1 import os
, random
, string
, unittest
2 from testlib
import testutil
, PygrTestProgram
, SkipTest
3 from pygr
.sqlgraph
import SQLTable
, SQLTableNoCache
,\
4 MapView
, GraphView
, DBServerInfo
, import_sqlite
5 from pygr
import logger
def catch_iterator(self, *args, **kwargs):
    """Guarded replacement for ``generic_iterator`` on the catcher classes.

    Raises AssertionError('this should not iterate!') when the table has a
    true ``catchIter`` attribute.  A table without that attribute is treated
    as unguarded.  Otherwise the call is forwarded unchanged to
    ``SQLTable.generic_iterator`` and its result is returned.
    """
    # Explicit raise instead of ``assert`` so the guard is not compiled away
    # when the tests are run with ``python -O``.
    if getattr(self, 'catchIter', False):
        raise AssertionError('this should not iterate!')
    return SQLTable.generic_iterator(self, *args, **kwargs)
class SQLTableCatcher(SQLTable):
    """SQLTable whose generic_iterator is replaced by catch_iterator.

    Tests set ``catchIter = True`` on an instance to make any call to
    generic_iterator raise AssertionError.
    """

    generic_iterator = catch_iterator
class SQLTableNoCacheCatcher(SQLTableNoCache):
    """SQLTableNoCache whose generic_iterator is replaced by catch_iterator.

    NOTE(review): catch_iterator forwards to SQLTable.generic_iterator, not
    to SQLTableNoCache's own implementation -- confirm this is intended.
    """

    generic_iterator = catch_iterator
22 class SQLTable_Setup(unittest
.TestCase
):
23 tableClass
= SQLTableCatcher
def __init__(self, *args, **kwargs):
    """Initialise the TestCase, then create the DBServerInfo it will use."""
    unittest.TestCase.__init__(self, *args, **kwargs)
    # share conn for all tests
    self.serverInfo = DBServerInfo()
29 self
.load_data(writeable
=self
.writeable
)
31 raise SkipTest('missing MySQLdb module?')
32 def load_data(self
, tableName
='test.sqltable_test', writeable
=False):
33 'create 3 tables and load 9 rows for our tests'
34 self
.tableName
= tableName
35 self
.joinTable1
= joinTable1
= tableName
+ '1'
36 self
.joinTable2
= joinTable2
= tableName
+ '2'
38 CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
40 self
.db
= self
.tableClass(tableName
, dropIfExists
=True,
41 serverInfo
=self
.serverInfo
,
42 createTable
=createTable
,
44 self
.sourceDB
= self
.tableClass(joinTable1
, serverInfo
=self
.serverInfo
,
45 dropIfExists
=True, createTable
="""\
46 CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
49 self
.targetDB
= self
.tableClass(joinTable2
, serverInfo
=self
.serverInfo
,
50 dropIfExists
=True, createTable
="""\
51 CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
55 INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
56 INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
57 INSERT INTO %s VALUES (2,'seq2')
58 INSERT INTO %s VALUES (3,'seq3')
59 INSERT INTO %s VALUES (4,'seq4')
60 INSERT INTO %s VALUES (7, 'seq2')
61 INSERT INTO %s VALUES (99, 'seq3')
62 INSERT INTO %s VALUES (6, 'seq4')
63 INSERT INTO %s VALUES (8, 'seq4')
64 """ % tuple(([tableName
]*2) + ([joinTable1
]*3) + ([joinTable2
]*4))
65 for line
in sql
.strip().splitlines(): # insert our test data
66 self
.db
.cursor
.execute(line
.strip())
68 # Another table, for the "ORDER BY" test
69 self
.orderTable
= tableName
+ '_orderBy'
70 self
.db
.cursor
.execute("""\
71 CREATE TABLE %s (id INTEGER PRIMARY KEY, number INTEGER, letter VARCHAR(1))
72 """ % self
.orderTable
)
73 for row
in range(0, 10):
74 self
.db
.cursor
.execute('INSERT INTO %s VALUES (%d, %d, \'%s\')' %
75 (self
.orderTable
, row
, random
.randint(0, 99),
76 string
.lowercase
[random
.randint(0,
80 self
.db
.cursor
.execute('drop table if exists %s' % self
.tableName
)
81 self
.db
.cursor
.execute('drop table if exists %s' % self
.joinTable1
)
82 self
.db
.cursor
.execute('drop table if exists %s' % self
.joinTable2
)
83 self
.db
.cursor
.execute('drop table if exists %s' % self
.orderTable
)
84 self
.serverInfo
.close()
86 class SQLTable_Test(SQLTable_Setup
):
87 writeable
= False # read-only database interface
93 self
.db
.catchIter
= True
94 assert len(self
.db
) == len(self
.db
.keys())
95 def test_contains(self
):
96 self
.db
.catchIter
= True
99 assert 'foo' not in self
.db
100 def test_has_key(self
):
101 self
.db
.catchIter
= True
102 assert self
.db
.has_key(1)
103 assert self
.db
.has_key(2)
104 assert not self
.db
.has_key('foo')
106 self
.db
.catchIter
= True
107 assert self
.db
.get('foo') is None
108 assert self
.db
.get(1) == self
.db
[1]
109 assert self
.db
.get(2) == self
.db
[2]
110 def test_items(self
):
111 i
= [ k
for (k
,v
) in self
.db
.items() ]
114 def test_iterkeys(self
):
117 ik
= list(self
.db
.iterkeys())
120 def test_itervalues(self
):
121 kv
= self
.db
.values()
123 iv
= list(self
.db
.itervalues())
126 def test_itervalues_long(self
):
127 """test iterator isolation from queries run inside iterator loop """
128 sql
= 'insert into %s (start) values (1)' % self
.tableName
129 for i
in range(40000): # insert 40000 rows
130 self
.db
.cursor
.execute(sql
)
132 print 'begin itervalues()'
133 for o
in self
.db
.itervalues():
134 status
= 99 in self
.db
# make it do a query inside iterator loop
136 print 'begin values()'
137 kv
= [o
.id for o
in self
.db
.values()]
138 assert len(kv
) == len(iv
)
141 def test_iteritems(self
):
144 ii
= list(self
.db
.iteritems())
147 def test_readonly(self
):
148 'test error handling of write attempts to read-only DB'
149 self
.db
.catchIter
= True # no iter expected in this test!
151 self
.db
.new(seq_id
='freddy', start
=3000, stop
=4500)
152 raise AssertionError('failed to trap attempt to write to db')
158 raise AssertionError('failed to trap attempt to write to db')
163 raise AssertionError('failed to trap attempt to write to db')
166 def test_orderBy(self
):
167 'test iterator with orderBy, iterSQL, iterColumns'
168 self
.targetDB
.catchIter
= True # should not iterate
169 self
.targetDB
.arraysize
= 2 # force it to use multiple queries to finish
170 result
= self
.targetDB
.keys()
171 assert result
== [6, 7, 8, 99]
172 self
.targetDB
.catchIter
= False # next statement will iterate
173 assert result
== list(iter(self
.targetDB
))
174 self
.targetDB
.catchIter
= True # should not iterate
175 self
.targetDB
.orderBy
= 'ORDER BY other_id'
176 result
= self
.targetDB
.keys()
177 assert result
== [7, 99, 6, 8]
178 self
.targetDB
.catchIter
= False # next statement will iterate
179 if self
.serverInfo
._serverType
== 'mysql': # only test this for mysql
181 assert result
== list(iter(self
.targetDB
))
182 raise AssertionError('failed to trap missing iterSQL attr')
183 except AttributeError:
185 self
.targetDB
.iterSQL
= 'WHERE other_id>%s' # tell it how to slice
186 self
.targetDB
.iterColumns
= ['other_id']
187 assert result
== list(iter(self
.targetDB
))
188 result
= self
.targetDB
.values()
189 assert result
== [self
.targetDB
[7], self
.targetDB
[99],
190 self
.targetDB
[6], self
.targetDB
[8]]
191 assert result
== list(self
.targetDB
.itervalues())
192 result
= self
.targetDB
.items()
193 assert result
== [(7, self
.targetDB
[7]), (99, self
.targetDB
[99]),
194 (6, self
.targetDB
[6]), (8, self
.targetDB
[8])]
195 assert result
== list(self
.targetDB
.iteritems())
197 ### @CTB need to test write access
198 def test_mapview(self
):
199 'test MapView of SQL join'
200 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
201 m
= MapView(self
.sourceDB
, self
.targetDB
,"""\
202 SELECT t2.third_id FROM %s t1, %s t2
203 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
204 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
)
205 assert m
[self
.sourceDB
[2]] == self
.targetDB
[7]
206 assert m
[self
.sourceDB
[3]] == self
.targetDB
[99]
207 assert self
.sourceDB
[2] in m
209 d
= m
[self
.sourceDB
[4]]
210 raise AssertionError('failed to trap non-unique mapping')
215 raise AssertionError('failed to trap non-invertible mapping')
218 self
.sourceDB
.cursor
.execute("INSERT INTO %s VALUES (5,'seq78')"
219 % self
.sourceDB
.name
)
220 assert len(self
.sourceDB
) == 4
221 self
.sourceDB
.catchIter
= False # next step will cause iteration
225 correct
= [self
.sourceDB
[2],self
.sourceDB
[3]]
228 def test_mapview_inverse(self
):
229 'test inverse MapView of SQL join'
230 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
231 m
= MapView(self
.sourceDB
, self
.targetDB
,"""\
232 SELECT t2.third_id FROM %s t1, %s t2
233 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
234 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
,
236 SELECT t1.my_id FROM %s t1, %s t2
237 WHERE t2.third_id=%%s and t1.other_id=t2.other_id
238 """ % (self
.joinTable1
,self
.joinTable2
))
239 r
= ~m
# get the inverse
240 assert self
.sourceDB
[2] == r
[self
.targetDB
[7]]
241 assert self
.sourceDB
[3] == r
[self
.targetDB
[99]]
242 assert self
.targetDB
[7] in r
244 m
= ~r
# get the inverse of the inverse!
245 assert m
[self
.sourceDB
[2]] == self
.targetDB
[7]
246 assert m
[self
.sourceDB
[3]] == self
.targetDB
[99]
247 assert self
.sourceDB
[2] in m
249 d
= m
[self
.sourceDB
[4]]
250 raise AssertionError('failed to trap non-unique mapping')
253 def test_graphview(self
):
254 'test GraphView of SQL join'
255 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
256 m
= GraphView(self
.sourceDB
, self
.targetDB
,"""\
257 SELECT t2.third_id FROM %s t1, %s t2
258 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
259 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
)
260 d
= m
[self
.sourceDB
[4]]
262 assert self
.targetDB
[6] in d
and self
.targetDB
[8] in d
263 assert self
.sourceDB
[2] in m
265 self
.sourceDB
.cursor
.execute("INSERT INTO %s VALUES (5,'seq78')"
266 % self
.sourceDB
.name
)
267 assert len(self
.sourceDB
) == 4
268 self
.sourceDB
.catchIter
= False # next step will cause iteration
272 correct
= [self
.sourceDB
[2],self
.sourceDB
[3],self
.sourceDB
[4]]
276 def test_graphview_inverse(self
):
277 'test inverse GraphView of SQL join'
278 self
.sourceDB
.catchIter
= self
.targetDB
.catchIter
= True
279 m
= GraphView(self
.sourceDB
, self
.targetDB
,"""\
280 SELECT t2.third_id FROM %s t1, %s t2
281 WHERE t1.my_id=%%s and t1.other_id=t2.other_id
282 """ % (self
.joinTable1
,self
.joinTable2
), serverInfo
=self
.serverInfo
,
284 SELECT t1.my_id FROM %s t1, %s t2
285 WHERE t2.third_id=%%s and t1.other_id=t2.other_id
286 """ % (self
.joinTable1
,self
.joinTable2
))
287 r
= ~m
# get the inverse
288 assert self
.sourceDB
[2] in r
[self
.targetDB
[7]]
289 assert self
.sourceDB
[3] in r
[self
.targetDB
[99]]
290 assert self
.targetDB
[7] in r
291 d
= r
[self
.targetDB
[6]]
293 assert self
.sourceDB
[4] in d
295 m
= ~r
# get inverse of the inverse!
296 d
= m
[self
.sourceDB
[4]]
298 assert self
.targetDB
[6] in d
and self
.targetDB
[8] in d
299 assert self
.sourceDB
[2] in m
class SQLiteBase(testutil.SQLite_Mixin):
    """Mixin whose sqlite_load() loads the test data as 'sqltable_test'."""

    def sqlite_load(self):
        # presumably called by testutil.SQLite_Mixin during set-up --
        # confirm against testutil.  load_data and writeable are expected
        # to come from the test class this mixin is combined with.
        table_name = 'sqltable_test'
        self.load_data(table_name, writeable=self.writeable)
305 class SQLiteTable_Test(SQLiteBase
, SQLTable_Test
):
306 def test_orderby(self
):
307 'test orderBy in SQLTable'
309 byNumber
= self
.tableClass(self
.orderTable
, serverInfo
=self
.serverInfo
,
310 orderBy
='ORDER BY number')
312 for val
in byNumber
.values():
313 bv
.append(val
.number
)
316 assert sortedBV
== bv
318 byLetter
= self
.tableClass(self
.orderTable
, serverInfo
=self
.serverInfo
,
319 orderBy
='ORDER BY letter')
321 for val
in byLetter
.values():
322 bl
.append(val
.letter
)
324 assert sortedBL
== bl
326 ## class SQLitePickle_Test(SQLiteTable_Test):
328 ## """Pickle / unpickle our serverInfo before trying to use it """
329 ## SQLiteTable_Test.setUp(self)
330 ## self.serverInfo.close()
332 ## s = pickle.dumps(self.serverInfo)
333 ## del self.serverInfo
334 ## self.serverInfo = pickle.loads(s)
335 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
336 ## self.sourceDB = self.tableClass(self.joinTable1,
337 ## serverInfo=self.serverInfo)
338 ## self.targetDB = self.tableClass(self.joinTable2,
339 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    # Inherits the SQLTable_Test cases, run against the non-caching table.
    # NOTE(review): this is plain SQLTableNoCache, not SQLTableNoCacheCatcher,
    # so the catchIter flags set by the inherited tests are presumably not
    # enforced here -- confirm whether that is intentional.
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    # Inherits the SQLiteTable_Test cases, run against the non-caching table.
    # NOTE(review): this is plain SQLTableNoCache, not SQLTableNoCacheCatcher,
    # so the catchIter flags set by the inherited tests are presumably not
    # enforced here -- confirm whether that is intentional.
    tableClass = SQLTableNoCache
347 class SQLTableRW_Test(SQLTable_Setup
):
348 'test write operations'
351 'test row creation with auto inc ID'
352 self
.db
.catchIter
= True # no iter expected in this test
354 o
= self
.db
.new(seq_id
='freddy', start
=3000, stop
=4500)
355 assert len(self
.db
) == n
+ 1
356 t
= self
.tableClass(self
.tableName
,
357 serverInfo
=self
.serverInfo
) # requery the db
358 t
.catchIter
= True # no iter expected in this test
360 assert result
.seq_id
== 'freddy' and result
.start
==3000 \
361 and result
.stop
==4500
363 'check row creation with specified ID'
364 self
.db
.catchIter
= True # no iter expected in this test
366 o
= self
.db
.new(id=99, seq_id
='jeff', start
=3000, stop
=4500)
367 assert len(self
.db
) == n
+ 1
369 t
= self
.tableClass(self
.tableName
,
370 serverInfo
=self
.serverInfo
) # requery the db
371 t
.catchIter
= True # no iter expected in this test
373 assert result
.seq_id
== 'jeff' and result
.start
==3000 \
374 and result
.stop
==4500
376 'test changing an attr value'
377 self
.db
.catchIter
= True # no iter expected in this test
379 assert o
.seq_id
== 'seq2'
380 o
.seq_id
= 'newval' # overwrite this attribute
381 assert o
.seq_id
== 'newval' # check cached value
382 t
= self
.tableClass(self
.tableName
,
383 serverInfo
=self
.serverInfo
) # requery the db
384 t
.catchIter
= True # no iter expected in this test
386 assert result
.seq_id
== 'newval'
387 def test_delitem(self
):
388 'test deletion of a row'
389 self
.db
.catchIter
= True # no iter expected in this test
392 assert len(self
.db
) == n
- 1
395 raise AssertionError('old ID still exists!')
398 def test_setitem(self
):
399 'test assigning new ID to existing object'
400 self
.db
.catchIter
= True # no iter expected in this test
401 o
= self
.db
.new(id=17, seq_id
='bob', start
=2000, stop
=2500)
406 raise AssertionError('old ID still exists!')
409 t
= self
.tableClass(self
.tableName
,
410 serverInfo
=self
.serverInfo
) # requery the db
411 t
.catchIter
= True # no iter expected in this test
413 assert result
.seq_id
== 'bob' and result
.start
==2000 \
414 and result
.stop
==2500
417 raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    # Combines SQLiteBase (which supplies sqlite_load) with the inherited
    # SQLTableRW_Test cases; nothing is overridden here.
    pass
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    # Inherits the SQLTableRW_Test cases, run against the non-caching table.
    # NOTE(review): this is plain SQLTableNoCache, not SQLTableNoCacheCatcher,
    # so the catchIter flags set by the inherited tests are presumably not
    # enforced here -- confirm whether that is intentional.
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    # Inherits the SQLiteTableRW_Test cases, run against the non-caching table.
    # NOTE(review): this is plain SQLTableNoCache, not SQLTableNoCacheCatcher,
    # so the catchIter flags set by the inherited tests are presumably not
    # enforced here -- confirm whether that is intentional.
    tableClass = SQLTableNoCache
431 class Ensembl_Test(unittest
.TestCase
):
434 # test will be skipped if mysql module or ensembldb server unavailable
436 logger
.debug('accessing ensembldb.ensembl.org')
437 conn
= DBServerInfo(host
='ensembldb.ensembl.org', user
='anonymous',
440 translationDB
= SQLTableCatcher('homo_sapiens_core_47_36i.translation',
442 translationDB
.catchIter
= True # should not iter in this test!
443 exonDB
= SQLTable('homo_sapiens_core_47_36i.exon', serverInfo
=conn
)
444 except ImportError,e
:
447 sql_statement
= '''SELECT t3.exon_id FROM
448 homo_sapiens_core_47_36i.translation AS tr,
449 homo_sapiens_core_47_36i.exon_transcript AS t1,
450 homo_sapiens_core_47_36i.exon_transcript AS t2,
451 homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
452 AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
453 t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
454 tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
455 t3.rank <= t2.rank ORDER BY t3.rank
457 print 'creating GraphView...'
458 self
.translationExons
= GraphView(translationDB
, exonDB
,
459 sql_statement
, serverInfo
=conn
)
460 print 'getting translation...'
461 self
.translation
= translationDB
[15121]
463 def test_orderBy(self
):
464 "Ensemble access, test order by"
465 'test issue 53: ensure that the ORDER BY results are correct'
467 exons
= self
.translationExons
[self
.translation
] # do the query
469 result
= [e
.id for e
in exons
]
470 correct
= [95160,95020,95035,95050,95059,95069,95081,95088,95101,
472 self
.assertEqual(result
, correct
) # make sure the exact order matches
if __name__ == '__main__':
    # Script entry point: hand control to testlib's PygrTestProgram with
    # verbosity level 2.
    PygrTestProgram(verbosity=2)