branch-test modify of techwars.py
[limo.git] / db.py
blob927b2ca31adb1292550f1a06ac3ab1b11e12d91d
1 #!/usr/bin/env python
3 from __future__ import with_statement
4 import random, re, os, datetime, hashlib, thread
5 try:
6 import MySQLdb
7 except ImportError,e:
8 raise ImportError("MYSQL ERROR: "+str(e))
9 import apsw
10 import warnings
11 from limoutil import log
13 warnings.filterwarnings('ignore', '^.*exists$') # ignore the warnings generated by 'create X if not exists' statements
15 class Model(object):
16 """ The Model object is a way to create database schema incrementally.
18 The goal is to allow disparate modules to require different tables, or different columns on existing tables.
20 1) Create the test database.
22 Clear the test database forcefully.
24 >>> q=Query("drop database test; create database test; use test", dsn='test')
26 2) Define the database schema.
28 .database() - Ensures the database exists.
29 If no call to .database() is made in a sequence, then the db from the dsn passed to the Model() constructor is assumed for future operations.
30 .table() - Ensures the table exists.
32 .col() - Ensures that the column exists.
34 .data() - Ensures that some default data is in place.
36 The execute() on the end will build only the pieces that are missing.
38 >>> start = datetime.datetime.now()
39 >>> s=Model(dsn="test")\
40 .database("test")\
41 .table("users")\
42 .col("uid int primary key auto_increment")\
43 .col("name text not null")\
44 .data([1, 'anonymous'])\
45 .data([2, 'jldailey'])\
46 .table("groups")\
47 .col("gid int primary key auto_increment")\
48 .col("name text not null")\
49 .data([1, 'admin'])\
50 .table("roles")\
51 .col("rid int primary key auto_increment")\
52 .col("name text not null")\
53 .data([1, 'admin'])\
54 .table("group_roles")\
55 .col("gid int not null")\
56 .col("rid int not null")\
57 .col("primary key (gid, rid)")\
58 .data([1, 1])\
59 .table("group_users")\
60 .col("gid int not null")\
61 .col("uid int not null")\
62 .data([1, 2])\
63 .table("user_roles")\
64 .col("uid int not null")\
65 .col("rid int not null")\
66 .execute()
68 3) Extending the existing model
70 The reason to use the schema object (rather than directly modifying the db) is that you can come along later
71 and extend the existing model without clobbering (or knowing much about) what already exists.
73 So this example adds 2 fields to the users table.
75 >>> s=Model(dsn="test")\
76 .database("test")\
77 .table("users")\
78 .col("email varchar(128)")\
79 .col("passwd varchar(32) not null default ''")\
80 .execute()
82 >>> Query("show create table test.users", dsn="test")[0][1]
83 u"CREATE TABLE `users` (\\n `uid` int(11) NOT NULL AUTO_INCREMENT,\\n `name` text NOT NULL,\\n `email` varchar(128) DEFAULT NULL,\\n `passwd` varchar(32) NOT NULL DEFAULT '',\\n PRIMARY KEY (`uid`)\\n) ENGINE=MyISAM AUTO_INCREMENT=3 DEFAULT CHARSET=latin1"
85 If code in 2 separate places requires the same schema, it has no effect.
87 So, adding the email column again here has no effect.
89 Think of the column definitions as column requirements, ensuring that the column exists as defined.
91 >>> s=Model(dsn="test")\
92 .database("test")\
93 .table("users")\
94 .col("email varchar(128)")\
95 .col("unique index uq_email (email)")\
96 .col("index ix_email_uid (email, uid)")\
97 .execute()
98 >>> Query("show create table test.users", dsn="test")[0][1]
99 u"CREATE TABLE `users` (\\n `uid` int(11) NOT NULL AUTO_INCREMENT,\\n `name` text NOT NULL,\\n `email` varchar(128) DEFAULT NULL,\\n `passwd` varchar(32) NOT NULL DEFAULT '',\\n PRIMARY KEY (`uid`),\\n UNIQUE KEY `uq_email` (`email`),\\n KEY `ix_email_uid` (`email`,`uid`)\\n) ENGINE=MyISAM AUTO_INCREMENT=3 DEFAULT CHARSET=latin1"
101 >>> s=Model(dsn="test").defaultTable("pages")\
102 .column("pid int primary key auto_increment")\
103 .column("url varchar(256) not null")\
104 .column("name varchar(256) not null")\
105 .execute()
106 >>> s=Model(dsn="test").defaultTable("pages")\
107 .column("pid int primary key auto_increment")\
108 .column("url varchar(256) not null")\
109 .column("name varchar(256) not null")\
110 .execute()
113 4) You can forcefully modify an existing column definition by specifying a version on the column.
115 Also in this example you can see that passing the database() explicitly is not required, it can be inferred from the dsn passed to the Model() constructor.
117 >>> s=Model(dsn="test")\
118 .table("users")\
119 .col("email varchar(256)", version=1)\
120 .execute()
121 >>> Query("show create table test.users", dsn="test")[0][1]
122 u"CREATE TABLE `users` (\\n `uid` int(11) NOT NULL AUTO_INCREMENT,\\n `name` text NOT NULL,\\n `email` varchar(256) DEFAULT NULL,\\n `passwd` varchar(32) NOT NULL DEFAULT '',\\n PRIMARY KEY (`uid`),\\n UNIQUE KEY `uq_email` (`email`),\\n KEY `ix_email_uid` (`email`,`uid`)\\n) ENGINE=MyISAM AUTO_INCREMENT=3 DEFAULT CHARSET=latin1"
124 >>> s=Model(dsn="test")\
125 .table("users")\
126 .col("name varchar(64) not null", version=1)\
127 .col("unique index (name)")\
128 .execute()
129 >>> Query("show create table test.users", dsn="test")[0][1]
130 u"CREATE TABLE `users` (\\n `uid` int(11) NOT NULL AUTO_INCREMENT,\\n `name` varchar(64) NOT NULL,\\n `email` varchar(256) DEFAULT NULL,\\n `passwd` varchar(32) NOT NULL DEFAULT '',\\n PRIMARY KEY (`uid`),\\n UNIQUE KEY `uq_name` (`name`),\\n UNIQUE KEY `uq_email` (`email`),\\n KEY `ix_email_uid` (`email`,`uid`)\\n) ENGINE=MyISAM AUTO_INCREMENT=3 DEFAULT CHARSET=latin1"
132 This test always fails.
134 >>> _ms_elapsed(start)
_init = {}  # class-level: records which Model subclasses have already run model()/execute()
def __init__(self, dsn="default", skipExecute=False):
    """ Create a schema builder bound to the named datasource.

    dsn -- name of a datasource entry in db.cfg (see Query).
    skipExecute -- when True, only initialize state; do not declare or
        build any schema against the database.
    """
    self.sql = ""
    self.dsn = dsn
    self.db = {}            # {database name: {table name: {column name: definition}}}
    self._db = None         # current default database (set by .database())
    self._table = None      # current default table (set by .table())
    self._table_opts = {}   # {table name: extra text appended to "create table"}
    self._column = None     # most recently required column name
    self._data = {}         # {db+table: [rows of required default data]}
    self._order = {}        # declaration order: {db: [tables]}, {db+table: [(col, defn)]}
    self._version = {}      # {db+table+column: forced version number}
    if not skipExecute:
        # only update the tables the first time
        if not self._init.get(self.__class__.__name__, False):
            self.model()
            self.execute()
            self._init[self.__class__.__name__] = True
def model(self):
    """ Declare the base schema: the schema_history table used by execute()
    to track column versions.  Subclasses override this to declare their own model.
    """
    # NOTE(review): in the original source this chain ended with a line
    # continuation; any final chained call on the missing next line is lost here.
    self\
        .table("schema_history")\
        .col("`table_name` varchar(128) NOT NULL")\
        .col("`column_name` varchar(128) NOT NULL")\
        .col("`column_definition` text NOT NULL")\
        .col("`version` int(11) NOT NULL")\
        .col("`entered` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP")
def defaultDatabase(self, name):
    """ Require that a database exists, creating if necessary, and make it
    the default for subsequent .table()/.col() calls.  Returns self (fluent).
    """
    known = self.db.get(name, None)
    if known is None:
        # first mention of this database: start an empty table registry for it
        self.db[name] = {}
    self._db = name
    return self
def database(self, name):
    """ Alias for defaultDatabase(). """
    return self.defaultDatabase(name)
def defaultTable(self, name, options=None):
    """ Require that a certain table be present.

    Uses the most recent call to .database() for the table's context; if
    none was made, parses the database name out of the db.cfg entry for
    this Model's dsn.  Subsequent .col() calls attach to this table.

    name -- the table name.
    options -- extra text appended to the generated "create table" statement.
    Returns self (fluent).
    Raises Exception when no default database can be determined.
    """
    if self._db is None:
        # infer the database from the dsn's cfg entry, e.g. mysql://u:p@host:port/db
        try:
            Query._read_cfg()
            (prot, user, password, host, port, db) = re.findall(Query._db_cfg_regex, Query._cfg[self.dsn])[0]
        except (TypeError, IndexError, KeyError) as e:
            # BUG FIX: the original 'except TypeError:' never bound 'e', so the
            # raise below crashed with NameError; it also missed the IndexError
            # raised by [0] when the regex does not match, and the KeyError
            # raised when the dsn is absent from the cfg.
            raise Exception("no default database found: dsn: %s cfg: %s -> %s" % (str(self.dsn), str(Query._cfg.get(self.dsn)), str(e)))
        self.defaultDatabase(db)
    if self.db[self._db].get(name, None) is None:
        self.db[self._db][name] = {}
    self._table = name
    if options is not None:
        self._table_opts[name] = options
    key = self._db
    if self._order.get(key, None) is None:
        self._order[key] = []
    # remember declaration order; execute() builds tables in this order
    self._order[key].append(self._table)
    self.log("Requiring table %s: [%s] = %s" % (name, key, name))
    return self
def table(self, name, options=None):
    """ Alias for defaultTable(). """
    return self.defaultTable(name, options)
def defaultColumn(self, definition, version=0, skipNormalize=False):
    """
    Require that a certain column be present.
    Uses the most recent call to .database() and .table() to get the context
    for where this column should exist.

    definition -- raw SQL column definition, normalized via the probe table
        unless skipNormalize is set.
    version -- when > 0, forces a versioned redefinition of an existing column
        (see execute()).
    skipNormalize -- internal flag used for the recursive calls below.
    Returns self (fluent).
    """
    if not skipNormalize:
        definition = self._normalize_column_definition(definition)
    # after a column definition is normalized, it might actually turn out to be
    # more than one column (a data column, and an index)
    if type(definition) == list:
        # if so, recursively add the column and then the index
        self.log("Column requirement is a list: %s" % ( str(definition) ) )
        for d in definition:
            self.defaultColumn(d.lower(), version=version, skipNormalize=True)
        return self
    definition = definition.lower()
    name = _getNameFromDefn(definition)
    self._column = name
    if self.db[self._db][self._table].get(name, None) is None:
        self.db[self._db][self._table][name] = definition
    key = self._db+self._table
    if version > 0:
        self.log("Saving _version argument: [%s] = %s" % (key+name, version))
        self._version[key+name] = version
    if self._order.get(key, None) is None:
        self._order[key] = []
    # declaration order is preserved so execute() adds columns in sequence
    self._order[key].append((name,definition))
    self.log("Requiring column %s: [%s] = %s" % (name, key, definition))
    return self
def column(self, d, version=0):
    """ Alias for defaultColumn(). """
    return self.defaultColumn(d, version=version)

def col(self, d, version=0):
    """ Alias for defaultColumn(). """
    return self.defaultColumn(d, version=version)

def c(self, d, version=0):
    """ Alias for defaultColumn(). """
    return self.defaultColumn(d, version=version)
def defaultData(self, row):
    """ Require that *row* exist as default data in the current table.

    Rows accumulate per db+table and are inserted (insert ignore) by execute().
    Returns self (fluent).
    """
    key = self._db + self._table
    # setdefault: create the bucket on first use, then append
    self._data.setdefault(key, []).append(row)
    return self
def data(self, row):
    """ Alias for defaultData(). """
    return self.defaultData(row)
248 @staticmethod
249 def log(msg):
250 if msg[-1] != '\n':
251 msg += '\n'
252 with open("db.log", "a+") as f:
253 f.write(msg)
def execute(self):
    """ Build whatever parts of the declared schema are missing.

    For each declared database/table/column (in declaration order): create the
    database and table if absent, add missing columns, apply versioned column
    redefinitions via the schema_history table, and insert required default data.

    Raises ModelConflictError when a column differs from its declaration and no
    version number was supplied.

    NOTE(review): all SQL here is built by string interpolation; the declared
    schema is trusted input, but quoting relies on safe() in only one place.
    """
    for db, tables in self.db.items():  # 'tables' is unused; lookups go via self.db[db]
        FastQuery("create database if not exists %s" % db, dsn=self.dsn)
        # pull out all the tables in the order they were specified
        for table_name in self._order[db]:
            columns = self.db[db][table_name]
            # __filler__ keeps the create legal with zero declared columns; dropped below
            FastQuery("create table if not exists %s.%s ( __filler__ int null ) %s" %
                (db, table_name, self._table_opts.get(table_name,'')), dsn=self.dsn)
            table = _TableInfo(db+"."+table_name, dsn=self.dsn)
            assert len(table.lines.keys()) > 0
            key = db+table_name
            skipPKLater = False
            # pull out all the column in the order they were specified
            for (column_name, column_definition) in self._order.get(key, []):
                if table.names.get(unicode(column_name), False): # if the table already has this column
                    if table.names[unicode(column_name)] == column_definition: # if the existing column is identical
                        # do nothing
                        continue
                    # retrieve the version specified for this column update
                    version = self._version.get(db+table_name+column_name, None)
                    self.log('DETECTED VERSION CHANGE [%s] \'%s\' :: %s != %s (version:%s)' % (table_name+"."+column_name, unicode(column_name), table.names[unicode(column_name)], column_definition, version))
                    # if none was specified then we cant do versioning, so bail
                    if version is None or version == 0:
                        e = ModelConflictError("%s already defined as %s, and no new version number was specified" % (column_definition, table.names[unicode(column_name)]))
                        raise e
                    # get the current version of the column as it is on the table
                    # NOTE(review): getName('primary key') looks suspicious here --
                    # one would expect getName(column_name); confirm against _TableInfo.
                    current_definition = _TableInfo(db+"."+table_name, dsn=self.dsn).getName('primary key')
                    self.log("CURRENT DEFINITION: %s" % current_definition)
                    current_version = Query("select coalesce(max(version),0) from schema_history where table_name = '%s' and column_name = '%s' and column_definition = '%s'" % (table_name, column_name, current_definition), dsn=self.dsn)[0][0]
                    current_version = int(current_version)
                    # if the specified version is newer
                    if version > current_version:
                        # update the history
                        # TODO: save version 0
                        FastQuery("replace into schema_history (table_name, column_name, column_definition, version) values ('%s', '%s', '%s', %d)" \
                            % (table_name, column_name, safe(column_definition), version), dsn=self.dsn)
                        # and then modify the existing column
                        if column_definition.startswith("primary key"):
                            # NOTE(review): these two calls omit dsn=self.dsn, unlike
                            # every other FastQuery here -- likely an oversight.
                            FastQuery("alter table %s.%s drop primary key" % (db, table_name))
                            FastQuery("alter table %s.%s add %s" % (db, table_name, column_definition))
                        else:
                            FastQuery("alter table %s.%s modify %s" % (db, table_name, column_definition), dsn=self.dsn)
                    else:
                        # skip obsolete column definition
                        continue
                else:
                    # the column didnt exist so just add it
                    # HACK: to fix bug with primary key columns
                    if skipPKLater and column_definition.find("primary key") > -1:
                        skipPKLater = False
                        continue
                    if column_definition.find("auto_increment") > -1:
                        # MySQL requires auto_increment columns to be a key; fold the
                        # primary key into this add and skip the separate PK definition
                        column_definition = column_definition.replace("auto_increment", "auto_increment primary key")
                        skipPKLater = True
                    try:
                        FastQuery("alter table %s.%s add %s" % (db, table_name, column_definition), dsn=self.dsn)
                    except Exception, e:
                        self.log("FAILED: %s (sql: %s)" % ( str(e), "alter table %s.%s add %s" % (db, table_name, column_definition)))
                        raise
            if table.names.get("__filler__", False):
                FastQuery("alter table %s.%s drop __filler__" % (db, table_name))
            if self._data.get(key, None) is not None:
                data = self._data[key]
                self.log("Key: %s has required data: %s" % (key, str(data)))
                table = _TableInfo(db+"."+table_name, dsn=self.dsn) # reload the table definition
                for row in data:
                    cols = ', '.join([c for c in table.order[0:len(row)] if c != 'primary key'])
                    for r in range(len(row)):
                        if row[r] is None:
                            row[r] = "NULL"
                    values = ', '.join(["'%s'"%s for s in row])
                    FastQuery("insert ignore into %s ( %s ) values ( %s )" % (table_name, cols, values), dsn=self.dsn)
                del self._data[key]
def _normalize_column_definition(self, defn):
    """ Normalize a raw column definition to the server's canonical form.

    This creates an empty 'probe' table in the database with just the necessary columns.
    Then calls 'show table' to get the server to produce the standard version of that sql.

    >>> Model()._normalize_column_definition('email varchar(32)')
    u'`email` varchar(32) default null'

    If a single line column definition actually results in more than one column
    on the resulting table, a list of columns is returned.

    >>> Model()._normalize_column_definition('uid int primary key auto_increment')
    [u'`uid` int(11) not null auto_increment', u'primary key (`uid`)']

    >>> Model()._normalize_column_definition('primary key (id, name)')
    u'primary key (`id`,`name`)'

    >>> Model()._normalize_column_definition('unique index (vid, url(32), sid)')
    u'unique key `uq_vid_url_sid` (`vid`,`url`(32),`sid`)'
    """
    t = 'probe_table_%d' % random.randrange(0,9999999)
    defn = defn.lower()
    try:
        # text between the outermost parens, split on commas
        columns = [ x.strip() for x in defn[defn.find('(')+1:defn.rfind(')')].split(',') ]
    except IndexError:
        # NOTE(review): str.find returns -1 rather than raising, so this
        # handler appears unreachable; slicing never raises IndexError either.
        columns = []
    ret = []
    is_compound_pk = (re.search('\S\sprimary key', defn) is not None)
    if re.search("(?:index|key)\s*\(", defn) is not None:
        # synthesize an index name (ix_/uq_ + joined column names)
        prefix = "ix_"
        if defn.find("unique") > -1:
            prefix = "uq_"
        name = prefix + "_".join([ re.sub("\(\d+\)","",x) for x in columns ]).replace("`","").replace("'","")
        defn = re.sub("(?:index|key)", "key `%s`" % name, defn)
    # if its a simple index column, then the probe table needs the indexed columns to be built
    if _match_any(defn, ('unique ', 'key ', 'index ', 'fulltext ')) and not is_compound_pk:
        # NOTE: 'type' and 'open' below shadow builtins; kept as-is.
        if _match_any(defn, ('fulltext',)):
            type = "text"
        else:
            type = "integer"
        for col in columns:
            col = col.replace(" ","")
            # if the column being indexed specified a prefix, then it must refer
            # to a varchar field that is long enough
            save_type = type
            open = col.find('(')
            if open > -1:
                close = col.rfind(')')
                # build the new type with a long enough varchar
                type = "varchar(%d)" % (int(col[open+1:close]) + 1)
                # set the column name to be the name without the prefix
                col = col[0:open]
            # prepend this to the existing list of columns
            defn = "%s %s, %s "%(col, type, defn)
            type = save_type
    # create the probe table
    Query("create table %s ( %s )" % (t, defn))
    # read back the table definition
    table = _TableInfo(t)
    Query("drop table %s" % t)
    # pull out just the column text
    for name in table.order:
        line = table.names[name]
        if name not in columns: # dont return columns that were generated to satisfy an index
            ret.append(line)
    if is_compound_pk:
        ret = [r.strip().lower() for r in ret]
    else:
        ret = ret[-1].strip().lower()
    return ret
@staticmethod
def int_uid(from_string):
    """ Return an integer id derived from the sha1 hex digest of *from_string*. """
    # idiom: use hash() rather than invoking __hash__ directly (same result for str)
    return hash(Model.uid(from_string))
413 @staticmethod
414 def uid(from_string):
415 return hashlib.sha1(from_string).hexdigest()
417 class Query(object):
419 A good non-dbapi db system should:
421 Gets its configuration from a db.cfg file, no connect params in code
423 db.cfg:
424 default="mysql://user:password@localhost/db1"
425 foo="mysql://user:password@localhost/foo"
427 1) Allow direct, named, query access:
429 >>> result = Query("select * from users where 1 = 0",dsn="test")
430 >>> type(result)
431 <class '__main__.Query'>
432 >>> len(result)
435 2) Support iteration:
437 >>> for row in Query("select * from users",dsn="test"):
438 ... print "uid: %d name: %s rowIndex: %d" % (int(row.uid), row.name, row.rowIndex)
439 uid: 1 name: anonymous rowIndex: 0
440 uid: 2 name: jldailey rowIndex: 1
442 3) Allow indexed access by columns in a wide variety of ways:
444 >>> result = Query("select uid, name from users",dsn="test")
445 >>> result[0].name
446 u'anonymous'
447 >>> result[0]['name']
448 u'anonymous'
450 3a) Using the name as the first indexor creates a projection of the data, and further indexers pull rows from that projection
452 >>> result['name'][0]
453 u'anonymous'
454 >>> result.name[0]
455 u'anonymous'
457 3b) Using the name as an attribute selects a single column
459 >>> result[1]['name']
460 u'jldailey'
461 >>> result[1].name
462 u'jldailey'
464 3c) You can see any individual Row object in its entirety
466 >>> type(result[0])
467 <class '__main__.Row'>
469 3d) You can pull out multi-dimensional projections of the data using comma-separated string indexes
471 >>> result[0]['uid, name']
472 [1L, u'anonymous']
474 >>> result[0:2]['uid, name']
475 [[1L, u'anonymous'], [2L, u'jldailey']]
477 >>> result[0:2].uid
478 [1L, 2L]
480 >>> for i in result[0:2].uid:
481 ... print i
485 >>> for i in result.uid:
486 ... print i
490 >>> for i in result['uid, name']:
491 ... print i
492 {Row:'uid': 1L,'name': u'anonymous'}
493 {Row:'uid': 2L,'name': u'jldailey'}
495 >>> result['uid, name']
496 (RowList:{Row:'uid': 1L,'name': u'anonymous'},{Row:'uid': 2L,'name': u'jldailey'})
497 >>> result['uid, name'][0]
498 {Row:'uid': 1L,'name': u'anonymous'}
499 >>> result['uid, name'][0:1]
500 (RowList:{Row:'uid': 1L,'name': u'anonymous'})
501 >>> result['uid, name'][0:2]
502 (RowList:{Row:'uid': 1L,'name': u'anonymous'},{Row:'uid': 2L,'name': u'jldailey'})
504 NOTE: The critical difference is whether you use names first: names select columns, numbers select rows.
505 If you select rows first, it returns you just the raw data in a tuple
506 If you select the columns first, it returns you real row objects
508 4) Allow query of queries
510 First do something to pull back a data set
512 >>> result = Query('''\
513 select u.name as user_name, r.name as role_name from users u\
514 left join user_roles ur on ur.uid = u.uid\
515 left join roles r on ur.rid = r.rid\
516 union\
517 select u.name as user_name, r.name as role_name from users u\
518 left join group_users gu on gu.uid = u.uid\
519 left join groups g on gu.gid = g.gid\
520 left join group_roles gr on gr.gid = g.gid\
521 left join roles r on r.rid = gr.rid\
522 ''',dsn="test")
523 >>> len(result)
525 >>> result[0:3]
526 (RowList:{Row:'user_name': u'anonymous','role_name': u'None'},{Row:'user_name': u'jldailey','role_name': u'None'},{Row:'user_name': u'jldailey','role_name': u'admin'})
528 Then execute queries against that resultset.
530 >>> result2 = result.Query("select distinct role_name from __self__")
531 >>> type(result2)
532 <class '__main__.Query'>
533 >>> len(result2)
535 >>> result2[0:2]
536 (RowList:u'None',u'admin')
538 >>> for i in result2:
539 ... print i
540 {Row:'role_name': u'None'}
541 {Row:'role_name': u'admin'}
543 >>> assert(u'admin' == result2[1].role_name)
544 >>> assert(u'None' == result2[0].role_name)
546 >>> count = result.Query("select count( distinct role_name ) from __self__")
547 >>> count[0][0]
548 u'2'
549 >>> count = result.Query("select count( distinct role_name ) as qty from __self__")
551 4a) All sub-queries support the same operations as real queries, except they lose type information (a temporary limitation of using pysqlite2 instead of 3, waiting for sqlite3/python2.5 to be fixed)
552 Everything that comes out of a sub-query will be a unicode string
554 >>> count[0].qty
555 u'2'
557 >>> int(count[0].qty)
560 4b) Query of queries can join 2 result sets together
562 >>> get_users = Query(name="get_users", sql="select * from users",dsn="test")
563 >>> get_groups = Query(name="get_groups", sql="select * from group_users",dsn="test")
564 >>> result = get_users.Query("select u.uid, u.name, r.gid from get_users u join get_groups r on r.uid = u.uid")
565 >>> result[0]
566 {Row:'uid': u'2','name': u'jldailey','gid': u'1'}
568 If you dont want to specify a name for the query, you can use the special token: __self__ in the query.
569 The __self__ token will be replaced by the name of the parent query.
570 You can also access a queries _name property to get the name of the backing table directly.
571 (If you dont specify a name, this will be an autogenerated value like: 'autoname_<40 character sha1 hash>')
573 >>> get_users = Query("select * from users", dsn="test")
574 >>> get_groups = Query("select * from group_users",dsn="test")
575 >>> result = get_users.Query("select u.uid, u.name, r.gid from __self__ u join %s r on r.uid = u.uid" % get_groups._name)
576 >>> result[0]
577 {Row:'uid': u'2','name': u'jldailey','gid': u'1'}
579 5) Get a log of all queries executed:
580 >>> for (sql, duration) in Query.getQueryLog():
581 ... pass # dont verify the output, since there is too much of it
583 5a) Clear the log of query data.
585 >>> Query.clearQueryLog()
_db = {}        # {thread id: {dsn: MySQLdb connection}} -- per-thread connection pool
_sqlite = {}    # {thread id: apsw in-memory connection} holding result backing tables
_cfg = None     # parsed db.cfg contents: {dsn name: url string}; None until first read
_cfg_mtime = 0  # mtime of db.cfg at last parse, for change detection
_cache = {}     # {sql: (backing table name, columns, row count)} for cache=True queries
_db_cfg_regex = '(mysql)://(\w+):(\w*)@([^:]+):(\d+)/(\w+)'  # prot,user,pass,host,port,db
@staticmethod
def _init_cfg():
    """ Parse db.cfg on first use. """
    if Query._cfg is None:
        Query._read_cfg()
@staticmethod
def _read_cfg():
    """ (Re)load db.cfg (next to this module) if it has been modified.

    Each non-empty line has the form name=url; connections for redefined
    names are closed so they reconnect with the new settings.
    """
    # read the cfg if modified
    db_cfg = os.path.sep.join(__file__.split(os.path.sep)[:-1] + ["db.cfg"])
    m = os.path.getmtime( db_cfg )
    if m > Query._cfg_mtime:
        Query._cfg_mtime = m
        s = open(db_cfg).read()
        Query._cfg = {}
        for p in [p.split('=') for p in s.replace('\r','').split('\n') if len(p) > 0]:
            # NOTE(review): p[1] drops anything after a second '=' in the value
            Query._cfg[p[0]] = p[1]
            # NOTE(review): _db is keyed by thread id elsewhere but by cfg name
            # here -- this close-on-change lookup looks inconsistent; verify.
            if Query._db.get(p[0], None) is not None:
                Query._db[p[0]].close()
                Query._db[p[0]] = None
@staticmethod
def _update_connections(dsn):
    """ Connect to the sqlite results datasource, and make sure that we have a mysql connection.
    Returns the mysql connection for *dsn* on the current thread.

    Raises Exception when the dsn is not present in db.cfg, and ValueError
    when its port is not an integer.
    """
    tid = thread.get_ident()
    if Query._sqlite.get(tid, None) is None:
        Query._sqlite[tid] = apsw.Connection(":memory:")
    if Query._db.get(tid,None) is None:
        Query._db[tid] = {}
    Query._read_cfg()
    if Query._db[tid].get(dsn,None) is None: # instantiate the needed connections
        try:
            (prot, user, password, host, port, db) = re.findall(Query._db_cfg_regex, Query._cfg[dsn])[0]
            port = int(port)
        except KeyError:
            # BUG FIX: the original interpolated 'db', which is unbound when
            # Query._cfg[dsn] raises KeyError; report the dsn that was looked up.
            raise Exception("DSN argument passed to Query (dsn='%s') was not found in the cfg (available: '%s')" % (dsn, str(Query._cfg.keys())))
        except ValueError:
            raise ValueError("DSN specified in db.cfg has non-integer port value: '%s'" % port)
        try:
            log("FastQuery: Connecting to db... %s %s %s" % (host,db,user))
            Query._db[tid][dsn] = MySQLdb.connect(host=host, port=port, db=db, user=user, passwd=password)
            log("FastQuery: Connected...")
        except MySQLdb.OperationalError:
            # try re-connecting without specifying the database
            # BUG FIX: the reconnect path omitted port=, so non-default ports failed
            Query._db[tid][dsn] = MySQLdb.connect(host=host, port=port, user=user, passwd=password)
            # then create it and connect to the empty db
            Query._db[tid][dsn].cursor().execute("create database if not exists %s" % db)
            Query._db[tid][dsn].close()
            Query._db[tid][dsn] = MySQLdb.connect(host=host, port=port, db=db, user=user, passwd=password)
    return Query._db[tid][dsn]
@staticmethod
def _close_connection(dsn):
    """ Close and forget this thread's MySQL connection for *dsn*, if one exists. """
    tid = thread.get_ident()
    connection = Query._db.get(tid, {}).get(dsn, None)
    if connection is None:
        return
    connection.close()
    Query._db[tid][dsn] = None
def __init__(self, sql=None, name=None, dsn="default", cursor=None, cache=False):
    """ Query -
    Can be instantiated in 2 ways: with a sql string, which will be executed against the remote data source
    Or, with a cursor object that is already full of results. (.description and .fetchall() must work)

    sql -- statement to run (MySQL path), or the statement that produced *cursor*.
    name -- sqlite backing-table name; autogenerated from the sql hash when omitted.
    dsn -- datasource name from db.cfg.
    cursor -- a pre-filled DBAPI or apsw cursor to load rows from instead of executing.
    cache -- when True, reuse/record results in the class-level cache keyed by sql.

    NOTE(review): indentation in this reconstruction of the cursor-loading
    branch (insert batching and cache recording) is a best-effort reading of
    a whitespace-mangled original -- verify against upstream.
    """
    if name is None:
        if sql is not None:
            name = "autoname_%s" % hashlib.sha1(sql).hexdigest()
        else:
            name = "autoname_%s" % hashlib.sha1(str(random.randrange(0,99999999))).hexdigest()
    self._name = name
    self.dsn = dsn
    self.currentRow = -1
    self.rowCount = 0
    # if we were called with just sql, this means we want to run statements directly on the real back end
    if sql is not None and cursor is None:
        self.sql = sql
        if cache and Query._cache.get(sql, None) is not None:
            start = datetime.datetime.now()
            (other_name, self.columns, self.rowCount) = Query._cache[sql]
            if self._name.startswith('autoname'):
                # since the caller didnt specify a name, they dont care what the
                # name is, so we can share a backing table with the cached data
                self._name = other_name
            else:
                # since we are giving this data a new name, it needs a new backing
                # table, in case we need to join against it
                # TODO: a dict full of redirections would serve? Difficulty: any number of autoname queries could reference the same backing table, but once that table gets a real name, you have to start copying, because someone could issue an update statement on it, and change other result sets.
                c = Query._sqlite_cursor()
                c.setexectrace(self.log_sqlite)
                c.execute("create table %s as select * from %s" % (self._name, other_name))
            Query.logQuery("CACHED: "+sql, _ms_elapsed(start))
            return
        # read the cfg file if it has changed
        Query._read_cfg()
        # get a db connection
        self.db = Query._update_connections(dsn)
        # start a timer
        start = datetime.datetime.now()
        try:
            # run the query
            c = self.db.cursor()
            c.execute(sql)
        except Exception, e:
            if e[0] == 2006:
                # MySQL error 2006: server has gone away -- reconnect once and retry
                self.log("LOST CONNECTION TO DB: RECONNECTING...")
                Query._close_connection(self.dsn)
                self.db = Query._update_connections(self.dsn)
                c = self.db.cursor()
                c.execute(sql)
            else:
                raise SQLException(str(e), sql)
        # re-enter the constructor through the cursor-loading branch below
        Query.__init__(self, sql=sql, name=self._name, cursor=c)
        # log the query
        Query.logQuery(sql, _ms_elapsed(start))
    elif cursor is not None: # otherwise, we are loading the raw way, using the result of some other query
        if sql is not None:
            self.sql = sql
        try:
            # DBAPI cursor protocol first...
            self.columns = cursor.description
            self.rows = cursor.fetchall()
        except AttributeError:
            # ...then fall back to the apsw cursor protocol
            self.rows = []
            try:
                self.columns = cursor.getdescription()
                for row in cursor:
                    self.rows.append(Row(cursor, row, index=len(self.rows)))
            except apsw.ExecutionCompleteError:
                self.columns = []
        self.rowCount = len(self.rows)
        c = Query._sqlite_cursor()
        c.setexectrace(self.log_sqlite)
        if self.columns is not None and len(self.columns) > 0:
            # mirror the result set into a sqlite backing table named self._name
            create = "create table %s ( %s )" % (self._name, ','.join(["[%s] %s"%(x[0], self._get_sqlite_typename(x[1])) for x in self.columns]))
            try:
                c.execute(create)
            except apsw.SchemaChangeError, e: # some version of apsw raise SchemaChangeError, some versions raise SQLError, so we have to catch both
                s = str(e)
                if s.endswith('already exists'):
                    pass
                else:
                    raise
            except apsw.SQLError, e:
                s = str(e)
                if s.endswith('already exists'):
                    pass
                else:
                    raise
            if len(self.rows) > 0:
                # batch inserts: sqlite limits the number of unioned selects
                max_rows_per = 499
                num_inserts = int(len(self.rows)/max_rows_per) + 1
                for n in range(num_inserts):
                    a = n*max_rows_per
                    b = min(len(self.rows),(n+1)*max_rows_per)
                    insert = u"insert or ignore into %s ( %s ) %s" % (
                        self._name,
                        u','.join([u"[%s]"%x[0] for x in self.columns]),
                        u' union all '.join([u'select '+u','.join([u"%s" % self._sqlrepr(s) for s in row]) for row in self.rows[a:b]]))
                    try:
                        c.execute(insert)
                    except Exception, e:
                        raise Exception("FAILED INSERT: (a: %s, b: %s len(self.rows[a:b]): %s)" % (a,b,len(self.rows[a:b])))
                # NOTE(review): sql may be None on this path (cursor-only
                # construction); sql.lower() would then raise -- verify upstream.
                if sql.lower().startswith('select'):
                    Query._cache[sql] = (self._name, self.columns, self.rowCount)
def __del__(self):
    """ Drop the sqlite backing table when the Query is garbage-collected. """
    c = Query._sqlite_cursor()
    c.setexectrace(self.log_sqlite)
    self.log_sqlite("IN __DEL__", None)
    c.execute("drop table if exists %s" % self._name)
@staticmethod
def logQuery(sql, duration):
    """ Record an executed statement and its duration in db.log and in the
    in-memory query_log table (created on demand).
    """
    Query.log("[%.2f] %s" % (duration, sql))
    if Query._sqlite is None:
        return
    cursor = Query._sqlite_cursor()
    cursor.execute(u"create table if not exists query_log ( sql text not null, duration int )")
    escaped = sql.replace("'", "''")
    cursor.execute("insert into query_log ( sql, duration ) values ( '%s', %f )" % (escaped, duration))
@staticmethod
def getQueryLog():
    """ Yield (sql, duration) for every statement recorded in query_log, oldest first. """
    if Query._sqlite is None:
        return
    cursor = Query._sqlite_cursor()
    cursor.execute("select sql, duration from query_log order by rowid asc")
    for (stmt, ms) in cursor:
        yield (stmt, ms)
@staticmethod
def clearQueryLog():
    """ Delete all rows from the in-memory query_log table. """
    # The original wrapped this in "except Exception, e: raise e", which only
    # destroyed the traceback; let errors propagate naturally instead.
    c2 = Query._sqlite_cursor()
    c2.execute("delete from query_log")
def _sqlrepr(self, o):
    """ Render *o* as a sqlite SQL literal: bare digits for ints, a
    single-quoted (doubled-quote escaped) string otherwise.

    NOTE(review): assumes non-int values decode as iso-8859-2 -- confirm
    that choice of encoding against the data actually stored.
    """
    if type(o) == int:
        return unicode(o)
    else:
        return u"'%s'" % str(o).decode('iso-8859-2').replace("'","''")
def _get_sqlite_typename(self, t):
    """ Map a MySQLdb column type to the sqlite typename used in backing tables. """
    return "integer" if t == MySQLdb.NUMBER else "text"
@staticmethod
def _update_sqlite():
    """ Ensure the current thread has an in-memory sqlite connection. """
    tid = thread.get_ident()
    # BUG FIX: direct indexing raised KeyError for a thread with no entry yet;
    # use .get() as _update_connections does.
    if Query._sqlite.get(tid, None) is None:
        Query._sqlite[tid] = apsw.Connection(":memory:")
@staticmethod
def _sqlite_cursor():
    """ Return a new cursor on this thread's in-memory sqlite connection.

    NOTE(review): assumes the thread's connection already exists (created by
    _update_connections/_update_sqlite); raises KeyError otherwise.
    """
    tid = thread.get_ident()
    return Query._sqlite[tid].cursor()
