added test for unwanted iteration
[pygr.git] / tests / sqltable_test.py
blob 0228cda59c44078538383398056d0bdc62c98270
1 import os, unittest
2 from testlib import testutil, PygrTestProgram, SkipTest
3 from pygr.sqlgraph import SQLTable, SQLTableNoCache,\
4 MapView, GraphView, DBServerInfo, import_sqlite
5 from pygr import logger
class SQLTableCatcher(SQLTable):
    """SQLTable that can be told to fail if it is ever iterated.

    Setting ``catchIter = True`` on an instance makes any call to
    generic_iterator() raise AssertionError; instances without a
    ``catchIter`` attribute (or with a false value) behave exactly like
    SQLTable.
    """
    def generic_iterator(self, *args, **kwargs):
        # Explicit raise instead of a bare ``assert``: assert statements are
        # stripped under ``python -O``, which would silently disable the guard.
        if getattr(self, 'catchIter', False):
            raise AssertionError('this should not iterate!')
        return SQLTable.generic_iterator(self, *args, **kwargs)
class SQLTable_Setup(unittest.TestCase):
    """Shared fixture: three small tables holding 9 rows, rebuilt per test.

    Concrete subclasses set ``writeable`` (passed to the main table) and may
    override ``tableClass`` to exercise another table implementation.
    """
    tableClass = SQLTableCatcher
    # Read by setUp(); concrete subclasses override it.
    writeable = False

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        # unittest creates one TestCase instance per test method, so this is
        # one connection per test, shared by every table load_data() builds.
        self.serverInfo = DBServerInfo()

    def setUp(self):
        try:
            self.load_data(writeable=self.writeable)
        except ImportError:
            raise SkipTest('missing MySQLdb module?')

    def load_data(self, tableName='test.sqltable_test', writeable=False):
        'create 3 tables and load 9 rows for our tests'
        self.tableName = tableName
        self.joinTable1 = joinTable1 = tableName + '1'
        self.joinTable2 = joinTable2 = tableName + '2'
        # After the tableName substitution, %%(AUTO_INCREMENT)s becomes
        # %(AUTO_INCREMENT)s and is left for the table class to fill in
        # (presumably a per-backend keyword -- confirm in pygr.sqlgraph).
        createTable = """\
CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
""" % tableName
        self.db = self.tableClass(tableName, dropIfExists=True,
                                  serverInfo=self.serverInfo,
                                  createTable=createTable,
                                  writeable=writeable)
        self.sourceDB = self.tableClass(joinTable1, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
other_id VARCHAR(16))
""" % joinTable1)
        self.targetDB = self.tableClass(joinTable2, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
other_id VARCHAR(16))
""" % joinTable2)
        # 2 rows in the main table (primary_id left to the database),
        # 3 rows in joinTable1 and 4 rows in joinTable2.
        statements = [
            "INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)"
            % tableName,
            "INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)"
            % tableName,
            "INSERT INTO %s VALUES (2,'seq2')" % joinTable1,
            "INSERT INTO %s VALUES (3,'seq3')" % joinTable1,
            "INSERT INTO %s VALUES (4,'seq4')" % joinTable1,
            "INSERT INTO %s VALUES (7, 'seq2')" % joinTable2,
            "INSERT INTO %s VALUES (99, 'seq3')" % joinTable2,
            "INSERT INTO %s VALUES (6, 'seq4')" % joinTable2,
            "INSERT INTO %s VALUES (8, 'seq4')" % joinTable2,
        ]
        for statement in statements: # insert our test data
            self.db.cursor.execute(statement)

    def tearDown(self):
        try:
            for name in (self.tableName, self.joinTable1, self.joinTable2):
                self.db.cursor.execute('drop table if exists %s' % name)
        finally:
            # release the connection even if one of the DROPs fails
            self.serverInfo.close()
