Merge branch 'tutorial'
[pygr.git] / tests / sqltable_test.py
blob3f45e4f1c3c6f2dee8393e6030ba824787375faa
1 import os, unittest
2 from testlib import testutil, PygrTestProgram, SkipTest
3 from pygr.sqlgraph import SQLTable, SQLTableNoCache,\
4 MapView, GraphView, DBServerInfo, import_sqlite
5 from pygr import logger
class SQLTable_Setup(unittest.TestCase):
    """Common fixture for the SQL table tests.

    ``setUp`` builds three small tables through ``tableClass`` and loads
    nine rows into them; ``tearDown`` drops the tables again.  Concrete
    test classes must define a ``writeable`` class attribute, which
    ``setUp`` forwards to ``load_data``.
    """
    tableClass = SQLTable

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        # one server-info object, handed to every table this test case opens
        self.serverInfo = DBServerInfo()

    def setUp(self):
        'create the fixture tables; an ImportError turns into a skipped test'
        try:
            self.load_data(writeable=self.writeable)
        except ImportError:
            raise SkipTest('missing MySQLdb module?')

    def load_data(self, tableName='test.sqltable_test', writeable=False):
        """create 3 tables and load 9 rows for our tests

        tableName -- name of the main table; the two join tables are named
                     by appending '1' and '2' to it.
        writeable -- passed to ``tableClass`` for the main table only.

        Sets ``self.db`` (main table), ``self.sourceDB`` (join table 1) and
        ``self.targetDB`` (join table 2), plus the three table-name
        attributes that ``tearDown`` uses.
        """
        self.tableName = tableName
        self.joinTable1 = joinTable1 = tableName + '1'
        self.joinTable2 = joinTable2 = tableName + '2'
        # '%%(AUTO_INCREMENT)s' survives the substitution below as a
        # '%(AUTO_INCREMENT)s' placeholder -- presumably filled in by the
        # table class for the backend in use; confirm in pygr.sqlgraph.
        createTable = """\
        CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
        """ % tableName
        self.db = self.tableClass(tableName, dropIfExists=True,
                                  serverInfo=self.serverInfo,
                                  createTable=createTable,
                                  writeable=writeable)
        self.sourceDB = self.tableClass(joinTable1, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable1)
        self.targetDB = self.tableClass(joinTable2, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable2)
        # one statement per line: 2 rows for the main table, 3 for join
        # table 1 and 4 for join table 2 (must match the tuple below)
        sql = """
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
            INSERT INTO %s VALUES (2,'seq2')
            INSERT INTO %s VALUES (3,'seq3')
            INSERT INTO %s VALUES (4,'seq4')
            INSERT INTO %s VALUES (7, 'seq2')
            INSERT INTO %s VALUES (99, 'seq3')
            INSERT INTO %s VALUES (6, 'seq4')
            INSERT INTO %s VALUES (8, 'seq4')
        """ % tuple(([tableName]*2) + ([joinTable1]*3) + ([joinTable2]*4))
        for line in sql.strip().splitlines(): # insert our test data
            self.db.cursor.execute(line.strip())

    def tearDown(self):
        'drop the three fixture tables and close the server connection'
        try:
            for name in (self.tableName, self.joinTable1, self.joinTable2):
                self.db.cursor.execute('drop table if exists %s' % name)
        finally:
            # close even when a DROP fails, so the connection is not leaked
            self.serverInfo.close()
