database.database_connection: Fix _copy_exceptions
[autotest-zwu.git] / database / database_connection.py
blob 53903c9e3b16ee2bec0d44cf6c5a50b0bf674d4d
1 import re, time, traceback
2 import common
3 from autotest_lib.client.common_lib import global_config
# Sentinel value: assign it to DatabaseConnection.max_reconnect_attempts to
# keep retrying a failed connection without any limit.
RECONNECT_FOREVER = object()

# Names of the exception classes copied from the underlying database module
# onto backends and connections by _copy_exceptions().
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')

# DatabaseConnection option name -> global config key, for the options whose
# names differ; every other option is looked up under its own name.
_GLOBAL_CONFIG_NAMES = dict(username='user', db_name='database')
def _copy_exceptions(source, destination):
    """
    Copy the exception classes named in _DB_EXCEPTIONS from source onto
    destination as attributes of the same name.

    Any of those names that source does not define is aliased to
    source.DatabaseError instead (AttributeError propagates if source has no
    DatabaseError either).
    """
    for exception_name in _DB_EXCEPTIONS:
        # Guard only the lookup: an AttributeError raised by setattr() on the
        # destination is a genuine error, not a missing exception class.
        try:
            exception_class = getattr(source, exception_name)
        except AttributeError:
            # Under the django backend:
            # Django 1.3 does not have OperationalError and ProgrammingError.
            # Let's just mock these classes with the base DatabaseError.
            exception_class = getattr(source, 'DatabaseError')
        setattr(destination, exception_name, exception_class)