class SQLTable_Test(SQLTable_Setup):
    """Read-only dict interface of the table class, plus MapView / GraphView
    joins between the sourceDB and targetDB fixture tables."""
    writeable = False # read-only database interface

    def _join_sql(self):
        """SQL mapping a sourceDB my_id (the %s placeholder) to the third_id
        of every targetDB row sharing its other_id."""
        return """\
SELECT t2.third_id FROM %s t1, %s t2
WHERE t1.my_id=%%s and t1.other_id=t2.other_id
""" % (self.joinTable1, self.joinTable2)

    def _inverse_join_sql(self):
        """Reverse of _join_sql(): maps a targetDB third_id (the %s
        placeholder) to the my_id of every sourceDB row sharing its
        other_id."""
        return """\
SELECT t1.my_id FROM %s t1, %s t2
WHERE t2.third_id=%%s and t1.other_id=t2.other_id
""" % (self.joinTable1, self.joinTable2)

    def test_keys(self):
        self.assertEqual(sorted(self.db.keys()), [1, 2])

    def test_len(self):
        self.assertEqual(len(self.db), len(self.db.keys()))

    def test_contains(self):
        self.assertTrue(1 in self.db)
        self.assertTrue(2 in self.db)
        self.assertFalse('foo' in self.db)

    def test_has_key(self):
        self.assertTrue(self.db.has_key(1))
        self.assertTrue(self.db.has_key(2))
        self.assertFalse(self.db.has_key('foo'))

    def test_get(self):
        self.assertTrue(self.db.get('foo') is None)
        self.assertEqual(self.db.get(1), self.db[1])
        self.assertEqual(self.db.get(2), self.db[2])

    def test_items(self):
        self.assertEqual(sorted([k for (k, v) in self.db.items()]), [1, 2])

    def test_iterkeys(self):
        self.assertEqual(sorted(self.db.keys()), sorted(self.db.iterkeys()))

    def test_itervalues(self):
        self.assertEqual(sorted(self.db.values()),
                         sorted(self.db.itervalues()))

    def test_itervalues_long(self):
        """test iterator isolation from queries run inside iterator loop """
        sql = 'insert into %s (start) values (1)' % self.tableName
        for i in range(40000): # insert 40000 rows
            self.db.cursor.execute(sql)
        logger.debug('begin itervalues()')
        iv = []
        for o in self.db.itervalues():
            # Result deliberately ignored: the point is to run a query while
            # the itervalues() loop is still open.
            _ = 99 in self.db
            iv.append(o.id)
        logger.debug('begin values()')
        kv = [o.id for o in self.db.values()]
        self.assertEqual(len(kv), len(iv))
        self.assertEqual(kv, iv)
        logger.debug('done')

    def test_iteritems(self):
        self.assertEqual(sorted(self.db.items()), sorted(self.db.iteritems()))

    def test_readonly(self):
        'test error handling of write attempts to read-only DB'
        self.assertRaises(ValueError, self.db.new,
                          seq_id='freddy', start=3000, stop=4500)
        o = self.db[1]
        self.assertRaises(ValueError, self.db.__setitem__, 33, o)
        self.assertRaises(ValueError, self.db.__delitem__, 2)

    # NOTE: write access is exercised separately by SQLTableRW_Test.
    def test_mapview(self):
        'test MapView of SQL join'
        m = MapView(self.sourceDB, self.targetDB, self._join_sql(),
                    serverInfo=self.serverInfo)
        self.assertEqual(m[self.sourceDB[2]], self.targetDB[7])
        self.assertEqual(m[self.sourceDB[3]], self.targetDB[99])
        self.assertTrue(self.sourceDB[2] in m)
        # sourceDB[4] ('seq4') joins to two targetDB rows: not a unique mapping
        self.assertRaises(KeyError, lambda: m[self.sourceDB[4]])
        # built without inverseSQL, so inversion must be refused
        self.assertRaises(ValueError, lambda: ~m)
        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        self.assertEqual(len(self.sourceDB), 4)
        self.assertEqual(len(m), 2)
        correct = sorted([self.sourceDB[2], self.sourceDB[3]])
        self.assertEqual(sorted(m.keys()), correct)

    def test_mapview_inverse(self):
        'test inverse MapView of SQL join'
        # answering these lookups must not require iterating either table
        self.sourceDB.catchIter = self.targetDB.catchIter = True
        m = MapView(self.sourceDB, self.targetDB, self._join_sql(),
                    serverInfo=self.serverInfo,
                    inverseSQL=self._inverse_join_sql())
        r = ~m # get the inverse
        self.assertEqual(self.sourceDB[2], r[self.targetDB[7]])
        self.assertEqual(self.sourceDB[3], r[self.targetDB[99]])
        self.assertTrue(self.targetDB[7] in r)

        m = ~r # get the inverse of the inverse!
        self.assertEqual(m[self.sourceDB[2]], self.targetDB[7])
        self.assertEqual(m[self.sourceDB[3]], self.targetDB[99])
        self.assertTrue(self.sourceDB[2] in m)
        self.assertRaises(KeyError, lambda: m[self.sourceDB[4]])

    def test_graphview(self):
        'test GraphView of SQL join'
        m = GraphView(self.sourceDB, self.targetDB, self._join_sql(),
                      serverInfo=self.serverInfo)
        d = m[self.sourceDB[4]]
        self.assertEqual(len(d), 2)
        self.assertTrue(self.targetDB[6] in d)
        self.assertTrue(self.targetDB[8] in d)
        self.assertTrue(self.sourceDB[2] in m)

        self.sourceDB.cursor.execute("INSERT INTO %s VALUES (5,'seq78')"
                                     % self.sourceDB.name)
        self.assertEqual(len(self.sourceDB), 4)
        self.assertEqual(len(m), 3)
        correct = sorted([self.sourceDB[2], self.sourceDB[3],
                          self.sourceDB[4]])
        self.assertEqual(sorted(m.keys()), correct)

    def test_graphview_inverse(self):
        'test inverse GraphView of SQL join'
        m = GraphView(self.sourceDB, self.targetDB, self._join_sql(),
                      serverInfo=self.serverInfo,
                      inverseSQL=self._inverse_join_sql())
        r = ~m # get the inverse
        self.assertTrue(self.sourceDB[2] in r[self.targetDB[7]])
        self.assertTrue(self.sourceDB[3] in r[self.targetDB[99]])
        self.assertTrue(self.targetDB[7] in r)
        d = r[self.targetDB[6]]
        self.assertEqual(len(d), 1)
        self.assertTrue(self.sourceDB[4] in d)

        m = ~r # get inverse of the inverse!
        d = m[self.sourceDB[4]]
        self.assertEqual(len(d), 2)
        self.assertTrue(self.targetDB[6] in d)
        self.assertTrue(self.targetDB[8] in d)
        self.assertTrue(self.sourceDB[2] in m)
