added test of len() method for SQLTable
[pygr.git] / tests / sqltable_test.py
blobd5957fda86f2ef1a4662083daab91a4a2d4efed1
1 import os, unittest
2 from testlib import testutil, PygrTestProgram, SkipTest
3 from pygr.sqlgraph import SQLTable, SQLTableNoCache,\
4 MapView, GraphView, DBServerInfo, import_sqlite
5 from pygr import logger
class SQLTable_Setup(unittest.TestCase):
    """Shared fixture for the SQLTable tests.

    Creates three tables through ``tableClass`` (one main table and two
    join tables), loads nine rows of test data into them, and drops the
    tables again in ``tearDown``.  ``setUp`` reads ``self.writeable``,
    which this class does not define; the concrete test classes set it.
    """
    tableClass = SQLTable  # table implementation exercised by the tests

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        self.serverInfo = DBServerInfo() # share conn for all tests

    def setUp(self):
        """Build the test tables; an ImportError is reported as a skip."""
        try:
            self.load_data(writeable=self.writeable)
        except ImportError:
            raise SkipTest('missing MySQLdb module?')

    def load_data(self, tableName='test.sqltable_test', writeable=False):
        """Create 3 tables and load 9 rows for our tests.

        tableName: name of the main table; the two join tables are named
            by appending '1' and '2' to it.
        writeable: passed through to ``tableClass`` for the main table only.

        Sets ``self.db`` (main table), ``self.sourceDB`` and
        ``self.targetDB`` (join tables), plus the three table-name
        attributes that ``tearDown`` uses.
        """
        self.tableName = tableName
        self.joinTable1 = joinTable1 = tableName + '1'
        self.joinTable2 = joinTable2 = tableName + '2'
        # %%(AUTO_INCREMENT)s survives the formatting below as
        # %(AUTO_INCREMENT)s -- presumably filled in by tableClass for the
        # backend in use; confirm in pygr.sqlgraph.
        createTable = """\
        CREATE TABLE %s (primary_id INTEGER PRIMARY KEY %%(AUTO_INCREMENT)s, seq_id TEXT, start INTEGER, stop INTEGER)
        """ % tableName
        self.db = self.tableClass(tableName, dropIfExists=True,
                                  serverInfo=self.serverInfo,
                                  createTable=createTable,
                                  writeable=writeable)
        self.sourceDB = self.tableClass(joinTable1, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (my_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable1)
        self.targetDB = self.tableClass(joinTable2, serverInfo=self.serverInfo,
                                        dropIfExists=True, createTable="""\
        CREATE TABLE %s (third_id INTEGER PRIMARY KEY,
              other_id VARCHAR(16))
        """ % joinTable2)
        # 2 rows for the main table, 3 for joinTable1, 4 for joinTable2;
        # the tuple below supplies the table name for each statement in order.
        sql = """
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq1', 0, 10)
            INSERT INTO %s (seq_id, start, stop) VALUES ('seq2', 5, 15)
            INSERT INTO %s VALUES (2,'seq2')
            INSERT INTO %s VALUES (3,'seq3')
            INSERT INTO %s VALUES (4,'seq4')
            INSERT INTO %s VALUES (7, 'seq2')
            INSERT INTO %s VALUES (99, 'seq3')
            INSERT INTO %s VALUES (6, 'seq4')
            INSERT INTO %s VALUES (8, 'seq4')
        """ % tuple(([tableName]*2) + ([joinTable1]*3) + ([joinTable2]*4))
        for line in sql.strip().splitlines(): # insert our test data
            self.db.cursor.execute(line.strip())

    def tearDown(self):
        """Drop the three test tables and close the server connection."""
        try:
            for name in (self.tableName, self.joinTable1, self.joinTable2):
                self.db.cursor.execute('drop table if exists %s' % name)
        finally:
            # close even when a DROP fails, so the connection is not leaked
            self.serverInfo.close()
class SQLTable_Test(SQLTable_Setup):
    """Read-only tests of the dict-style table interface and the join views."""
    writeable = False # read-only database interface

    def test_keys(self):
        """keys() must list exactly the IDs 1 and 2."""
        assert sorted(self.db.keys()) == [1, 2]

    def test_len(self):
        """len() of the table must agree with the number of keys."""
        row_count = len(self.db)
        key_count = len(self.db.keys())
        assert row_count == key_count

    def test_contains(self):
        """Membership test: existing IDs are found, a bogus key is not."""
        for key in (1, 2):
            assert key in self.db
        assert 'foo' not in self.db

    def test_has_key(self):
        """has_key() must give the same answers as the ``in`` operator test."""
        for key in (1, 2):
            assert self.db.has_key(key)
        assert not self.db.has_key('foo')

    def test_get(self):
        """get() returns None for a missing key, else the same as indexing."""
        assert self.db.get('foo') is None
        for key in (1, 2):
            assert self.db.get(key) == self.db[key]

    def test_items(self):
        """The keys reported by items() must be exactly 1 and 2."""
        ids = sorted([key for (key, _row) in self.db.items()])
        assert ids == [1, 2]

    def test_iterkeys(self):
        """iterkeys() must produce the same keys as keys()."""
        from_list = sorted(self.db.keys())
        from_iter = sorted(self.db.iterkeys())
        assert from_list == from_iter

    def test_itervalues(self):
        """itervalues() must produce the same rows as values()."""
        from_list = sorted(self.db.values())
        from_iter = sorted(self.db.itervalues())
        assert from_list == from_iter

    def test_itervalues_long(self):
        """test iterator isolation from queries run inside iterator loop """
        insert_sql = 'insert into %s (start) values (1)' % self.tableName
        for _ in range(40000): # insert 40000 rows
            self.db.cursor.execute(insert_sql)
        iterated_ids = []
        for row in self.db.itervalues():
            99 in self.db # make it do a query inside iterator loop
            iterated_ids.append(row.id)
        listed_ids = [row.id for row in self.db.values()]
        assert len(listed_ids) == len(iterated_ids)
        assert listed_ids == iterated_ids

    def test_iteritems(self):
        """iteritems() must produce the same pairs as items()."""
        from_list = sorted(self.db.items())
        from_iter = sorted(self.db.iteritems())
        assert from_list == from_iter

    def test_readonly(self):
        'test error handling of write attempts to read-only DB'
        # each write path must raise ValueError on a read-only table
        try:
            self.db.new(seq_id='freddy', start=3000, stop=4500)
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        row = self.db[1]
        try:
            self.db[33] = row
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')
        try:
            del self.db[2]
        except ValueError:
            pass
        else:
            raise AssertionError('failed to trap attempt to write to db')

    ### @CTB need to test write access
    def test_mapview(self):
        'test MapView of SQL join'
        join_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
        WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2)
        view = MapView(self.sourceDB, self.targetDB, join_sql,
                       serverInfo=self.serverInfo)
        assert view[self.sourceDB[2]] == self.targetDB[7]
        assert view[self.sourceDB[3]] == self.targetDB[99]
        assert self.sourceDB[2] in view
        # source row 4 ('seq4') joins to two target rows (6 and 8), so a
        # one-to-one lookup is expected to fail with KeyError
        try:
            view[self.sourceDB[4]]
        except KeyError:
            pass
        else:
            raise AssertionError('failed to trap non-unique mapping')

    def test_graphview(self):
        'test GraphView of SQL join'
        join_sql = """\
        SELECT t2.third_id FROM %s t1, %s t2
        WHERE t1.my_id=%%s and t1.other_id=t2.other_id
        """ % (self.joinTable1, self.joinTable2)
        graph = GraphView(self.sourceDB, self.targetDB, join_sql,
                          serverInfo=self.serverInfo)
        targets = graph[self.sourceDB[4]]
        assert len(targets) == 2
        assert self.targetDB[6] in targets and self.targetDB[8] in targets
        assert self.sourceDB[2] in graph