class _GenericBackend(object):
    """
    Base class for backends wrapping a Python database module (such as
    MySQLdb or pysqlite2's dbapi2).

    Subclasses implement connect(); this class provides disconnect() and a
    cursor-based execute(), and exposes the module's exception classes
    (see _copy_exceptions()) as attributes.
    """
    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # Updated from the cursor after every execute().
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open a connection and a cursor. Must be implemented by subclasses.

        This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the current connection, if any, and drop it and its cursor."""
        if self._connection:
            try:
                self._connection.close()
            finally:
                # Reset state even if close() fails, so a later disconnect()
                # does not try to close the same broken connection again.
                self._connection = None
                self._cursor = None


    def execute(self, query, parameters=None):
        """
        Execute query on the current cursor and return cursor.fetchall().

        @param query: the SQL string handed to cursor.execute().
        @param parameters: query parameters; None means an empty tuple.
        @returns the fetched rows. self.rowcount is set from the cursor.
        """
        if parameters is None:
            parameters = ()
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL, built on the MySQLdb module."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """
        MySQLdb converter that renders a boolean as an integer string
        ('0' or '1'). conversion_dict is required by the converter calling
        convention but unused.
        """
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Connect to MySQL with autocommit enabled and open a cursor."""
        import MySQLdb.converters
        # Work on a copy: MySQLdb.converters.conversions is module-level state
        # of MySQLdb, and registering our bool converter there would leak it
        # into every other user of MySQLdb's default conversions.
        convert_dict = MySQLdb.converters.conversions.copy()
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
class _SqliteBackend(_GenericBackend):
    """
    SQLite backend. Uses pysqlite2 when it is installed and otherwise falls
    back to the standard library's sqlite3 module, which offers the same
    DB-API (connect(), isolation_level, qmark parameters, exceptions).
    """
    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            # sqlite3 is the standard-library version of pysqlite2.
            import sqlite3 as dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open db_name (a file path or ':memory:') in autocommit mode.
        host, username and password are ignored.
        """
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate MySQL-flavoured query text for SQLite, then execute it."""
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary (this also rewrites
        # '%s' occurring inside string literals).
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
class _DjangoBackend(_GenericBackend):
    """
    Backend that reuses Django's database connection (django.db.connection).

    After every query, whether it succeeded or raised, it calls
    django.db.transaction.commit_unless_managed().
    """
    def __init__(self):
        # Import only the names this backend uses; the previously imported
        # django.db.backend was never referenced.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Adopt Django's connection and open a cursor on it; the arguments are
        ignored.
        """
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Execute query, then call commit_unless_managed() unconditionally."""
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
# db_type string accepted by DatabaseConnection -> backend class implementing
# it.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection, backed by one of the mysql,
    sqlite or django backends.

    Public attributes:
    * reconnect_enabled: when True, an OperationalError triggers an automatic
      reconnection attempt.
    * reconnect_delay_sec: seconds to wait between reconnection attempts.
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up; RECONNECT_FOREVER removes the limit.
    * rowcount: the backend's cursor rowcount after each call to execute().
    * global_config_section: global config section holding the DB
      information. Pass it to the constructor rather than setting it later.
      It may be None, in which case the information must be passed to
      connect().
    * debug: when True, every query is printed before it is executed.
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        self.global_config_section = global_config_section
        self.debug = debug
        self._backend = None
        self.rowcount = None

        # Reconnection defaults; callers may override these attributes.
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve a single connection option.

        A non-None provided_value wins. Otherwise the value is read from the
        global config section when one is set; failing that, the current
        attribute value (or None) is returned.
        """
        if provided_value is not None:
            return provided_value
        if not self.global_config_section:
            return getattr(self, name, None)
        config_key = _GLOBAL_CONFIG_NAMES.get(name, name)
        return global_config.global_config.get_config_value(
                self.global_config_section, config_key)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Set every connection attribute through _get_option()."""
        provided = (('db_type', db_type),
                    ('host', host),
                    ('username', username),
                    ('password', password),
                    ('db_name', db_name))
        for option_name, option_value in provided:
            setattr(self, option_name,
                    self._get_option(option_name, option_value))


    def _get_backend(self, db_type):
        """Instantiate the backend registered in _BACKEND_MAP for db_type."""
        backend_class = _BACKEND_MAP.get(db_type)
        if backend_class is None:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """
        True once num_attempts exceeds max_reconnect_attempts; always False
        when the limit is RECONNECT_FOREVER.
        """
        if self.max_reconnect_attempts is RECONNECT_FOREVER:
            return False
        return num_attempts > self.max_reconnect_attempts


    def _is_reconnect_enabled(self, supplied_param):
        """A non-None per-call value overrides self.reconnect_enabled."""
        if supplied_param is None:
            return self.reconnect_enabled
        return supplied_param


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend, retrying on OperationalError.

        Retrying only happens while reconnection is enabled and
        max_reconnect_attempts has not been exceeded; otherwise the
        OperationalError is re-raised. Each retry prints the traceback,
        sleeps reconnect_delay_sec seconds and disconnects first.
        """
        failed_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
            except self._backend.OperationalError:
                failed_attempts += 1
                if (not self._is_reconnect_enabled(try_reconnecting) or
                        self._reached_max_attempts(failed_attempts)):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()
            else:
                return


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        (Re)connect to the database.

        Any argument given here overrides the corresponding default from the
        global config; try_reconnecting, when not None, overrides
        self.reconnect_enabled for this call. The backend's exception classes
        are copied onto this object.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        backend = self._get_backend(self.db_type)
        self._backend = backend
        _copy_exceptions(backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the backend, if one has been created."""
        backend = self._backend
        if backend:
            backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall().

        On OperationalError, if reconnection is enabled (try_reconnecting,
        when not None, overrides self.reconnect_enabled), reconnect once and
        re-run the query; otherwise the error propagates.
        """
        if self.debug:
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() already retries, so one reconnect suffices here.
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict mapping each of _DATABASE_ATTRIBUTES to its value."""
        info = {}
        for attribute in self._DATABASE_ATTRIBUTES:
            info[attribute] = getattr(self, attribute)
        return info


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory returning a connected sqlite DatabaseConnection with automatic
        reconnection disabled; by default the database lives in memory.
        """
        test_db = cls(**constructor_kwargs)
        test_db.reconnect_enabled = False
        test_db.connect(db_type='sqlite', db_name=file_path)
        return test_db
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary substitution callables to each
    query string before executing it. Useful for SQLite testing.
    """
    def __init__(self, translators, global_config_section=None, debug=False):
        """
        @param translators: list of callables to apply to each query string
                (in order). Each accepts a query string and returns a
                (possibly) modified query string.
        @param global_config_section: forwarded to DatabaseConnection.
        @param debug: forwarded to DatabaseConnection.
        """
        # Forwarding the parent's options (with the parent's defaults) lets
        # e.g. get_test_database(translators=..., debug=True) work.
        super(TranslatingDatabase, self).__init__(
                global_config_section=global_config_section, debug=debug)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Run every translator over query, then execute the result."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator