1 import re
, time
, traceback
3 from autotest_lib
.client
.common_lib
import global_config
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
RECONNECT_FOREVER = object()

# DB-API exception class names mirrored onto backends and connections.
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
# Option names whose global-config key differs from the attribute name used
# by DatabaseConnection; names not listed are looked up unchanged.
_GLOBAL_CONFIG_NAMES = {
    'username': 'user',
    'db_name': 'database',
}
def _copy_exceptions(source, destination):
    """Mirror the DB-API exception classes of `source` onto `destination`.

    Every name in _DB_EXCEPTIONS becomes an attribute of `destination`.  When
    `source` lacks one of them, its DatabaseError is used instead, so code can
    always write ``except obj.OperationalError`` regardless of the driver.
    """
    for name in _DB_EXCEPTIONS:
        try:
            exc_class = getattr(source, name)
        except AttributeError:
            # Under the django backend: Django 1.3 does not have
            # OperationalError and ProgrammingError, so stand in the base
            # DatabaseError for them.
            exc_class = getattr(source, 'DatabaseError')
        setattr(destination, name, exc_class)
class _GenericBackend(object):
    """Shared DB-API plumbing for the concrete backends.

    Subclasses implement connect(), which must set self._connection and
    self._cursor.  This base class runs queries through the cursor, records
    the cursor's rowcount, and copies the driver module's exception classes
    onto the instance (via _copy_exceptions) so callers can catch
    ``backend.OperationalError`` without knowing the driver.
    """

    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        # Set by connect(); execute() needs it.
        self._cursor = None
        # Mirrors cursor.rowcount after each execute(); None before that.
        self.rowcount = None
        _copy_exceptions(database_module, self)

    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the connection and cursor.  Must be provided by subclasses.

        This is assumed to enable autocommit.
        """
        raise NotImplementedError

    def disconnect(self):
        """Close the current connection, if any, and forget the cursor.

        Safe to call when not connected.
        """
        if self._connection is not None:
            self._connection.close()
        self._connection = None
        self._cursor = None

    def execute(self, query, parameters=None):
        """Run `query` and return cursor.fetchall().

        @param query: SQL text in the driver's paramstyle.
        @param parameters: bind parameters; None means "no parameters".
        @returns whatever the cursor's fetchall() returns.
        """
        if parameters is None:
            # Hand drivers an empty sequence rather than None; not all of
            # them accept None here.
            parameters = ()
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()
class _MySqlBackend(_GenericBackend):
    """Backend built on the MySQLdb driver."""

    def __init__(self):
        # Imported lazily so the module loads without MySQLdb installed.
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)

    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """Convert booleans to integer strings ('0' / '1').

        `conversion_dict` is unused; it is accepted because MySQLdb passes the
        conversion mapping to its converters.
        """
        return str(int(boolean))

    def connect(self, host=None, username=None, password=None, db_name=None):
        """Connect to MySQL with autocommit on and a bool-aware converter."""
        import MySQLdb.converters
        # NOTE: this is MySQLdb's module-level table, so setdefault registers
        # the bool converter process-wide (only if none was registered yet).
        convert_dict = MySQLdb.converters.conversions
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
class _SqliteBackend(_GenericBackend):
    """SQLite backend, used for unit tests.

    Rewrites the MySQL-flavoured SQL used elsewhere (``%s`` placeholders and
    ``LAST_INSERT_ID()``) into what SQLite accepts.
    """

    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            # The stdlib sqlite3 module grew out of pysqlite and exposes the
            # same DB-API (qmark paramstyle, isolation_level, exceptions).
            import sqlite3 as dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # SQL function names are case-insensitive, and \b (rather than \s)
        # also matches after '(' or ',' as in "VALUES (LAST_INSERT_ID())".
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)

    def connect(self, host=None, username=None, password=None, db_name=None):
        # Only db_name (a file path or ':memory:') is used by SQLite.
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None  # enable autocommit
        self._cursor = self._connection.cursor()

    def execute(self, query, parameters=None):
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
class _DjangoBackend(_GenericBackend):
    """Backend that reuses Django's already-configured database connection."""

    def __init__(self):
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction

    def connect(self, host=None, username=None, password=None, db_name=None):
        # The arguments are ignored: Django's own connection object is used.
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()

    def execute(self, query, parameters=None):
        # Commit in `finally` so the transaction is closed whether the query
        # succeeds or raises.
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
# Maps the `db_type` option to the backend class that implements it.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection.  Supports the mysql, sqlite
    and django backends registered in _BACKEND_MAP.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up.  Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. This
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    # Connection settings reported by get_database_info().
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        # Resolve db_type/host/... now so get_database_info() works even
        # before connect() has been called.
        self._read_options()

    def _get_option(self, name, provided_value):
        """
        Resolve a single connection option.

        An explicitly provided value wins.  Otherwise, when a global config
        section was given, the value is read from it (under the key mapped by
        _GLOBAL_CONFIG_NAMES, or `name` itself).  Otherwise the current
        attribute value is kept (None if it was never set).
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                self.global_config_section, global_config_name)
        return getattr(self, name, None)

    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Set every attribute in _DATABASE_ATTRIBUTES via _get_option()."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)

    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for `db_type`.

        @raises ValueError: if `db_type` is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            # sorted() keeps the message stable across runs.
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(sorted(_BACKEND_MAP))))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()

    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds the limit (never if unlimited)."""
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)

    def _is_reconnect_enabled(self, supplied_param):
        """A per-call override, if not None, beats self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled

    def _connect_backend(self, try_reconnecting=None):
        """
        Connect self._backend, retrying on OperationalError.

        Retries only happen while reconnecting is enabled and
        _reached_max_attempts() is false; between attempts it prints the
        traceback, sleeps reconnect_delay_sec and disconnects.  When giving
        up, the last OperationalError is re-raised.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                # Drop any half-open state before the next attempt.
                self.disconnect()

    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config.  try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        # Close any previous backend connection instead of leaking it.
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the backend's exception classes as self.OperationalError etc.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)

    def disconnect(self):
        """Disconnect the backend, if one was ever created."""
        if self._backend is not None:
            self._backend.disconnect()

    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.
        """
        if self.debug:
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results

    def get_database_info(self):
        """Return a dict mapping each _DATABASE_ATTRIBUTES name to its value."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)

    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a temporary in-memory
        SQLite database (or the SQLite file at `file_path`), with automatic
        reconnection disabled.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary substitution regexps to each query
    string.  Useful for SQLite testing.
    """
    def __init__(self, translators):
        """
        @param translators: list of callables to apply to each query string
                (in order).  Each accepts a query string and returns a
                (possibly) modified query string.
        """
        super(TranslatingDatabase, self).__init__()
        self._translators = translators

    def execute(self, query, parameters=None, try_reconnecting=None):
        """Run every translator over `query`, then execute the result."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)

    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator