[autotest-zwu.git] / tko / db.py
import re, os, sys, types, time, random

import common
from autotest_lib.client.common_lib import global_config
from autotest_lib.tko import utils

class MySQLTooManyRows(Exception):
    pass


class db_sql(object):
    def __init__(self, debug=False, autocommit=True, host=None,
                 database=None, user=None, password=None):
        self.debug = debug
        self.autocommit = autocommit
        self._load_config(host, database, user, password)

        self.con = None
        self._init_db()

        # build the status lookup maps (word <-> index) from tko_status
        self.status_idx = {}
        self.status_word = {}
        status_rows = self.select('status_idx, word', 'tko_status', None)
        for s in status_rows:
            self.status_idx[s[1]] = s[0]
            self.status_word[s[0]] = s[1]

        machine_map = os.path.join(os.path.dirname(__file__),
                                   'machines')
        if os.path.exists(machine_map):
            self.machine_map = machine_map
        else:
            self.machine_map = None
        self.machine_group = {}

    def _load_config(self, host, database, user, password):
        # grab the global config
        get_value = global_config.global_config.get_config_value

        # grab the host, database
        if host:
            self.host = host
        else:
            self.host = get_value("AUTOTEST_WEB", "host")
        if database:
            self.database = database
        else:
            self.database = get_value("AUTOTEST_WEB", "database")

        # grab the user and password
        if user:
            self.user = user
        else:
            self.user = get_value("AUTOTEST_WEB", "user")
        if password is not None:
            self.password = password
        else:
            self.password = get_value("AUTOTEST_WEB", "password")

        # grab the timeout configuration
        self.query_timeout = get_value("AUTOTEST_WEB", "query_timeout",
                                       type=int, default=3600)
        self.min_delay = get_value("AUTOTEST_WEB", "min_retry_delay", type=int,
                                   default=20)
        self.max_delay = get_value("AUTOTEST_WEB", "max_retry_delay", type=int,
                                   default=60)

    def _init_db(self):
        # make sure we clean up any existing connection
        if self.con:
            self.con.close()
            self.con = None

        # create the db connection and cursor
        self.con = self.connect(self.host, self.database,
                                self.user, self.password)
        self.cur = self.con.cursor()


    def _random_delay(self):
        delay = random.randint(self.min_delay, self.max_delay)
        time.sleep(delay)

    def run_with_retry(self, function, *args, **dargs):
        """Call function(*args, **dargs) until either it passes
        without an operational error, or a timeout is reached.
        This will re-connect to the database, so it is NOT safe
        to use this inside of a database transaction.

        It can be safely used with transactions, but the
        transaction start & end must be completely contained
        within the call to 'function'."""
        OperationalError = _get_error_class("OperationalError")

        success = False
        start_time = time.time()
        while not success:
            try:
                result = function(*args, **dargs)
            except OperationalError, e:
                self._log_operational_error(e)
                stop_time = time.time()
                elapsed_time = stop_time - start_time
                if elapsed_time > self.query_timeout:
                    raise
                else:
                    try:
                        self._random_delay()
                        self._init_db()
                    except OperationalError, e:
                        self._log_operational_error(e)
            else:
                success = True
        return result

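    # Illustrative usage sketch (not part of the original file): any callable
    # whose transaction, if any, starts and ends inside the call can be
    # wrapped; it is retried with a randomized delay after each
    # OperationalError until query_timeout elapses.  db_obj and the query
    # below are hypothetical.
    #
    #     def count_jobs():
    #         db_obj.cur.execute('SELECT COUNT(*) FROM tko_jobs', [])
    #         return db_obj.cur.fetchall()[0][0]
    #
    #     total = db_obj.run_with_retry(count_jobs)
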
    def _log_operational_error(self, e):
        msg = ("%s: An operational error occurred during a database "
               "operation: %s" % (time.strftime("%X %x"), str(e)))
        print >> sys.stderr, msg
        sys.stderr.flush()  # we want these msgs to show up immediately

    def dprint(self, value):
        if self.debug:
            sys.stdout.write('SQL: ' + str(value) + '\n')


    def commit(self):
        self.con.commit()


    def get_last_autonumber_value(self):
        self.cur.execute('SELECT LAST_INSERT_ID()', [])
        return self.cur.fetchall()[0][0]


    def _quote(self, field):
        return '`%s`' % field

    def _where_clause(self, where):
        if not where:
            return '', []

        if isinstance(where, dict):
            # key/value pairs (which should be equal, or None for null)
            keys, values = [], []
            for field, value in where.iteritems():
                quoted_field = self._quote(field)
                if value is None:
                    keys.append(quoted_field + ' is null')
                else:
                    keys.append(quoted_field + '=%s')
                    values.append(value)
            where_clause = ' and '.join(keys)
        elif isinstance(where, basestring):
            # the exact string
            where_clause = where
            values = []
        elif isinstance(where, tuple):
            # preformatted where clause + values
            where_clause, values = where
            assert where_clause
        else:
            raise ValueError('Invalid "where" value: %r' % where)

        return ' WHERE ' + where_clause, values

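    # Illustrative sketch (not part of the original file) of the three
    # accepted "where" forms and the (clause, values) pairs they produce
    # (dict key order may vary):
    #
    #     self._where_clause({'tag': 'x', 'afe_job_id': None})
    #         -> (" WHERE `tag`=%s and `afe_job_id` is null", ['x'])
    #     self._where_clause("job_idx > 5")
    #         -> (" WHERE job_idx > 5", [])
    #     self._where_clause(('tag = %s', ['x']))
    #         -> (" WHERE tag = %s", ['x'])
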
    def select(self, fields, table, where, distinct=False, group_by=None,
               max_rows=None):
        """\
        This selects all the fields requested from a
        specific table with a particular where clause.
        The where clause can either be a dictionary of
        field=value pairs, a string, or a tuple of (string,
        a list of values).  The last option is what you
        should use when accepting user input, as it will
        protect you against sql injection attacks (if
        all user data is placed in the array rather than
        the raw SQL).

        For example:
          where = ("a = %s AND b = %s", ['val', 'val'])
        is better than
          where = "a = 'val' AND b = 'val'"
        """
        cmd = ['select']
        if distinct:
            cmd.append('distinct')
        cmd += [fields, 'from', table]

        where_clause, values = self._where_clause(where)
        cmd.append(where_clause)

        if group_by:
            cmd.append(' GROUP BY ' + group_by)

        self.dprint('%s %s' % (' '.join(cmd), values))

        # create a re-runnable function for executing the query
        def exec_sql():
            sql = ' '.join(cmd)
            numRec = self.cur.execute(sql, values)
            if max_rows is not None and numRec > max_rows:
                msg = 'Exceeded allowed number of records'
                raise MySQLTooManyRows(msg)
            return self.cur.fetchall()

        # run the query, re-trying after operational errors
        if self.autocommit:
            return self.run_with_retry(exec_sql)
        else:
            return exec_sql()

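    # Illustrative usage sketch (not part of the original file); the job
    # index and user_tag variable are hypothetical:
    #
    #     # all tests of job 42, counted per status
    #     db_obj.select('status, count(*)', 'tko_tests',
    #                   {'job_idx': 42}, group_by='status')
    #     # injection-safe form for user-supplied input
    #     db_obj.select('job_idx', 'tko_jobs', ('tag = %s', [user_tag]))
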
    def select_sql(self, fields, table, sql, values):
        """\
        select fields from table "sql"
        """
        cmd = 'select %s from %s %s' % (fields, table, sql)
        self.dprint(cmd)

        # create a re-runnable function for executing the query
        def exec_sql():
            self.cur.execute(cmd, values)
            return self.cur.fetchall()

        # run the query, re-trying after operational errors
        if self.autocommit:
            return self.run_with_retry(exec_sql)
        else:
            return exec_sql()

    def _exec_sql_with_commit(self, sql, values, commit):
        if self.autocommit:
            # re-run the query until it succeeds
            def exec_sql():
                self.cur.execute(sql, values)
                self.con.commit()
            self.run_with_retry(exec_sql)
        else:
            # take one shot at running the query
            self.cur.execute(sql, values)
            if commit:
                self.con.commit()

    def insert(self, table, data, commit=None):
        """\
        'insert into table (keys) values (%s ... %s)', values

        data:
                dictionary of fields and data
        """
        fields = data.keys()
        refs = ['%s' for field in fields]
        values = [data[field] for field in fields]
        cmd = ('insert into %s (%s) values (%s)' %
               (table, ','.join(self._quote(field) for field in fields),
                ','.join(refs)))
        self.dprint('%s %s' % (cmd, values))

        self._exec_sql_with_commit(cmd, values, commit)

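    # Illustrative sketch (not part of the original file) of the SQL this
    # builds; the hostname/owner values are hypothetical and column order
    # follows dict iteration order:
    #
    #     db_obj.insert('tko_machines', {'hostname': 'host1', 'owner': 'me'})
    #     # -> insert into tko_machines (`hostname`,`owner`) values (%s,%s)
    #     #    with values ['host1', 'me']
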
    def delete(self, table, where, commit=None):
        cmd = ['delete from', table]
        if commit is None:
            commit = self.autocommit
        where_clause, values = self._where_clause(where)
        cmd.append(where_clause)
        sql = ' '.join(cmd)
        self.dprint('%s %s' % (sql, values))

        self._exec_sql_with_commit(sql, values, commit)

    def update(self, table, data, where, commit=None):
        """\
        'update table set data values (%s ... %s) where ...'

        data:
                dictionary of fields and data
        """
        if commit is None:
            commit = self.autocommit
        cmd = 'update %s ' % table
        fields = data.keys()
        data_refs = [self._quote(field) + '=%s' for field in fields]
        data_values = [data[field] for field in fields]
        cmd += ' set ' + ', '.join(data_refs)

        where_clause, where_values = self._where_clause(where)
        cmd += where_clause

        values = data_values + where_values
        self.dprint('%s %s' % (cmd, values))

        self._exec_sql_with_commit(cmd, values, commit)

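    # Illustrative sketch (not part of the original file); the label and
    # index values are hypothetical:
    #
    #     db_obj.update('tko_jobs', {'label': 'nightly'}, {'job_idx': 42})
    #     # -> update tko_jobs  set `label`=%s WHERE `job_idx`=%s
    #     #    with values ['nightly', 42]
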
    def delete_job(self, tag, commit=None):
        job_idx = self.find_job(tag)
        for test_idx in self.find_tests(job_idx):
            where = {'test_idx': test_idx}
            self.delete('tko_iteration_result', where)
            self.delete('tko_iteration_attributes', where)
            self.delete('tko_test_attributes', where)
            self.delete('tko_test_labels_tests', {'test_id': test_idx})
        where = {'job_idx': job_idx}
        self.delete('tko_tests', where)
        self.delete('tko_jobs', where)

    def insert_job(self, tag, job, commit=None):
        job.machine_idx = self.lookup_machine(job.machine)
        if not job.machine_idx:
            job.machine_idx = self.insert_machine(job, commit=commit)
        else:
            self.update_machine_information(job, commit=commit)

        afe_job_id = utils.get_afe_job_id(tag)

        data = {'tag': tag,
                'label': job.label,
                'username': job.user,
                'machine_idx': job.machine_idx,
                'queued_time': job.queued_time,
                'started_time': job.started_time,
                'finished_time': job.finished_time,
                'afe_job_id': afe_job_id}
        is_update = hasattr(job, 'index')
        if is_update:
            self.update('tko_jobs', data, {'job_idx': job.index}, commit=commit)
        else:
            self.insert('tko_jobs', data, commit=commit)
            job.index = self.get_last_autonumber_value()
        self.update_job_keyvals(job, commit=commit)
        for test in job.tests:
            self.insert_test(job, test, commit=commit)

    def update_job_keyvals(self, job, commit=None):
        for key, value in job.keyval_dict.iteritems():
            where = {'job_id': job.index, 'key': key}
            data = dict(where, value=value)
            exists = self.select('id', 'tko_job_keyvals', where=where)

            if exists:
                self.update('tko_job_keyvals', data, where=where,
                            commit=commit)
            else:
                self.insert('tko_job_keyvals', data, commit=commit)

    def insert_test(self, job, test, commit=None):
        kver = self.insert_kernel(test.kernel, commit=commit)
        data = {'job_idx': job.index, 'test': test.testname,
                'subdir': test.subdir, 'kernel_idx': kver,
                'status': self.status_idx[test.status],
                'reason': test.reason, 'machine_idx': job.machine_idx,
                'started_time': test.started_time,
                'finished_time': test.finished_time}
        is_update = hasattr(test, "test_idx")
        if is_update:
            test_idx = test.test_idx
            self.update('tko_tests', data,
                        {'test_idx': test_idx}, commit=commit)
            where = {'test_idx': test_idx}
            self.delete('tko_iteration_result', where)
            self.delete('tko_iteration_attributes', where)
            where['user_created'] = 0
            self.delete('tko_test_attributes', where)
        else:
            self.insert('tko_tests', data, commit=commit)
            test_idx = test.test_idx = self.get_last_autonumber_value()
        data = {'test_idx': test_idx}

        for i in test.iterations:
            data['iteration'] = i.index
            for key, value in i.attr_keyval.iteritems():
                data['attribute'] = key
                data['value'] = value
                self.insert('tko_iteration_attributes', data,
                            commit=commit)
            for key, value in i.perf_keyval.iteritems():
                data['attribute'] = key
                data['value'] = value
                self.insert('tko_iteration_result', data,
                            commit=commit)

        for key, value in test.attributes.iteritems():
            data = {'test_idx': test_idx, 'attribute': key,
                    'value': value}
            self.insert('tko_test_attributes', data, commit=commit)

        if not is_update:
            for label_index in test.labels:
                data = {'test_id': test_idx, 'testlabel_id': label_index}
                self.insert('tko_test_labels_tests', data, commit=commit)

    def read_machine_map(self):
        if self.machine_group or not self.machine_map:
            return
        for line in open(self.machine_map, 'r').readlines():
            (machine, group) = line.split()
            self.machine_group[machine] = group

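    # Illustrative sketch (not part of the original file): the optional
    # tko/machines map file holds one whitespace-separated
    # "<hostname> <group>" pair per line, e.g. (hypothetical hosts):
    #
    #     host1.example.com   rack-1
    #     host2.example.com   rack-1
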
    def machine_info_dict(self, job):
        hostname = job.machine
        group = job.machine_group
        owner = job.machine_owner

        if not group:
            self.read_machine_map()
            group = self.machine_group.get(hostname, hostname)
            if group == hostname and owner:
                group = owner + '/' + hostname

        return {'hostname': hostname, 'machine_group': group, 'owner': owner}

    def insert_machine(self, job, commit=None):
        machine_info = self.machine_info_dict(job)
        self.insert('tko_machines', machine_info, commit=commit)
        return self.get_last_autonumber_value()


    def update_machine_information(self, job, commit=None):
        machine_info = self.machine_info_dict(job)
        self.update('tko_machines', machine_info,
                    where={'hostname': machine_info['hostname']},
                    commit=commit)


    def lookup_machine(self, hostname):
        where = {'hostname': hostname}
        rows = self.select('machine_idx', 'tko_machines', where)
        if rows:
            return rows[0][0]
        else:
            return None

    def lookup_kernel(self, kernel):
        rows = self.select('kernel_idx', 'tko_kernels',
                           {'kernel_hash': kernel.kernel_hash})
        if rows:
            return rows[0][0]
        else:
            return None

    def insert_kernel(self, kernel, commit=None):
        kver = self.lookup_kernel(kernel)
        if kver:
            return kver

        # If this kernel has any significant patches, append the new
        # kernel row index to the printable name as a differentiator.
        printable = kernel.base
        patch_count = 0
        for patch in kernel.patches:
            match = re.match(r'.*(-mm[0-9]+|-git[0-9]+)\.(bz2|gz)$',
                             patch.reference)
            if not match:
                patch_count += 1

        self.insert('tko_kernels',
                    {'base': kernel.base,
                     'kernel_hash': kernel.kernel_hash,
                     'printable': printable},
                    commit=commit)
        kver = self.get_last_autonumber_value()

        if patch_count > 0:
            printable += ' p%d' % (kver)
            self.update('tko_kernels',
                        {'printable': printable},
                        {'kernel_idx': kver})

        for patch in kernel.patches:
            self.insert_patch(kver, patch, commit=commit)
        return kver

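    # Illustrative sketch (not part of the original file): a kernel whose
    # patch list contains anything other than -mmN/-gitN patches gets its
    # printable name suffixed with the new row index, e.g. (hypothetical
    # values) base '2.6.18' inserted as kernel_idx 37:
    #
    #     kver = db_obj.insert_kernel(kernel)   # kver == 37
    #     # tko_kernels.printable for that row is now '2.6.18 p37'
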
    def insert_patch(self, kver, patch, commit=None):
        print patch.reference
        name = os.path.basename(patch.reference)[:80]
        self.insert('tko_patches',
                    {'kernel_idx': kver,
                     'name': name,
                     'url': patch.reference,
                     'hash': patch.hash},
                    commit=commit)

    def find_test(self, job_idx, testname, subdir):
        where = {'job_idx': job_idx, 'test': testname, 'subdir': subdir}
        rows = self.select('test_idx', 'tko_tests', where)
        if rows:
            return rows[0][0]
        else:
            return None


    def find_tests(self, job_idx):
        where = {'job_idx': job_idx}
        rows = self.select('test_idx', 'tko_tests', where)
        if rows:
            return [row[0] for row in rows]
        else:
            return []


    def find_job(self, tag):
        rows = self.select('job_idx', 'tko_jobs', {'tag': tag})
        if rows:
            return rows[0][0]
        else:
            return None


def _get_db_type():
    """Get the database type name to use from the global config."""
    get_value = global_config.global_config.get_config_value
    return "db_" + get_value("AUTOTEST_WEB", "db_type", default="mysql")


def _get_error_class(class_name):
    """Retrieves the appropriate error class by name from the database
    module."""
    db_module = __import__("autotest_lib.tko." + _get_db_type(),
                           globals(), locals(), ["driver"])
    return getattr(db_module.driver, class_name)

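# Illustrative sketch (not part of the original file): with the default
# db_type of "mysql", _get_db_type() returns "db_mysql", so
# _get_error_class("OperationalError") imports autotest_lib.tko.db_mysql
# and returns the OperationalError class exposed by that module's driver.
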
def db(*args, **dargs):
    """Creates an instance of the database class with the arguments
    provided in args and dargs, using the database type specified by
    the global configuration (defaulting to mysql)."""
    db_type = _get_db_type()
    db_module = __import__("autotest_lib.tko." + db_type, globals(),
                           locals(), [db_type])
    db = getattr(db_module, db_type)(*args, **dargs)
    return db

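# A minimal usage sketch (not part of the original file).  It assumes the
# AUTOTEST_WEB section of the global config points at a reachable TKO
# database; 'some-job-tag' is a hypothetical value.
if __name__ == '__main__':
    tko_db = db(debug=True)
    # look up a job index by its tag and list its test indexes
    job_idx = tko_db.find_job('some-job-tag')
    if job_idx is not None:
        print tko_db.find_tests(job_idx)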