1 import os
, random
, string
, unittest
2 from testlib
import testutil
, PygrTestProgram
, SkipTest
3 from pygr
.sqlgraph
import SQLTable
, SQLTableNoCache
,SQLTableClustered
,\
4 MapView
, GraphView
, DBServerInfo
, import_sqlite
5 from pygr
import logger
8 def catch_iterator(self
, *args
, **kwargs
):
10 assert not self
.catchIter
, 'this should not iterate!'
11 except AttributeError:
13 return klass
.generic_iterator(self
, *args
, **kwargs
)
class SQLTableCatcher(SQLTable):
    """SQLTable subclass the tests use to detect unwanted iteration.

    generic_iterator is replaced by whatever entrap(SQLTable) returns.
    entrap() is defined elsewhere in this file; presumably it returns the
    catch_iterator wrapper, which asserts that self.catchIter is not set
    ('this should not iterate!') before delegating to the wrapped class's
    generic_iterator.
    """

    generic_iterator = entrap(SQLTable)
class SQLTableNoCacheCatcher(SQLTableNoCache):
    """SQLTableNoCache subclass whose iteration can be trapped by tests.

    generic_iterator is replaced by the result of entrap(SQLTableNoCache);
    presumably that wrapper refuses to iterate while self.catchIter is
    true and otherwise defers to SQLTableNoCache.generic_iterator.
    """

    generic_iterator = entrap(SQLTableNoCache)
class SQLTableClusteredCatcher(SQLTableClustered):
    """SQLTableClustered subclass whose iteration can be trapped by tests.

    generic_iterator is replaced by the result of
    entrap(SQLTableClustered); presumably that wrapper refuses to iterate
    while self.catchIter is true and otherwise defers to
    SQLTableClustered.generic_iterator.
    """

    generic_iterator = entrap(SQLTableClustered)