class SQLiteBase(testutil.SQLite_Mixin):
    """Mixin that loads the test data into a table named 'sqltable_test'.

    NOTE(review): sqlite_load is presumably the hook that
    testutil.SQLite_Mixin calls during setUp -- confirm in testutil.
    """

    def sqlite_load(self):
        """Load the fixture data under the table name 'sqltable_test'."""
        self.load_data(tableName='sqltable_test', writeable=self.writeable)
class SQLiteTable_Test(SQLiteBase, SQLTable_Test):
    """The read-only SQLTable_Test suite combined with the SQLiteBase mixin."""
164 ## class SQLitePickle_Test(SQLiteTable_Test):
165 ## def setUp(self):
166 ## """Pickle / unpickle our serverInfo before trying to use it """
167 ## SQLiteTable_Test.setUp(self)
168 ## self.serverInfo.close()
169 ## import pickle
170 ## s = pickle.dumps(self.serverInfo)
171 ## del self.serverInfo
172 ## self.serverInfo = pickle.loads(s)
173 ## self.db = self.tableClass(self.tableName, serverInfo=self.serverInfo)
174 ## self.sourceDB = self.tableClass(self.joinTable1,
175 ## serverInfo=self.serverInfo)
176 ## self.targetDB = self.tableClass(self.joinTable2,
177 ## serverInfo=self.serverInfo)
class SQLTable_NoCache_Test(SQLTable_Test):
    """Re-run the read-only SQLTable_Test suite with SQLTableNoCache."""
    tableClass = SQLTableNoCache
class SQLiteTable_NoCache_Test(SQLiteTable_Test):
    """Re-run the SQLiteTable_Test suite with SQLTableNoCache."""
    tableClass = SQLTableNoCache
