1 import os
, random
, string
, unittest
2 from testlib
import testutil
, PygrTestProgram
, SkipTest
3 from pygr
.sqlgraph
import SQLTable
, SQLTableNoCache
,\
4 MapView
, GraphView
, DBServerInfo
, import_sqlite
5 from pygr
import logger
def catch_iterator(self, *args, **kwargs):
    """Refuse full-table iteration on tables whose ``catchIter`` flag is set.

    Installed as ``generic_iterator`` on the catcher table classes so a test
    can set ``table.catchIter = True`` around statements that must be served
    without iterating the whole table.  A table on which ``catchIter`` was
    never assigned is allowed to iterate.  Otherwise this delegates to
    ``SQLTable.generic_iterator`` with the same arguments.

    Raises:
        AssertionError: if ``self.catchIter`` is truthy.
    """
    # Explicit raise instead of ``assert``: the guard must stay active even
    # when Python runs with -O, which strips assert statements.
    if getattr(self, 'catchIter', False):
        raise AssertionError('this should not iterate!')
    # NOTE(review): always delegates to SQLTable.generic_iterator, even when
    # installed on an SQLTableNoCache subclass -- confirm that is intended.
    return SQLTable.generic_iterator(self, *args, **kwargs)
class SQLTableCatcher(SQLTable):
    """SQLTable whose generic iterator is routed through catch_iterator.

    Setting ``catchIter = True`` on an instance makes any full-table
    iteration raise AssertionError, so tests can prove an operation is
    answered without scanning the table.
    """
    generic_iterator = catch_iterator
class SQLTableNoCacheCatcher(SQLTableNoCache):
    """SQLTableNoCache whose generic iterator is routed through catch_iterator."""
    # NOTE(review): not referenced by any test class in this file; the
    # NoCache test classes use plain SQLTableNoCache.  Confirm whether they
    # were meant to use this catcher instead.
    generic_iterator = catch_iterator