class SQLiteBase(testutil.SQLite_Mixin):
    """Mixin supplying the SQLite variant of the fixture loader.

    sqlite_load() is presumably called by testutil.SQLite_Mixin's own setup
    -- confirm in testutil.
    """
    def sqlite_load(self):
        # bare table name, unlike load_data()'s default 'test.sqltable_test'
        self.load_data(tableName='sqltable_test', writeable=self.writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    """Every SQLTable_Test case, with the fixture loaded via SQLiteBase."""
248 ## class SQLitePickle_Test(SQLiteTable_Test):
249 ## def setUp(self):
250 ## """Pickle / unpickle our serverInfo before trying to use it """
251 ## SQLiteTable_Test.setUp(self)
252 ## self.serverInfo.close()
253 ## import pickle
254 ## s = pickle.dumps(self.serverInfo)
255 ## del self.serverInfo
256 ## self.serverInfo = pickle.loads(s)
257 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
258 ## self.sourceDB = self.tableClass(self.joinTable1,
259 ## serverInfo=self.serverInfo)
260 ## self.targetDB = self.tableClass(self.joinTable2,
261 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    """Repeat SQLTable_Test with SQLTableNoCache as the table class."""
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    """Repeat SQLiteTable_Test with SQLTableNoCache as the table class."""
    tableClass = SQLTableNoCache
class SQLTableRW_Test(SQLTable_Setup):
    'test write operations'
    writeable = True

    def _requery(self):
        """Return a new tableClass object for self.tableName, so rows are
        fetched again rather than taken from self.db."""
        return self.tableClass(self.tableName, serverInfo=self.serverInfo)

    def _assert_row(self, row, seq_id, start, stop):
        "Check the three data columns of a row object."
        self.assertEqual((row.seq_id, row.start, row.stop),
                         (seq_id, start, stop))

    def test_new(self):
        'test row creation with auto inc ID'
        n = len(self.db)
        o = self.db.new(seq_id='freddy', start=3000, stop=4500)
        self.assertEqual(len(self.db), n + 1)
        t = self._requery()
        self._assert_row(t[o.id], 'freddy', 3000, 4500)

    def test_new2(self):
        'check row creation with specified ID'
        n = len(self.db)
        o = self.db.new(id=99, seq_id='jeff', start=3000, stop=4500)
        self.assertEqual(len(self.db), n + 1)
        self.assertEqual(o.id, 99)
        t = self._requery()
        self._assert_row(t[99], 'jeff', 3000, 4500)

    def test_attr(self):
        'test changing an attr value'
        o = self.db[2]
        self.assertEqual(o.seq_id, 'seq2')
        o.seq_id = 'newval' # overwrite this attribute
        self.assertEqual(o.seq_id, 'newval') # check cached value
        t = self._requery()
        self.assertEqual(t[2].seq_id, 'newval')

    def test_delitem(self):
        'test deletion of a row'
        n = len(self.db)
        del self.db[1]
        self.assertEqual(len(self.db), n - 1)
        # old ID must no longer exist
        self.assertRaises(KeyError, lambda: self.db[1])

    def test_setitem(self):
        'test assigning new ID to existing object'
        o = self.db.new(id=17, seq_id='bob', start=2000, stop=2500)
        self.db[13] = o
        self.assertEqual(o.id, 13)
        # old ID must be gone both in self.db and in a fresh query
        self.assertRaises(KeyError, lambda: self.db[17])
        t = self._requery()
        self._assert_row(t[13], 'bob', 2000, 2500)
        self.assertRaises(KeyError, lambda: t[17])
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    """Every SQLTableRW_Test case, with the fixture loaded via SQLiteBase."""
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    """Repeat SQLTableRW_Test with SQLTableNoCache as the table class."""
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    """Repeat SQLiteTableRW_Test with SQLTableNoCache as the table class."""
    tableClass = SQLTableNoCache
class Ensembl_Test(unittest.TestCase):
    """GraphView query against the public ensembldb.ensembl.org server."""

    def setUp(self):
        import sys  # local: only needed to fetch the ImportError below
        # NOTE(review): only an ImportError while creating the tables is
        # turned into SkipTest here; confirm whether an unreachable server
        # is also skipped or surfaces as a test error.
        logger.debug('accessing ensembldb.ensembl.org')
        conn = DBServerInfo(host='ensembldb.ensembl.org', user='anonymous',
                            passwd='')
        self.serverInfo = conn # closed in tearDown()
        try:
            translationDB = SQLTableCatcher(
                'homo_sapiens_core_47_36i.translation', serverInfo=conn)
            translationDB.catchIter = True # should not iter in this test!
            exonDB = SQLTable('homo_sapiens_core_47_36i.exon', serverInfo=conn)
        except ImportError:
            # sys.exc_info() keeps this valid on every Python 2.x and 3.x
            raise SkipTest(sys.exc_info()[1])

        # exons of one translation, from its start exon to its end exon,
        # in transcript rank order
        sql_statement = '''SELECT t3.exon_id FROM
homo_sapiens_core_47_36i.translation AS tr,
homo_sapiens_core_47_36i.exon_transcript AS t1,
homo_sapiens_core_47_36i.exon_transcript AS t2,
homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
t3.rank <= t2.rank ORDER BY t3.rank
'''
        logger.debug('creating GraphView...')
        self.translationExons = GraphView(translationDB, exonDB,
                                          sql_statement, serverInfo=conn)
        logger.debug('getting translation...')
        self.translation = translationDB[15121]

    def tearDown(self):
        self.serverInfo.close()

    def test_orderBy(self):
        """Ensembl access: ORDER BY results are correct (test issue 53)."""
        logger.debug('do mapping')
        exons = self.translationExons[self.translation] # do the query
        logger.debug('done')
        result = [e.id for e in exons]
        correct = [95160, 95020, 95035, 95050, 95059, 95069, 95081, 95088,
                   95101, 95110, 95172]
        self.assertEqual(result, correct) # make sure the exact order matches
if '__main__' == __name__:
    # Run as a script (not imported): hand control to pygr's test runner
    # with verbose per-test output.
    PygrTestProgram(verbosity=2)