class SQLTableRW_Test(SQLTable_Setup):
    'test write operations'
    writeable = True

    def _requery(self):
        """Return a new tableClass object on the same table and connection,
        so that values are looked up again rather than taken from self.db."""
        return self.tableClass(self.tableName,
                               serverInfo=self.serverInfo) # requery the db

    def test_new(self):
        'test row creation with auto inc ID'
        before = len(self.db)
        created = self.db.new(seq_id='freddy', start=3000, stop=4500)
        assert len(self.db) == before + 1
        stored = self._requery()[created.id]
        assert stored.seq_id == 'freddy'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_new2(self):
        'check row creation with specified ID'
        before = len(self.db)
        created = self.db.new(id=99, seq_id='jeff', start=3000, stop=4500)
        assert len(self.db) == before + 1
        assert created.id == 99
        stored = self._requery()[99]
        assert stored.seq_id == 'jeff'
        assert stored.start == 3000
        assert stored.stop == 4500

    def test_attr(self):
        'test changing an attr value'
        row = self.db[2]
        assert row.seq_id == 'seq2'
        row.seq_id = 'newval' # overwrite this attribute
        assert row.seq_id == 'newval' # check cached value
        stored = self._requery()[2]
        assert stored.seq_id == 'newval'

    def test_delitem(self):
        'test deletion of a row'
        before = len(self.db)
        del self.db[1]
        assert len(self.db) == before - 1
        try:
            self.db[1]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')

    def test_setitem(self):
        'test assigning new ID to existing object'
        row = self.db.new(id=17, seq_id='bob', start=2000, stop=2500)
        self.db[13] = row
        assert row.id == 13
        # the old ID must be gone from the table object we wrote through ...
        try:
            self.db[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
        fresh = self._requery()
        stored = fresh[13]
        assert stored.seq_id == 'bob'
        assert stored.start == 2000
        assert stored.stop == 2500
        # ... and from a freshly opened table object as well
        try:
            fresh[17]
        except KeyError:
            pass
        else:
            raise AssertionError('old ID still exists!')
class SQLiteTableRW_Test(SQLiteBase, SQLTableRW_Test):
    """The SQLTableRW_Test write suite combined with the SQLiteBase mixin."""
class SQLTableRW_NoCache_Test(SQLTableRW_Test):
    """Re-run the SQLTableRW_Test write suite with SQLTableNoCache."""
    tableClass = SQLTableNoCache
class SQLiteTableRW_NoCache_Test(SQLiteTableRW_Test):
    """Re-run the SQLiteTableRW_Test suite with SQLTableNoCache."""
    tableClass = SQLTableNoCache
class Ensembl_Test(unittest.TestCase):
    """Query the public Ensembl MySQL server through a GraphView.

    Requires network access to ensembldb.ensembl.org.
    """

    def setUp(self):
        """Open the translation and exon tables and build the GraphView."""
        # test will be skipped if mysql module or ensembldb server unavailable
        # NOTE(review): only ImportError is turned into a skip below; a
        # server outage is skipped only if it surfaces as ImportError --
        # confirm against pygr.sqlgraph.
        logger.debug('accessing ensembldb.ensembl.org')
        conn = DBServerInfo(host='ensembldb.ensembl.org', user='anonymous',
                            passwd='')
        self.serverInfo = conn  # kept so tearDown can close it
        try:
            translationDB = SQLTable('homo_sapiens_core_47_36i.translation',
                                     serverInfo=conn)
            exonDB = SQLTable('homo_sapiens_core_47_36i.exon', serverInfo=conn)
        except ImportError as e:
            raise SkipTest(e)

        # exons of one translation, from its start exon to its end exon,
        # in rank order; %s is the translation_id
        sql_statement = '''SELECT t3.exon_id FROM
homo_sapiens_core_47_36i.translation AS tr,
homo_sapiens_core_47_36i.exon_transcript AS t1,
homo_sapiens_core_47_36i.exon_transcript AS t2,
homo_sapiens_core_47_36i.exon_transcript AS t3 WHERE tr.translation_id = %s
AND tr.transcript_id = t1.transcript_id AND t1.transcript_id =
t2.transcript_id AND t2.transcript_id = t3.transcript_id AND t1.exon_id =
tr.start_exon_id AND t2.exon_id = tr.end_exon_id AND t3.rank >= t1.rank AND
t3.rank <= t2.rank ORDER BY t3.rank
'''
        self.translationExons = GraphView(translationDB, exonDB,
                                          sql_statement, serverInfo=conn)
        self.translation = translationDB[15121]

    def tearDown(self):
        """Close the server connection opened in setUp."""
        self.serverInfo.close()

    def test_orderBy(self):
        """Ensembl access, test order by

        test issue 53: ensure that the ORDER BY results are correct
        """
        exons = self.translationExons[self.translation] # do the query
        result = [e.id for e in exons]
        correct = [95160,95020,95035,95050,95059,95069,95081,95088,95101,
                   95110,95172]
        self.assertEqual(result, correct) # make sure the exact order matches
# Run the tests through testlib's PygrTestProgram when executed as a script.
if __name__ == '__main__':
    PygrTestProgram(verbosity=2)