# Shared fixture base class for the SQLTable tests.
#  * __init__ builds a DBServerInfo from the subclass-supplied `serverArgs`
#    and stores it on self.serverInfo.
#  * load_data(tableName, writeable) creates the main table, two join tables
#    (tableName + '1' and tableName + '2') and an ORDER BY table
#    (tableName + '_orderBy').  It inserts 9 fixture rows into the first
#    three tables and 10 rows with random `number`/`letter` values into the
#    last one.
#  * The trailing statements drop those four tables and close serverInfo
#    (presumably tearDown; its def line is among the missing lines).
# tableClass defaults to SQLTableCatcher, so tests can check that a
# statement does not iterate the whole table (see catchIter).
#
# NOTE(review): this class was damaged in extraction and cannot parse.
# Tokens are split across lines and the original line numbers are embedded
# in the text.  Original lines 29-30, 32, 39, 41, 45, 49-50, 54-56 and 79-81
# are missing.  These presumably include the setUp/try header around
# load_data, the `createTable = """` assignment and the assignment of the
# `sql` text that the insert loop reads.  Restore this class from version
# control rather than hand-editing it.
# NOTE(review): with the standard unittest loader a separate TestCase
# instance is built per test method, so __init__ (and its DBServerInfo)
# runs once per test; the "share conn for all tests" comment looks
# inaccurate.
# NOTE(review): string.lowercase exists only on Python 2
# (string.ascii_lowercase is the portable name).
22 class SQLTable_Setup(unittest
.TestCase
):
23 tableClass
= SQLTableCatcher
25 def __init__(self
, *args
, **kwargs
):
26 unittest
.TestCase
.__init
__(self
, *args
, **kwargs
)
27 # share conn for all tests
28 self
.serverInfo
= DBServerInfo(** self
.serverArgs
)
31 self
.load_data(writeable
=self
.writeable
)
33 raise SkipTest('missing MySQLdb module?')
34 def load_data(self
, tableName
='test.sqltable_test', writeable
=False):
35 'create 3 tables and load 9 rows for our tests'
36 self
.tableName
= tableName
37 self
.joinTable1
= joinTable1
= tableName
+ '1'
38 self
.joinTable2
= joinTable2
= tableName
+ '2'
40 CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
42 self
.db
= self
.tableClass(tableName
, dropIfExists
=True,
43 serverInfo
=self
.serverInfo
,
44 createTable
=createTable
,
46 self
.sourceDB
= self
.tableClass(joinTable1
, serverInfo
=self
.serverInfo
,
47 dropIfExists
=True, createTable
="""\
48 CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
51 self
.targetDB
= self
.tableClass(joinTable2
, serverInfo
=self
.serverInfo
,
52 dropIfExists
=True, createTable
="""\
53 CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
57 INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
58 INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
59 INSERT INTO %s VALUES (2,'seq2')
60 INSERT INTO %s VALUES (3,'seq3')
61 INSERT INTO %s VALUES (4,'seq4')
62 INSERT INTO %s VALUES (7, 'seq2')
63 INSERT INTO %s VALUES (99, 'seq3')
64 INSERT INTO %s VALUES (6, 'seq4')
65 INSERT INTO %s VALUES (8, 'seq4')
66 """ % tuple(([tableName
]*2) + ([joinTable1
]*3) + ([joinTable2
]*4))
67 for line
in sql
.strip().splitlines(): # insert our test data
68 self
.db
.cursor
.execute(line
.strip())
70 # Another table, for the "ORDER BY" test
71 self
.orderTable
= tableName
+ '_orderBy'
72 self
.db
.cursor
.execute("""\
73 CREATE TABLE %s (id INTEGER PRIMARY KEY, number INTEGER, letter VARCHAR(1))
74 """ % self
.orderTable
)
75 for row
in range(0, 10):
76 self
.db
.cursor
.execute('INSERT INTO %s VALUES (%d, %d, \'%s\')' %
77 (self
.orderTable
, row
, random
.randint(0, 99),
78 string
.lowercase
[random
.randint(0,
82 self
.db
.cursor
.execute('drop table if exists %s' % self
.tableName
)
83 self
.db
.cursor
.execute('drop table if exists %s' % self
.joinTable1
)
84 self
.db
.cursor
.execute('drop table if exists %s' % self
.joinTable2
)
85 self
.db
.cursor
.execute('drop table if exists %s' % self
.orderTable
)
86 self
.serverInfo
.close()
# Read-only SQLTable tests over the fixture tables from
# SQLTable_Setup.load_data.  They cover:
#  * len / contains / has_key / get;
#  * items / keys / values and their iterators, and pickling;
#  * ORDER BY iteration (orderBy, iterSQL, iterColumns);
#  * MapView / GraphView joins, including their inverses.
# Tests set <table>.catchIter = True before statements that must not iterate
# the whole table, and set it to False right before statements that will.
#
# NOTE(review): this class was damaged in extraction and cannot parse.
# Tokens are split across lines and the original line numbers are embedded
# in the text.  Many original lines are missing, for example:
#  * the def lines of some tests (the statements at embedded lines 95 and
#    108 have none);
#  * the `try:` lines paired with `except` (embedded 187) or implied by the
#    "failed to trap" raises;
#  * assignments the asserts rely on (`iv` in test_itervalues_long, `db` in
#    test_pickle, `sortedBV` / `sortedBL`).
# Restore this class from version control rather than hand-editing it.
# NOTE(review): `pickle` is used in test_pickle and test_orderBy, but no
# import of it is visible (possibly on a missing line).
# NOTE(review): test_itervalues_long issues 40000 single-row INSERTs, which
# can make this test slow.
88 class SQLTable_Test(SQLTable_Setup
):
89 writeable
= False # read-only database interface
95 self
.db
.catchIter
= True
96 assert len(self
.db
) == len(self
.db
.keys())
97 def test_contains(self
):
98 self
.db
.catchIter
= True
101 assert 'foo' not in self
.db
102 def test_has_key(self
):
103 self
.db
.catchIter
= True
104 assert self
.db
.has_key(1)
105 assert self
.db
.has_key(2)
106 assert not self
.db
.has_key('foo')
108 self
.db
.catchIter
= True
109 assert self
.db
.get('foo') is None
110 assert self
.db
.get(1) == self
.db
[1]
111 assert self
.db
.get(2) == self
.db
[2]
112 def test_items(self
):
113 i
= [ k
for (k
,v
) in self
.db
.items() ]
116 def test_iterkeys(self
):
118 ik
= list(self
.db
.iterkeys())
120 def test_pickle(self
):
123 s
= pickle
.dumps(self
.db
)
126 ik
= list(db
.iterkeys())
129 db
.serverInfo
.close() # close extra DB connection
130 def test_itervalues(self
):
131 kv
= self
.db
.values()
132 iv
= list(self
.db
.itervalues())
134 def test_itervalues_long(self
):
135 """test iterator isolation from queries run inside iterator loop """
136 sql
= 'insert into %s (start) values (1)' % self
.tableName
137 for i
in range(40000): # insert 40000 rows
138 self
.db
.cursor
.execute(sql
)
140 for o
in self
.db
.itervalues():
141 status
= 99 in self
.db
# make it do a query inside iterator loop
143 kv
= [o
.id for o
in self
.db
.values()]
144 assert len(kv
) == len(iv
)
146 def test_iteritems(self
):
148 ii
= list(self
.db
.iteritems())
150 def test_readonly(self
):
151 'test error handling of write attempts to read-only DB'
152 self
.db
.catchIter
= True # no iter expected in this test!
154 self
.db
.new(seq_id
='freddy', start
=3000, stop
=4500)
155 raise AssertionError('failed to trap attempt to write to db')
161 raise AssertionError('failed to trap attempt to write to db')
166 raise AssertionError('failed to trap attempt to write to db')
169 def test_orderBy(self
):
170 'test iterator with orderBy, iterSQL, iterColumns'
171 self
.targetDB
.catchIter
= True # should not iterate
172 self
.targetDB
.arraysize
= 2 # force it to use multiple queries to finish
173 result
= self
.targetDB
.keys()
174 assert result
== [6, 7, 8, 99]
175 self
.targetDB
.catchIter
= False # next statement will iterate
176 assert result
== list(iter(self
.targetDB
))
177 self
.targetDB
.catchIter
= True # should not iterate
178 self
.targetDB
.orderBy
= 'ORDER BY other_id'
179 result
= self
.targetDB
.keys()
180 assert result
== [7, 99, 6, 8]
181 self
.targetDB
.catchIter
= False # next statement will iterate
182 if self
.serverInfo
._serverType
== 'mysql' \
183 and self
.serverInfo
.custom_iter_keys
: # only test this for mysql
185 assert result
== list(iter(self
.targetDB
))
186 raise AssertionError('failed to trap missing iterSQL attr')
187 except AttributeError:
189 self
.targetDB
.iterSQL
= 'WHERE other_id>%s' # tell it how to slice
190 self
.targetDB
.iterColumns
= ['other_id']
191 assert result
== list(iter(self
.targetDB
))
192 result
= self
.targetDB
.values()
193 assert result
== [self
.targetDB
[7], self
.targetDB
[99],
194 self
.targetDB
[6], self
.targetDB
[8]]
195 assert result
== list(self
.targetDB
.itervalues())
196 result
= self
.targetDB
.items()
197 assert result
== [(7, self
.targetDB
[7]), (99, self
.targetDB
[99]),
198 (6, self
.targetDB
[6]), (8, self
.targetDB
[8])]
199 assert result
== list(self
.targetDB
.iteritems())
201 s
= pickle
.dumps(self
.targetDB
) # test pickling & unpickling
204 correct
= self
.targetDB
.keys()
205 result
= list(iter(db
))
206 assert result
== correct
208 db
.serverInfo
.close() # close extra DB connection
210 def test_orderby_random(self
):
211 'test orderBy in SQLTable'
212 if self
.serverInfo
._serverType
== 'mysql' \
213 and self
.serverInfo
.custom_iter_keys
:
215 byNumber
= self
.tableClass(self
.orderTable
, arraysize
=2,
216 serverInfo
=self
.serverInfo
,
217 orderBy
='ORDER BY number')
218 raise AssertionError('failed to trap orderBy without iterSQL!')
221 byNumber
= self
.tableClass(self
.orderTable
, serverInfo
=self
.serverInfo
,
222 arraysize
=2, orderBy
='ORDER BY number,id',
223 iterSQL
='WHERE number>%s or (number=%s and id>%s)',
224 iterColumns
=('number','number','id'))
225 bv
= [val
.number
for val
in byNumber
.values()]
228 assert sortedBV
== bv
229 bv
= [val
.number
for val
in byNumber
.itervalues()]
230 assert sortedBV
== bv
232 byLetter
= self
.tableClass(self
.orderTable
, serverInfo
=self
.serverInfo
,
233 arraysize
=2, orderBy
='ORDER BY letter,id',
234 iterSQL
='WHERE letter>%s or (letter=%s and id>%s)',
235 iterColumns
=('letter','letter','id'))
236 bl
= [val
.letter
for val
in byLetter
.values()]
238 assert sortedBL
== bl
239 bl
= [val
.letter
for val
in byLetter
.itervalues()]
240 assert sortedBL
== bl
242 ### @CTB need to test write access
243 def test_mapview(self
):
244 'test MapView of SQL join'
245 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
246 m
= MapView(self
.sourceDB
, self
.targetDB
,"""\
247 SELECT t2.third_id FROM %s t1, %s t2
248 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
249 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
)
250 assert m
[self
.sourceDB
[2]] == self
.targetDB
[7]
251 assert m
[self
.sourceDB
[3]] == self
.targetDB
[99]
252 assert self
.sourceDB
[2] in m
254 d
= m
[self
.sourceDB
[4]]
255 raise AssertionError('failed to trap non-unique mapping')
260 raise AssertionError('failed to trap non-invertible mapping')
263 self
.sourceDB
.cursor
.execute("INSERT INTO %s VALUES (5,'seq78')"
264 % self
.sourceDB
.name
)
265 assert len(self
.sourceDB
) == 4
266 self
.sourceDB
.catchIter
= False # next step will cause iteration
270 correct
= [self
.sourceDB
[2],self
.sourceDB
[3]]
273 def test_mapview_inverse(self
):
274 'test inverse MapView of SQL join'
275 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
276 m
= MapView(self
.sourceDB
, self
.targetDB
,"""\
277 SELECT t2.third_id FROM %s t1, %s t2
278 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
279 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
,
281 SELECT t1.my_id FROM %s t1, %s t2
282 WHERE t2.third_id=%%s and t1.other_id=t2.other_id
283 """ % (self
.joinTable1
,self
.joinTable2
))
284 r
= ~m
# get the inverse
285 assert self
.sourceDB
[2] == r
[self
.targetDB
[7]]
286 assert self
.sourceDB
[3] == r
[self
.targetDB
[99]]
287 assert self
.targetDB
[7] in r
289 m
= ~r
# get the inverse of the inverse!
290 assert m
[self
.sourceDB
[2]] == self
.targetDB
[7]
291 assert m
[self
.sourceDB
[3]] == self
.targetDB
[99]
292 assert self
.sourceDB
[2] in m
294 d
= m
[self
.sourceDB
[4]]
295 raise AssertionError('failed to trap non-unique mapping')
298 def test_graphview(self
):
299 'test GraphView of SQL join'
300 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
301 m
= GraphView(self
.sourceDB
, self
.targetDB
,"""\
302 SELECT t2.third_id FROM %s t1, %s t2
303 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
304 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
)
305 d
= m
[self
.sourceDB
[4]]
307 assert self
.targetDB
[6] in d
and self
.targetDB
[8] in d
308 assert self
.sourceDB
[2] in m
310 self
.sourceDB
.cursor
.execute("INSERT INTO %s VALUES (5,'seq78')"
311 % self
.sourceDB
.name
)
312 assert len(self
.sourceDB
) == 4
313 self
.sourceDB
.catchIter
= False # next step will cause iteration
317 correct
= [self
.sourceDB
[2],self
.sourceDB
[3],self
.sourceDB
[4]]
321 def test_graphview_inverse(self
):
322 'test inverse GraphView of SQL join'
323 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
324 m
= GraphView(self
.sourceDB
, self
.targetDB
,"""\
325 SELECT t2.third_id FROM %s t1, %s t2
326 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
327 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
,
329 SELECT t1.my_id FROM %s t1, %s t2
330 WHERE t2.third_id=%%s and t1.other_id=t2.other_id
331 """ % (self
.joinTable1
,self
.joinTable2
))
332 r
= ~m
# get the inverse
333 assert self
.sourceDB
[2] in r
[self
.targetDB
[7]]
334 assert self
.sourceDB
[3] in r
[self
.targetDB
[99]]
335 assert self
.targetDB
[7] in r
336 d
= r
[self
.targetDB
[6]]
338 assert self
.sourceDB
[4] in d
340 m
= ~r
# get inverse of the inverse!
341 d
= m
[self
.sourceDB
[4]]
343 assert self
.targetDB
[6] in d
and self
.targetDB
[8] in d
344 assert self
.sourceDB
[2] in m
class SQLTable_No_SSCursor_Test(SQLTable_Test):
    """Re-run the read-only SQLTable tests with server-side cursors disabled."""
    # Passed as keyword arguments to DBServerInfo by SQLTable_Setup.__init__.
    serverArgs = dict(serverSideCursors=False)
class SQLTable_OldIter_Test(SQLTable_Test):
    """Re-run the read-only tests with server-side cursors and block iterators off."""
    # Passed as keyword arguments to DBServerInfo by SQLTable_Setup.__init__.
    serverArgs = dict(serverSideCursors=False,
                      blockIterators=False)
class SQLiteBase(testutil.SQLite_Mixin):
    """Mixin that loads the SQLTable fixture tables for the SQLite test variants."""

    def sqlite_load(self):
        # Uses the unqualified name 'sqltable_test' rather than load_data's
        # default 'test.sqltable_test'.  Presumably invoked by
        # testutil.SQLite_Mixin during setup -- confirm there.
        self.load_data('sqltable_test', writeable=self.writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    """Run the read-only SQLTable tests through the SQLite mixin."""
    # NOTE(review): the original body line was lost in extraction; it was
    # presumably just ``pass``, which this docstring body is equivalent to.
## NOTE(review): this disabled test was damaged in extraction. Its indentation
## is gone and two original lines are missing: one between the class header
## and the docstring (presumably `def setUp(self):`) and one just before
## `s = pickle.dumps(...)`. Restore it from version control before
## re-enabling it.
## class SQLitePickle_Test(SQLiteTable_Test):
## """Pickle / unpickle our serverInfo before trying to use it """
## SQLiteTable_Test.setUp(self)
## self.serverInfo.close()
## s = pickle.dumps(self.serverInfo)
## del self.serverInfo
## self.serverInfo = pickle.loads(s)
## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
## self.sourceDB = self.tableClass(self.joinTable1,
## serverInfo=self.serverInfo)
## self.targetDB = self.tableClass(self.joinTable2,
## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    """Re-run the read-only SQLTable tests with SQLTableNoCache as the table class."""
    # NOTE(review): plain SQLTableNoCache is not one of the catcher classes,
    # so the inherited `catchIter = True` guards presumably have no effect
    # here.  SQLTableNoCacheCatcher exists in this file but is unused --
    # confirm which class was intended.
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    """Run the SQLite read-only tests with SQLTableNoCache as the table class."""
    # NOTE(review): not a catcher class, so inherited catchIter guards are
    # presumably inert here -- confirm this is intended.
    tableClass = SQLTableNoCache
# Write-path SQLTable tests.  They cover:
#  * creating rows via new() with an auto-increment id or an explicit id;
#  * changing a column attribute;
#  * deleting a row;
#  * assigning an existing object to a new id.
# Several tests re-open the table with a fresh tableClass instance
# ("requery the db") to confirm the change reached the database.
#
# NOTE(review): this class was damaged in extraction and cannot parse.
# Tokens are split across lines and the original line numbers are embedded
# in the text.  Original lines are missing, including:
#  * the def lines of the first three tests;
#  * the assignments of `n` and `result` that the asserts read;
#  * the `try:` lines behind the "old ID still exists!" raises.
# No `writeable = True` is visible although these tests write; presumably it
# is on a missing line -- confirm.  Restore this class from version control
# rather than hand-editing it.
381 class SQLTableRW_Test(SQLTable_Setup
):
382 'test write operations'
385 'test row creation with auto inc ID'
386 self
.db
.catchIter
= True # no iter expected in this test
388 o
= self
.db
.new(seq_id
='freddy', start
=3000, stop
=4500)
389 assert len(self
.db
) == n
+ 1
390 t
= self
.tableClass(self
.tableName
,
391 serverInfo
=self
.serverInfo
) # requery the db
392 t
.catchIter
= True # no iter expected in this test
394 assert result
.seq_id
== 'freddy' and result
.start
==3000 \
395 and result
.stop
==4500
397 'check row creation with specified ID'
398 self
.db
.catchIter
= True # no iter expected in this test
400 o
= self
.db
.new(id=99, seq_id
='jeff', start
=3000, stop
=4500)
401 assert len(self
.db
) == n
+ 1
403 t
= self
.tableClass(self
.tableName
,
404 serverInfo
=self
.serverInfo
) # requery the db
405 t
.catchIter
= True # no iter expected in this test
407 assert result
.seq_id
== 'jeff' and result
.start
==3000 \
408 and result
.stop
==4500
410 'test changing an attr value'
411 self
.db
.catchIter
= True # no iter expected in this test
413 assert o
.seq_id
== 'seq2'
414 o
.seq_id
= 'newval' # overwrite this attribute
415 assert o
.seq_id
== 'newval' # check cached value
416 t
= self
.tableClass(self
.tableName
,
417 serverInfo
=self
.serverInfo
) # requery the db
418 t
.catchIter
= True # no iter expected in this test
420 assert result
.seq_id
== 'newval'
421 def test_delitem(self
):
422 'test deletion of a row'
423 self
.db
.catchIter
= True # no iter expected in this test
426 assert len(self
.db
) == n
- 1
429 raise AssertionError('old ID still exists!')
432 def test_setitem(self
):
433 'test assigning new ID to existing object'
434 self
.db
.catchIter
= True # no iter expected in this test
435 o
= self
.db
.new(id=17, seq_id
='bob', start
=2000, stop
=2500)
440 raise AssertionError('old ID still exists!')
443 t
= self
.tableClass(self
.tableName
,
444 serverInfo
=self
.serverInfo
) # requery the db
445 t
.catchIter
= True # no iter expected in this test
447 assert result
.seq_id
== 'bob' and result
.start
==2000 \
448 and result
.stop
==2500
451 raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    """Run the write-path SQLTable tests through the SQLite mixin."""
    # NOTE(review): the original body line was lost in extraction; it was
    # presumably just ``pass``, which this docstring body is equivalent to.
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    """Re-run the write-path tests with SQLTableNoCache as the table class."""
    # NOTE(review): not a catcher class, so inherited catchIter guards are
    # presumably inert here -- confirm this is intended.
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    """Run the SQLite write-path tests with SQLTableNoCache as the table class."""
    # NOTE(review): not a catcher class, so inherited catchIter guards are
    # presumably inert here -- confirm this is intended.
    tableClass = SQLTableNoCache
# Integration test against the public Ensembl MySQL server
# (ensembldb.ensembl.org, user 'anonymous').  It builds a GraphView from
# homo_sapiens_core_47_36i translations to exons via a join ordered by exon
# rank.  It then checks issue 53: translation 15121's exons must come back
# in exactly the expected ORDER BY order.
#
# NOTE(review): this class was damaged in extraction and cannot parse.
# Tokens are split across lines and the original line numbers are embedded
# in the text.  Original lines are missing, including:
#  * presumably the setUp def and `try:` before the connection code;
#  * the rest of the DBServerInfo and SQLTableCatcher calls;
#  * the body of the `except` clause;
#  * the tail of the expected exon-id list.
# Restore this class from version control.
# NOTE(review): `except ImportError,e:` is Python 2-only syntax.
# NOTE(review): test_orderBy has a second string literal after its
# docstring, which is a no-op expression.
465 class Ensembl_Test(unittest
.TestCase
):
468 # test will be skipped if mysql module or ensembldb server unavailable
470 logger
.debug('accessing ensembldb.ensembl.org')
471 conn
= DBServerInfo(host
='ensembldb.ensembl.org', user
='anonymous',
474 translationDB
= SQLTableCatcher('homo_sapiens_core_47_36i.translation',
476 translationDB
.catchIter
= True # should not iter in this test!
477 exonDB
= SQLTable('homo_sapiens_core_47_36i.exon', serverInfo
=conn
)
478 except ImportError,e
:
481 sql_statement
= '''SELECT t3.exon_id FROM
482 homo_sapiens_core_47_36i.translation AS tr,
483 homo_sapiens_core_47_36i.exon_transcript AS t1,
484 homo_sapiens_core_47_36i.exon_transcript AS t2,
485 homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
486 AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
487 t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
488 tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
489 t3.rank <= t2.rank ORDER BY t3.rank
491 self
.translationExons
= GraphView(translationDB
, exonDB
,
492 sql_statement
, serverInfo
=conn
)
493 self
.translation
= translationDB
[15121]
495 def test_orderBy(self
):
496 "Ensemble access, test order by"
497 'test issue 53: ensure that the ORDER BY results are correct'
498 exons
= self
.translationExons
[self
.translation
] # do the query
499 result
= [e
.id for e
in exons
]
500 correct
= [95160,95020,95035,95050,95059,95069,95081,95088,95101,
502 self
.assertEqual(result
, correct
) # make sure the exact order matches
# Run the whole module's tests verbosely when executed as a script.
if __name__ == '__main__':
    PygrTestProgram(verbosity=2)