def Query(self, sql=None, name=None):
    """
    Returns a new Query that selects from this result set as if it were a table.
    - Can join to other resultsets by name.
    - The token __self__ in *sql* is replaced by this result's backing-table name.

    (The method shadows the class name; inside it, 'Query' still resolves to
    the class via the module globals.)
    """
    c = Query._sqlite_cursor()
    c.setexectrace(self.log_sqlite)
    sql = sql.replace("__self__", self._name)
    c.execute(sql)
    return Query(name=name, sql=sql, cursor=c)
def __len__(self):
    # number of rows in the result set
    return self.rowCount
def __getitem__(self, key):
    """ Index the result set: int/slice select rows; a string selects a
    column projection across all rows.
    """
    if isinstance(key, slice):
        i = key.indices(len(self))
        return self._rows("*", range(i[0],i[1],i[2]))
    elif isinstance(key, str):
        return self._rows(key, range(0,len(self)))
    elif isinstance(key, int):
        return self._rows("*", [key], single=True)
    else:
        # BUG FIX: 'ArgumentError' is not a defined name, so the original
        # raised NameError here; TypeError is the conventional exception.
        raise TypeError("Unknown type: "+str(type(key)))
    def __getattribute__(self, name):
        """ Attribute access with column fallback.

        Real attributes win; an unknown attribute name is treated as a column
        name and fetched from every row of the backing table.  A failed column
        fetch is converted to AttributeError so hasattr() behaves normally.
        """
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            try:
                # read rowCount via object.__getattribute__ to avoid re-entering
                # this method (and looping) if rowCount were ever missing
                return self._rows(name, range(0, object.__getattribute__(self, "rowCount")))
            except apsw.SQLError, e:
                raise AttributeError(name)
