Merge commit 'remotes/ctb/xmlrpc_patches' into tryme
[pygr.git] / tests / sqltable_test.py
blobd99a5db36903b9e8bb799997b0a10aaa4b814fa2
1 import os, random, string, unittest
2 from testlib import testutil, PygrTestProgram, SkipTest
3 from pygr.sqlgraph import SQLTable, SQLTableNoCache,\
4 MapView, GraphView, DBServerInfo, import_sqlite
5 from pygr import logger
def catch_iterator(self, *args, **kwargs):
    """Stand-in for ``generic_iterator`` that traps unexpected iteration.

    A test sets ``catchIter = True`` on a table object to declare that the
    operation being tested must not iterate over the table.  If the flag is
    truthy this raises AssertionError; otherwise (flag falsy or never set)
    the call is forwarded unchanged to ``SQLTable.generic_iterator``.

    NOTE(review): the call is always forwarded to ``SQLTable``'s
    implementation, whatever class ``self`` actually is -- confirm that is
    appropriate for SQLTableNoCache-based subclasses.
    """
    # an object that never had catchIter assigned is allowed to iterate
    iteration_forbidden = getattr(self, 'catchIter', False)
    assert not iteration_forbidden, 'this should not iterate!'
    return SQLTable.generic_iterator(self, *args, **kwargs)
class SQLTableCatcher(SQLTable):
    """SQLTable whose ``generic_iterator`` is routed through catch_iterator.

    Setting ``catchIter = True`` on an instance makes any call to
    ``generic_iterator`` raise AssertionError.
    """
    generic_iterator = catch_iterator
class SQLTableNoCacheCatcher(SQLTableNoCache):
    """SQLTableNoCache whose ``generic_iterator`` goes through catch_iterator.

    NOTE(review): no test class visible in this file selects this class as
    its ``tableClass`` -- confirm whether the NoCache tests were meant to.
    """
    generic_iterator = catch_iterator
class SQLTable_Setup(unittest.TestCase):
    """Shared fixture that creates and populates the SQL test tables.

    Subclasses must define ``writeable`` (read by setUp) and may override
    ``tableClass`` and ``serverArgs``.
    """
    tableClass = SQLTableCatcher  # class used to open every test table
    serverArgs = {}  # keyword arguments handed to DBServerInfo

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        # share conn for all tests
        self.serverInfo = DBServerInfo(**self.serverArgs)

    def setUp(self):
        """Load the fixture tables; an ImportError turns into SkipTest."""
        try:
            self.load_data(writeable=self.writeable)
        except ImportError:
            raise SkipTest('missing MySQLdb module?')

    def load_data(self, tableName='test.sqltable_test', writeable=False):
        """Create the four fixture tables and fill them with test rows.

        tableName -- name of the main table; the join tables and the
            ORDER BY table are named tableName + '1', + '2', + '_orderBy'.
        writeable -- passed to tableClass when opening the main table.

        Loads 2 rows into the main table, 3 into join table 1, 4 into join
        table 2, and 10 rows of random (number, letter) data into the
        ORDER BY table.  Sets self.db, self.sourceDB, self.targetDB and the
        table-name attributes used by tearDown.
        """
        self.tableName = tableName
        self.joinTable1 = joinTable1 = tableName + '1'
        self.joinTable2 = joinTable2 = tableName + '2'
        createTable = """\
        CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
        """ % tableName
        self.db = self.tableClass(tableName, dropIfExists=True,
                                  serverInfo=self.serverInfo,
                                  createTable=createTable,
                                  writeable=writeable)
        self.sourceDB = self.tableClass(joinTable1, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable1)
        self.targetDB = self.tableClass(joinTable2, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable2)
        sql = """
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
            INSERT INTO %s VALUES (2,'seq2')
            INSERT INTO %s VALUES (3,'seq3')
            INSERT INTO %s VALUES (4,'seq4')
            INSERT INTO %s VALUES (7, 'seq2')
            INSERT INTO %s VALUES (99, 'seq3')
            INSERT INTO %s VALUES (6, 'seq4')
            INSERT INTO %s VALUES (8, 'seq4')
        """ % tuple(([tableName]*2) + ([joinTable1]*3) + ([joinTable2]*4))
        for line in sql.strip().splitlines():  # insert our test data
            self.db.cursor.execute(line.strip())

        # Another table, for the "ORDER BY" test
        self.orderTable = tableName + '_orderBy'
        # The other tables are opened with dropIfExists=True; do the same
        # here so a table left behind by an aborted run cannot make the
        # CREATE fail.
        self.db.cursor.execute('drop table if exists %s' % self.orderTable)
        self.db.cursor.execute("""\
        CREATE TABLE %s (id INTEGER PRIMARY KEY, number INTEGER, letter VARCHAR(1))
        """ % self.orderTable)
        for row in range(0, 10):
            # ascii_lowercase (not the locale-dependent string.lowercase)
            # keeps every letter a single ASCII character
            self.db.cursor.execute('INSERT INTO %s VALUES (%d, %d, \'%s\')' %
                                   (self.orderTable, row,
                                    random.randint(0, 99),
                                    random.choice(string.ascii_lowercase)))

    def tearDown(self):
        """Drop the fixture tables and close the server connection."""
        try:
            for name in (self.tableName, self.joinTable1,
                         self.joinTable2, self.orderTable):
                self.db.cursor.execute('drop table if exists %s' % name)
        finally:
            # close the connection even if one of the DROPs failed
            self.serverInfo.close()
class SQLTable_Test(SQLTable_Setup):
    """Read-only tests of the dict-like table interface and SQL join views.

    Many tests set ``catchIter = True`` on a table first; with the catcher
    table classes this makes any use of generic_iterator fail the test.
    """
    writeable = False  # read-only database interface

    def test_keys(self):
        # the two rows inserted into the main table get IDs 1 and 2
        k = self.db.keys()
        k.sort()
        assert k == [1, 2]

    def test_len(self):
        self.db.catchIter = True
        assert len(self.db) == len(self.db.keys())

    def test_contains(self):
        self.db.catchIter = True
        assert 1 in self.db
        assert 2 in self.db
        assert 'foo' not in self.db

    def test_has_key(self):
        self.db.catchIter = True
        assert self.db.has_key(1)
        assert self.db.has_key(2)
        assert not self.db.has_key('foo')

    def test_get(self):
        self.db.catchIter = True
        assert self.db.get('foo') is None
        assert self.db.get(1) == self.db[1]
        assert self.db.get(2) == self.db[2]

    def test_items(self):
        i = [k for (k, v) in self.db.items()]
        i.sort()
        assert i == [1, 2]

    def test_iterkeys(self):
        kk = self.db.keys()
        ik = list(self.db.iterkeys())
        assert kk == ik

    def test_pickle(self):
        # a pickled and unpickled table must yield the same keys
        kk = self.db.keys()
        import pickle
        s = pickle.dumps(self.db)
        db = pickle.loads(s)
        try:
            ik = list(db.iterkeys())
            assert kk == ik
        finally:
            db.serverInfo.close()  # close extra DB connection

    def test_itervalues(self):
        kv = self.db.values()
        iv = list(self.db.itervalues())
        assert kv == iv

    def test_itervalues_long(self):
        """test iterator isolation from queries run inside iterator loop """
        sql = 'insert into %s (start) values (1)' % self.tableName
        for i in range(40000):  # insert 40000 rows
            self.db.cursor.execute(sql)
        iv = []
        for o in self.db.itervalues():
            # result is irrelevant; the point is to run a query mid-iteration
            status = 99 in self.db
            iv.append(o.id)
        kv = [o.id for o in self.db.values()]
        assert len(kv) == len(iv)
        assert kv == iv

    def test_iteritems(self):
        ki = self.db.items()
        ii = list(self.db.iteritems())
        assert ki == ii

    def test_readonly(self):
        'test error handling of write attempts to read-only DB'
        self.db.catchIter = True  # no iter expected in this test!
        try:
            self.db.new(seq_id='freddy', start=3000, stop=4500)
            raise AssertionError('failed to trap attempt to write to db')
        except ValueError:
            pass
        o = self.db[1]
        try:
            self.db[33] = o
            raise AssertionError('failed to trap attempt to write to db')
        except ValueError:
            pass
        try:
            del self.db[2]
            raise AssertionError('failed to trap attempt to write to db')
        except ValueError:
            pass

    def test_orderBy(self):
        'test iterator with orderBy, iterSQL, iterColumns'
        self.targetDB.catchIter = True  # should not iterate
        # force it to use multiple queries to finish
        self.targetDB.arraysize = 2
        result = self.targetDB.keys()
        assert result == [6, 7, 8, 99]
        self.targetDB.catchIter = False  # next statement will iterate
        assert result == list(iter(self.targetDB))
        self.targetDB.catchIter = True  # should not iterate
        self.targetDB.orderBy = 'ORDER BY other_id'
        result = self.targetDB.keys()
        assert result == [7, 99, 6, 8]
        self.targetDB.catchIter = False  # next statement will iterate
        if self.serverInfo._serverType == 'mysql' \
               and self.serverInfo.custom_iter_keys:  # only test this for mysql
            try:
                assert result == list(iter(self.targetDB))
                raise AssertionError('failed to trap missing iterSQL attr')
            except AttributeError:
                pass
        self.targetDB.iterSQL = 'WHERE other_id>%s'  # tell it how to slice
        self.targetDB.iterColumns = ['other_id']
        assert result == list(iter(self.targetDB))
        result = self.targetDB.values()
        assert result == [self.targetDB[7], self.targetDB[99],
                          self.targetDB[6], self.targetDB[8]]
        assert result == list(self.targetDB.itervalues())
        result = self.targetDB.items()
        assert result == [(7, self.targetDB[7]), (99, self.targetDB[99]),
                          (6, self.targetDB[6]), (8, self.targetDB[8])]
        assert result == list(self.targetDB.iteritems())
        import pickle
        s = pickle.dumps(self.targetDB)  # test pickling & unpickling
        db = pickle.loads(s)
        try:
            correct = self.targetDB.keys()
            result = list(iter(db))
            assert result == correct
        finally:
            db.serverInfo.close()  # close extra DB connection

    def test_orderby_random(self):
        'test orderBy in SQLTable'
        if self.serverInfo._serverType == 'mysql' \
               and self.serverInfo.custom_iter_keys:
            try:
                byNumber = self.tableClass(self.orderTable, arraysize=2,
                                           serverInfo=self.serverInfo,
                                           orderBy='ORDER BY number')
                raise AssertionError('failed to trap orderBy without iterSQL!')
            except ValueError:
                pass
        byNumber = self.tableClass(self.orderTable, serverInfo=self.serverInfo,
                                   arraysize=2, orderBy='ORDER BY number,id',
                                   iterSQL='WHERE number>%s or (number=%s and id>%s)',
                                   iterColumns=('number', 'number', 'id'))
        bv = [val.number for val in byNumber.values()]
        sortedBV = bv[:]
        sortedBV.sort()
        assert sortedBV == bv
        bv = [val.number for val in byNumber.itervalues()]
        assert sortedBV == bv

        byLetter = self.tableClass(self.orderTable, serverInfo=self.serverInfo,
                                   arraysize=2, orderBy='ORDER BY letter,id',
                                   iterSQL='WHERE letter>%s or (letter=%s and id>%s)',
                                   iterColumns=('letter', 'letter', 'id'))
        bl = [val.letter for val in byLetter.values()]
        sortedBL = bl[:]
        # without this sort the comparison below checks a list against its
        # own copy and can never fail
        sortedBL.sort()
        assert sortedBL == bl
        bl = [val.letter for val in byLetter.itervalues()]
        assert sortedBL == bl

    ### @CTB need to test write access
    def test_mapview(self):
        'test MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = MapView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo)
        assert m[self.sourceDB[2]] == self.targetDB[7]
        assert m[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in m
        try:
            # source row 4 ('seq4') joins to two target rows (6 and 8)
            d = m[self.sourceDB[4]]
            raise AssertionError('failed to trap non-unique mapping')
        except KeyError:
            pass
        try:
            r = ~m  # no inverseSQL was supplied
            raise AssertionError('failed to trap non-invertible mapping')
        except ValueError:
            pass
        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(m) == 2
        l = m.keys()
        l.sort()
        correct = [self.sourceDB[2], self.sourceDB[3]]
        correct.sort()
        assert l == correct

    def test_mapview_inverse(self):
        'test inverse MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = MapView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo,
                    inverseSQL="""\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2))
        r = ~m  # get the inverse
        assert self.sourceDB[2] == r[self.targetDB[7]]
        assert self.sourceDB[3] == r[self.targetDB[99]]
        assert self.targetDB[7] in r

        m = ~r  # get the inverse of the inverse!
        assert m[self.sourceDB[2]] == self.targetDB[7]
        assert m[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in m
        try:
            d = m[self.sourceDB[4]]
            raise AssertionError('failed to trap non-unique mapping')
        except KeyError:
            pass

    def test_graphview(self):
        'test GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = GraphView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo)
        d = m[self.sourceDB[4]]
        assert len(d) == 2
        assert self.targetDB[6] in d and self.targetDB[8] in d
        assert self.sourceDB[2] in m

        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(m) == 3
        l = m.keys()
        l.sort()
        correct = [self.sourceDB[2], self.sourceDB[3], self.sourceDB[4]]
        correct.sort()
        assert l == correct

    def test_graphview_inverse(self):
        'test inverse GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = GraphView(self.sourceDB, self.targetDB, """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2), serverInfo=self.serverInfo,
                      inverseSQL="""\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2))
        r = ~m  # get the inverse
        assert self.sourceDB[2] in r[self.targetDB[7]]
        assert self.sourceDB[3] in r[self.targetDB[99]]
        assert self.targetDB[7] in r
        d = r[self.targetDB[6]]
        assert len(d) == 1
        assert self.sourceDB[4] in d

        m = ~r  # get inverse of the inverse!
        d = m[self.sourceDB[4]]
        assert len(d) == 2
        assert self.targetDB[6] in d and self.targetDB[8] in d
        assert self.sourceDB[2] in m
class SQLTable_No_SSCursor_Test(SQLTable_Test):
    """Run the SQLTable_Test suite with serverSideCursors=False."""
    serverArgs = {'serverSideCursors': False}
class SQLTable_OldIter_Test(SQLTable_Test):
    """Run the SQLTable_Test suite with server-side cursors and block
    iterators both switched off."""
    serverArgs = {
        'serverSideCursors': False,
        'blockIterators': False,
    }
class SQLiteBase(testutil.SQLite_Mixin):
    """Mixin that loads the fixture tables under the name 'sqltable_test'.

    Expects to be combined with a class that provides ``load_data`` and
    ``writeable``.
    """

    def sqlite_load(self):
        # presumably invoked by testutil.SQLite_Mixin during setup -- confirm
        self.load_data('sqltable_test', writeable=self.writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    """The read-only SQLTable_Test suite combined with the SQLiteBase mixin."""
## class SQLitePickle_Test(SQLiteTable_Test):
##     def setUp(self):
##         """Pickle / unpickle our serverInfo before trying to use it """
##         SQLiteTable_Test.setUp(self)
##         self.serverInfo.close()
##         import pickle
##         s = pickle.dumps(self.serverInfo)
##         del self.serverInfo
##         self.serverInfo = pickle.loads(s)
##         self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
##         self.sourceDB = self.tableClass(self.joinTable1,
##                                         serverInfo=self.serverInfo)
##         self.targetDB = self.tableClass(self.joinTable2,
##                                         serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    """Run the SQLTable_Test suite against SQLTableNoCache.

    NOTE(review): this selects SQLTableNoCache itself rather than a
    catch_iterator-based subclass, so the catchIter flags set by the tests
    may not trap anything here -- confirm this is intended.
    """
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    """Run the SQLiteTable_Test suite against SQLTableNoCache.

    NOTE(review): this selects SQLTableNoCache itself rather than a
    catch_iterator-based subclass, so the catchIter flags set by the tests
    may not trap anything here -- confirm this is intended.
    """
    tableClass = SQLTableNoCache
class SQLTableRW_Test(SQLTable_Setup):
    'test write operations'
    writeable = True

    def _reopen_table(self):
        # Build a second table object on the same server to requery the db;
        # like self.db in these tests, it must never iterate.
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        fresh.catchIter = True  # no iter expected in this test
        return fresh

    def test_new(self):
        'test row creation with auto inc ID'
        self.db.catchIter = True  # no iter expected in this test
        count_before = len(self.db)
        created = self.db.new(seq_id='freddy', start=3000, stop=4500)
        assert len(self.db) == count_before + 1
        stored = self._reopen_table()[created.id]
        assert stored.seq_id == 'freddy'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_new2(self):
        'check row creation with specified ID'
        self.db.catchIter = True  # no iter expected in this test
        count_before = len(self.db)
        created = self.db.new(id=99, seq_id='jeff', start=3000, stop=4500)
        assert len(self.db) == count_before + 1
        assert created.id == 99
        stored = self._reopen_table()[99]
        assert stored.seq_id == 'jeff'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_attr(self):
        'test changing an attr value'
        self.db.catchIter = True  # no iter expected in this test
        row = self.db[2]
        assert row.seq_id == 'seq2'
        row.seq_id = 'newval'  # overwrite this attribute
        assert row.seq_id == 'newval'  # check cached value
        stored = self._reopen_table()[2]
        assert stored.seq_id == 'newval'

    def test_delitem(self):
        'test deletion of a row'
        self.db.catchIter = True  # no iter expected in this test
        count_before = len(self.db)
        del self.db[1]
        assert len(self.db) == count_before - 1
        try:
            self.db[1]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')

    def test_setitem(self):
        'test assigning new ID to existing object'
        self.db.catchIter = True  # no iter expected in this test
        row = self.db.new(id=17, seq_id='bob', start=2000, stop=2500)
        self.db[13] = row
        assert row.id == 13
        try:
            self.db[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
        fresh = self._reopen_table()
        stored = fresh[13]
        assert stored.seq_id == 'bob'
        assert stored.start == 2000
        assert stored.stop == 2500
        try:
            fresh[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    """The SQLTableRW_Test write suite combined with the SQLiteBase mixin."""
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    """Run the SQLTableRW_Test write suite against SQLTableNoCache."""
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    """Run the SQLiteTableRW_Test write suite against SQLTableNoCache."""
    tableClass = SQLTableNoCache
class Ensembl_Test(unittest.TestCase):
    """Check ORDER BY handling in a GraphView over the public Ensembl
    database at ensembldb.ensembl.org."""

    def setUp(self):
        """Open the translation and exon tables and build the GraphView.

        Raises SkipTest if opening the tables raises ImportError.
        NOTE(review): the original comment also promised a skip when the
        ensembldb server is unavailable, but only ImportError is trapped
        here -- confirm how connection failures should be handled.
        """
        logger.debug('accessing ensembldb.ensembl.org')
        conn = DBServerInfo(host='ensembldb.ensembl.org', user='anonymous',
                            passwd='')
        try:
            translationDB = SQLTableCatcher(
                'homo_sapiens_core_47_36i.translation', serverInfo=conn)
            translationDB.catchIter = True  # should not iter in this test!
            exonDB = SQLTable('homo_sapiens_core_47_36i.exon',
                              serverInfo=conn)
        except ImportError:
            # sys.exc_info() works on every Python version, unlike either
            # spelling of "except ImportError, e" / "except ImportError as e"
            import sys
            raise SkipTest(sys.exc_info()[1])

        # exons of one translation, from its start exon to its end exon,
        # ordered by rank within the transcript
        sql_statement = '''SELECT t3.exon_id FROM
homo_sapiens_core_47_36i.translation AS tr,
homo_sapiens_core_47_36i.exon_transcript AS t1,
homo_sapiens_core_47_36i.exon_transcript AS t2,
homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
t3.rank <= t2.rank ORDER BY t3.rank
        '''
        self.translationExons = GraphView(translationDB, exonDB,
                                          sql_statement, serverInfo=conn)
        self.translation = translationDB[15121]

    def test_orderBy(self):
        """Ensembl access, test order by

        test issue 53: ensure that the ORDER BY results are correct
        """
        exons = self.translationExons[self.translation]  # do the query
        result = [e.id for e in exons]
        correct = [95160, 95020, 95035, 95050, 95059, 95069, 95081, 95088,
                   95101, 95110, 95172]
        # make sure the exact order matches
        self.assertEqual(result, correct)
# When executed as a script, hand control to testlib's PygrTestProgram
# with verbose (level 2) output.
if __name__ == '__main__':
    PygrTestProgram(verbosity=2)