improved iterator trapping tests
[pygr.git] / tests / sqltable_test.py
blob7766df38529c7a2c35dc8a2e23a3e5f60dbfcb4f
1 import os, random, string, unittest
2 from testlib import testutil, PygrTestProgram, SkipTest
3 from pygr.sqlgraph import SQLTable, SQLTableNoCache,SQLTableClustered,\
4 MapView, GraphView, DBServerInfo, import_sqlite
5 from pygr import logger
def entrap(klass):
    """Return a generic_iterator replacement that traps unwanted iteration.

    The returned function is meant to be installed as ``generic_iterator``
    on a subclass of *klass*.  When the instance has a true ``catchIter``
    attribute the call fails its assertion; otherwise (attribute false or
    not set at all) the call is forwarded to ``klass.generic_iterator``.
    """
    def catch_iterator(self, *args, **kwargs):
        # an instance that never set catchIter is treated as "trap disabled"
        iteration_forbidden = getattr(self, 'catchIter', False)
        assert not iteration_forbidden, 'this should not iterate!'
        return klass.generic_iterator(self, *args, **kwargs)
    return catch_iterator
class SQLTableCatcher(SQLTable):
    'SQLTable whose generic_iterator asserts when catchIter is set true'
    # route generic_iterator through the trap built by entrap()
    generic_iterator = entrap(SQLTable)
class SQLTableNoCacheCatcher(SQLTableNoCache):
    'SQLTableNoCache whose generic_iterator asserts when catchIter is set true'
    # route generic_iterator through the trap built by entrap()
    generic_iterator = entrap(SQLTableNoCache)
class SQLTableClusteredCatcher(SQLTableClustered):
    'SQLTableClustered whose generic_iterator asserts when catchIter is set true'
    # route generic_iterator through the trap built by entrap()
    generic_iterator = entrap(SQLTableClustered)
class SQLTable_Setup(unittest.TestCase):
    """Shared fixture that creates and drops the tables used by these tests.

    Subclasses must define ``writeable`` (passed to the main table) and may
    override ``tableClass`` (class used to open every table), ``serverArgs``
    (keyword arguments for DBServerInfo) and ``loadArgs`` (extra keyword
    arguments for load_data).
    """
    tableClass = SQLTableCatcher
    serverArgs = {}
    loadArgs = {}

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        # share conn for all tests
        self.serverInfo = DBServerInfo(**self.serverArgs)

    def setUp(self):
        'load the test tables; an ImportError is reported as a skipped test'
        try:
            self.load_data(writeable=self.writeable, **self.loadArgs)
        except ImportError:
            raise SkipTest('missing MySQLdb module?')

    def load_data(self, tableName='test.sqltable_test', writeable=False,
                  dbargs=None, sourceDBargs=None, targetDBargs=None):
        """Create 4 tables and load 19 rows for our tests.

        tableName: name of the main table; the two join tables append '1'
            and '2' to it and the ORDER BY table appends '_orderBy'.
        writeable: passed through to the main table object (self.db).
        dbargs, sourceDBargs, targetDBargs: optional dicts of extra keyword
            arguments for the main, first join and second join table
            objects respectively; None means no extra arguments.
        """
        self.tableName = tableName
        self.joinTable1 = joinTable1 = tableName + '1'
        self.joinTable2 = joinTable2 = tableName + '2'
        createTable = """\
        CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
        """ % tableName
        self.db = self.tableClass(tableName, dropIfExists=True,
                                  serverInfo=self.serverInfo,
                                  createTable=createTable,
                                  writeable=writeable, **(dbargs or {}))
        self.sourceDB = self.tableClass(joinTable1, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable1, **(sourceDBargs or {}))
        self.targetDB = self.tableClass(joinTable2, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable2, **(targetDBargs or {}))
        sql = """
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
            INSERT INTO %s VALUES (2,'seq2')
            INSERT INTO %s VALUES (3,'seq3')
            INSERT INTO %s VALUES (4,'seq4')
            INSERT INTO %s VALUES (7, 'seq2')
            INSERT INTO %s VALUES (99, 'seq3')
            INSERT INTO %s VALUES (6, 'seq4')
            INSERT INTO %s VALUES (8, 'seq4')
        """ % tuple(([tableName]*2) + ([joinTable1]*3) + ([joinTable2]*4))
        for line in sql.strip().splitlines():  # insert our test data
            self.db.cursor.execute(line.strip())

        # Another table, for the "ORDER BY" test
        self.orderTable = tableName + '_orderBy'
        self.db.cursor.execute("""\
        CREATE TABLE %s (id INTEGER PRIMARY KEY, number INTEGER, letter VARCHAR(1))
        """ % self.orderTable)
        for row in range(0, 10):
            # ascii_lowercase (unlike string.lowercase) does not vary with
            # the locale, so every letter fits the VARCHAR(1) column
            self.db.cursor.execute('INSERT INTO %s VALUES (%d, %d, \'%s\')' %
                                   (self.orderTable, row,
                                    random.randint(0, 99),
                                    random.choice(string.ascii_lowercase)))

    def tearDown(self):
        'drop every table created by load_data, then close the connection'
        try:
            for name in (self.tableName, self.joinTable1, self.joinTable2,
                         self.orderTable):
                self.db.cursor.execute('drop table if exists %s' % name)
        finally:
            # close the connection even if one of the drops failed
            self.serverInfo.close()