845 def _rows(self, columns, keys, single=False, rowIndex=-1):
846 c = Query._sqlite_cursor()
847 c.setexectrace(self.log_sqlite)
848 ret = []
849 try:
850 for row in c.execute("select %s from %s where (rowid-1) in (%s)" % (columns, self._name, ','.join([str(x) for x in keys]))):
851 ret.append(Row(c,row, index=len(ret)))
852 except apsw.SQLError, e:
853 raise #apsw.SQLError("SQLite error. Original query that generated sqlitetable: %s SQLite error: %s" % (str(self.sql), str(e)))
854 if single:
855 try:
856 ret[0].rowIndex = rowIndex
857 return ret[0]
858 except IndexError:
859 raise IndexError("columns: %s keys: %s expected a result and got none" % (columns, keys))
860 else:
861 # r = RowList(ret)
862 # assert(len(r) == len(ret))
863 # return r
864 return RowList(ret)
866 def __iter__(self):
867 return self
869 def next(self):
870 self.currentRow += 1
871 if self.currentRow >= self.rowCount:
872 self.currentRow = -1
873 raise StopIteration
874 return self._rows("*",[self.currentRow, ], single=True, rowIndex=self.currentRow)
    def __repr__(self):
        """ Debug form: the originating SQL and the sqlite scratch-table name. """
        return "{Query: %s backed by table: %s}" % (self.sql, self._name)
