KVM test: tests_base.cfg: Introduce parameter 'vm_type'
[autotest-zwu.git] / database / database_connection.py
blobca5f491f76bc639bc73c345ec447694cd4391e3f
1 import re, time, traceback
2 import common
3 from autotest_lib.client.common_lib import global_config
# Sentinel value for DatabaseConnection.max_reconnect_attempts that removes
# the limit on reconnection attempts.
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes that _copy_exceptions() mirrors from a
# driver module onto backend and connection objects.
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)

# DatabaseConnection option name -> global config key, for the options whose
# names differ; any option not listed is looked up under its own name.
_GLOBAL_CONFIG_NAMES = dict(
    username='user',
    db_name='database',
)
def _copy_exceptions(source, destination):
    """
    Copy every exception class named in _DB_EXCEPTIONS from source onto
    destination, using the same attribute name.

    @param source: object providing each name in _DB_EXCEPTIONS (a DB-API
            driver module or a backend instance, in this file).
    @param destination: object that receives the attributes.
    @raises AttributeError: if source lacks one of the names.
    """
    for name in _DB_EXCEPTIONS:
        exception_class = getattr(source, name)
        setattr(destination, name, exception_class)
18 class _GenericBackend(object):
19 def __init__(self, database_module):
20 self._database_module = database_module
21 self._connection = None
22 self._cursor = None
23 self.rowcount = None
24 _copy_exceptions(database_module, self)
27 def connect(self, host=None, username=None, password=None, db_name=None):
28 """
29 This is assumed to enable autocommit.
30 """
31 raise NotImplementedError
34 def disconnect(self):
35 if self._connection:
36 self._connection.close()
37 self._connection = None
38 self._cursor = None
41 def execute(self, query, parameters=None):
42 if parameters is None:
43 parameters = ()
44 self._cursor.execute(query, parameters)
45 self.rowcount = self._cursor.rowcount
46 return self._cursor.fetchall()
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL, built on the MySQLdb driver."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        'Convert booleans to integer strings'
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Connect to MySQL, register a bool converter and enable autocommit.

        @param host: server host, passed to MySQLdb as host.
        @param username: passed to MySQLdb as user.
        @param password: passed to MySQLdb as passwd.
        @param db_name: passed to MySQLdb as db.
        """
        import MySQLdb.converters
        # MySQLdb.converters.conversions is module-level state. Calling
        # setdefault() on it directly would leak our bool converter into that
        # shared table, so build this connection's conversions from a copy.
        convert_dict = dict(MySQLdb.converters.conversions)
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
class _SqliteBackend(_GenericBackend):
    """
    Backend for SQLite through pysqlite2. It rewrites MySQL-style '%s'
    placeholders and LAST_INSERT_ID() calls in each query before running it.
    """
    def __init__(self):
        from pysqlite2 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # A word boundary (rather than requiring a preceding whitespace
        # character) also matches calls that directly follow punctuation,
        # e.g. "SELECT(LAST_INSERT_ID())" or "VALUES (1,LAST_INSERT_ID())".
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the SQLite database db_name in autocommit mode. host, username
        and password are ignored.
        """
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate query for SQLite, then execute it like the base class."""
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
class _DjangoBackend(_GenericBackend):
    """
    Backend that reuses django.db's connection object rather than opening a
    connection of its own.
    """
    def __init__(self):
        from django.db import backend, connection, transaction
        super(_DjangoBackend, self).__init__(backend.Database)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Adopt django.db's connection and open a cursor on it. All arguments
        are ignored.
        """
        self._connection = self._django_connection
        self._cursor = self._django_connection.cursor()


    def execute(self, query, parameters=None):
        """
        Execute via the generic backend, then always call
        transaction.commit_unless_managed(), even when the query raised.
        """
        try:
            results = super(_DjangoBackend, self).execute(
                    query, parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
        return results
# db_type value accepted by DatabaseConnection -> backend class implementing it.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports mysql, sqlite and
    django backends (see _BACKEND_MAP).

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting before
      giving up. Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section holding the DB
                settings, or None to rely on arguments passed to connect().
        @param debug: if True, print every query before executing it.
        """
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve a single connection option.

        An explicitly provided (non-None) value wins. Otherwise, when a global
        config section is set, the value is read from it (translating the name
        via _GLOBAL_CONFIG_NAMES). Otherwise the value already stored on this
        object is kept (None if it was never set).
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve and store every connection option via _get_option()."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for db_type.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds a finite max_reconnect_attempts."""
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param if given, else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend with the stored options.

        On OperationalError, re-raise immediately when reconnecting is
        disabled or the attempt limit is exceeded; otherwise print the
        traceback, wait reconnect_delay_sec, disconnect and retry.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the driver's exception classes on this object too.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.
        """
        if self.debug:
            # Single-argument call form: prints the same text under Python 2
            # and, unlike the print statement, is also valid Python 3 syntax.
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the current connection options."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a connected instance backed by SQLite
        (an in-memory database unless file_path says otherwise), with
        automatic reconnection disabled.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary substitution regexps to each query
    string. Useful for SQLite testing.
    """
    def __init__(self, translators, global_config_section=None, debug=False):
        """
        @param translators: list of callables to apply to each query
                string (in order). Each accepts a query string and returns a
                (possibly) modified query string.
        @param global_config_section: passed through to DatabaseConnection.
        @param debug: passed through to DatabaseConnection.
        """
        super(TranslatingDatabase, self).__init__(
                global_config_section=global_config_section, debug=debug)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Run query through every translator in order, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator