Merge commit 'remotes/mkszuba/sqltable_orderby' into sscursor
[pygr.git] / tests / sqltable_test.py
blob3d25dc04102a862f77b5a1642cf5ee5a89451630
1 import os, random, string, unittest
2 from testlib import testutil, PygrTestProgram, SkipTest
3 from pygr.sqlgraph import SQLTable, SQLTableNoCache,\
4 MapView, GraphView, DBServerInfo, import_sqlite
5 from pygr import logger
def catch_iterator(self, *args, **kwargs):
    """Guarded stand-in for SQLTable.generic_iterator.

    Fails with an AssertionError when the table has a true ``catchIter``
    attribute; a table without that attribute is treated as allowed to
    iterate.  Otherwise delegates to SQLTable.generic_iterator with the
    same arguments and returns its result.
    """
    # A missing attribute yields the default False, so the assertion passes.
    assert not getattr(self, 'catchIter', False), 'this should not iterate!'
    return SQLTable.generic_iterator(self, *args, **kwargs)
class SQLTableCatcher(SQLTable):
    'SQLTable whose generic_iterator is replaced by catch_iterator'
    generic_iterator = catch_iterator
class SQLTableNoCacheCatcher(SQLTableNoCache):
    'SQLTableNoCache whose generic_iterator is replaced by catch_iterator'
    generic_iterator = catch_iterator
class SQLTable_Setup(unittest.TestCase):
    """Shared fixture that builds and drops the tables used by these tests.

    Subclasses must define ``writeable`` (passed to load_data by setUp);
    ``tableClass`` selects the table class wrapped around each SQL table.
    """
    tableClass = SQLTableCatcher

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        self.serverInfo = DBServerInfo()  # share conn for all tests

    def setUp(self):
        'load the test data; an ImportError is reported as a skipped test'
        try:
            self.load_data(writeable=self.writeable)
        except ImportError:
            raise SkipTest('missing MySQLdb module?')

    def load_data(self, tableName='test.sqltable_test', writeable=False):
        """create 4 tables and load 19 rows for our tests

        tableName: name of the main table; the two join tables and the
            ORDER BY table are named by appending '1', '2' and '_orderBy'.
        writeable: passed through to tableClass for the main table.

        Sets self.db, self.sourceDB, self.targetDB and the table-name
        attributes that tearDown uses to drop the tables again.
        """
        self.tableName = tableName
        self.joinTable1 = joinTable1 = tableName + '1'
        self.joinTable2 = joinTable2 = tableName + '2'
        createTable = """\
        CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
        """ % tableName
        self.db = self.tableClass(tableName, dropIfExists=True,
                                  serverInfo=self.serverInfo,
                                  createTable=createTable,
                                  writeable=writeable)
        self.sourceDB = self.tableClass(joinTable1, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable1)
        self.targetDB = self.tableClass(joinTable2, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable2)
        sql = """
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
            INSERT INTO %s VALUES (2,'seq2')
            INSERT INTO %s VALUES (3,'seq3')
            INSERT INTO %s VALUES (4,'seq4')
            INSERT INTO %s VALUES (7, 'seq2')
            INSERT INTO %s VALUES (99, 'seq3')
            INSERT INTO %s VALUES (6, 'seq4')
            INSERT INTO %s VALUES (8, 'seq4')
        """ % tuple(([tableName]*2) + ([joinTable1]*3) + ([joinTable2]*4))
        for line in sql.strip().splitlines():  # insert our test data
            self.db.cursor.execute(line.strip())

        # Another table, for the "ORDER BY" test
        self.orderTable = tableName + '_orderBy'
        # The other tables are created with dropIfExists=True; do the same
        # here so a table left behind by an aborted run cannot make the
        # CREATE TABLE below fail.
        self.db.cursor.execute('drop table if exists %s' % self.orderTable)
        self.db.cursor.execute("""\
        CREATE TABLE %s (id INTEGER PRIMARY KEY, number INTEGER, letter VARCHAR(1))
        """ % self.orderTable)
        for row in range(0, 10):
            # random contents, so that ORDER BY has something to reorder
            self.db.cursor.execute("INSERT INTO %s VALUES (%d, %d, '%s')" %
                                   (self.orderTable, row,
                                    random.randint(0, 99),
                                    random.choice(string.lowercase)))

    def tearDown(self):
        'drop every table created by load_data, then close the connection'
        for name in (self.tableName, self.joinTable1, self.joinTable2,
                     self.orderTable):
            self.db.cursor.execute('drop table if exists %s' % name)
        self.serverInfo.close()