879 def __str__(self):
880 return str(self[0:len(self)])
882 def log_sqlite(self, sql, bindings,unknown=None):
883 # self.log("SQLITE: "+sql+" BINDINGS: "+str(bindings))
884 return True
886 @staticmethod
887 def log(msg):
888 if msg[-1] != '\n':
889 msg += '\n'
890 with open("db.log", "a+") as f:
891 f.write(msg)
class RowList(object):
    """ A thin sequence wrapper around a list of Row objects.

    Rows with exactly one column are unwrapped to their bare value at
    construction time, so ``rl[0]`` yields the value directly.
    """
    def __init__(self, list):
        """ Takes a list of Row objects """
        self.__list = []
        for row in list:
            self.__list.append(row[0] if row.columncount() == 1 else row)
        assert(len(self.__list) == len(list))
    def __getitem__(self, key):
        # int -> one element; slice -> a new RowList; anything else is a
        # column key fanned out across every row
        if isinstance(key, int):
            return self.__list[key]
        if isinstance(key, slice):
            return RowList(self.__list[key])
        return [item.__getitem__(key) for item in self.__list]
    def __getattribute__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError:
            # unknown attributes fan out across the rows
            return [item.__getattribute__(key) for item in self.__list]
    def __repr__(self):
        return self.__str__()
    def __str__(self):
        return "(RowList:" + ','.join([repr(x) for x in self.__list]) + ")"
    def __len__(self):
        return len(self.__list)
    def __iter__(self):
        return self.__list.__iter__()
class Row(object):
    """ One result row: column names, declared types, and values.

    NULL values are coerced to 0 at construction time, so numeric consumers
    never see None.
    """
    def __init__(self, cursor, row, index=None):
        self.rowIndex = index
        desc = cursor.getdescription()
        self.__c = [col[0] for col in desc]  # column names
        self.__t = [col[1] for col in desc]  # declared column types
        self.__d = [(0 if value is None else value) for value in row]
    def columncount(self):
        return len(self.__c)
    def column(self, i):
        return self.__c[i]
    def __repr__(self):
        return self.__str__()
    def __str__(self):
        parts = []
        for i in range(0, len(self.__c)):
            parts.append("'%s': %s" % (self.__c[i], repr(self.__d[i])))
        return "{Row:" + ','.join(parts) + "}"
    def __getattribute__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError:
            try:
                return self.__d[self.__c.index(key)]
            except ValueError:
                raise AttributeError(key)
    def __getitem__(self, key):
        if isinstance(key, str):
            try:
                return self.__d[self.__c.index(key)]
            except ValueError:
                # "a, b" style multi-column lookup
                if ',' in key:
                    return [self.__getitem__(part.replace(' ', '')) for part in key.split(',')]
                raise KeyError(key)
        elif isinstance(key, int):
            return self.__d[key]
        elif isinstance(key, slice):
            return self.__d[key]
        # NOTE(review): any other key type silently returns None, as before
    def dict(self):
        result = {}
        for i, name in enumerate(self.__c):
            result[name] = self.__d[i]
        return result
    def __len__(self):
        return 1
973 class FastQuery(object):
974 """ FastQuery is a way to shortcut all the bulky components in query that allow requerying, slicing, etc.
975 Still uses dsn="..." and the db.cfg file, except it only supports iterating the results
977 def __init__(self, sql, dsn="default"):
978 self.sql = sql
979 self.dsn = dsn
980 self.c = None
981 # get a db connection
982 db = Query._update_connections(dsn)
983 # execute the query
984 self.start = datetime.datetime.now()
985 self.c = db.cursor()
986 try:
987 self.c.execute(sql)
988 except Exception, e:
989 if e[0] == 2006:
990 Query._close_connection(self.dsn)
991 self.c = Query._update_connections(self.dsn).cursor()
992 self.c.execute(sql)
993 else:
994 raise SQLException(str(e), sql)
995 Query.logQuery("FAST: "+self.sql, _ms_elapsed(self.start))
996 def __iter__(self):
997 return self
998 def next(self):
999 row = self.c.fetchone()
1000 if row is None:
1001 raise StopIteration
1002 return [ (d if d is not None else 0) for d in row ]
1003 def __del__(self):
1004 self.c.close()
1006 def _getNameFromDefn(defn):
1008 >>> Model()._getNameFromDefn("`email` varchar(256) not null")
1009 'email'
1011 >>> Model()._getNameFromDefn("`email` text not null")
1012 'email'
1014 >>> Model()._getNameFromDefn("primary key (`foo`)")
1015 'primary key'
1017 >>> Model()._getNameFromDefn("index `foo` (`bar`)")
1018 'foo'
1020 >>> Model()._getNameFromDefn("index (`bar`)")
1021 'ix_bar'
1024 # primary keys have a constant name
1025 if re.search('^\s*primary key\s*\(', defn, re.I) is not None:
1026 return "primary key"
1027 # normal columns just get the name parsed off
1028 ret = defn.split("`")
1029 try:
1030 return ret[1].lower()
1031 except IndexError, e:
1032 raise IndexError("%s did not split on `" % defn)
class _TableInfo:
    """ Parses `show create table` output into name <-> definition-line maps.

    Cleaned up: removed a no-op ``except MySQLdb.ProgrammingError: raise``
    wrapper, and the loop variable no longer shadows the ctor parameter.
    """
    def __init__(self, name, dsn="default"):
        self.name = name
        self.names = {}   # object name -> definition line
        self.lines = {}   # definition line -> object name
        self.order = []   # object names in declaration order
        self.text = ""
        # MySQLdb.ProgrammingError (e.g. table missing) propagates to the caller
        self.text = Query("show create table %s" % name, dsn=dsn, cache=False)
        assert type(self.text) == Query, "Unexpected type: %s" % str(type(self.text))
        self.text = self.text[0]
        assert type(self.text) == Row, "Unexpected type: %s" % str(type(self.text))
        self.text = self.text[1]
        assert type(self.text) == unicode, "Unexpected type: %s" % str(type(self.text))
        self.text = self.text.lower()
        # skip the "create table ..." header line and the closing ")" line
        for line in self.text.split('\n')[1:-1]:
            defn_name = _getNameFromDefn(line)
            # strip leading whitespace and any trailing comma from both
            defn_name = re.sub(r',*\s*$', '', re.sub(r'^\s*', '', defn_name))
            line = re.sub(r',*\s*$', '', re.sub(r'^\s*', '', line))
            self.names[defn_name] = line
            self.lines[line] = defn_name
            self.order.append(defn_name)
    def getName(self, name):
        """ Return the definition line for object *name*, or None. """
        return self.names.get(name, None)
    def hasDefn(self, defn):
        """ Return the object name owning definition line *defn*, or None. """
        return self.lines.get(defn, None)
def safe(s):
    """ Returns a db-injection safe copy of the given string.

    Every single quote is doubled, per SQL string-literal escaping.
    Fixed: the old regex ``([^'])'`` required a preceding character, so a
    quote at the very start of the string (and every second quote in a run)
    escaped un-doubled — an injection hole.
    """
    # TODO: more filters
    return s.replace("'", "''")
class ModelConflictError(Exception):
    """ Error type for Model schema conflicts (raise sites live elsewhere
    in this file). """
    pass
class SQLException(Exception):
    """ Exception wrapper that carries the offending SQL statement.

    >>> Query("this is not a valid sql statement", dsn="test") #doctest: ELLIPSIS
    Traceback (most recent call last):
    SQLException: ...
    """
    def __init__(self, msg, sql):
        Exception.__init__(self, msg)
        self.sql = sql
    def __str__(self):
        base = Exception.__str__(self)
        return "%s (sql: %s)" % (base, self.sql)
    def __unicode__(self):
        return unicode(self.__str__())
    def __repr__(self):
        return self.__str__()
1090 def _ms_elapsed(since):
1091 d = (datetime.datetime.now() - since)
1092 return (d.seconds *1000.0) + (d.microseconds / 1000.0)
1094 def _match_any(s, l):
1095 for t in l:
1096 if s.find(t) > -1:
1097 return True
if __name__ == "__main__":
    # Run the doctests the hard way, to force the order: Model first, then Query.
    # Fixed: removed the unittest import and the TestSuite that was built but
    # never populated or run.
    import doctest
    finder = doctest.DocTestFinder()
    tests = []
    tests.extend(finder.find(Model))
    tests.extend(finder.find(Query))
    runner = doctest.DocTestRunner()
    for test in tests:
        runner.run(test)