26 class SQLTable_Setup(unittest
.TestCase
):
27 tableClass
= SQLTableCatcher
30 def __init__(self
, *args
, **kwargs
):
31 unittest
.TestCase
.__init
__(self
, *args
, **kwargs
)
32 # share conn for all tests
33 self
.serverInfo
= DBServerInfo(** self
.serverArgs
)
36 self
.load_data(writeable
=self
.writeable
, ** self
.loadArgs
)
38 raise SkipTest('missing MySQLdb module?')
39 def load_data(self
, tableName
='test.sqltable_test', writeable
=False,
40 dbargs
={},sourceDBargs
={},targetDBargs
={}):
41 'create 3 tables and load 9 rows for our tests'
42 self
.tableName
= tableName
43 self
.joinTable1
= joinTable1
= tableName
+ '1'
44 self
.joinTable2
= joinTable2
= tableName
+ '2'
46 CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
48 self
.db
= self
.tableClass(tableName
, dropIfExists
=True,
49 serverInfo
=self
.serverInfo
,
50 createTable
=createTable
,
51 writeable
=writeable
, **dbargs
)
52 self
.sourceDB
= self
.tableClass(joinTable1
, serverInfo
=self
.serverInfo
,
53 dropIfExists
=True, createTable
="""\
54 CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
56 """ % joinTable1
, **sourceDBargs
)
57 self
.targetDB
= self
.tableClass(joinTable2
, serverInfo
=self
.serverInfo
,
58 dropIfExists
=True, createTable
="""\
59 CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
61 """ % joinTable2
, **targetDBargs
)
63 INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
64 INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
65 INSERT INTO %s VALUES (2,'seq2')
66 INSERT INTO %s VALUES (3,'seq3')
67 INSERT INTO %s VALUES (4,'seq4')
68 INSERT INTO %s VALUES (7, 'seq2')
69 INSERT INTO %s VALUES (99, 'seq3')
70 INSERT INTO %s VALUES (6, 'seq4')
71 INSERT INTO %s VALUES (8, 'seq4')
72 """ % tuple(([tableName
]*2) + ([joinTable1
]*3) + ([joinTable2
]*4))
73 for line
in sql
.strip().splitlines(): # insert our test data
74 self
.db
.cursor
.execute(line
.strip())
76 # Another table, for the "ORDER BY" test
77 self
.orderTable
= tableName
+ '_orderBy'
78 self
.db
.cursor
.execute("""\
79 CREATE TABLE %s (id INTEGER PRIMARY KEY, number INTEGER, letter VARCHAR(1))
80 """ % self
.orderTable
)
81 for row
in range(0, 10):
82 self
.db
.cursor
.execute('INSERT INTO %s VALUES (%d, %d, \'%s\')' %
83 (self
.orderTable
, row
, random
.randint(0, 99),
84 string
.lowercase
[random
.randint(0,
88 self
.db
.cursor
.execute('drop table if exists %s' % self
.tableName
)
89 self
.db
.cursor
.execute('drop table if exists %s' % self
.joinTable1
)
90 self
.db
.cursor
.execute('drop table if exists %s' % self
.joinTable2
)
91 self
.db
.cursor
.execute('drop table if exists %s' % self
.orderTable
)
92 self
.serverInfo
.close()
94 class SQLTable_Test(SQLTable_Setup
):
95 writeable
= False # read-only database interface
101 self
.db
.catchIter
= True
102 assert len(self
.db
) == len(self
.db
.keys())
103 def test_contains(self
):
104 self
.db
.catchIter
= True
107 assert 'foo' not in self
.db
def test_has_key(self):
    # has_key() is expected to answer without iterating over the table,
    # so arm the iteration trap before making any lookups.
    table = self.db
    table.catchIter = True
    # Keys 1 and 2 are expected to be present in the test table.
    for present_id in (1, 2):
        assert table.has_key(present_id)
    # A key that was never inserted must be reported as absent.
    assert not table.has_key('foo')
114 self
.db
.catchIter
= True
115 assert self
.db
.get('foo') is None
116 assert self
.db
.get(1) == self
.db
[1]
117 assert self
.db
.get(2) == self
.db
[2]
118 def test_items(self
):
119 i
= [ k
for (k
,v
) in self
.db
.items() ]
122 def test_iterkeys(self
):
124 ik
= list(self
.db
.iterkeys())
126 def test_pickle(self
):
129 s
= pickle
.dumps(self
.db
)
132 ik
= list(db
.iterkeys())
135 db
.serverInfo
.close() # close extra DB connection
136 def test_itervalues(self
):
137 kv
= self
.db
.values()
138 iv
= list(self
.db
.itervalues())
140 def test_itervalues_long(self
):
141 """test iterator isolation from queries run inside iterator loop """
142 sql
= 'insert into %s (start) values (1)' % self
.tableName
143 for i
in range(40000): # insert 40000 rows
144 self
.db
.cursor
.execute(sql
)
146 for o
in self
.db
.itervalues():
147 status
= 99 in self
.db
# make it do a query inside iterator loop
149 kv
= [o
.id for o
in self
.db
.values()]
150 assert len(kv
) == len(iv
)
152 def test_iteritems(self
):
154 ii
= list(self
.db
.iteritems())
156 def test_readonly(self
):
157 'test error handling of write attempts to read-only DB'
158 self
.db
.catchIter
= True # no iter expected in this test!
160 self
.db
.new(seq_id
='freddy', start
=3000, stop
=4500)
161 raise AssertionError('failed to trap attempt to write to db')
167 raise AssertionError('failed to trap attempt to write to db')
172 raise AssertionError('failed to trap attempt to write to db')
175 def test_orderBy(self
):
176 'test iterator with orderBy, iterSQL, iterColumns'
177 self
.targetDB
.catchIter
= True # should not iterate
178 self
.targetDB
.arraysize
= 2 # force it to use multiple queries to finish
179 result
= self
.targetDB
.keys()
180 assert result
== [6, 7, 8, 99]
181 self
.targetDB
.catchIter
= False # next statement will iterate
182 assert result
== list(iter(self
.targetDB
))
183 self
.targetDB
.catchIter
= True # should not iterate
184 self
.targetDB
.orderBy
= 'ORDER BY other_id'
185 result
= self
.targetDB
.keys()
186 assert result
== [7, 99, 6, 8]
187 self
.targetDB
.catchIter
= False # next statement will iterate
188 if self
.serverInfo
._serverType
== 'mysql' \
189 and self
.serverInfo
.custom_iter_keys
: # only test this for mysql
191 assert result
== list(iter(self
.targetDB
))
192 raise AssertionError('failed to trap missing iterSQL attr')
193 except AttributeError:
195 self
.targetDB
.iterSQL
= 'WHERE other_id>%s' # tell it how to slice
196 self
.targetDB
.iterColumns
= ['other_id']
197 assert result
== list(iter(self
.targetDB
))
198 result
= self
.targetDB
.values()
199 assert result
== [self
.targetDB
[7], self
.targetDB
[99],
200 self
.targetDB
[6], self
.targetDB
[8]]
201 assert result
== list(self
.targetDB
.itervalues())
202 result
= self
.targetDB
.items()
203 assert result
== [(7, self
.targetDB
[7]), (99, self
.targetDB
[99]),
204 (6, self
.targetDB
[6]), (8, self
.targetDB
[8])]
205 assert result
== list(self
.targetDB
.iteritems())
207 s
= pickle
.dumps(self
.targetDB
) # test pickling & unpickling
210 correct
= self
.targetDB
.keys()
211 result
= list(iter(db
))
212 assert result
== correct
214 db
.serverInfo
.close() # close extra DB connection
216 def test_orderby_random(self
):
217 'test orderBy in SQLTable'
218 if self
.serverInfo
._serverType
== 'mysql' \
219 and self
.serverInfo
.custom_iter_keys
:
221 byNumber
= self
.tableClass(self
.orderTable
, arraysize
=2,
222 serverInfo
=self
.serverInfo
,
223 orderBy
='ORDER BY number')
224 raise AssertionError('failed to trap orderBy without iterSQL!')
227 byNumber
= self
.tableClass(self
.orderTable
, serverInfo
=self
.serverInfo
,
228 arraysize
=2, orderBy
='ORDER BY number,id',
229 iterSQL
='WHERE number>%s or (number=%s and id>%s)',
230 iterColumns
=('number','number','id'))
231 bv
= [val
.number
for val
in byNumber
.values()]
234 assert sortedBV
== bv
235 bv
= [val
.number
for val
in byNumber
.itervalues()]
236 assert sortedBV
== bv
238 byLetter
= self
.tableClass(self
.orderTable
, serverInfo
=self
.serverInfo
,
239 arraysize
=2, orderBy
='ORDER BY letter,id',
240 iterSQL
='WHERE letter>%s or (letter=%s and id>%s)',
241 iterColumns
=('letter','letter','id'))
242 bl
= [val
.letter
for val
in byLetter
.values()]
244 assert sortedBL
== bl
245 bl
= [val
.letter
for val
in byLetter
.itervalues()]
246 assert sortedBL
== bl
248 ### @CTB need to test write access
249 def test_mapview(self
):
250 'test MapView of SQL join'
251 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
252 m
= MapView(self
.sourceDB
, self
.targetDB
,"""\
253 SELECT t2.third_id FROM %s t1, %s t2
254 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
255 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
)
256 assert m
[self
.sourceDB
[2]] == self
.targetDB
[7]
257 assert m
[self
.sourceDB
[3]] == self
.targetDB
[99]
258 assert self
.sourceDB
[2] in m
260 d
= m
[self
.sourceDB
[4]]
261 raise AssertionError('failed to trap non-unique mapping')
266 raise AssertionError('failed to trap non-invertible mapping')
269 self
.sourceDB
.cursor
.execute("INSERT INTO %s VALUES (5,'seq78')"
270 % self
.sourceDB
.name
)
271 assert len(self
.sourceDB
) == 4
272 self
.sourceDB
.catchIter
= False # next step will cause iteration
276 correct
= [self
.sourceDB
[2],self
.sourceDB
[3]]
279 def test_mapview_inverse(self
):
280 'test inverse MapView of SQL join'
281 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
282 m
= MapView(self
.sourceDB
, self
.targetDB
,"""\
283 SELECT t2.third_id FROM %s t1, %s t2
284 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
285 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
,
287 SELECT t1.my_id FROM %s t1, %s t2
288 WHERE t2.third_id=%%s and t1.other_id=t2.other_id
289 """ % (self
.joinTable1
,self
.joinTable2
))
290 r
= ~m
# get the inverse
291 assert self
.sourceDB
[2] == r
[self
.targetDB
[7]]
292 assert self
.sourceDB
[3] == r
[self
.targetDB
[99]]
293 assert self
.targetDB
[7] in r
295 m
= ~r
# get the inverse of the inverse!
296 assert m
[self
.sourceDB
[2]] == self
.targetDB
[7]
297 assert m
[self
.sourceDB
[3]] == self
.targetDB
[99]
298 assert self
.sourceDB
[2] in m
300 d
= m
[self
.sourceDB
[4]]
301 raise AssertionError('failed to trap non-unique mapping')
304 def test_graphview(self
):
305 'test GraphView of SQL join'
306 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
307 m
= GraphView(self
.sourceDB
, self
.targetDB
,"""\
308 SELECT t2.third_id FROM %s t1, %s t2
309 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
310 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
)
311 d
= m
[self
.sourceDB
[4]]
313 assert self
.targetDB
[6] in d
and self
.targetDB
[8] in d
314 assert self
.sourceDB
[2] in m
316 self
.sourceDB
.cursor
.execute("INSERT INTO %s VALUES (5,'seq78')"
317 % self
.sourceDB
.name
)
318 assert len(self
.sourceDB
) == 4
319 self
.sourceDB
.catchIter
= False # next step will cause iteration
323 correct
= [self
.sourceDB
[2],self
.sourceDB
[3],self
.sourceDB
[4]]
327 def test_graphview_inverse(self
):
328 'test inverse GraphView of SQL join'
329 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
330 m
= GraphView(self
.sourceDB
, self
.targetDB
,"""\
331 SELECT t2.third_id FROM %s t1, %s t2
332 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
333 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
,
335 SELECT t1.my_id FROM %s t1, %s t2
336 WHERE t2.third_id=%%s and t1.other_id=t2.other_id
337 """ % (self
.joinTable1
,self
.joinTable2
))
338 r
= ~m
# get the inverse
339 assert self
.sourceDB
[2] in r
[self
.targetDB
[7]]
340 assert self
.sourceDB
[3] in r
[self
.targetDB
[99]]
341 assert self
.targetDB
[7] in r
342 d
= r
[self
.targetDB
[6]]
344 assert self
.sourceDB
[4] in d
346 m
= ~r
# get inverse of the inverse!
347 d
= m
[self
.sourceDB
[4]]
349 assert self
.targetDB
[6] in d
and self
.targetDB
[8] in d
350 assert self
.sourceDB
[2] in m
class SQLTable_No_SSCursor_Test(SQLTable_Test):
    """Re-run the SQLTable_Test cases with serverSideCursors set to False."""

    # SQLTable_Setup.__init__ expands this as keyword arguments to
    # DBServerInfo.
    serverArgs = {'serverSideCursors': False}
class SQLTable_OldIter_Test(SQLTable_Test):
    """Re-run the SQLTable_Test cases with both serverSideCursors and
    blockIterators set to False."""

    # SQLTable_Setup.__init__ expands this as keyword arguments to
    # DBServerInfo.
    serverArgs = {
        'serverSideCursors': False,
        'blockIterators': False,
    }
class SQLiteBase(testutil.SQLite_Mixin):
    """Common base for the SQLite variants of the table tests."""

    def sqlite_load(self):
        """Load the test tables under the name 'sqltable_test'.

        Presumably called by testutil.SQLite_Mixin during test setup --
        confirm against testlib.  The writeable flag comes from the
        concrete test class.
        """
        table_name = 'sqltable_test'
        self.load_data(table_name, writeable=self.writeable)
363 class SQLiteTable_Test(SQLiteBase
, SQLTable_Test
):
366 ## class SQLitePickle_Test(SQLiteTable_Test):
368 ## """Pickle / unpickle our serverInfo before trying to use it """
369 ## SQLiteTable_Test.setUp(self)
370 ## self.serverInfo.close()
372 ## s = pickle.dumps(self.serverInfo)
373 ## del self.serverInfo
374 ## self.serverInfo = pickle.loads(s)
375 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
376 ## self.sourceDB = self.tableClass(self.joinTable1,
377 ## serverInfo=self.serverInfo)
378 ## self.targetDB = self.tableClass(self.joinTable2,
379 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    """Run the read-only SQLTable_Test cases against SQLTableNoCache."""

    # Uses the catcher subclass, so the catchIter flags set by the tests
    # still guard against unexpected iteration.
    tableClass = SQLTableNoCacheCatcher
384 class SQLTableClustered_Test(SQLTable_Test
):
385 tableClass
= SQLTableClusteredCatcher
386 loadArgs
= dict(dbargs
=dict(clusterKey
='seq_id', arraysize
=2),
387 sourceDBargs
=dict(clusterKey
='other_id', arraysize
=2),
388 targetDBargs
=dict(clusterKey
='other_id', arraysize
=2))
389 def test_orderBy(self
): # neither of these tests useful in this context
391 def test_orderby_random(self
):
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    """Run the SQLiteTable_Test cases against SQLTableNoCache."""

    # NOTE(review): this is the plain SQLTableNoCache, not
    # SQLTableNoCacheCatcher, so setting catchIter in the inherited tests
    # presumably traps nothing here -- confirm that this is intentional.
    tableClass = SQLTableNoCache
397 class SQLTableRW_Test(SQLTable_Setup
):
398 'test write operations'
401 'test row creation with auto inc ID'
402 self
.db
.catchIter
= True # no iter expected in this test
404 o
= self
.db
.new(seq_id
='freddy', start
=3000, stop
=4500)
405 assert len(self
.db
) == n
+ 1
406 t
= self
.tableClass(self
.tableName
,
407 serverInfo
=self
.serverInfo
) # requery the db
408 t
.catchIter
= True # no iter expected in this test
410 assert result
.seq_id
== 'freddy' and result
.start
==3000 \
411 and result
.stop
==4500
413 'check row creation with specified ID'
414 self
.db
.catchIter
= True # no iter expected in this test
416 o
= self
.db
.new(id=99, seq_id
='jeff', start
=3000, stop
=4500)
417 assert len(self
.db
) == n
+ 1
419 t
= self
.tableClass(self
.tableName
,
420 serverInfo
=self
.serverInfo
) # requery the db
421 t
.catchIter
= True # no iter expected in this test
423 assert result
.seq_id
== 'jeff' and result
.start
==3000 \
424 and result
.stop
==4500
426 'test changing an attr value'
427 self
.db
.catchIter
= True # no iter expected in this test
429 assert o
.seq_id
== 'seq2'
430 o
.seq_id
= 'newval' # overwrite this attribute
431 assert o
.seq_id
== 'newval' # check cached value
432 t
= self
.tableClass(self
.tableName
,
433 serverInfo
=self
.serverInfo
) # requery the db
434 t
.catchIter
= True # no iter expected in this test
436 assert result
.seq_id
== 'newval'
437 def test_delitem(self
):
438 'test deletion of a row'
439 self
.db
.catchIter
= True # no iter expected in this test
442 assert len(self
.db
) == n
- 1
445 raise AssertionError('old ID still exists!')
448 def test_setitem(self
):
449 'test assigning new ID to existing object'
450 self
.db
.catchIter
= True # no iter expected in this test
451 o
= self
.db
.new(id=17, seq_id
='bob', start
=2000, stop
=2500)
456 raise AssertionError('old ID still exists!')
459 t
= self
.tableClass(self
.tableName
,
460 serverInfo
=self
.serverInfo
) # requery the db
461 t
.catchIter
= True # no iter expected in this test
463 assert result
.seq_id
== 'bob' and result
.start
==2000 \
464 and result
.stop
==2500
467 raise AssertionError('old ID still exists!')
472 class SQLiteTableRW_Test(SQLiteBase
, SQLTableRW_Test
):
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    """Run the SQLTableRW_Test write tests against SQLTableNoCache."""

    # NOTE(review): this is the plain SQLTableNoCache, not
    # SQLTableNoCacheCatcher, so setting catchIter in the inherited tests
    # presumably traps nothing here -- confirm that this is intentional.
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    """Run the SQLiteTableRW_Test write tests against SQLTableNoCache."""

    # NOTE(review): this is the plain SQLTableNoCache, not
    # SQLTableNoCacheCatcher, so setting catchIter in the inherited tests
    # presumably traps nothing here -- confirm that this is intentional.
    tableClass = SQLTableNoCache
481 class Ensembl_Test(unittest
.TestCase
):
484 # test will be skipped if mysql module or ensembldb server unavailable
486 logger
.debug('accessing ensembldb.ensembl.org')
487 conn
= DBServerInfo(host
='ensembldb.ensembl.org', user
='anonymous',
490 translationDB
= SQLTableCatcher('homo_sapiens_core_47_36i.translation',
492 translationDB
.catchIter
= True # should not iter in this test!
493 exonDB
= SQLTable('homo_sapiens_core_47_36i.exon', serverInfo
=conn
)
494 except ImportError,e
:
497 sql_statement
= '''SELECT t3.exon_id FROM
498 homo_sapiens_core_47_36i.translation AS tr,
499 homo_sapiens_core_47_36i.exon_transcript AS t1,
500 homo_sapiens_core_47_36i.exon_transcript AS t2,
501 homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
502 AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
503 t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
504 tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
505 t3.rank <= t2.rank ORDER BY t3.rank
507 self
.translationExons
= GraphView(translationDB
, exonDB
,
508 sql_statement
, serverInfo
=conn
)
509 self
.translation
= translationDB
[15121]
511 def test_orderBy(self
):
512 "Ensemble access, test order by"
513 'test issue 53: ensure that the ORDER BY results are correct'
514 exons
= self
.translationExons
[self
.translation
] # do the query
515 result
= [e
.id for e
in exons
]
516 correct
= [95160,95020,95035,95050,95059,95069,95081,95088,95101,
518 self
.assertEqual(result
, correct
) # make sure the exact order matches
if __name__ == '__main__':
    # When this file is executed directly, hand control to the project's
    # test program (imported from testlib) with verbose output.
    PygrTestProgram(verbosity=2)