import re
import time
import traceback

from autotest_lib.client.common_lib import global_config

# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "never give
# up reconnecting".  It is compared by identity, so it cannot collide with any
# integer limit.
RECONNECT_FOREVER = object()

# DB-API exception class names that are copied from the underlying database
# module onto backend and connection wrappers (see _copy_exceptions), so that
# callers can write e.g. ``except connection.OperationalError``.
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')

# Maps DatabaseConnection option names to the key used in the global config
# when the two differ; options not listed here are looked up under their own
# name.
_GLOBAL_CONFIG_NAMES = {
    # NOTE(review): this entry's original line was lost in extraction; the
    # autotest global config stores the DB user under 'user' -- confirm.
    'username': 'user',
    'db_name': 'database',
}
def _copy_exceptions(source, destination):
    """Expose the DB-API exception classes of ``source`` on ``destination``.

    For every name in _DB_EXCEPTIONS the attribute of that name is read from
    ``source`` and set on ``destination``.  An AttributeError propagates if
    ``source`` lacks one of the names.
    """
    for attr_name in _DB_EXCEPTIONS:
        exception_class = getattr(source, attr_name)
        setattr(destination, attr_name, exception_class)
class _GenericBackend(object):
    """Base class for the DB-API backends used by DatabaseConnection.

    Subclasses implement connect(), which must set self._connection and
    self._cursor.  execute() and disconnect() then operate on those.  The
    database module's DB-API exception classes are copied onto the instance
    so callers can catch them without knowing which module is in use.
    """

    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # cursor.rowcount of the most recent execute() call.
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open a connection and cursor.  Must be overridden by subclasses.

        This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection, if any, and drop the connection/cursor.

        Safe to call when not connected and safe to call repeatedly.
        """
        if self._connection is not None:
            self._connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """Run ``query`` and return cursor.fetchall().

        ``parameters`` defaults to an empty tuple because some DB-API modules
        (e.g. pysqlite2) reject None.  self.rowcount is updated afterwards.
        """
        if parameters is None:
            parameters = ()
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL through the MySQLdb module."""

    def __init__(self):
        # Imported here so MySQLdb is only needed when this backend is used.
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        'Convert booleans to integer strings'
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open an autocommit MySQL connection that encodes bools as 0/1."""
        import MySQLdb.converters
        # Copy the default table so registering the bool converter does not
        # mutate MySQLdb's process-wide conversions dict.
        convert_dict = dict(MySQLdb.converters.conversions)
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
class _SqliteBackend(_GenericBackend):
    """SQLite backend, used for the non-concurrent unittest use case."""

    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            # pysqlite2 is a legacy external package; the standard library's
            # sqlite3 module exposes the same DB-API interface.
            import sqlite3 as dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # \b (rather than \s) also matches e.g. "(LAST_INSERT_ID())"; SQL
        # function names are case-insensitive, hence IGNORECASE.
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open ``db_name`` (a file path or ':memory:'); other args ignored."""
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate MySQL-flavoured ``query`` for SQLite, then execute it."""
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
class _DjangoBackend(_GenericBackend):
    """Backend that reuses the connection managed by Django's ORM."""

    def __init__(self):
        from django.db import backend, connection, transaction
        super(_DjangoBackend, self).__init__(backend.Database)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        # The connection parameters are ignored: Django's own connection is
        # used as configured.
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Execute ``query`` and always run commit_unless_managed() after.

        The commit step runs in ``finally`` so it happens even when the query
        raises.  NOTE(review): commit_unless_managed() is a legacy Django
        transaction API -- confirm the Django version in use provides it.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
# Maps the db_type option accepted by DatabaseConnection to its backend class.
_BACKEND_MAP = {
    'mysql': _MySqlBackend,
    'sqlite': _SqliteBackend,
    'django': _DjangoBackend,
}
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection.  Supports the backends listed
    in _BACKEND_MAP (mysql, sqlite and django).

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up.  Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information.  This
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        # Populate the _DATABASE_ATTRIBUTES up front so get_database_info()
        # works before connect() is called.
        self._read_options()


    def _get_option(self, name, provided_value):
        """Resolve one connection option.

        Precedence: an explicitly provided (non-None) value; otherwise the
        global config section, if one was given; otherwise the attribute's
        current value (None if it was never set).
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Set each of _DATABASE_ATTRIBUTES via _get_option()."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """Instantiate the backend registered for ``db_type``.

        Raises ValueError if ``db_type`` is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            # sorted() keeps the message deterministic across dict orderings.
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(sorted(_BACKEND_MAP))))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds a (non-infinite) reconnect limit."""
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param if given, else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """Connect self._backend, retrying on OperationalError.

        Re-raises the OperationalError when reconnecting is disabled or the
        attempt limit is exceeded; otherwise sleeps reconnect_delay_sec
        between attempts.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config.  try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the backend's DB-API exceptions on this object as well.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the backend, if connect() has created one."""
        if self._backend is not None:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall().  try_reconnecting, if
        passed, will override self.reconnect_enabled.
        """
        if self.debug:
            # Parenthesised form works on both Python 2 and Python 3.
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return {attribute: value} for each of _DATABASE_ATTRIBUTES."""
        return {attribute: getattr(self, attribute)
                for attribute in self._DATABASE_ATTRIBUTES}


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a connected sqlite DatabaseConnection with
        reconnecting disabled.  By default (file_path=':memory:') the database
        is a temporary in-memory one; pass a path to use a file instead.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
283 class TranslatingDatabase(DatabaseConnection
):
285 Database wrapper than applies arbitrary substitution regexps to each query
286 string. Useful for SQLite testing.
288 def __init__(self
, translators
):
290 @param translation_regexps: list of callables to apply to each query
291 string (in order). Each accepts a query string and returns a
292 (possibly) modified query string.
294 super(TranslatingDatabase
, self
).__init
__()
295 self
._translators
= translators
298 def execute(self
, query
, parameters
=None, try_reconnecting
=None):
299 for translator
in self
._translators
:
300 query
= translator(query
)
301 return super(TranslatingDatabase
, self
).execute(
302 query
, parameters
=parameters
, try_reconnecting
=try_reconnecting
)
306 def make_regexp_translator(cls
, search_re
, replace_str
):
308 Returns a translator that calls re.sub() on the query with the given
309 search and replace arguments.
311 def translator(query
):
312 return re
.sub(search_re
, replace_str
, query
)