class SQLTable_Test(SQLTable_Setup):
    """Read-only tests of the table mapping interface and the SQL join views.

    Tests that set ``catchIter = True`` on a table expect the operations
    that follow not to go through the table's generic_iterator.
    """
    writeable = False  # read-only database interface

    def test_keys(self):
        k = self.db.keys()
        k.sort()
        assert k == [1, 2]

    def test_len(self):
        self.db.catchIter = True
        assert len(self.db) == len(self.db.keys())

    def test_contains(self):
        self.db.catchIter = True
        assert 1 in self.db
        assert 2 in self.db
        assert 'foo' not in self.db

    def test_has_key(self):
        self.db.catchIter = True
        assert self.db.has_key(1)
        assert self.db.has_key(2)
        assert not self.db.has_key('foo')

    def test_get(self):
        self.db.catchIter = True
        assert self.db.get('foo') is None
        assert self.db.get(1) == self.db[1]
        assert self.db.get(2) == self.db[2]

    def test_items(self):
        i = [k for (k, v) in self.db.items()]
        i.sort()
        assert i == [1, 2]

    def test_iterkeys(self):
        kk = self.db.keys()
        ik = list(self.db.iterkeys())
        assert kk == ik

    def test_pickle(self):
        kk = self.db.keys()
        import pickle
        s = pickle.dumps(self.db)
        db = pickle.loads(s)
        try:
            ik = list(db.iterkeys())
            assert kk == ik
        finally:
            db.serverInfo.close()  # close extra DB connection

    def test_itervalues(self):
        kv = self.db.values()
        iv = list(self.db.itervalues())
        assert kv == iv

    def test_itervalues_long(self):
        """test iterator isolation from queries run inside iterator loop """
        sql = 'insert into %s (start) values (1)' % self.tableName
        for i in range(40000):  # insert 40000 rows
            self.db.cursor.execute(sql)
        iv = []
        for o in self.db.itervalues():
            99 in self.db  # make it do a query inside iterator loop
            iv.append(o.id)
        kv = [o.id for o in self.db.values()]
        assert len(kv) == len(iv)
        assert kv == iv

    def test_iteritems(self):
        ki = self.db.items()
        ii = list(self.db.iteritems())
        assert ki == ii

    def test_readonly(self):
        'test error handling of write attempts to read-only DB'
        self.db.catchIter = True  # no iter expected in this test!
        try:
            self.db.new(seq_id='freddy', start=3000, stop=4500)
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        o = self.db[1]
        try:
            self.db[33] = o
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        try:
            del self.db[2]
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')

    def test_orderBy(self):
        'test iterator with orderBy, iterSQL, iterColumns'
        self.targetDB.catchIter = True  # should not iterate
        # force it to use multiple queries to finish
        self.targetDB.arraysize = 2
        result = self.targetDB.keys()
        assert result == [6, 7, 8, 99]
        self.targetDB.catchIter = False  # next statement will iterate
        assert result == list(iter(self.targetDB))
        self.targetDB.catchIter = True  # should not iterate
        self.targetDB.orderBy = 'ORDER BY other_id'
        result = self.targetDB.keys()
        assert result == [7, 99, 6, 8]
        self.targetDB.catchIter = False  # next statement will iterate
        if self.serverInfo._serverType == 'mysql' \
               and self.serverInfo.custom_iter_keys:  # only test this for mysql
            try:
                list(iter(self.targetDB))
            except AttributeError:
                pass
            else:
                raise AssertionError('failed to trap missing iterSQL attr')
        self.targetDB.iterSQL = 'WHERE other_id>%s'  # tell it how to slice
        self.targetDB.iterColumns = ['other_id']
        assert result == list(iter(self.targetDB))
        result = self.targetDB.values()
        assert result == [self.targetDB[7], self.targetDB[99],
                          self.targetDB[6], self.targetDB[8]]
        assert result == list(self.targetDB.itervalues())
        result = self.targetDB.items()
        assert result == [(7, self.targetDB[7]), (99, self.targetDB[99]),
                          (6, self.targetDB[6]), (8, self.targetDB[8])]
        assert result == list(self.targetDB.iteritems())
        import pickle
        s = pickle.dumps(self.targetDB)  # test pickling & unpickling
        db = pickle.loads(s)
        try:
            correct = self.targetDB.keys()
            result = list(iter(db))
            assert result == correct
        finally:
            db.serverInfo.close()  # close extra DB connection

    def test_orderby_random(self):
        'test orderBy in SQLTable'
        if self.serverInfo._serverType == 'mysql' \
               and self.serverInfo.custom_iter_keys:
            try:
                self.tableClass(self.orderTable, arraysize=2,
                                serverInfo=self.serverInfo,
                                orderBy='ORDER BY number')
            except ValueError:
                pass
            else:
                raise AssertionError('failed to trap orderBy without iterSQL!')
        byNumber = self.tableClass(self.orderTable, serverInfo=self.serverInfo,
                                   arraysize=2, orderBy='ORDER BY number,id',
                                   iterSQL='WHERE number>%s or (number=%s and id>%s)',
                                   iterColumns=('number', 'number', 'id'))
        bv = [val.number for val in byNumber.values()]
        sortedBV = bv[:]
        sortedBV.sort()
        assert sortedBV == bv
        bv = [val.number for val in byNumber.itervalues()]
        assert sortedBV == bv

        byLetter = self.tableClass(self.orderTable, serverInfo=self.serverInfo,
                                   arraysize=2, orderBy='ORDER BY letter,id',
                                   iterSQL='WHERE letter>%s or (letter=%s and id>%s)',
                                   iterColumns=('letter', 'letter', 'id'))
        bl = [val.letter for val in byLetter.values()]
        sortedBL = bl[:]
        # without this sort the two comparisons below compare bl to its own
        # copy and can never fail
        sortedBL.sort()
        assert sortedBL == bl
        bl = [val.letter for val in byLetter.itervalues()]
        assert sortedBL == bl

    ### @CTB need to test write access
    def test_mapview(self):
        'test MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = MapView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo)
        assert m[self.sourceDB[2]] == self.targetDB[7]
        assert m[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in m
        try:
            m[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')
        try:
            ~m
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap non-invertible mapping')
        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(m) == 2
        l = m.keys()
        l.sort()
        correct = [self.sourceDB[2], self.sourceDB[3]]
        correct.sort()
        assert l == correct

    def test_mapview_inverse(self):
        'test inverse MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = MapView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo,
                    inverseSQL="""\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2))
        r = ~m  # get the inverse
        assert self.sourceDB[2] == r[self.targetDB[7]]
        assert self.sourceDB[3] == r[self.targetDB[99]]
        assert self.targetDB[7] in r

        m = ~r  # get the inverse of the inverse!
        assert m[self.sourceDB[2]] == self.targetDB[7]
        assert m[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in m
        try:
            m[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')

    def test_graphview(self):
        'test GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = GraphView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo)
        d = m[self.sourceDB[4]]
        assert len(d) == 2
        assert self.targetDB[6] in d and self.targetDB[8] in d
        assert self.sourceDB[2] in m

        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(m) == 3
        l = m.keys()
        l.sort()
        correct = [self.sourceDB[2], self.sourceDB[3], self.sourceDB[4]]
        correct.sort()
        assert l == correct

    def test_graphview_inverse(self):
        'test inverse GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = GraphView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo,
                      inverseSQL="""\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2))
        r = ~m  # get the inverse
        assert self.sourceDB[2] in r[self.targetDB[7]]
        assert self.sourceDB[3] in r[self.targetDB[99]]
        assert self.targetDB[7] in r
        d = r[self.targetDB[6]]
        assert len(d) == 1
        assert self.sourceDB[4] in d

        m = ~r  # get inverse of the inverse!
        d = m[self.sourceDB[4]]
        assert len(d) == 2
        assert self.targetDB[6] in d and self.targetDB[8] in d
        assert self.sourceDB[2] in m
