added SQLTable_Test.test_orderBy()
[pygr.git] / tests / sqltable_test.py
blob0f75af8d6697611e3058fc262b03335766b070d0
1 import os, unittest
2 from testlib import testutil, PygrTestProgram, SkipTest
3 from pygr.sqlgraph import SQLTable, SQLTableNoCache,\
4 MapView, GraphView, DBServerInfo, import_sqlite
5 from pygr import logger
def catch_iterator(self, *args, **kwargs):
    """Guarded stand-in for ``generic_iterator`` on the Catcher classes.

    If the instance carries a true ``catchIter`` attribute, raise
    AssertionError: the running test has declared that no iteration
    should take place.  If the attribute is missing or false, delegate
    to ``SQLTable.generic_iterator`` with the same arguments and return
    its result.
    """
    # Raise explicitly instead of using an ``assert`` statement, so the
    # guard is still enforced when Python runs with -O (which removes
    # assert statements).  getattr() with a default treats a missing
    # attribute as "iteration allowed".
    if getattr(self, 'catchIter', False):
        raise AssertionError('this should not iterate!')
    return SQLTable.generic_iterator(self, *args, **kwargs)
class SQLTableCatcher(SQLTable):
    'SQLTable whose generic_iterator is replaced by catch_iterator'
    # lets a test set obj.catchIter = True to flag unexpected iteration
    generic_iterator = catch_iterator
class SQLTableNoCacheCatcher(SQLTableNoCache):
    'SQLTableNoCache whose generic_iterator is replaced by catch_iterator'
    # NOTE(review): no test class in this file selects this class as its
    # tableClass; the NoCache tests use plain SQLTableNoCache -- confirm
    # whether that is intentional.
    generic_iterator = catch_iterator
class SQLTable_Setup(unittest.TestCase):
    """Shared fixture that builds three small tables for the tests below.

    Subclasses must define a ``writeable`` class attribute; setUp()
    passes it to load_data().  ``tableClass`` selects the table class
    used to open all three tables.
    """
    tableClass = SQLTableCatcher

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        # one server connection, handed to every table this test opens
        self.serverInfo = DBServerInfo()

    def setUp(self):
        'load the test tables, or skip the test if an import fails'
        try:
            self.load_data(writeable=self.writeable)
        except ImportError:
            raise SkipTest('missing MySQLdb module?')

    def load_data(self, tableName='test.sqltable_test', writeable=False):
        'create 3 tables and load 9 rows for our tests'
        self.tableName = tableName
        self.joinTable1 = joinTable1 = tableName + '1'
        self.joinTable2 = joinTable2 = tableName + '2'
        # %%(AUTO_INCREMENT)s survives this first substitution as
        # %(AUTO_INCREMENT)s and is left for the table class to fill in
        createTable = """\
        CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
        """ % tableName
        # main table: the only one that receives the writeable flag
        self.db = self.tableClass(tableName, dropIfExists=True,
                                  serverInfo=self.serverInfo,
                                  createTable=createTable,
                                  writeable=writeable)
        self.sourceDB = self.tableClass(joinTable1, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable1)
        self.targetDB = self.tableClass(joinTable2, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable2)
        # one INSERT per line: 2 rows for the main table, 3 for joinTable1
        # and 4 for joinTable2 (matching the tuple built below)
        sql = """
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
            INSERT INTO %s VALUES (2,'seq2')
            INSERT INTO %s VALUES (3,'seq3')
            INSERT INTO %s VALUES (4,'seq4')
            INSERT INTO %s VALUES (7, 'seq2')
            INSERT INTO %s VALUES (99, 'seq3')
            INSERT INTO %s VALUES (6, 'seq4')
            INSERT INTO %s VALUES (8, 'seq4')
        """ % tuple(([tableName]*2) + ([joinTable1]*3) + ([joinTable2]*4))
        for line in sql.strip().splitlines():  # insert our test data
            self.db.cursor.execute(line.strip())

    def tearDown(self):
        'drop the three test tables and always release the connection'
        try:
            for name in (self.tableName, self.joinTable1, self.joinTable2):
                self.db.cursor.execute('drop table if exists %s' % name)
        finally:
            # close even when a DROP fails, so the connection is not leaked
            self.serverInfo.close()