class SQLTable_Test(SQLTable_Setup):
    """Read-only checks of the dict-like table interface and the join views."""
    writeable = False # read-only database interface

    def test_keys(self):
        # the two rows loaded into the main table are expected as IDs 1 and 2
        assert sorted(self.db.keys()) == [1, 2]

    def test_len(self):
        size = len(self.db)
        assert size == len(self.db.keys())

    def test_contains(self):
        for key in (1, 2):
            assert key in self.db
        assert 'foo' not in self.db

    def test_has_key(self):
        for key in (1, 2):
            assert self.db.has_key(key)
        assert not self.db.has_key('foo')

    def test_get(self):
        # a missing key yields None; present keys match item access
        assert self.db.get('foo') is None
        for key in (1, 2):
            assert self.db.get(key) == self.db[key]

    def test_items(self):
        ids = sorted([key for (key, value) in self.db.items()])
        assert ids == [1, 2]

    def test_iterkeys(self):
        from_list = sorted(self.db.keys())
        from_iter = sorted(self.db.iterkeys())
        assert from_list == from_iter

    def test_itervalues(self):
        from_list = sorted(self.db.values())
        from_iter = sorted(self.db.itervalues())
        assert from_list == from_iter

    def test_itervalues_long(self):
        """test iterator isolation from queries run inside iterator loop """
        insert = 'insert into %s (start) values (1)' % self.tableName
        for _ in range(40000): # insert 40000 rows
            self.db.cursor.execute(insert)
        seen = []
        for row in self.db.itervalues():
            # result is ignored: the membership test is only here to make
            # it do a query inside iterator loop
            _ = 99 in self.db
            seen.append(row.id)
        expected = [row.id for row in self.db.values()]
        assert len(expected) == len(seen)
        assert expected == seen

    def test_iteritems(self):
        from_list = sorted(self.db.items())
        from_iter = sorted(self.db.iteritems())
        assert from_list == from_iter

    def test_readonly(self):
        'test error handling of write attempts to read-only DB'
        # creating a row must be refused
        try:
            self.db.new(seq_id='freddy', start=3000, stop=4500)
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        # so must storing an existing row under another ID
        row = self.db[1]
        try:
            self.db[33] = row
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        # and deleting a row
        try:
            del self.db[2]
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')

    ### @CTB need to test write access
    def test_mapview(self):
        'test MapView of SQL join'
        join_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
        WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1,self.joinTable2)
        m = MapView(self.sourceDB, self.targetDB, join_sql,
                    serverInfo=self.serverInfo)
        for source_id, target_id in ((2, 7), (3, 99)):
            assert m[self.sourceDB[source_id]] == self.targetDB[target_id]
        assert self.sourceDB[2] in m
        # source row 4 ('seq4') joins to two target rows (6 and 8)
        try:
            m[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')
        # no inverseSQL was supplied, so inversion has to fail
        try:
            ~m
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap non-invertible mapping')

    def test_mapview_inverse(self):
        'test inverse MapView of SQL join'
        forward_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
        WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1,self.joinTable2)
        backward_sql = """\
        SELECT t1.my_id FROM %s t1, %s t2
        WHERE t2.third_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1,self.joinTable2)
        m = MapView(self.sourceDB, self.targetDB, forward_sql,
                    serverInfo=self.serverInfo, inverseSQL=backward_sql)
        for source_id, target_id in ((2, 7), (3, 99)):
            assert m[self.sourceDB[source_id]] == self.targetDB[target_id]
        assert self.sourceDB[2] in m

        r = ~m # get the inverse
        for source_id, target_id in ((2, 7), (3, 99)):
            assert self.sourceDB[source_id] == r[self.targetDB[target_id]]
        assert self.targetDB[7] in r

        # source row 4 ('seq4') joins to two target rows (6 and 8)
        try:
            m[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')

    def test_graphview(self):
        'test GraphView of SQL join'
        join_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
        WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1,self.joinTable2)
        m = GraphView(self.sourceDB, self.targetDB, join_sql,
                      serverInfo=self.serverInfo)
        targets = m[self.sourceDB[4]]
        assert len(targets) == 2
        for target_id in (6, 8):
            assert self.targetDB[target_id] in targets
        assert self.sourceDB[2] in m
class SQLiteBase(testutil.SQLite_Mixin):
    """Mixin that loads the fixture tables under an unqualified table name."""

    def sqlite_load(self):
        # presumably the hook invoked by testutil.SQLite_Mixin -- confirm there
        self.load_data(tableName='sqltable_test', writeable=self.writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    """The SQLTable_Test cases, combined with the SQLiteBase mixin."""
192 ## class SQLitePickle_Test(SQLiteTable_Test):
193 ## def setUp(self):
194 ## """Pickle / unpickle our serverInfo before trying to use it """
195 ## SQLiteTable_Test.setUp(self)
196 ## self.serverInfo.close()
197 ## import pickle
198 ## s = pickle.dumps(self.serverInfo)
199 ## del self.serverInfo
200 ## self.serverInfo = pickle.loads(s)
201 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
202 ## self.sourceDB = self.tableClass(self.joinTable1,
203 ## serverInfo=self.serverInfo)
204 ## self.targetDB = self.tableClass(self.joinTable2,
205 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    """The SQLTable_Test cases, with tables built by SQLTableNoCache."""
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    """The SQLiteTable_Test cases, with tables built by SQLTableNoCache."""
    tableClass = SQLTableNoCache
class SQLTableRW_Test(SQLTable_Setup):
    'test write operations'
    writeable = True

    def test_new(self):
        'test row creation with auto inc ID'
        before = len(self.db)
        created = self.db.new(seq_id='freddy', start=3000, stop=4500)
        assert len(self.db) == before + 1
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        stored = fresh[created.id]
        assert stored.seq_id == 'freddy'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_new2(self):
        'check row creation with specified ID'
        before = len(self.db)
        created = self.db.new(id=99, seq_id='jeff', start=3000, stop=4500)
        assert len(self.db) == before + 1
        assert created.id == 99
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        stored = fresh[99]
        assert stored.seq_id == 'jeff'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_attr(self):
        'test changing an attr value'
        row = self.db[2]
        assert row.seq_id == 'seq2'
        row.seq_id = 'newval' # overwrite this attribute
        assert row.seq_id == 'newval' # check cached value
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        assert fresh[2].seq_id == 'newval'

    def test_delitem(self):
        'test deletion of a row'
        before = len(self.db)
        del self.db[1]
        assert len(self.db) == before - 1
        try:
            self.db[1]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')

    def test_setitem(self):
        'test assigning new ID to existing object'
        row = self.db.new(id=17, seq_id='bob', start=2000, stop=2500)
        self.db[13] = row
        assert row.id == 13
        # the old ID must be gone from the table object we wrote through
        try:
            self.db[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
        # requery the db through a second table object
        fresh = self.tableClass(self.tableName, serverInfo=self.serverInfo)
        stored = fresh[13]
        assert stored.seq_id == 'bob'
        assert stored.start == 2000
        assert stored.stop == 2500
        # ... and the old ID must be gone there as well
        try:
            fresh[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    """The SQLTableRW_Test cases, combined with the SQLiteBase mixin."""
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    """The SQLTableRW_Test cases, with tables built by SQLTableNoCache."""
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    """The SQLiteTableRW_Test cases, with tables built by SQLTableNoCache."""
    tableClass = SQLTableNoCache
class Ensembl_Test(unittest.TestCase):
    """Query the public Ensembl MySQL server through a GraphView join."""

    def setUp(self):
        """Open the translation and exon tables and build the exon GraphView.

        Raises SkipTest when opening the tables raises ImportError.
        """
        # NOTE(review): the original comment said the test is also skipped
        # when the ensembldb server is unavailable, but only ImportError is
        # caught below -- confirm what a failed connection raises before
        # widening the except clause.
        logger.debug('accessing ensembldb.ensembl.org')
        conn = DBServerInfo(host='ensembldb.ensembl.org', user='anonymous',
                            passwd='')
        try:
            translationDB = SQLTable('homo_sapiens_core_47_36i.translation',
                                     serverInfo=conn)
            exonDB = SQLTable('homo_sapiens_core_47_36i.exon', serverInfo=conn)
        except ImportError as e: # 'as' form: valid on Python 2.6+ and 3
            raise SkipTest(e)

        # exons of one translation, from its start exon to its end exon,
        # returned in rank order
        sql_statement = '''SELECT t3.exon_id FROM
homo_sapiens_core_47_36i.translation AS tr,
homo_sapiens_core_47_36i.exon_transcript AS t1,
homo_sapiens_core_47_36i.exon_transcript AS t2,
homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
t3.rank <= t2.rank ORDER BY t3.rank
'''
        self.translationExons = GraphView(translationDB, exonDB,
                                          sql_statement, serverInfo=conn)
        self.translation = translationDB[15121]

    def test_orderBy(self):
        """Ensembl access, test order by

        test issue 53: ensure that the ORDER BY results are correct
        """
        exons = self.translationExons[self.translation] # do the query
        result = [e.id for e in exons]
        correct = [95160,95020,95035,95050,95059,95069,95081,95088,95101,
                   95110,95172]
        self.assertEqual(result, correct) # make sure the exact order matches
# run the tests verbosely when this file is executed as a script
if __name__ == "__main__":
    PygrTestProgram(verbosity=2)