class SQLTable_Test(SQLTable_Setup):
    """Read-only tests of the dict-like table interface and the SQL views.

    Setting ``catchIter = True`` on a table built from a catcher class
    makes the test fail if that table calls its generic_iterator.
    """
    writeable = False  # read-only database interface

    def test_keys(self):
        'keys() lists the two IDs of the main table'
        assert sorted(self.db.keys()) == [1, 2]

    def test_len(self):
        'len() agrees with the number of keys, without iterating'
        self.db.catchIter = True
        row_count = len(self.db)
        key_count = len(self.db.keys())
        assert row_count == key_count

    def test_contains(self):
        'membership test, without iterating'
        self.db.catchIter = True
        for key in (1, 2):
            assert key in self.db
        assert 'foo' not in self.db

    def test_has_key(self):
        'has_key(), without iterating'
        self.db.catchIter = True
        for key in (1, 2):
            assert self.db.has_key(key)
        assert not self.db.has_key('foo')

    def test_get(self):
        'get() returns None for a missing key, else the same row as []'
        self.db.catchIter = True
        assert self.db.get('foo') is None
        for key in (1, 2):
            assert self.db.get(key) == self.db[key]

    def test_items(self):
        'items() pairs carry the two IDs of the main table'
        ids = sorted([key for (key, _) in self.db.items()])
        assert ids == [1, 2]

    def test_iterkeys(self):
        'iterkeys() yields the same keys as keys()'
        from_list = sorted(self.db.keys())
        from_iter = sorted(self.db.iterkeys())
        assert from_list == from_iter

    def test_itervalues(self):
        'itervalues() yields the same rows as values()'
        from_list = sorted(self.db.values())
        from_iter = sorted(self.db.itervalues())
        assert from_list == from_iter

    def test_itervalues_long(self):
        """test iterator isolation from queries run inside iterator loop """
        insert_sql = 'insert into %s (start) values (1)' % self.tableName
        for _ in range(40000):  # insert 40000 rows
            self.db.cursor.execute(insert_sql)
        iter_ids = []
        print('begin itervalues()')
        for row in self.db.itervalues():
            _ = 99 in self.db  # make it do a query inside iterator loop
            iter_ids.append(row.id)
        print('begin values()')
        list_ids = [row.id for row in self.db.values()]
        assert len(list_ids) == len(iter_ids)
        assert list_ids == iter_ids
        print('done')

    def test_iteritems(self):
        'iteritems() yields the same pairs as items()'
        from_list = sorted(self.db.items())
        from_iter = sorted(self.db.iteritems())
        assert from_list == from_iter

    def test_readonly(self):
        'test error handling of write attempts to read-only DB'
        self.db.catchIter = True  # no iter expected in this test!
        try:
            self.db.new(seq_id='freddy', start=3000, stop=4500)
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        row = self.db[1]
        try:
            self.db[33] = row
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        try:
            del self.db[2]
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')

    def test_orderBy(self):
        'test iterator with orderBy, iterSQL, iterColumns'
        target = self.targetDB
        ordered_ids = (7, 99, 6, 8)  # third_id values sorted by other_id
        target.catchIter = True  # should not iterate
        target.arraysize = 2  # force it to use multiple queries to finish
        found = target.keys()
        assert found == [6, 7, 8, 99]
        target.catchIter = False  # next statement will iterate
        assert found == list(iter(target))
        target.catchIter = True  # should not iterate
        target.orderBy = 'ORDER BY other_id'
        found = target.keys()
        assert found == list(ordered_ids)
        target.catchIter = False  # next statement will iterate
        if self.serverInfo._serverType == 'mysql':  # only test this for mysql
            try:
                assert found == list(iter(target))
                raise AssertionError('failed to trap missing iterSQL attr')
            except AttributeError:
                pass
        target.iterSQL = 'WHERE other_id>%s'  # tell it how to slice
        target.iterColumns = ['other_id']
        assert found == list(iter(target))
        found = target.values()
        assert found == [target[key] for key in ordered_ids]
        assert found == list(target.itervalues())
        found = target.items()
        assert found == [(key, target[key]) for key in ordered_ids]
        assert found == list(target.iteritems())

    ### @CTB need to test write access
    def test_mapview(self):
        'test MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        join_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2)
        m = MapView(self.sourceDB, self.targetDB, join_sql,
                    serverInfo=self.serverInfo)
        assert m[self.sourceDB[2]] == self.targetDB[7]
        assert m[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in m
        try:
            m[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')
        try:
            ~m
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap non-invertible mapping')
        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(m) == 2
        mapped = sorted(m.keys())
        expected = sorted([self.sourceDB[2], self.sourceDB[3]])
        assert mapped == expected

    def test_mapview_inverse(self):
        'test inverse MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        tables = (self.joinTable1, self.joinTable2)
        forward_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % tables
        backward_sql = """\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % tables
        m = MapView(self.sourceDB, self.targetDB, forward_sql,
                    serverInfo=self.serverInfo, inverseSQL=backward_sql)
        r = ~m  # get the inverse
        assert self.sourceDB[2] == r[self.targetDB[7]]
        assert self.sourceDB[3] == r[self.targetDB[99]]
        assert self.targetDB[7] in r

        m = ~r  # get the inverse of the inverse!
        assert m[self.sourceDB[2]] == self.targetDB[7]
        assert m[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in m
        try:
            m[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')

    def test_graphview(self):
        'test GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        join_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2)
        m = GraphView(self.sourceDB, self.targetDB, join_sql,
                      serverInfo=self.serverInfo)
        neighbours = m[self.sourceDB[4]]
        assert len(neighbours) == 2
        assert self.targetDB[6] in neighbours \
               and self.targetDB[8] in neighbours
        assert self.sourceDB[2] in m

        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(m) == 3
        mapped = sorted(m.keys())
        expected = sorted([self.sourceDB[2], self.sourceDB[3],
                           self.sourceDB[4]])
        assert mapped == expected

    def test_graphview_inverse(self):
        'test inverse GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        tables = (self.joinTable1, self.joinTable2)
        forward_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % tables
        backward_sql = """\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % tables
        m = GraphView(self.sourceDB, self.targetDB, forward_sql,
                      serverInfo=self.serverInfo, inverseSQL=backward_sql)
        r = ~m  # get the inverse
        assert self.sourceDB[2] in r[self.targetDB[7]]
        assert self.sourceDB[3] in r[self.targetDB[99]]
        assert self.targetDB[7] in r
        sources = r[self.targetDB[6]]
        assert len(sources) == 1
        assert self.sourceDB[4] in sources

        m = ~r  # get inverse of the inverse!
        neighbours = m[self.sourceDB[4]]
        assert len(neighbours) == 2
        assert self.targetDB[6] in neighbours \
               and self.targetDB[8] in neighbours
        assert self.sourceDB[2] in m
class SQLiteBase(testutil.SQLite_Mixin):
    'mixin that loads the test tables under the name sqltable_test'

    def sqlite_load(self):
        'load the test data, honouring the test class writeable flag'
        self.load_data('sqltable_test', writeable=self.writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    'read-only tests run via SQLiteBase, plus an orderBy constructor test'

    def test_orderby(self):
        'test orderBy in SQLTable'
        # Rows must come back sorted by the column named in orderBy.
        byNumber = self.tableClass(self.orderTable,
                                   serverInfo=self.serverInfo,
                                   orderBy='ORDER BY number')
        numbers = [row.number for row in byNumber.values()]
        assert sorted(numbers) == numbers

        byLetter = self.tableClass(self.orderTable,
                                   serverInfo=self.serverInfo,
                                   orderBy='ORDER BY letter')
        letters = [row.letter for row in byLetter.values()]
        # Compare against a sorted copy; an unsorted copy would make this
        # check pass no matter what order the rows came back in.
        assert sorted(letters) == letters
326 ## class SQLitePickle_Test(SQLiteTable_Test):
327 ## def setUp(self):
328 ## """Pickle / unpickle our serverInfo before trying to use it """
329 ## SQLiteTable_Test.setUp(self)
330 ## self.serverInfo.close()
331 ## import pickle
332 ## s = pickle.dumps(self.serverInfo)
333 ## del self.serverInfo
334 ## self.serverInfo = pickle.loads(s)
335 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
336 ## self.sourceDB = self.tableClass(self.joinTable1,
337 ## serverInfo=self.serverInfo)
338 ## self.targetDB = self.tableClass(self.joinTable2,
339 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    'rerun the read-only tests with SQLTableNoCache as the table class'
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    'rerun SQLiteTable_Test with SQLTableNoCache as the table class'
    tableClass = SQLTableNoCache
class SQLTableRW_Test(SQLTable_Setup):
    'test write operations'
    writeable = True

    def _requery(self):
        'open a second table object on the main table; it must not iterate'
        fresh = self.tableClass(self.tableName,
                                serverInfo=self.serverInfo)  # requery the db
        fresh.catchIter = True  # no iter expected in this test
        return fresh

    def test_new(self):
        'test row creation with auto inc ID'
        self.db.catchIter = True  # no iter expected in this test
        before = len(self.db)
        created = self.db.new(seq_id='freddy', start=3000, stop=4500)
        assert len(self.db) == before + 1
        stored = self._requery()[created.id]
        assert stored.seq_id == 'freddy'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_new2(self):
        'check row creation with specified ID'
        self.db.catchIter = True  # no iter expected in this test
        before = len(self.db)
        created = self.db.new(id=99, seq_id='jeff', start=3000, stop=4500)
        assert len(self.db) == before + 1
        assert created.id == 99
        stored = self._requery()[99]
        assert stored.seq_id == 'jeff'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_attr(self):
        'test changing an attr value'
        self.db.catchIter = True  # no iter expected in this test
        row = self.db[2]
        assert row.seq_id == 'seq2'
        row.seq_id = 'newval'  # overwrite this attribute
        assert row.seq_id == 'newval'  # check cached value
        stored = self._requery()[2]
        assert stored.seq_id == 'newval'

    def test_delitem(self):
        'test deletion of a row'
        self.db.catchIter = True  # no iter expected in this test
        before = len(self.db)
        del self.db[1]
        assert len(self.db) == before - 1
        try:
            self.db[1]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')

    def test_setitem(self):
        'test assigning new ID to existing object'
        self.db.catchIter = True  # no iter expected in this test
        moved = self.db.new(id=17, seq_id='bob', start=2000, stop=2500)
        self.db[13] = moved
        assert moved.id == 13
        try:
            self.db[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
        fresh = self._requery()
        stored = fresh[13]
        assert stored.seq_id == 'bob'
        assert stored.start == 2000
        assert stored.stop == 2500
        try:
            fresh[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    'write tests with the data loaded through SQLiteBase'
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    'rerun the write tests with SQLTableNoCache as the table class'
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    'rerun SQLiteTableRW_Test with SQLTableNoCache as the table class'
    tableClass = SQLTableNoCache
class Ensembl_Test(unittest.TestCase):
    'check the row order of a GraphView query against ensembldb.ensembl.org'

    def setUp(self):
        'connect to ensembldb and build the translation -> exon GraphView'
        # test will be skipped if mysql module or ensembldb server unavailable
        # NOTE(review): only ImportError is turned into SkipTest below; how a
        # connection failure is reported is not visible here -- confirm.

        logger.debug('accessing ensembldb.ensembl.org')
        conn = DBServerInfo(host='ensembldb.ensembl.org', user='anonymous',
                            passwd='')
        try:
            translationDB = SQLTableCatcher(
                'homo_sapiens_core_47_36i.translation', serverInfo=conn)
            translationDB.catchIter = True  # should not iter in this test!
            exonDB = SQLTable('homo_sapiens_core_47_36i.exon',
                              serverInfo=conn)
        except ImportError as e:
            raise SkipTest(e)

        # Exons between a translation's start and end exon, in rank order.
        sql_statement = '''SELECT t3.exon_id FROM
homo_sapiens_core_47_36i.translation AS tr,
homo_sapiens_core_47_36i.exon_transcript AS t1,
homo_sapiens_core_47_36i.exon_transcript AS t2,
homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
t3.rank <= t2.rank ORDER BY t3.rank
'''
        print('creating GraphView...')
        self.translationExons = GraphView(translationDB, exonDB,
                                          sql_statement, serverInfo=conn)
        print('getting translation...')
        self.translation = translationDB[15121]

    def test_orderBy(self):
        """Ensemble access, test order by

        test issue 53: ensure that the ORDER BY results are correct
        """
        print('do mapping')
        exons = self.translationExons[self.translation]  # do the query
        print('done')
        result = [e.id for e in exons]
        correct = [95160, 95020, 95035, 95050, 95059, 95069, 95081, 95088,
                   95101, 95110, 95172]
        self.assertEqual(result, correct)  # make sure the exact order matches
if __name__ == '__main__':
    # Script entry point: hand control to PygrTestProgram from testlib.
    PygrTestProgram(verbosity=2)