class SQLTable_No_SSCursor_Test(SQLTable_Test):
    'rerun the SQLTable_Test suite with serverSideCursors=False'
    serverArgs = {'serverSideCursors': False}
class SQLTable_OldIter_Test(SQLTable_Test):
    'rerun SQLTable_Test with serverSideCursors and blockIterators both False'
    serverArgs = {
        'serverSideCursors': False,
        'blockIterators': False,
    }
class SQLiteBase(testutil.SQLite_Mixin):
    'mixin that loads the test data into a table named sqltable_test'

    def sqlite_load(self):
        'load the test tables using an unqualified table name'
        # NOTE(review): presumably invoked by testutil.SQLite_Mixin - confirm
        is_writeable = self.writeable
        self.load_data('sqltable_test', writeable=is_writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    'SQLTable_Test suite combined with the SQLiteBase data loader'
366 ## class SQLitePickle_Test(SQLiteTable_Test):
367 ## def setUp(self):
368 ## """Pickle / unpickle our serverInfo before trying to use it """
369 ## SQLiteTable_Test.setUp(self)
370 ## self.serverInfo.close()
371 ## import pickle
372 ## s = pickle.dumps(self.serverInfo)
373 ## del self.serverInfo
374 ## self.serverInfo = pickle.loads(s)
375 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
376 ## self.sourceDB = self.tableClass(self.joinTable1,
377 ## serverInfo=self.serverInfo)
378 ## self.targetDB = self.tableClass(self.joinTable2,
379 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    'rerun the SQLTable_Test suite on the iteration-trapping SQLTableNoCache'
    tableClass = SQLTableNoCacheCatcher
class SQLTableClustered_Test(SQLTable_Test):
    'rerun the SQLTable_Test suite on the iteration-trapping SQLTableClustered'
    tableClass = SQLTableClusteredCatcher
    # every table gets a clusterKey and a small arraysize
    loadArgs = {
        'dbargs': {'clusterKey': 'seq_id', 'arraysize': 2},
        'sourceDBargs': {'clusterKey': 'other_id', 'arraysize': 2},
        'targetDBargs': {'clusterKey': 'other_id', 'arraysize': 2},
    }

    # neither of these tests useful in this context, so both are
    # overridden with no-ops
    def test_orderBy(self):
        pass

    def test_orderby_random(self):
        pass
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    'rerun the SQLiteTable_Test suite on SQLTableNoCache'
    # NOTE(review): this is the plain SQLTableNoCache, not
    # SQLTableNoCacheCatcher, so catchIter set by the inherited tests has
    # no trapping effect here - confirm that is intended
    tableClass = SQLTableNoCache
class SQLTableRW_Test(SQLTable_Setup):
    'test write operations'
    writeable = True

    def test_new(self):
        'test row creation with auto inc ID'
        self.db.catchIter = True  # no iter expected in this test
        count_before = len(self.db)
        created = self.db.new(seq_id='freddy', start=3000, stop=4500)
        assert len(self.db) == count_before + 1
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        fresh.catchIter = True  # no iter expected in this test
        found = fresh[created.id]
        assert found.seq_id == 'freddy' and found.start == 3000 \
               and found.stop == 4500

    def test_new2(self):
        'check row creation with specified ID'
        self.db.catchIter = True  # no iter expected in this test
        count_before = len(self.db)
        created = self.db.new(id=99, seq_id='jeff', start=3000, stop=4500)
        assert len(self.db) == count_before + 1
        assert created.id == 99
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        fresh.catchIter = True  # no iter expected in this test
        found = fresh[99]
        assert found.seq_id == 'jeff' and found.start == 3000 \
               and found.stop == 4500

    def test_attr(self):
        'test changing an attr value'
        self.db.catchIter = True  # no iter expected in this test
        row = self.db[2]
        assert row.seq_id == 'seq2'
        row.seq_id = 'newval'  # overwrite this attribute
        assert row.seq_id == 'newval'  # check cached value
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        fresh.catchIter = True  # no iter expected in this test
        found = fresh[2]
        assert found.seq_id == 'newval'

    def test_delitem(self):
        'test deletion of a row'
        self.db.catchIter = True  # no iter expected in this test
        count_before = len(self.db)
        del self.db[1]
        assert len(self.db) == count_before - 1
        try:
            self.db[1]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')

    def test_setitem(self):
        'test assigning new ID to existing object'
        self.db.catchIter = True  # no iter expected in this test
        row = self.db.new(id=17, seq_id='bob', start=2000, stop=2500)
        self.db[13] = row
        assert row.id == 13
        try:
            self.db[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        fresh.catchIter = True  # no iter expected in this test
        found = fresh[13]
        assert found.seq_id == 'bob' and found.start == 2000 \
               and found.stop == 2500
        try:
            fresh[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    'SQLTableRW_Test suite combined with the SQLiteBase data loader'
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    'rerun the SQLTableRW_Test suite on SQLTableNoCache'
    # NOTE(review): this is the plain SQLTableNoCache, not
    # SQLTableNoCacheCatcher, so catchIter set by the inherited tests has
    # no trapping effect here - confirm that is intended
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    'rerun the SQLiteTableRW_Test suite on SQLTableNoCache'
    # NOTE(review): this is the plain SQLTableNoCache, not
    # SQLTableNoCacheCatcher, so catchIter set by the inherited tests has
    # no trapping effect here - confirm that is intended
    tableClass = SQLTableNoCache
class Ensembl_Test(unittest.TestCase):
    'test ORDER BY handling of a GraphView built on ensembldb.ensembl.org'

    def setUp(self):
        'open the translation and exon tables and build the exon GraphView'
        # test will be skipped if the database module cannot be imported
        logger.debug('accessing ensembldb.ensembl.org')
        conn = DBServerInfo(host='ensembldb.ensembl.org', user='anonymous',
                            passwd='')
        self.serverInfo = conn  # kept so tearDown can close it
        try:
            translationDB = SQLTableCatcher(
                'homo_sapiens_core_47_36i.translation', serverInfo=conn)
            translationDB.catchIter = True  # should not iter in this test!
            exonDB = SQLTable('homo_sapiens_core_47_36i.exon',
                              serverInfo=conn)
        except ImportError as e:
            raise SkipTest(e)

        # exons of one translation, from its start exon to its end exon,
        # ordered by rank within the transcript
        sql_statement = '''SELECT t3.exon_id FROM
homo_sapiens_core_47_36i.translation AS tr,
homo_sapiens_core_47_36i.exon_transcript AS t1,
homo_sapiens_core_47_36i.exon_transcript AS t2,
homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
t3.rank <= t2.rank ORDER BY t3.rank
        '''
        self.translationExons = GraphView(translationDB, exonDB,
                                          sql_statement, serverInfo=conn)
        self.translation = translationDB[15121]

    def tearDown(self):
        'close the connection opened by setUp'
        self.serverInfo.close()

    def test_orderBy(self):
        """Ensembl access, test order by

        test issue 53: ensure that the ORDER BY results are correct
        """
        exons = self.translationExons[self.translation]  # do the query
        result = [e.id for e in exons]
        correct = [95160, 95020, 95035, 95050, 95059, 95069, 95081, 95088,
                   95101, 95110, 95172]
        # make sure the exact order matches
        self.assertEqual(result, correct)
if __name__ == '__main__':
    # executed as a script: hand over to pygr's test program, verbosity 2
    PygrTestProgram(verbosity=2)