class SQLTable_Test(SQLTable_Setup):
    'read-only checks of the SQLTable mapping interface and SQL join views'
    writeable = False  # read-only database interface

    def test_keys(self):
        'keys() returns the IDs of the two rows in the main table'
        ids = sorted(self.db.keys())
        assert ids == [1, 2]

    def test_len(self):
        'len() matches the number of keys'
        self.db.catchIter = True
        assert len(self.db) == len(self.db.keys())

    def test_contains(self):
        'membership test with the in operator'
        self.db.catchIter = True
        for key in (1, 2):
            assert key in self.db
        assert 'foo' not in self.db

    def test_has_key(self):
        'membership test with has_key()'
        self.db.catchIter = True
        for key in (1, 2):
            assert self.db.has_key(key)
        assert not self.db.has_key('foo')

    def test_get(self):
        'get() returns None for a missing key, else the same row as []'
        self.db.catchIter = True
        assert self.db.get('foo') is None
        for key in (1, 2):
            assert self.db.get(key) == self.db[key]

    def test_items(self):
        'items() yields (id, row) pairs for both rows'
        ids = sorted([key for (key, row) in self.db.items()])
        assert ids == [1, 2]

    def test_iterkeys(self):
        'iterkeys() yields the same IDs as keys()'
        fromList = sorted(self.db.keys())
        fromIter = sorted(self.db.iterkeys())
        assert fromList == fromIter

    def test_itervalues(self):
        'itervalues() yields the same rows as values()'
        fromList = sorted(self.db.values())
        fromIter = sorted(self.db.itervalues())
        assert fromList == fromIter

    def test_itervalues_long(self):
        """test iterator isolation from queries run inside iterator loop """
        insertRow = 'insert into %s (start) values (1)' % self.tableName
        for _ in range(40000):  # insert 40000 rows
            self.db.cursor.execute(insertRow)
        iterIDs = []
        print('begin itervalues()')
        for row in self.db.itervalues():
            99 in self.db  # make it do a query inside iterator loop
            iterIDs.append(row.id)
        print('begin values()')
        listIDs = [row.id for row in self.db.values()]
        assert len(listIDs) == len(iterIDs)
        assert listIDs == iterIDs
        print('done')

    def test_iteritems(self):
        'iteritems() yields the same pairs as items()'
        fromList = sorted(self.db.items())
        fromIter = sorted(self.db.iteritems())
        assert fromList == fromIter

    def test_readonly(self):
        'test error handling of write attempts to read-only DB'
        self.db.catchIter = True  # no iter expected in this test!
        try:
            self.db.new(seq_id='freddy', start=3000, stop=4500)
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        row = self.db[1]
        try:
            self.db[33] = row
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        try:
            del self.db[2]
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')

    def test_orderBy(self):
        'test iterator with orderBy, iterSQL, iterColumns'
        target = self.targetDB
        target.catchIter = True  # should not iterate
        target.arraysize = 2  # force it to use multiple queries to finish
        result = target.keys()
        assert result == [6, 7, 8, 99]
        target.catchIter = False  # next statement will iterate
        assert result == list(iter(target))
        target.catchIter = True  # should not iterate
        target.orderBy = 'ORDER BY other_id'
        result = target.keys()
        assert result == [7, 99, 6, 8]
        target.catchIter = False  # next statement will iterate
        if self.serverInfo._serverType == 'mysql':  # only test this for mysql
            # iterating before iterSQL is set is expected to fail
            try:
                assert result == list(iter(target))
                raise AssertionError('failed to trap missing iterSQL attr')
            except AttributeError:
                pass
        target.iterSQL = 'WHERE other_id>%s'  # tell it how to slice
        target.iterColumns = ['other_id']
        assert result == list(iter(target))
        result = target.values()
        assert result == [target[7], target[99], target[6], target[8]]
        assert result == list(target.itervalues())
        result = target.items()
        assert result == [(7, target[7]), (99, target[99]),
                          (6, target[6]), (8, target[8])]
        assert result == list(target.iteritems())

    ### @CTB need to test write access
    def test_mapview(self):
        'test MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        joinSQL = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1,self.joinTable2)
        mapping = MapView(self.sourceDB, self.targetDB, joinSQL,
                          serverInfo=self.serverInfo)
        assert mapping[self.sourceDB[2]] == self.targetDB[7]
        assert mapping[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in mapping
        try:  # source row 4 joins to two target rows
            mapping[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')
        try:  # no inverseSQL was supplied
            ~mapping
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap non-invertible mapping')
        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(mapping) == 2
        found = sorted(mapping.keys())
        expected = sorted([self.sourceDB[2],self.sourceDB[3]])
        assert found == expected

    def test_mapview_inverse(self):
        'test inverse MapView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        tables = (self.joinTable1,self.joinTable2)
        forwardSQL = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % tables
        reverseSQL = """\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % tables
        forward = MapView(self.sourceDB, self.targetDB, forwardSQL,
                          serverInfo=self.serverInfo, inverseSQL=reverseSQL)
        inverse = ~forward  # get the inverse
        assert self.sourceDB[2] == inverse[self.targetDB[7]]
        assert self.sourceDB[3] == inverse[self.targetDB[99]]
        assert self.targetDB[7] in inverse

        forward = ~inverse  # get the inverse of the inverse!
        assert forward[self.sourceDB[2]] == self.targetDB[7]
        assert forward[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in forward
        try:
            forward[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')

    def test_graphview(self):
        'test GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        joinSQL = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1,self.joinTable2)
        graph = GraphView(self.sourceDB, self.targetDB, joinSQL,
                          serverInfo=self.serverInfo)
        edges = graph[self.sourceDB[4]]
        assert len(edges) == 2
        assert self.targetDB[6] in edges and self.targetDB[8] in edges
        assert self.sourceDB[2] in graph

        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        assert len(self.sourceDB) == 4
        self.sourceDB.catchIter = False  # next step will cause iteration
        assert len(graph) == 3
        found = sorted(graph.keys())
        expected = sorted([self.sourceDB[2],self.sourceDB[3],
                           self.sourceDB[4]])
        assert found == expected

    def test_graphview_inverse(self):
        'test inverse GraphView of SQL join'
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        tables = (self.joinTable1,self.joinTable2)
        forwardSQL = """\
        SELECT t2.third_id FROM %s t1, %s t2
           WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % tables
        reverseSQL = """\
        SELECT t1.my_id FROM %s t1, %s t2
           WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % tables
        forward = GraphView(self.sourceDB, self.targetDB, forwardSQL,
                            serverInfo=self.serverInfo, inverseSQL=reverseSQL)
        inverse = ~forward  # get the inverse
        assert self.sourceDB[2] in inverse[self.targetDB[7]]
        assert self.sourceDB[3] in inverse[self.targetDB[99]]
        assert self.targetDB[7] in inverse
        sources = inverse[self.targetDB[6]]
        assert len(sources) == 1
        assert self.sourceDB[4] in sources

        forward = ~inverse  # get inverse of the inverse!
        targets = forward[self.sourceDB[4]]
        assert len(targets) == 2
        assert self.targetDB[6] in targets and self.targetDB[8] in targets
        assert self.sourceDB[2] in forward
class SQLiteBase(testutil.SQLite_Mixin):
    'mixin that loads the test tables under the name sqltable_test'
    # NOTE(review): sqlite_load is presumably the hook that
    # testutil.SQLite_Mixin calls during setUp -- confirm in testutil.

    def sqlite_load(self):
        'build the fixture tables via load_data()'
        self.load_data('sqltable_test', writeable=self.writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    'the read-only SQLTable_Test suite combined with the SQLiteBase mixin'
295 ## class SQLitePickle_Test(SQLiteTable_Test):
296 ## def setUp(self):
297 ## """Pickle / unpickle our serverInfo before trying to use it """
298 ## SQLiteTable_Test.setUp(self)
299 ## self.serverInfo.close()
300 ## import pickle
301 ## s = pickle.dumps(self.serverInfo)
302 ## del self.serverInfo
303 ## self.serverInfo = pickle.loads(s)
304 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
305 ## self.sourceDB = self.tableClass(self.joinTable1,
306 ## serverInfo=self.serverInfo)
307 ## self.targetDB = self.tableClass(self.joinTable2,
308 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    'the read-only SQLTable_Test suite using SQLTableNoCache tables'
    # NOTE(review): this is the plain SQLTableNoCache, not the Catcher
    # variant, so it looks like the catchIter flags set by the tests are
    # not enforced by catch_iterator here -- confirm this is intended.
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    'the SQLiteTable_Test suite using SQLTableNoCache tables'
    tableClass = SQLTableNoCache
class SQLTableRW_Test(SQLTable_Setup):
    'test write operations'
    writeable = True

    def test_new(self):
        'test row creation with auto inc ID'
        self.db.catchIter = True  # no iter expected in this test
        before = len(self.db)
        created = self.db.new(seq_id='freddy', start=3000, stop=4500)
        assert len(self.db) == before + 1
        fresh = self.tableClass(self.tableName,
                                serverInfo=self.serverInfo)  # requery the db
        fresh.catchIter = True  # no iter expected in this test
        row = fresh[created.id]
        assert (row.seq_id == 'freddy' and row.start==3000
                and row.stop==4500)

    def test_new2(self):
        'check row creation with specified ID'
        self.db.catchIter = True  # no iter expected in this test
        before = len(self.db)
        created = self.db.new(id=99, seq_id='jeff', start=3000, stop=4500)
        assert len(self.db) == before + 1
        assert created.id == 99
        fresh = self.tableClass(self.tableName,
                                serverInfo=self.serverInfo)  # requery the db
        fresh.catchIter = True  # no iter expected in this test
        row = fresh[99]
        assert (row.seq_id == 'jeff' and row.start==3000
                and row.stop==4500)

    def test_attr(self):
        'test changing an attr value'
        self.db.catchIter = True  # no iter expected in this test
        cached = self.db[2]
        assert cached.seq_id == 'seq2'
        cached.seq_id = 'newval'  # overwrite this attribute
        assert cached.seq_id == 'newval'  # check cached value
        fresh = self.tableClass(self.tableName,
                                serverInfo=self.serverInfo)  # requery the db
        fresh.catchIter = True  # no iter expected in this test
        row = fresh[2]
        assert row.seq_id == 'newval'

    def test_delitem(self):
        'test deletion of a row'
        self.db.catchIter = True  # no iter expected in this test
        before = len(self.db)
        del self.db[1]
        assert len(self.db) == before - 1
        try:
            self.db[1]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')

    def test_setitem(self):
        'test assigning new ID to existing object'
        self.db.catchIter = True  # no iter expected in this test
        moved = self.db.new(id=17, seq_id='bob', start=2000, stop=2500)
        self.db[13] = moved
        assert moved.id == 13
        try:  # the old ID must be gone from the table we wrote through
            self.db[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
        fresh = self.tableClass(self.tableName,
                                serverInfo=self.serverInfo)  # requery the db
        fresh.catchIter = True  # no iter expected in this test
        row = fresh[13]
        assert (row.seq_id == 'bob' and row.start==2000
                and row.stop==2500)
        try:  # ... and from a freshly opened view of the same table
            fresh[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    'the SQLTableRW_Test write suite combined with the SQLiteBase mixin'
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    'the SQLTableRW_Test write suite using SQLTableNoCache tables'
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    'the SQLiteTableRW_Test write suite using SQLTableNoCache tables'
    tableClass = SQLTableNoCache
class Ensembl_Test(unittest.TestCase):
    'check GraphView ORDER BY results against ensembldb.ensembl.org'

    def setUp(self):
        """Open the translation and exon tables and build the GraphView.

        The test is skipped when opening the tables raises ImportError.
        """
        # NOTE(review): the earlier comment here also promised a skip when
        # the ensembldb server is unavailable, but only ImportError is
        # trapped -- confirm whether connection errors should skip too.
        logger.debug('accessing ensembldb.ensembl.org')
        conn = DBServerInfo(host='ensembldb.ensembl.org', user='anonymous',
                            passwd='')
        try:
            translationDB = SQLTableCatcher(
                'homo_sapiens_core_47_36i.translation', serverInfo=conn)
            translationDB.catchIter = True  # should not iter in this test!
            exonDB = SQLTable('homo_sapiens_core_47_36i.exon',
                              serverInfo=conn)
        except ImportError as e:
            raise SkipTest(e)
        self.conn = conn  # kept so tearDown() can close it

        # exons of one translation, ordered by rank; %s is left in place
        # as the placeholder for the translation ID
        sql_statement = '''SELECT t3.exon_id FROM
homo_sapiens_core_47_36i.translation AS tr,
homo_sapiens_core_47_36i.exon_transcript AS t1,
homo_sapiens_core_47_36i.exon_transcript AS t2,
homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
t3.rank <= t2.rank ORDER BY t3.rank
'''
        print('creating GraphView...')
        self.translationExons = GraphView(translationDB, exonDB,
                                          sql_statement, serverInfo=conn)
        print('getting translation...')
        self.translation = translationDB[15121]

    def tearDown(self):
        'release the server connection opened by setUp()'
        self.conn.close()

    def test_orderBy(self):
        """Ensemble access, test order by

        test issue 53: ensure that the ORDER BY results are correct
        """
        print('do mapping')
        exons = self.translationExons[self.translation]  # do the query
        print('done')
        result = [e.id for e in exons]
        correct = [95160,95020,95035,95050,95059,95069,95081,95088,95101,
                   95110,95172]
        self.assertEqual(result, correct)  # make sure the exact order matches
if __name__ == '__main__':
    # script entry point: hand over to PygrTestProgram with verbosity=2
    PygrTestProgram(verbosity=2)