# pygr/sqlgraph.py
3 from __future__ import generators
4 from mapping import *
5 from sequence import SequenceBase, DNA_SEQTYPE, RNA_SEQTYPE, PROTEIN_SEQTYPE
6 import types
7 from classutil import methodFactory,standard_getstate,\
8 override_rich_cmp,generate_items,get_bound_subclass,standard_setstate,\
9 get_valid_path,standard_invert,RecentValueDictionary,read_only_error,\
10 SourceFileName, split_kwargs
11 import os
12 import platform
13 import UserDict
14 import warnings
15 import logger
17 class TupleDescriptor(object):
18 'return tuple entry corresponding to named attribute'
19 def __init__(self, db, attr):
20 self.icol = db.data[attr] # index of this attribute in the tuple
21 def __get__(self, obj, klass):
22 return obj._data[self.icol]
23 def __set__(self, obj, val):
24 raise AttributeError('this database is read-only!')
26 class TupleIDDescriptor(TupleDescriptor):
27 def __set__(self, obj, val):
28 raise AttributeError('''You cannot change obj.id directly.
29 Instead, use db[newID] = obj''')
31 class TupleDescriptorRW(TupleDescriptor):
32 'read-write interface to named attribute'
33 def __init__(self, db, attr):
34 self.attr = attr
35 self.icol = db.data[attr] # index of this attribute in the tuple
36 self.attrSQL = db._attrSQL(attr, sqlColumn=True) # SQL column name
37 def __set__(self, obj, val):
38 obj.db._update(obj.id, self.attrSQL, val) # AND UPDATE THE DATABASE
39 obj.save_local(self.attr, val)
41 class SQLDescriptor(object):
42 'return attribute value by querying the database'
43 def __init__(self, db, attr):
44 self.selectSQL = db._attrSQL(attr) # SQL expression for this attr
45 def __get__(self, obj, klass):
46 return obj._select(self.selectSQL)
47 def __set__(self, obj, val):
48 raise AttributeError('this database is read-only!')
50 class SQLDescriptorRW(SQLDescriptor):
51 'writeable proxy to corresponding column in the database'
52 def __set__(self, obj, val):
53 obj.db._update(obj.id, self.selectSQL, val) #just update the database
55 class ReadOnlyDescriptor(object):
56 'enforce read-only policy, e.g. for ID attribute'
57 def __init__(self, db, attr):
58 self.attr = '_'+attr
59 def __get__(self, obj, klass):
60 return getattr(obj, self.attr)
61 def __set__(self, obj, val):
62 raise AttributeError('attribute %s is read-only!' % self.attr)
65 def select_from_row(row, what):
66 "return value from SQL expression applied to this row"
67 sql,params = row.db._format_query('select %s from %s where %s=%%s limit 2'
68 % (what,row.db.name,row.db.primary_key),
69 (row.id,))
70 row.db.cursor.execute(sql, params)
71 t = row.db.cursor.fetchmany(2) # get at most two rows
72 if len(t) != 1:
73 raise KeyError('%s[%s].%s not found, or not unique'
74 % (row.db.name,str(row.id),what))
75 return t[0][0] #return the single field we requested
77 def init_row_subclass(cls, db):
78 'add descriptors for db attributes'
79 for attr in db.data: # bind all database columns
80 if attr == 'id': # handle ID attribute specially
81 setattr(cls, attr, cls._idDescriptor(db, attr))
82 continue
83 try: # check if this attr maps to an SQL column
84 db._attrSQL(attr, columnNumber=True)
85 except AttributeError: # treat as SQL expression
86 setattr(cls, attr, cls._sqlDescriptor(db, attr))
87 else: # treat as interface to our stored tuple
88 setattr(cls, attr, cls._columnDescriptor(db, attr))
90 def dir_row(self):
91 """get list of column names as our attributes """
92 return self.db.data.keys()
94 class TupleO(object):
95 """Provides attribute interface to a database tuple.
96 Storing the data as a tuple instead of a standard Python object
97 (which is stored using __dict__) uses about five-fold less
98 memory and is also much faster (the tuples returned from the
99 DB API fetch are simply referenced by the TupleO, with no
100 need to copy their individual values into __dict__).
102 This class follows the 'subclass binding' pattern, which
103 means that instead of using __getattr__ to process all
104 attribute requests (which is un-modular and leads to all
105 sorts of trouble), we follow Python's new model for
106 customizing attribute access, namely Descriptors.
107 We use classutil.get_bound_subclass() to automatically
108 create a subclass of this class, calling its _init_subclass()
109 class method to add all the descriptors needed for the
110 database table to which it is bound.
112 See the Pygr Developer Guide section of the docs for a
113 complete discussion of the subclass binding pattern."""
114 _columnDescriptor = TupleDescriptor
115 _idDescriptor = TupleIDDescriptor
116 _sqlDescriptor = SQLDescriptor
117 _init_subclass = classmethod(init_row_subclass)
118 _select = select_from_row
119 __dir__ = dir_row
120 def __init__(self, data):
121 self._data = data # save our data tuple
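# Illustrative sketch (not part of the original module): how the subclass
# binding pattern turns database columns into attributes.  _init_subclass()
# installs a TupleDescriptor per column, so attribute reads are plain tuple
# indexing.  _FakeDB and its column layout are assumptions for the example;
# normally classutil.get_bound_subclass() performs the subclassing for you.
def _example_tupleo_binding():
    class _FakeDB(object):
        'minimal stand-in supplying the db.data column map the descriptors use'
        data = {'id': 0, 'name': 1}
        def _attrSQL(self, attr, **kwargs):
            return self.data[attr]  # treat every attribute as a real column
    bound = type('TupleO_example', (TupleO,), {})  # hand-made bound subclass
    bound._init_subclass(_FakeDB())  # adds the id / name descriptors
    row = bound((17, 'alice'))
    return row.id, row.name  # -> (17, 'alice'), read straight from the tuple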
123 def insert_and_cache_id(self, l, **kwargs):
124 'insert tuple into db and cache its rowID on self'
125 self.db._insert(l) # save to database
126 try:
127 rowID = kwargs['id'] # use the ID supplied by user
128 except KeyError:
129 rowID = self.db.get_insert_id() # get auto-inc ID value
130 self.cache_id(rowID) # cache this ID on self
132 class TupleORW(TupleO):
133 'read-write version of TupleO'
134 _columnDescriptor = TupleDescriptorRW
135 insert_and_cache_id = insert_and_cache_id
136 def __init__(self, data, newRow=False, **kwargs):
137 if not newRow: # just cache data from the database
138 self._data = data
139 return
140 self._data = self.db.tuple_from_dict(kwargs) # convert to tuple
141 self.insert_and_cache_id(self._data, **kwargs)
142 def cache_id(self,row_id):
143 self.save_local('id',row_id)
144 def save_local(self,attr,val):
145 icol = self._attrcol[attr]
146 try:
147 self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
148 except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
149 self._data = list(self._data)
150 self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
152 TupleO._RWClass = TupleORW # record this as writeable interface class
154 class ColumnDescriptor(object):
155 'read-write interface to column in a database, cached in obj.__dict__'
156 def __init__(self, db, attr, readOnly = False):
157 self.attr = attr
158 self.col = db._attrSQL(attr, sqlColumn=True) # MAP THIS TO SQL COLUMN NAME
159 self.db = db
160 if readOnly:
161 self.__class__ = self._readOnlyClass
162 def __get__(self, obj, objtype):
163 try:
164 return obj.__dict__[self.attr]
165 except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
166 if self.col==self.db.primary_key:
167 raise AttributeError
168 self.db._select('where %s=%%s' % self.db.primary_key,(obj.id,),self.col)
169 l = self.db.cursor.fetchall()
170 if len(l)!=1:
171 raise AttributeError('db row not found or not unique!')
172 obj.__dict__[self.attr] = l[0][0] # UPDATE THE CACHE
173 return l[0][0]
174 def __set__(self, obj, val):
175 if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
176 self.db._update(obj.id, self.col, val) # UPDATE THE DATABASE
177 obj.__dict__[self.attr] = val # UPDATE THE CACHE
178 ## try:
179 ## m = self.consequences
180 ## except AttributeError:
181 ## return
182 ## m(obj,val) # GENERATE CONSEQUENCES
183 ## def bind_consequences(self,f):
184 ## 'make function f be run as consequences method whenever value is set'
185 ## import new
186 ## self.consequences = new.instancemethod(f,self,self.__class__)
187 class ReadOnlyColumnDesc(ColumnDescriptor):
188 def __set__(self, obj, val):
189 raise AttributeError('The ID of a database object is not writeable.')
190 ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc
194 class SQLRow(object):
195 """Provide transparent interface to a row in the database: attribute access
196 will be mapped to SELECT of the appropriate column, but data is not
197 cached on this object.
198 """
199 _columnDescriptor = _sqlDescriptor = SQLDescriptor
200 _idDescriptor = ReadOnlyDescriptor
201 _init_subclass = classmethod(init_row_subclass)
202 _select = select_from_row
203 __dir__ = dir_row
204 def __init__(self, rowID):
205 self._id = rowID
208 class SQLRowRW(SQLRow):
209 'read-write version of SQLRow'
210 _columnDescriptor = SQLDescriptorRW
211 insert_and_cache_id = insert_and_cache_id
212 def __init__(self, rowID, newRow=False, **kwargs):
213 if not newRow: # just cache data from the database
214 return self.cache_id(rowID)
215 l = self.db.tuple_from_dict(kwargs) # convert to tuple
216 self.insert_and_cache_id(l, **kwargs)
217 def cache_id(self, rowID):
218 self._id = rowID
220 SQLRow._RWClass = SQLRowRW
224 def list_to_dict(names, values):
225 'return dictionary of those named args that are present in values[]'
226 d={}
227 for i,v in enumerate(values):
228 try:
229 d[names[i]] = v
230 except IndexError:
231 break
232 return d
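# Illustrative sketch (not part of the original module): list_to_dict() pairs
# positional values with the supplied names, ignoring any values beyond the
# last name, which is how get_name_cursor() below tolerates short arg lists.
def _example_list_to_dict():
    names = ('host', 'user', 'passwd')
    assert list_to_dict(names, ['myhost', 'me']) == {'host': 'myhost',
                                                     'user': 'me'}
    assert list_to_dict(names, []) == {}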
235 def get_name_cursor(name=None, **kwargs):
236 '''get table name and cursor by parsing name, or from keyword args.
237 Any "host user passwd" fields appended to name are passed to DBServerInfo;
238 if no connection info is given, your MySQL config file is used instead.'''
239 if name is not None:
240 argList = name.split() # TREAT AS WS-SEPARATED LIST
241 if len(argList)>1:
242 name = argList[0] # USE 1ST ARG AS TABLE NAME
243 argnames = ('host','user','passwd') # READ ARGS IN THIS ORDER
244 kwargs = kwargs.copy() # a copy we can overwrite
245 kwargs.update(list_to_dict(argnames, argList[1:]))
246 serverInfo = DBServerInfo(**kwargs)
247 return name,serverInfo.cursor(),serverInfo
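# Illustrative sketch (not part of the original module): get_name_cursor()
# splits a whitespace-separated name into "tablename [host [user [passwd]]]"
# and hands those fields to DBServerInfo.  The coordinates below are
# placeholders, and this would open a real MySQL connection when run.
def _example_get_name_cursor():
    name, cursor, serverInfo = get_name_cursor('test.mytable localhost myuser')
    # name == 'test.mytable'; cursor / serverInfo wrap the live connection
    return name, cursor, serverInfo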
249 def mysql_connect(connect=None, configFile=None, useStreaming=False, **args):
250 """return connection and cursor objects, using .my.cnf if necessary"""
251 kwargs = args.copy() # a copy we can modify
252 if 'user' not in kwargs and configFile is None: #Find where config file is
253 osname = platform.system()
254 if osname in('Microsoft', 'Windows'): # Machine is a Windows box
255 paths = []
256 try: # handle case where WINDIR not defined by Windows...
257 windir = os.environ['WINDIR']
258 paths += [(windir, 'my.ini'), (windir, 'my.cnf')]
259 except KeyError:
260 pass
261 try:
262 sysdrv = os.environ['SYSTEMDRIVE']
263 paths += [(sysdrv, os.path.sep + 'my.ini'),
264 (sysdrv, os.path.sep + 'my.cnf')]
265 except KeyError:
266 pass
267 if len(paths) > 0:
268 configFile = get_valid_path(*paths)
269 else: # treat as normal platform with home directories
270 configFile = os.path.join(os.path.expanduser('~'), '.my.cnf')
272 # allow a local MySQL configuration file (mysql.cnf) to be read
273 # from the current directory
274 configFile = configFile or os.path.join( os.getcwd(), 'mysql.cnf' )
276 if configFile and os.path.exists(configFile):
277 kwargs['read_default_file'] = configFile
278 connect = None # force it to use MySQLdb
279 if connect is None:
280 import MySQLdb
281 connect = MySQLdb.connect
282 kwargs['compress'] = True
283 if useStreaming: # use server side cursors for scalable result sets
284 try:
285 from MySQLdb import cursors
286 kwargs['cursorclass'] = cursors.SSCursor
287 except (ImportError, AttributeError):
288 pass
289 conn = connect(**kwargs)
290 cursor = conn.cursor()
291 return conn,cursor
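# Illustrative sketch (not part of the original module): mysql_connect() falls
# back to your MySQL config file when no user is supplied, and useStreaming=True
# swaps in MySQLdb's server-side SSCursor so large result sets are fetched
# incrementally rather than buffered client-side.  The credentials below are
# placeholders and require a reachable MySQL server.
def _example_mysql_connect():
    conn, cursor = mysql_connect(user='myuser', passwd='mypass',
                                 host='localhost', useStreaming=True)
    cursor.execute('select 1')
    return cursor.fetchall()  # one row containing 1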
293 _mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
294 AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
295 SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')
297 def mysql_table_schema(self, analyzeSchema=True):
298 'retrieve table schema from a MySQL database, save on self'
299 import MySQLdb
300 self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
301 if not analyzeSchema:
302 return
303 self.clear_schema() # reset settings and dictionaries
304 self.cursor.execute('describe %s' % self.name) # get info about columns
305 columns = self.cursor.fetchall()
306 self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
307 for icol,c in enumerate(columns):
308 field = c[0]
309 self.columnName.append(field) # list of columns in same order as table
310 if c[3] == "PRI": # record as primary key
311 if self.primary_key is None:
312 self.primary_key = field
313 else:
314 try:
315 self.primary_key.append(field)
316 except AttributeError:
317 self.primary_key = [self.primary_key,field]
318 if c[1][:3].lower() == 'int':
319 self.usesIntID = True
320 else:
321 self.usesIntID = False
322 elif c[3] == "MUL":
323 self.indexed[field] = icol
324 self.description[field] = self.cursor.description[icol]
325 self.columnType[field] = c[1] # SQL COLUMN TYPE
327 _sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
328 AUTO_INCREMENT='', SUBSTRING='substr',
329 SUBSTR_FROM=',', SUBSTR_FOR=',')
331 def import_sqlite():
332 'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
333 try:
334 import sqlite3 as sqlite
335 except ImportError:
336 from pysqlite2 import dbapi2 as sqlite
337 return sqlite
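# Illustrative sketch (not part of the original module): import_sqlite() hides
# the sqlite3 / pysqlite2 naming difference and returns an ordinary DB-API 2.0
# module, so callers never import either name directly.
def _example_import_sqlite():
    sqlite = import_sqlite()
    conn = sqlite.connect(':memory:')  # throwaway in-memory database
    conn.cursor().execute('create table demo (id integer primary key)')
    return sqlite.paramstyle  # 'qmark' for sqlite3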
339 def sqlite_table_schema(self, analyzeSchema=True):
340 'retrieve table schema from a sqlite3 database, save on self'
341 sqlite = import_sqlite()
342 self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
343 if not analyzeSchema:
344 return
345 self.clear_schema() # reset settings and dictionaries
346 self.cursor.execute('PRAGMA table_info("%s")' % self.name)
347 columns = self.cursor.fetchall()
348 self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
349 for icol,c in enumerate(columns):
350 field = c[1]
351 self.columnName.append(field) # list of columns in same order as table
352 self.description[field] = self.cursor.description[icol]
353 self.columnType[field] = c[2] # SQL COLUMN TYPE
354 self.cursor.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self.name) # get primary key / unique indexes
355 for indexname in self.cursor.fetchall(): # search indexes for primary key
356 self.cursor.execute('PRAGMA index_info("%s")' % indexname)
357 l = self.cursor.fetchall() # get list of columns in this index
358 if len(l) == 1: # assume 1st single-column unique index is primary key!
359 self.primary_key = l[0][2]
360 break # done searching for primary key!
361 if self.primary_key is None: # grrr, INTEGER PRIMARY KEY handled differently
362 self.cursor.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self.name)
363 sql = self.cursor.fetchall()[0][0]
364 for columnSQL in sql[sql.index('(') + 1 :].split(','):
365 if 'primary key' in columnSQL.lower(): # must be the primary key!
366 col = columnSQL.split()[0] # get column name
367 if col in self.columnType:
368 self.primary_key = col
369 break # done searching for primary key!
370 else:
371 raise ValueError('unknown primary key %s in table %s'
372 % (col,self.name))
373 if self.primary_key is not None: # check its type
374 if self.columnType[self.primary_key] == 'int' or \
375 self.columnType[self.primary_key] == 'integer':
376 self.usesIntID = True
377 else:
378 self.usesIntID = False
380 class SQLFormatDict(object):
381 '''Perform SQL keyword replacements for maintaining compatibility across
382 a wide range of SQL backends. Uses Python dict-based string format
383 function to do simple string replacements, and also to convert
384 params list to the paramstyle required for this interface.
385 Create by passing a dict of macros and the db-api paramstyle:
386 sfd = SQLFormatDict("qmark", substitutionDict)
388 Then transform queries+params as follows; input should be "format" style:
389 sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
390 cursor.execute(sql, params)
391 '''
392 _paramFormats = dict(pyformat='%%(%d)s', numeric=':%d', named=':%d',
393 qmark='(ignore)', format='(ignore)')
394 def __init__(self, paramstyle, substitutionDict={}):
395 self.substitutionDict = substitutionDict.copy()
396 self.paramstyle = paramstyle
397 self.paramFormat = self._paramFormats[paramstyle]
398 self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
399 if paramstyle == 'qmark': # handle these as simple substitution
400 self.substitutionDict['?'] = '?'
401 elif paramstyle == 'format':
402 self.substitutionDict['?'] = '%s'
403 def __getitem__(self, k):
404 'apply correct substitution for this SQL interface'
405 try:
406 return self.substitutionDict[k] # apply our substitutions
407 except KeyError:
408 pass
409 if k == '?': # sequential parameter
410 s = self.paramFormat % self.iparam
411 self.iparam += 1 # advance to the next parameter
412 return s
413 raise KeyError('unknown macro: %s' % k)
414 def __call__(self, sql, paramList):
415 'returns corrected sql,params for this interface'
416 self.iparam = 1 # DB-API param indexing begins at 1
417 sql = sql.replace('%s', '%(?)s') # convert format into pyformat
418 s = sql % self # apply all %(x)s replacements in sql
419 if self.makeDict: # construct a params dict
420 paramDict = {}
421 for i,param in enumerate(paramList):
422 paramDict[str(i + 1)] = param # DB-API param indexing begins at 1
423 return s,paramDict
424 else: # just return the original params list
425 return s,paramList
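# Illustrative sketch (not part of the original module): SQLFormatDict rewrites
# the %(MACRO)s keywords and the %s parameter markers for a given backend.
# With sqlite's 'qmark' paramstyle the markers become '?', and the params are
# passed through unchanged; 'foo' is just an example table name.
def _example_sqlformatdict():
    sfd = SQLFormatDict('qmark', _sqliteMacros)
    sql, params = sfd('%(REPLACE)s into foo values (%s,%s)', (1, 'x'))
    # sql == 'insert or replace into foo values (?,?)'; params == (1, 'x')
    return sql, params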
427 def get_table_schema(self, analyzeSchema=True):
428 'run the right schema function based on type of db server connection'
429 try:
430 modname = self.cursor.__class__.__module__
431 except AttributeError:
432 raise ValueError('no cursor object or module information!')
433 try:
434 schema_func = self._schemaModuleDict[modname]
435 except KeyError:
436 raise KeyError('''unknown db module: %s. Use _schemaModuleDict
437 attribute to supply a method for obtaining table schema
438 for this module''' % modname)
439 schema_func(self, analyzeSchema) # run the schema function
442 _schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
443 'pysqlite2.dbapi2':sqlite_table_schema,
444 'sqlite3':sqlite_table_schema}
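# Note (not part of the original module): get_table_schema() dispatches on the
# module that provided the cursor, so supporting another DB-API driver only
# requires registering a schema function under that module's name.  The names
# below ('mydriver.cursors', my_table_schema) are hypothetical placeholders:
#
#     def my_table_schema(self, analyzeSchema=True):
#         ...fill self.columnName, self.columnType, self.primary_key etc...
#
#     _schemaModuleDict['mydriver.cursors'] = my_table_schema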
446 class SQLTableBase(object, UserDict.DictMixin):
447 "Store information about an SQL table as dict keyed by primary key"
448 _schemaModuleDict = _schemaModuleDict # default module list
449 get_table_schema = get_table_schema
450 def __init__(self,name,cursor=None,itemClass=None,attrAlias=None,
451 clusterKey=None,createTable=None,graph=None,maxCache=None,
452 arraysize=1024, itemSliceClass=None, dropIfExists=False,
453 serverInfo=None, autoGC=True, orderBy=None,
454 writeable=False, iterSQL=None, iterColumns=None, **kwargs):
455 if autoGC: # automatically garbage collect unused objects
456 self._weakValueDict = RecentValueDictionary(autoGC) # object cache
457 else:
458 self._weakValueDict = {}
459 self.autoGC = autoGC
460 self.orderBy = orderBy
461 if orderBy and serverInfo and serverInfo._serverType == 'mysql':
462 if iterSQL and iterColumns: # both required for mysql!
463 self.iterSQL, self.iterColumns = iterSQL, iterColumns
464 else:
465 raise ValueError('For MySQL tables with orderBy, you MUST specify iterSQL and iterColumns as well!')
467 self.writeable = writeable
468 if cursor is None:
469 if serverInfo is not None: # get cursor from serverInfo
470 cursor = serverInfo.cursor()
471 else: # try to read connection info from name or config file
472 name,cursor,serverInfo = get_name_cursor(name,**kwargs)
473 else:
474 warnings.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
475 DeprecationWarning, stacklevel=2)
476 self.cursor = cursor
477 if createTable is not None: # RUN COMMAND TO CREATE THIS TABLE
478 if dropIfExists: # get rid of any existing table
479 cursor.execute('drop table if exists ' + name)
480 self.get_table_schema(False) # check dbtype, init _format_query
481 sql,params = self._format_query(createTable, ()) # apply macros
482 cursor.execute(sql) # create the table
483 self.name = name
484 if graph is not None:
485 self.graph = graph
486 if maxCache is not None:
487 self.maxCache = maxCache
488 if arraysize is not None:
489 self.arraysize = arraysize
490 cursor.arraysize = arraysize
491 self.get_table_schema() # get schema of columns to serve as attrs
492 self.data = {} # map of all attributes, including aliases
493 for icol,field in enumerate(self.columnName):
494 self.data[field] = icol # 1st add mappings to columns
495 try:
496 self.data['id']=self.data[self.primary_key]
497 except (KeyError,TypeError):
498 pass
499 if hasattr(self,'_attr_alias'): # apply attribute aliases for this class
500 self.addAttrAlias(False,**self._attr_alias)
501 self.objclass(itemClass) # NEED TO SUBCLASS OUR ITEM CLASS
502 if itemSliceClass is not None:
503 self.itemSliceClass = itemSliceClass
504 get_bound_subclass(self, 'itemSliceClass', self.name) # need to subclass itemSliceClass
505 if attrAlias is not None: # ADD ATTRIBUTE ALIASES
506 self.attrAlias = attrAlias # RECORD FOR PICKLING PURPOSES
507 self.data.update(attrAlias)
508 if clusterKey is not None:
509 self.clusterKey=clusterKey
510 if serverInfo is not None:
511 self.serverInfo = serverInfo
513 def __len__(self):
514 self._select(selectCols='count(*)')
515 return self.cursor.fetchone()[0]
516 def __hash__(self):
517 return id(self)
518 def __cmp__(self, other):
519 'only match self and no other!'
520 if self is other:
521 return 0
522 else:
523 return cmp(id(self), id(other))
524 _pickleAttrs = dict(name=0, clusterKey=0, maxCache=0, arraysize=0,
525 attrAlias=0, serverInfo=0, autoGC=0, orderBy=0,
526 writeable=0, iterSQL=0, iterColumns=0)
527 __getstate__ = standard_getstate
528 def __setstate__(self,state):
529 # default cursor provisioning by worldbase is deprecated!
530 ## if 'serverInfo' not in state: # hmm, no address for db server?
531 ## try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
532 ## from Data import getResource
533 ## state['cursor'] = getResource.getTableCursor(state['name'])
534 ## except ImportError:
535 ## pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
536 self.__init__(**state)
537 def __repr__(self):
538 return '<SQL table '+self.name+'>'
540 def clear_schema(self):
541 'reset all schema information for this table'
542 self.description={}
543 self.columnName = []
544 self.columnType = {}
545 self.usesIntID = None
546 self.primary_key = None
547 self.indexed = {}
548 def _attrSQL(self,attr,sqlColumn=False,columnNumber=False):
549 "Translate python attribute name to appropriate SQL expression"
550 try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
551 field=self.data[attr]
552 except KeyError:
553 raise AttributeError('attribute %s not a valid column or alias in %s'
554 % (attr,self.name))
555 if sqlColumn: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
556 try: # CHECK IF field IS COLUMN NUMBER
557 return self.columnName[field] # RETURN SQL COLUMN NAME
558 except TypeError:
559 try: # CHECK IF field IS SQL COLUMN NAME
560 return self.columnName[self.data[field]] # THIS WILL JUST RETURN field...
561 except (KeyError,TypeError):
562 raise AttributeError('attribute %s does not map to an SQL column in %s'
563 % (attr,self.name))
564 if columnNumber:
565 try: # CHECK IF field IS A COLUMN NUMBER
566 return field+0 # ONLY RETURN AN INTEGER
567 except TypeError:
568 try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
569 return self.data[field]+0 # ONLY RETURN AN INTEGER
570 except (KeyError,TypeError):
571 raise ValueError('attribute %s does not map to a SQL column!' % attr)
572 if isinstance(field,types.StringType):
573 attr=field # USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
574 elif attr=='id':
575 attr=self.primary_key
576 return attr
577 def addAttrAlias(self,saveToPickle=True,**kwargs):
578 """Add new attributes as aliases of existing attributes.
579 They can be specified either as named args:
580 t.addAttrAlias(newattr=oldattr)
581 or by passing a dictionary kwargs whose keys are newattr
582 and values are oldattr:
583 t.addAttrAlias(**kwargs)
584 saveToPickle=True forces these aliases to be saved if object is pickled.
585 """
586 if saveToPickle:
587 self.attrAlias.update(kwargs)
588 for key,val in kwargs.items():
589 try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
590 self.data[val]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
591 raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
592 except TypeError: # val IS ITSELF AN ALIAS
593 self.data[key] = self.data[val] # SO MAP TO WHAT IT MAPS TO
594 except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
595 self.data[key] = val
596 def objclass(self,oclass=None):
597 "Create class representing a row in this table by subclassing oclass, adding data"
598 if oclass is not None: # use this as our base itemClass
599 self.itemClass = oclass
600 if self.writeable:
601 self.itemClass = self.itemClass._RWClass # use its writeable version
602 oclass = get_bound_subclass(self, 'itemClass', self.name,
603 subclassArgs=dict(db=self)) # bind itemClass
604 if issubclass(oclass, TupleO):
605 oclass._attrcol = self.data # BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
606 if hasattr(oclass,'_tableclass') and not isinstance(self,oclass._tableclass):
607 self.__class__=oclass._tableclass # ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
608 def _select(self, whereClause='', params=(), selectCols='t1.*',
609 cursor=None, orderBy='', limit=''):
610 'execute the specified query but do not fetch'
611 sql,params = self._format_query('select %s from %s t1 %s %s %s'
612 % (selectCols, self.name, whereClause, orderBy,
613 limit), params)
614 if cursor is None:
615 self.cursor.execute(sql, params)
616 else:
617 cursor.execute(sql, params)
618 def select(self,whereClause,params=None,oclass=None,selectCols='t1.*'):
619 "Generate the list of objects that satisfy the database SELECT"
620 if oclass is None:
621 oclass=self.itemClass
622 self._select(whereClause,params,selectCols)
623 l=self.cursor.fetchall()
624 for t in l:
625 yield self.cacheItem(t,oclass)
626 def query(self,**kwargs):
627 'query for intersection of all specified kwargs, returned as iterator'
628 criteria = []
629 params = []
630 for k,v in kwargs.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
631 if v is None: # CONVERT TO SQL NULL TEST
632 criteria.append('%s IS NULL' % self._attrSQL(k))
633 else: # TEST FOR EQUALITY
634 criteria.append('%s=%%s' % self._attrSQL(k))
635 params.append(v)
636 return self.select('where '+' and '.join(criteria),params)
637 def _update(self,row_id,col,val):
638 'update a single field in the specified row to the specified value'
639 sql,params = self._format_query('update %s set %s=%%s where %s=%%s'
640 %(self.name,col,self.primary_key),
641 (val,row_id))
642 self.cursor.execute(sql, params)
643 def getID(self,t):
644 try:
645 return t[self.data['id']] # GET ID FROM TUPLE
646 except TypeError: # treat as alias
647 return t[self.data[self.data['id']]]
648 def cacheItem(self,t,oclass):
649 'get obj from cache if possible, or construct from tuple'
650 try:
651 id=self.getID(t)
652 except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
653 return oclass(t)
654 try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
655 return self._weakValueDict[id]
656 except KeyError:
657 pass
658 o = oclass(t)
659 self._weakValueDict[id] = o # CACHE THIS ITEM IN OUR DICTIONARY
660 return o
661 def cache_items(self,rows,oclass=None):
662 if oclass is None:
663 oclass=self.itemClass
664 for t in rows:
665 yield self.cacheItem(t,oclass)
666 def foreignKey(self,attr,k):
667 'get iterator for objects with specified foreign key value'
668 return self.select('where %s=%%s'%attr,(k,))
669 def limit_cache(self):
670 'APPLY maxCache LIMIT TO CACHE SIZE'
671 try:
672 if self.maxCache<len(self._weakValueDict):
673 self._weakValueDict.clear()
674 except AttributeError:
675 pass
677 def get_new_cursor(self):
678 """Return a new cursor object, or None if not possible """
679 try:
680 new_cursor = self.serverInfo.new_cursor
681 except AttributeError:
682 return None
683 return new_cursor(self.arraysize)
685 def generic_iterator(self, cursor=None, fetch_f=None, cache_f=None,
686 map_f=iter, cursorHolder=None):
687 """generic iterator that runs fetch, cache and map functions.
688 cursorHolder is used only to keep a ref in this function's locals,
689 so that if it is prematurely terminated (by deleting its
690 iterator), cursorHolder.__del__() will close the cursor."""
691 if fetch_f is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
692 if cursor is None:
693 fetch_f = self.cursor.fetchmany
694 else: # isolate this iter from other queries
695 fetch_f = cursor.fetchmany
696 if cache_f is None:
697 cache_f = self.cache_items
698 while True:
699 self.limit_cache()
700 rows = fetch_f() # FETCH THE NEXT SET OF ROWS
701 if len(rows)==0: # NO MORE DATA SO ALL DONE
702 break
703 for v in map_f(cache_f(rows)): # CACHE AND GENERATE RESULTS
704 yield v
705 def tuple_from_dict(self, d):
706 'transform kwarg dict into tuple for storing in database'
707 l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
708 for col,icol in self.data.items():
709 try:
710 l[icol] = d[col]
711 except (KeyError,TypeError):
712 pass
713 return l
714 def tuple_from_obj(self, obj):
715 'transform object attributes into tuple for storing in database'
716 l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
717 for col,icol in self.data.items():
718 try:
719 l[icol] = getattr(obj,col)
720 except (AttributeError,TypeError):
721 pass
722 return l
723 def _insert(self, l):
724 '''insert tuple into the database. Note this uses the MySQL
725 extension REPLACE, which overwrites any duplicate key.'''
726 s = '%(REPLACE)s into ' + self.name + ' values (' \
727 + ','.join(['%s']*len(l)) + ')'
728 sql,params = self._format_query(s, l)
729 self.cursor.execute(sql, params)
730 def insert(self, obj):
731 '''insert new row by transforming obj to tuple of values'''
732 l = self.tuple_from_obj(obj)
733 self._insert(l)
734 def get_insert_id(self):
735 'get the primary key value for the last INSERT'
736 try: # ATTEMPT TO GET ASSIGNED ID FROM DB
737 auto_id = self.cursor.lastrowid
738 except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
739 raise NotImplementedError('''your db lacks lastrowid support?''')
740 if auto_id is None:
741 raise ValueError('lastrowid is None so cannot get ID from INSERT!')
742 return auto_id
743 def new(self, **kwargs):
744 'return a new record with the assigned attributes, added to DB'
745 if not self.writeable:
746 raise ValueError('this database is read only!')
747 obj = self.itemClass(None, newRow=True, **kwargs) # saves itself to db
748 self._weakValueDict[obj.id] = obj # AND SAVE TO OUR LOCAL DICT CACHE
749 return obj
750 def clear_cache(self):
751 'empty the cache'
752 self._weakValueDict.clear()
753 def __delitem__(self, k):
754 if not self.writeable:
755 raise ValueError('this database is read only!')
756 sql,params = self._format_query('delete from %s where %s=%%s'
757 % (self.name,self.primary_key),(k,))
758 self.cursor.execute(sql, params)
759 try:
760 del self._weakValueDict[k]
761 except KeyError:
762 pass
764 def getKeys(self,queryOption='', selectCols=None):
765 'uses db select; does not force load'
766 if selectCols is None:
767 selectCols=self.primary_key
768 if queryOption=='' and self.orderBy is not None:
769 queryOption = self.orderBy # apply default ordering
770 self.cursor.execute('select %s from %s %s'
771 %(selectCols,self.name,queryOption))
772 return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
774 def iter_keys(self, selectCols=None, orderBy='', map_f=iter,
775 cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
776 'guarantee correct iteration insulated from other queries'
777 if selectCols is None:
778 selectCols=self.primary_key
779 if orderBy=='' and self.orderBy is not None:
780 orderBy = self.orderBy # apply default ordering
781 cursor = self.get_new_cursor()
782 if cursor: # got our own cursor, guaranteeing query isolation
783 if hasattr(self.serverInfo, 'iter_keys') \
784 and self.serverInfo.custom_iter_keys:
785 # use custom iter_keys() method from serverInfo
786 return self.serverInfo.iter_keys(self, cursor, selectCols=selectCols,
787 map_f=map_f, orderBy=orderBy,
788 cache_f=cache_f, **kwargs)
789 else:
790 self._select(cursor=cursor, selectCols=selectCols,
791 orderBy=orderBy, **kwargs)
792 return self.generic_iterator(cursor=cursor, cache_f=cache_f,
793 map_f=map_f,
794 cursorHolder=CursorCloser(cursor))
795 else: # must pre-fetch all keys to ensure query isolation
796 if get_f is not None:
797 return iter(get_f())
798 else:
799 return iter(self.keys())
801 class SQLTable(SQLTableBase):
802 "Provide on-the-fly access to rows in the database, caching the results in dict"
803 itemClass = TupleO # our default itemClass; constructor can override
804 keys=getKeys
805 __iter__ = iter_keys
806 def load(self,oclass=None):
807 "Load all data from the table"
808 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
809 return self._isLoaded
810 except AttributeError:
811 pass
812 if oclass is None:
813 oclass=self.itemClass
814 self.cursor.execute('select * from %s' % self.name)
815 l=self.cursor.fetchall()
816 self._weakValueDict = {} # just store the whole dataset in memory
817 for t in l:
818 self.cacheItem(t,oclass) # CACHE IT IN LOCAL DICTIONARY
819 self._isLoaded=True # MARK THIS CONTAINER AS FULLY LOADED
821 def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
822 try:
823 return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
824 except KeyError: # NOT FOUND, SO TRY THE DATABASE
825 sql,params = self._format_query('select * from %s where %s=%%s limit 2'
826 % (self.name,self.primary_key),(k,))
827 self.cursor.execute(sql, params)
828 l = self.cursor.fetchmany(2) # get at most 2 rows
829 if len(l) != 1:
830 raise KeyError('%s not found in %s, or not unique' %(str(k),self.name))
831 self.limit_cache()
832 return self.cacheItem(l[0],self.itemClass) # CACHE IT IN LOCAL DICTIONARY
833 def __setitem__(self, k, v):
834 if not self.writeable:
835 raise ValueError('this database is read only!')
836 try:
837 if v.db is not self:
838 raise AttributeError
839 except AttributeError:
840 raise ValueError('object not bound to itemClass for this db!')
841 try:
842 oldID = v.id
843 if oldID is None:
844 raise AttributeError
845 except AttributeError:
846 pass
847 else: # delete row with old ID
848 del self[v.id]
849 v.cache_id(k) # cache the new ID on the object
850 self.insert(v) # SAVE TO THE RELATIONAL DB SERVER
851 self._weakValueDict[k] = v # CACHE THIS ITEM IN OUR DICTIONARY
852 def items(self):
853 'forces load of entire table into memory'
854 self.load()
855 return [(k,self[k]) for k in self] # apply orderBy rules...
856 def iteritems(self):
857 'uses arraysize / maxCache and fetchmany() to manage data transfer'
858 return iter_keys(self, selectCols='*', cache_f=None,
859 map_f=generate_items, get_f=self.items)
860 def values(self):
861 'forces load of entire table into memory'
862 self.load()
863 return [self[k] for k in self] # apply orderBy rules...
864 def itervalues(self):
865 'uses arraysize / maxCache and fetchmany() to manage data transfer'
866 return iter_keys(self, selectCols='*', cache_f=None, get_f=self.values)
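# Illustrative sketch (not part of the original module): a small writeable
# SQLTable on an in-memory sqlite database, passing the (deprecated) cursor
# argument so no serverInfo object is needed.  Table and column names are
# assumptions for the example.
def _example_sqltable():
    sqlite = import_sqlite()
    conn = sqlite.connect(':memory:')
    t = SQLTable('mytable', cursor=conn.cursor(), writeable=True,
                 createTable='create table mytable '
                             '(id integer primary key, name varchar(32))')
    alice = t.new(id=1, name='alice')  # INSERTs the row and caches it
    assert alice.id == 1 and t[1].name == 'alice'
    return [row.name for row in t.itervalues()]  # -> ['alice']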
868 def getClusterKeys(self,queryOption=''):
869 'uses db select; does not force load'
870 self.cursor.execute('select distinct %s from %s %s'
871 %(self.clusterKey,self.name,queryOption))
872 return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
875 class SQLTableClustered(SQLTable):
876 '''use clusterKey to load a whole cluster of rows at once,
877 specifically, all rows that share the same clusterKey value.'''
878 def __init__(self, *args, **kwargs):
879 kwargs = kwargs.copy() # get a copy we can alter
880 kwargs['autoGC'] = False # don't use WeakValueDictionary
881 SQLTable.__init__(self, *args, **kwargs)
882 def keys(self):
883 return getKeys(self,'order by %s' %self.clusterKey)
884 def clusterkeys(self):
885 return getClusterKeys(self, 'order by %s' %self.clusterKey)
886 def __getitem__(self,k):
887 try:
888 return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
889 except KeyError: # NOT FOUND, SO TRY THE DATABASE
890 sql,params = self._format_query('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
891 % (self.name,self.name,self.primary_key,
892 self.clusterKey,self.clusterKey),(k,))
893 self.cursor.execute(sql, params)
894 l=self.cursor.fetchall()
895 self.limit_cache()
896 for t in l: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
897 self.cacheItem(t,self.itemClass)
898 return self._weakValueDict[k] # should be in cache, if row k exists
899 def itercluster(self,cluster_id):
900 'iterate over all items from the specified cluster'
901 self.limit_cache()
902 return self.select('where %s=%%s'%self.clusterKey,(cluster_id,))
903 def fetch_cluster(self):
904 'use self.cursor.fetchmany to obtain all rows for next cluster'
905 icol = self._attrSQL(self.clusterKey,columnNumber=True)
906 result = []
907 try:
908 rows = self._fetch_cluster_cache # USE SAVED ROWS FROM PREVIOUS CALL
909 del self._fetch_cluster_cache
910 except AttributeError:
911 rows = self.cursor.fetchmany()
912 try:
913 cluster_id = rows[0][icol]
914 except IndexError:
915 return result
916 while len(rows)>0:
917 for i,t in enumerate(rows): # CHECK THAT ALL ROWS FROM THIS CLUSTER
918 if cluster_id != t[icol]: # START OF A NEW CLUSTER
919 result += rows[:i] # RETURN ROWS OF CURRENT CLUSTER
920 self._fetch_cluster_cache = rows[i:] # SAVE NEXT CLUSTER
921 return result
922 result += rows
923 rows = self.cursor.fetchmany() # GET NEXT SET OF ROWS
924 return result
925 def itervalues(self):
926 'uses arraysize / maxCache and fetchmany() to manage data transfer'
927 cursor = self.get_new_cursor()
928 self._select('order by %s' %self.clusterKey, cursor=cursor)
929 return self.generic_iterator(cursor, self.fetch_cluster,
930 cursorHolder=CursorHolder(cursor))
931 def iteritems(self):
932 'uses arraysize / maxCache and fetchmany() to manage data transfer'
933 cursor = self.get_new_cursor()
934 self._select('order by %s' %self.clusterKey, cursor=cursor)
935 return self.generic_iterator(cursor, self.fetch_cluster,
936 map_f=generate_items,
937 cursorHolder=CursorHolder(cursor))
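# Illustrative sketch (not part of the original module): SQLTableClustered
# fetches every row sharing a clusterKey value in a single query, so repeated
# lookups inside one cluster are served from the local cache.  The table
# layout is an assumption for the example.
def _example_sqltable_clustered():
    sqlite = import_sqlite()
    conn = sqlite.connect(':memory:')
    cursor = conn.cursor()
    cursor.execute('create table exons '
                   '(id integer primary key, gene varchar(32), start int)')
    cursor.executemany('insert into exons values (?,?,?)',
                       [(1, 'geneA', 100), (2, 'geneA', 300), (3, 'geneB', 50)])
    t = SQLTableClustered('exons', cursor=cursor, clusterKey='gene')
    assert t[1].gene == 'geneA'  # pulls the whole geneA cluster into cache
    return sorted(x.id for x in t.itercluster('geneA'))  # -> [1, 2]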
939 class SQLForeignRelation(object):
940 'mapping based on matching a foreign key in an SQL table'
941 def __init__(self,table,keyName):
942 self.table=table
943 self.keyName=keyName
944 def __getitem__(self,k):
945 'get list of objects o with getattr(o,keyName)==k.id'
946 l=[]
947 for o in self.table.select('where %s=%%s'%self.keyName,(k.id,)):
948 l.append(o)
949 if len(l)==0:
950 raise KeyError('%s not found in %s' %(str(k),self.name))
951 return l
954 class SQLTableNoCache(SQLTableBase):
955 '''Provide on-the-fly access to rows in the database;
956 values are simply an object interface (SQLRow) to back-end db query.
957 Row data are not stored locally, but always accessed by querying the db'''
958 itemClass=SQLRow # DEFAULT OBJECT CLASS FOR ROWS...
959 keys=getKeys
960 __iter__ = iter_keys
961 def getID(self,t): return t[0] # GET ID FROM TUPLE
962 def select(self,whereClause,params):
963 return SQLTableBase.select(self,whereClause,params,self.itemClass,
964 self._attrSQL('id'))
965 def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
966 try:
967 return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
968 except KeyError: # NOT FOUND, SO TRY THE DATABASE
969 self._select('where %s=%%s' % self.primary_key, (k,),
970 self.primary_key)
971 t = self.cursor.fetchmany(2)
972 if len(t) != 1:
973 raise KeyError('id %s non-existent or not unique' % k)
974 o = self.itemClass(k) # create obj referencing this ID
975 self._weakValueDict[k] = o # cache the SQLRow object
976 return o
977 def __setitem__(self, k, v):
978 if not self.writeable:
979 raise ValueError('this database is read only!')
980 try:
981 if v.db is not self:
982 raise AttributeError
983 except AttributeError:
984 raise ValueError('object not bound to itemClass for this db!')
985 try:
986 del self[k] # delete row with new ID if any
987 except KeyError:
988 pass
989 try:
990 del self._weakValueDict[v.id] # delete from old cache location
991 except KeyError:
992 pass
993 self._update(v.id, self.primary_key, k) # just change its ID in db
994 v.cache_id(k) # change the cached ID value
995 self._weakValueDict[k] = v # assign to new cache location
996 def addAttrAlias(self,**kwargs):
997 self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES
999 SQLRow._tableclass=SQLTableNoCache # SQLRow IS FOR NON-CACHING TABLE INTERFACE
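# Illustrative sketch (not part of the original module): SQLTableNoCache hands
# back lightweight SQLRow proxies, so each attribute access below is a fresh
# SELECT instead of a read from a locally cached tuple.  The table layout is
# an assumption for the example.
def _example_sqltable_nocache():
    sqlite = import_sqlite()
    conn = sqlite.connect(':memory:')
    cursor = conn.cursor()
    cursor.execute('create table snps '
                   '(id integer primary key, chrom varchar(8))')
    cursor.execute('insert into snps values (?,?)', (42, 'chr1'))
    t = SQLTableNoCache('snps', cursor=cursor)
    row = t[42]        # SQLRow proxy; only the ID is held locally
    return row.chrom   # runs "select chrom from snps where id=42 limit 2"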
1002 class SQLTableMultiNoCache(SQLTableBase):
1003 "Trivial on-the-fly access for table with key that returns multiple rows"
1004 itemClass = TupleO # default itemClass; constructor can override
1005 _distinct_key='id' # DEFAULT COLUMN TO USE AS KEY
1006 def keys(self):
1007 return getKeys(self, selectCols='distinct(%s)'
1008 % self._attrSQL(self._distinct_key))
1009 def __iter__(self):
1010 return iter_keys(self, 'distinct(%s)' % self._attrSQL(self._distinct_key))
1011 def __getitem__(self,id):
1012 sql,params = self._format_query('select * from %s where %s=%%s'
1013 %(self.name,self._attrSQL(self._distinct_key)),(id,))
1014 self.cursor.execute(sql, params)
1015 l=self.cursor.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1016 for row in l:
1017 yield self.itemClass(row)
1018 def addAttrAlias(self,**kwargs):
1019 self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES
1023 class SQLEdges(SQLTableMultiNoCache):
1024 '''provide iterator over edges as (source,target,edge)
1025 and getitem[edge] --> [(source,target),...]'''
1026 _distinct_key='edge_id'
1027 _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
1028 _pickleAttrs.update(dict(graph=0))
1029 def keys(self):
1030 self.cursor.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
1031 %(self._attrSQL('source_id'),self._attrSQL('target_id'),
1032 self._attrSQL('edge_id'),self.name,
1033 self._attrSQL('target_id'),self._attrSQL('source_id'),
1034 self._attrSQL('target_id')))
1035 l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1036 for source_id,target_id,edge_id in self.cursor.fetchall():
1037 l.append((self.graph.unpack_source(source_id),
1038 self.graph.unpack_target(target_id),
1039 self.graph.unpack_edge(edge_id)))
1040 return l
1041 __call__=keys
1042 def __iter__(self):
1043 return iter(self.keys())
1044 def __getitem__(self,edge):
1045 sql,params = self._format_query('select %s,%s from %s where %s=%%s'
1046 %(self._attrSQL('source_id'),
1047 self._attrSQL('target_id'),
1048 self.name,
1049 self._attrSQL(self._distinct_key)),
1050 (self.graph.pack_edge(edge),))
1051 self.cursor.execute(sql, params)
1052 l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1053 for source_id,target_id in self.cursor.fetchall():
1054 l.append((self.graph.unpack_source(source_id),
1055 self.graph.unpack_target(target_id)))
1056 return l
1059 class SQLEdgeDict(object):
1060 '2nd level graph interface to SQL database'
1061 def __init__(self,fromNode,table):
1062 self.fromNode=fromNode
1063 self.table=table
1064 if not hasattr(self.table,'allowMissingNodes'):
1065 sql,params = self.table._format_query('select %s from %s where %s=%%s limit 1'
1066 %(self.table.sourceSQL,
1067 self.table.name,
1068 self.table.sourceSQL),
1069 (self.fromNode,))
1070 self.table.cursor.execute(sql, params)
1071 if len(self.table.cursor.fetchall())<1:
1072 raise KeyError('node not in graph!')
1074 def __getitem__(self,target):
1075 sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
1076 %(self.table.edgeSQL,
1077 self.table.name,
1078 self.table.sourceSQL,
1079 self.table.targetSQL),
1080 (self.fromNode,
1081 self.table.pack_target(target)))
1082 self.table.cursor.execute(sql, params)
1083 l = self.table.cursor.fetchmany(2) # get at most two rows
1084 if len(l) != 1:
1085 raise KeyError('either no edge from source to target or not unique!')
1086 try:
1087 return self.table.unpack_edge(l[0][0]) # RETURN EDGE
1088 except IndexError:
1089 raise KeyError('no edge from node to target')
1090 def __setitem__(self,target,edge):
1091 sql,params = self.table._format_query('replace into %s values (%%s,%%s,%%s)'
1092 %self.table.name,
1093 (self.fromNode,
1094 self.table.pack_target(target),
1095 self.table.pack_edge(edge)))
1096 self.table.cursor.execute(sql, params)
1097 if not hasattr(self.table,'sourceDB') or \
1098 (hasattr(self.table,'targetDB') and
1099 self.table.sourceDB is self.table.targetDB):
1100 self.table += target # ADD AS NODE TO GRAPH
1101 def __iadd__(self,target):
1102 self[target] = None
1103 return self # iadd MUST RETURN self!
1104 def __delitem__(self,target):
1105 sql,params = self.table._format_query('delete from %s where %s=%%s and %s=%%s'
1106 %(self.table.name,
1107 self.table.sourceSQL,
1108 self.table.targetSQL),
1109 (self.fromNode,
1110 self.table.pack_target(target)))
1111 self.table.cursor.execute(sql, params)
1112 if self.table.cursor.rowcount < 1: # no rows deleted?
1113 raise KeyError('no edge from node to target')
1115 def iterator_query(self):
1116 sql,params = self.table._format_query('select %s,%s from %s where %s=%%s and %s is not null'
1117 %(self.table.targetSQL,
1118 self.table.edgeSQL,
1119 self.table.name,
1120 self.table.sourceSQL,
1121 self.table.targetSQL),
1122 (self.fromNode,))
1123 self.table.cursor.execute(sql, params)
1124 return self.table.cursor.fetchall()
1125 def keys(self):
1126 return [self.table.unpack_target(target_id)
1127 for target_id,edge_id in self.iterator_query()]
1128 def values(self):
1129 return [self.table.unpack_edge(edge_id)
1130 for target_id,edge_id in self.iterator_query()]
1131 def edges(self):
1132 return [(self.table.unpack_source(self.fromNode),self.table.unpack_target(target_id),
1133 self.table.unpack_edge(edge_id))
1134 for target_id,edge_id in self.iterator_query()]
1135 def items(self):
1136 return [(self.table.unpack_target(target_id),self.table.unpack_edge(edge_id))
1137 for target_id,edge_id in self.iterator_query()]
1138 def __iter__(self): return iter(self.keys())
1139 def itervalues(self): return iter(self.values())
1140 def iteritems(self): return iter(self.items())
1141 def __len__(self):
1142 return len(self.keys())
1143 __cmp__ = graph_cmp
1145 class SQLEdgelessDict(SQLEdgeDict):
1146 'for SQLGraph tables that lack edge_id column'
1147 def __getitem__(self,target):
1148 sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
1149 %(self.table.targetSQL,
1150 self.table.name,
1151 self.table.sourceSQL,
1152 self.table.targetSQL),
1153 (self.fromNode,
1154 self.table.pack_target(target)))
1155 self.table.cursor.execute(sql, params)
1156 l = self.table.cursor.fetchmany(2)
1157 if len(l) != 1:
1158 raise KeyError('either no edge from source to target or not unique!')
1159 return None # no edge info!
1160 def iterator_query(self):
1161 sql,params = self.table._format_query('select %s from %s where %s=%%s and %s is not null'
1162 %(self.table.targetSQL,
1163 self.table.name,
1164 self.table.sourceSQL,
1165 self.table.targetSQL),
1166 (self.fromNode,))
1167 self.table.cursor.execute(sql, params)
1168 return [(t[0],None) for t in self.table.cursor.fetchall()]
1170 SQLEdgeDict._edgelessClass = SQLEdgelessDict
1172 class SQLGraphEdgeDescriptor(object):
1173 'provide an SQLEdges interface on demand'
1174 def __get__(self,obj,objtype):
1175 try:
1176 attrAlias=obj.attrAlias.copy()
1177 except AttributeError:
1178 return SQLEdges(obj.name, obj.cursor, graph=obj)
1179 else:
1180 return SQLEdges(obj.name, obj.cursor, attrAlias=attrAlias,
1181 graph=obj)
1183 def getColumnTypes(createTable,attrAlias={},defaultColumnType='int',
1184 columnAttrs=('source','target','edge'),**kwargs):
1185 'return list of [(colname,coltype),...] for source,target,edge'
1186 l = []
1187 for attr in columnAttrs:
1188 try:
1189 attrName = attrAlias[attr+'_id']
1190 except KeyError:
1191 attrName = attr+'_id'
1192 try: # SEE IF USER SPECIFIED A DESIRED TYPE
1193 l.append((attrName,createTable[attr+'_id']))
1194 continue
1195 except (KeyError,TypeError):
1196 pass
1197 try: # get type info from primary key for that database
1198 db = kwargs[attr+'DB']
1199 if db is None:
1200 raise KeyError # FORCE IT TO USE DEFAULT TYPE
1201 except KeyError:
1202 pass
1203 else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
1204 it = iter(db)
1205 try: # GET ONE IDENTIFIER FROM THE DATABASE
1206 k = it.next()
1207 except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
1208 try:
1209 l.append((attrName,db.columnType[db.primary_key]))
1210 continue
1211 except AttributeError:
1212 pass
1213 else: # GET THE TYPE FROM THIS IDENTIFIER
1214 if isinstance(k,int) or isinstance(k,long):
1215 l.append((attrName,'int'))
1216 continue
1217 elif isinstance(k,str):
1218 l.append((attrName,'varchar(32)'))
1219 continue
1220 else:
1221 raise ValueError('SQLGraph node / edge must be int or str!')
1222 l.append((attrName,defaultColumnType))
1223 logger.warn('no type info found for %s, so using default: %s'
1224 % (attrName, defaultColumnType))
1227 return l
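# Illustrative sketch (not part of the original module): getColumnTypes()
# resolves one (name, SQL type) pair per graph column, preferring an explicit
# entry in the createTable dict, then the primary-key type of an attached
# sourceDB / targetDB / edgeDB, and finally defaultColumnType (with a warning).
# The dict below is an assumption for the example.
def _example_getcolumntypes():
    return getColumnTypes(createTable=dict(source_id='varchar(16)',
                                           target_id='varchar(16)',
                                           edge_id='int'))
    # -> [('source_id', 'varchar(16)'), ('target_id', 'varchar(16)'),
    #     ('edge_id', 'int')]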
1230 class SQLGraph(SQLTableMultiNoCache):
1231 '''provide a graph interface via a SQL table. Key capabilities are:
1232 - setitem with an empty dictionary: a dummy operation
1233 - getitem with a key that exists: return a placeholder
1234 - setitem with non empty placeholder: again a dummy operation
1235 EXAMPLE TABLE SCHEMA:
1236 create table mygraph (source_id int not null,target_id int,edge_id int,
1237 unique(source_id,target_id));
1238 '''
1239 _distinct_key='source_id'
1240 _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
1241 _pickleAttrs.update(dict(sourceDB=0,targetDB=0,edgeDB=0,allowMissingNodes=0))
1242 _edgeClass = SQLEdgeDict
1243 def __init__(self,name,*l,**kwargs):
1244 graphArgs,tableArgs = split_kwargs(kwargs,
1245 ('attrAlias','defaultColumnType','columnAttrs',
1246 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
1247 'edgeDictClass','graph'))
1248 if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
1249 c = getColumnTypes(**kwargs)
1250 tableArgs['createTable'] = \
1251 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1252 % (name,c[0][0],c[0][1],c[1][0],c[1][1],c[2][0],c[2][1],c[0][0],c[1][0])
1253 try:
1254 self.allowMissingNodes = kwargs['allowMissingNodes']
1255 except KeyError: pass
1256 SQLTableMultiNoCache.__init__(self,name,*l,**tableArgs)
1257 self.sourceSQL = self._attrSQL('source_id')
1258 self.targetSQL = self._attrSQL('target_id')
1259 try:
1260 self.edgeSQL = self._attrSQL('edge_id')
1261 except AttributeError:
1262 self.edgeSQL = None
1263 self._edgeClass = self._edgeClass._edgelessClass
1264 save_graph_db_refs(self,**kwargs)
1265 def __getitem__(self,k):
1266 return self._edgeClass(self.pack_source(k),self)
1267 def __iadd__(self,k):
1268 sql,params = self._format_query('delete from %s where %s=%%s and %s is null'
1269 % (self.name,self.sourceSQL,self.targetSQL),
1270 (self.pack_source(k),))
1271 self.cursor.execute(sql, params)
1272 sql,params = self._format_query('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
1273 % self.name,(self.pack_source(k),))
1274 self.cursor.execute(sql, params)
1275 return self # iadd MUST RETURN SELF!
1276 def __isub__(self,k):
1277 sql,params = self._format_query('delete from %s where %s=%%s'
1278 % (self.name,self.sourceSQL),
1279 (self.pack_source(k),))
1280 self.cursor.execute(sql, params)
1281 if self.cursor.rowcount == 0:
1282 raise KeyError('node not found in graph')
1283 return self # iadd MUST RETURN SELF!
1284 __setitem__ = graph_setitem
1285 def __contains__(self,k):
1286 sql,params = self._format_query('select * from %s where %s=%%s limit 1'
1287 %(self.name,self.sourceSQL),
1288 (self.pack_source(k),))
1289 self.cursor.execute(sql, params)
1290 l = self.cursor.fetchmany(2)
1291 return len(l) > 0
1292 def __invert__(self):
1293 'get an interface to the inverse graph mapping'
1294 try: # CACHED
1295 return self._inverse
1296 except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
1297 attrAlias = dict(source_id=self.targetSQL, # SWAP SOURCE & TARGET
1298 target_id=self.sourceSQL,
1299 edge_id=self.edgeSQL)
1300 if self.edgeSQL is None: # no edge interface
1301 del attrAlias['edge_id']
1302 self._inverse=SQLGraph(self.name,self.cursor,
1303 attrAlias=attrAlias,
1304 **graph_db_inverse_refs(self))
1305 self._inverse._inverse=self
1306 return self._inverse
1307 def __iter__(self):
1308 for k in SQLTableMultiNoCache.__iter__(self):
1309 yield self.unpack_source(k)
1310 def iteritems(self):
1311 for k in SQLTableMultiNoCache.__iter__(self):
1312 yield (self.unpack_source(k), self._edgeClass(k, self))
1313 def itervalues(self):
1314 for k in SQLTableMultiNoCache.__iter__(self):
1315 yield self._edgeClass(k, self)
1316 def keys(self):
1317 return [self.unpack_source(k) for k in SQLTableMultiNoCache.keys(self)]
1318 def values(self): return list(self.itervalues())
1319 def items(self): return list(self.iteritems())
1320 edges=SQLGraphEdgeDescriptor()
1321 update = update_graph
1322 def __len__(self):
1323 'get number of source nodes in graph'
1324 self.cursor.execute('select count(distinct %s) from %s'
1325 %(self.sourceSQL,self.name))
1326 return self.cursor.fetchone()[0]
1327 __cmp__ = graph_cmp
1328 override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
1329 ## def __cmp__(self,other):
1330 ## node = ()
1331 ## n = 0
1332 ## d = None
1333 ## it = iter(self.edges)
1334 ## while True:
1335 ## try:
1336 ## source,target,edge = it.next()
1337 ## except StopIteration:
1338 ## source = None
1339 ## if source!=node:
1340 ## if d is not None:
1341 ## diff = cmp(n_target,len(d))
1342 ## if diff!=0:
1343 ## return diff
1344 ## if source is None:
1345 ## break
1346 ## node = source
1347 ## n += 1 # COUNT SOURCE NODES
1348 ## n_target = 0
1349 ## try:
1350 ## d = other[node]
1351 ## except KeyError:
1352 ## return 1
1353 ## try:
1354 ## diff = cmp(edge,d[target])
1355 ## except KeyError:
1356 ## return 1
1357 ## if diff!=0:
1358 ## return diff
1359 ## n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
1360 ## return cmp(n,len(other))
1362 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1364 class SQLIDGraph(SQLGraph):
1365 add_trivial_packing_methods(locals())
1366 SQLGraph._IDGraphClass = SQLIDGraph
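# Illustrative sketch (not part of the original module): SQLGraph exposes an
# edge table as graph[source][target] = edge.  With no sourceDB / targetDB /
# edgeDB attached, save_graph_db_refs() (from mapping.py) is assumed to fall
# back to using plain IDs, as the pygr tests do; the table name and the
# integer node / edge values below are placeholders.
def _example_sqlgraph():
    sqlite = import_sqlite()
    conn = sqlite.connect(':memory:')
    g = SQLGraph('mygraph', cursor=conn.cursor(),
                 createTable=dict(source_id='int', target_id='int',
                                  edge_id='int'))
    g += 1              # register node 1
    g[1][2] = 10        # edge 1 -> 2 labelled 10 (also registers node 2)
    assert g[1][2] == 10
    return sorted(g.keys())  # -> [1, 2]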
1370 class SQLEdgeDictClustered(dict):
1371 'simple cache for 2nd level dictionary of target_id:edge_id'
1372 def __init__(self,g,fromNode):
1373 self.g=g
1374 self.fromNode=fromNode
1375 dict.__init__(self)
1376 def __iadd__(self,l):
1377 for target_id,edge_id in l:
1378 dict.__setitem__(self,target_id,edge_id)
1379 return self # iadd MUST RETURN SELF!
1381 class SQLEdgesClusteredDescr(object):
1382 def __get__(self,obj,objtype):
1383 e=SQLEdgesClustered(obj.table,obj.edge_id,obj.source_id,obj.target_id,
1384 graph=obj,**graph_db_inverse_refs(obj,True))
1385 for source_id,d in obj.d.iteritems(): # COPY EDGE CACHE
1386 e.load([(edge_id,source_id,target_id)
1387 for (target_id,edge_id) in d.iteritems()])
1388 return e
1390 class SQLGraphClustered(object):
1391 'SQL graph with clustered caching -- loads an entire cluster at a time'
1392 _edgeDictClass=SQLEdgeDictClustered
1393 def __init__(self,table,source_id='source_id',target_id='target_id',
1394 edge_id='edge_id',clusterKey=None,**kwargs):
1395 import types
1396 if isinstance(table,types.StringType): # CREATE THE TABLE INTERFACE
1397 if clusterKey is None:
1398 raise ValueError('you must provide a clusterKey argument!')
1399 if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
1400 c = getColumnTypes(attrAlias=dict(source_id=source_id,target_id=target_id,
1401 edge_id=edge_id),**kwargs)
1402 kwargs['createTable'] = \
1403 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1404 % (table,c[0][0],c[0][1],c[1][0],c[1][1],
1405 c[2][0],c[2][1],c[0][0],c[1][0])
1406 table = SQLTableClustered(table,clusterKey=clusterKey,**kwargs)
1407 self.table=table
1408 self.source_id=source_id
1409 self.target_id=target_id
1410 self.edge_id=edge_id
1411 self.d={}
1412 save_graph_db_refs(self,**kwargs)
1413 _pickleAttrs = dict(table=0,source_id=0,target_id=0,edge_id=0,sourceDB=0,targetDB=0,
1414 edgeDB=0)
1415 def __getstate__(self):
1416 state = standard_getstate(self)
1417 state['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1418 return state
1419 def __getitem__(self,k):
1420 'get edgeDict for source node k, from cache or by loading its cluster'
1421 try: # GET DIRECTLY FROM CACHE
1422 return self.d[k]
1423 except KeyError:
1424 if hasattr(self,'_isLoaded'):
1425 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1426 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1427 sql,params = self.table._format_query('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1428 %(self.source_id,self.target_id,
1429 self.edge_id,self.table.name,
1430 self.table.name,self.source_id,
1431 self.table.clusterKey,self.table.clusterKey,
1432 self.table.primary_key),
1433 (self.pack_source(k),))
1434 self.table.cursor.execute(sql, params)
1435 self.load(self.table.cursor.fetchall()) # CACHE THIS CLUSTER
1436 return self.d[k] # RETURN EDGE DICT FOR THIS NODE
1437 def load(self,l=None,unpack=True):
1438 'load the specified rows (or all, if None provided) into local cache'
1439 if l is None:
1440 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1441 return self._isLoaded
1442 except AttributeError:
1443 pass
1444 self.table.cursor.execute('select %s,%s,%s from %s'
1445 %(self.source_id,self.target_id,
1446 self.edge_id,self.table.name))
1447 l=self.table.cursor.fetchall()
1448 self._isLoaded=True
1449 self.d.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1450 for source,target,edge in l: # SAVE TO OUR CACHE
1451 if unpack:
1452 source = self.unpack_source(source)
1453 target = self.unpack_target(target)
1454 edge = self.unpack_edge(edge)
1455 try:
1456 self.d[source] += [(target,edge)]
1457 except KeyError:
1458 d = self._edgeDictClass(self,source)
1459 d += [(target,edge)]
1460 self.d[source] = d
1461 def __invert__(self):
1462 'interface to reverse graph mapping'
1463 try:
1464 return self._inverse # INVERSE MAP ALREADY EXISTS
1465 except AttributeError:
1466 pass
1467 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1468 self._inverse=SQLGraphClustered(self.table,self.target_id,self.source_id,
1469 self.edge_id,**graph_db_inverse_refs(self))
1470 self._inverse._inverse=self
1471 for source,d in self.d.iteritems(): # INVERT OUR CACHE
1472 self._inverse.load([(target,source,edge)
1473 for (target,edge) in d.iteritems()],unpack=False)
1474 return self._inverse
1475 edges=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1476 update = update_graph
1477 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1478 def __iter__(self): ################# ITERATORS
1479 'uses db select; does not force load'
1480 return iter(self.keys())
1481 def keys(self):
1482 'uses db select; does not force load'
1483 self.table.cursor.execute('select distinct(%s) from %s'
1484 %(self.source_id,self.table.name))
1485 return [self.unpack_source(t[0])
1486 for t in self.table.cursor.fetchall()]
1487 methodFactory(['iteritems','items','itervalues','values'],
1488 'lambda self:(self.load(),self.d.%s())[1]',locals())
1489 def __contains__(self,k):
1490 try:
1491 x=self[k]
1492 return True
1493 except KeyError:
1494 return False
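# Illustrative sketch (not part of the original module): building a clustered
# graph over a hypothetical 'test.splices' table whose rows share a gene_id
# cluster key, so that looking up one exon's edges loads that gene's entire
# cluster into the cache with a single query.  The table, column and database
# names, and the exonDB / spliceDB arguments, are assumptions for this example.
def _example_clustered_splice_graph(serverInfo, exonDB, spliceDB):
    'return an SQLGraphClustered over an assumed test.splices table'
    return SQLGraphClustered('test.splices', source_id='left_exon_id',
                             target_id='right_exon_id', edge_id='splice_id',
                             clusterKey='gene_id', serverInfo=serverInfo,
                             sourceDB=exonDB, targetDB=exonDB, edgeDB=spliceDB)
# usage sketch: g = _example_clustered_splice_graph(...); g[exon] caches the
# whole cluster, so lookups for other exons of the same gene hit the cache.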
1496 class SQLIDGraphClustered(SQLGraphClustered):
1497 add_trivial_packing_methods(locals())
1498 SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
1500 class SQLEdgesClustered(SQLGraphClustered):
1501 'edges interface for SQLGraphClustered'
1502 _edgeDictClass = list
1503 _pickleAttrs = SQLGraphClustered._pickleAttrs.copy()
1504 _pickleAttrs.update(dict(graph=0))
1505 def keys(self):
1506 self.load()
1507 result = []
1508 for edge_id,l in self.d.iteritems():
1509 for source_id,target_id in l:
1510 result.append((self.graph.unpack_source(source_id),
1511 self.graph.unpack_target(target_id),
1512 self.graph.unpack_edge(edge_id)))
1513 return result
1515 class ForeignKeyInverse(object):
1516 'map each key to a single value according to its foreign key'
1517 def __init__(self,g):
1518 self.g = g
1519 def __getitem__(self,obj):
1520 self.check_obj(obj)
1521 source_id = getattr(obj,self.g.keyColumn)
1522 if source_id is None:
1523 return None
1524 return self.g.sourceDB[source_id]
1525 def __setitem__(self,obj,source):
1526 self.check_obj(obj)
1527 if source is not None:
1528 self.g[source][obj] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1529 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1530 if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1531 old_source = self[obj]
1532 if old_source is not None:
1533 del self.g[old_source][obj]
1534 def check_obj(self,obj):
1535 'raise KeyError if obj not from this db'
1536 try:
1537 if obj.db is not self.g.targetDB:
1538 raise AttributeError
1539 except AttributeError:
1540 raise KeyError('key is not from targetDB of this graph!')
1541 def __contains__(self,obj):
1542 try:
1543 self.check_obj(obj)
1544 return True
1545 except KeyError:
1546 return False
1547 def __iter__(self):
1548 return self.g.targetDB.itervalues()
1549 def keys(self):
1550 return self.g.targetDB.values()
1551 def iteritems(self):
1552 for obj in self:
1553 source_id = getattr(obj,self.g.keyColumn)
1554 if source_id is None:
1555 yield obj,None
1556 else:
1557 yield obj,self.g.sourceDB[source_id]
1558 def items(self):
1559 return list(self.iteritems())
1560 def itervalues(self):
1561 for obj,val in self.iteritems():
1562 yield val
1563 def values(self):
1564 return list(self.itervalues())
1565 def __invert__(self):
1566 return self.g
1568 class ForeignKeyEdge(dict):
1569 '''edge interface to a foreign key in an SQL table.
1570 Caches dict of target nodes in itself; provides dict interface.
1571 Adds or deletes edges by setting foreign key values in the table'''
1572 def __init__(self,g,k):
1573 dict.__init__(self)
1574 self.g = g
1575 self.src = k
1576 for v in g.targetDB.select('where %s=%%s' % g.keyColumn,(k.id,)): # SEARCH THE DB
1577 dict.__setitem__(self,v,None) # SAVE IN CACHE
1578 def __setitem__(self,dest,v):
1579 if not hasattr(dest,'db') or dest.db is not self.g.targetDB:
1580 raise KeyError('dest is not in the targetDB bound to this graph!')
1581 if v is not None:
1582 raise ValueError('sorry, this graph cannot store edge information!')
1583 if not hasattr(dest,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1584 old_source = self.g._inverse[dest] # CHECK FOR PRE-EXISTING EDGE
1585 if old_source is not None: # REMOVE OLD EDGE FROM CACHE
1586 dict.__delitem__(self.g[old_source],dest)
1587 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1588 setattr(dest,self.g.keyColumn,self.src.id) # SAVE TO DB ATTRIBUTE
1589 dict.__setitem__(self,dest,None) # SAVE IN CACHE
1590 def __delitem__(self,dest):
1591 #self.g.targetDB._update(dest.id,self.g.keyColumn,None) # REMOVE FOREIGN KEY VALUE
1592 setattr(dest,self.g.keyColumn,None) # SAVE TO DB ATTRIBUTE
1593 dict.__delitem__(self,dest) # REMOVE FROM CACHE
1595 class ForeignKeyGraph(object, UserDict.DictMixin):
1596 '''graph interface to a foreign key in an SQL table.
1597 Caches dict of target nodes in itself; provides dict interface.
1598 '''
1599 def __init__(self, sourceDB, targetDB, keyColumn, autoGC=True, **kwargs):
1600 '''sourceDB is any database of source nodes;
1601 targetDB must be an SQL database of target nodes;
1602 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1603 if autoGC: # automatically garbage collect unused objects
1604 self._weakValueDict = RecentValueDictionary(autoGC) # object cache
1605 else:
1606 self._weakValueDict = {}
1607 self.autoGC = autoGC
1608 self.sourceDB = sourceDB
1609 self.targetDB = targetDB
1610 self.keyColumn = keyColumn
1611 self._inverse = ForeignKeyInverse(self)
1612 _pickleAttrs = dict(sourceDB=0, targetDB=0, keyColumn=0, autoGC=0)
1613 __getstate__ = standard_getstate ########### SUPPORT FOR PICKLING
1614 __setstate__ = standard_setstate
1615 def _inverse_schema(self):
1616 'provide custom schema rule for inverting this graph... just use keyColumn!'
1617 return dict(invert=True,uniqueMapping=True)
1618 def __getitem__(self,k):
1619 if not hasattr(k,'db') or k.db is not self.sourceDB:
1620 raise KeyError('object is not in the sourceDB bound to this graph!')
1621 try:
1622 return self._weakValueDict[k.id] # get from cache
1623 except KeyError:
1624 pass
1625 d = ForeignKeyEdge(self,k)
1626 self._weakValueDict[k.id] = d # save in cache
1627 return d
1628 def __setitem__(self, k, v):
1629 raise KeyError('''do not save as g[k]=v. Instead follow a graph
1630 interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
1631 def __delitem__(self, k):
1632 raise KeyError('''Instead of del g[k], follow a graph
1633 interface: del g[src][dest]''')
1634 def keys(self):
1635 return self.sourceDB.values()
1636 __invert__ = standard_invert
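# Illustrative sketch (assumed table / column names): mapping each gene to its
# exons via an 'exons.gene_id' foreign key column.  g[gene] gives a dict-like
# edge set of exon rows; ~g maps an exon back to its gene (or None).
def _example_gene_exon_graph(geneDB, exonDB):
    'geneDB: any pygr database of genes; exonDB: SQLTable with a gene_id column'
    return ForeignKeyGraph(geneDB, exonDB, 'gene_id')
# usage sketch: exons = _example_gene_exon_graph(geneDB, exonDB)[gene].keys()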
1638 def describeDBTables(name,cursor,idDict):
1639 """
1640 Get table info about database <name> via <cursor>, and store primary keys
1641 in idDict, along with a list of the tables each key indexes.
1642 """
1643 cursor.execute('use %s' % name)
1644 cursor.execute('show tables')
1645 tables={}
1646 l=[c[0] for c in cursor.fetchall()]
1647 for t in l:
1648 tname=name+'.'+t
1649 o=SQLTable(tname,cursor)
1650 tables[tname]=o
1651 for f in o.description:
1652 if f==o.primary_key:
1653 idDict.setdefault(f, []).append(o)
1654 elif f[-3:]=='_id' and f not in idDict:
1655 idDict[f]=[]
1656 return tables
1660 def indexIDs(tables,idDict=None):
1661 "Get an index of primary keys in the <tables> dictionary."
1662 if idDict==None:
1663 idDict={}
1664 for o in tables.values():
1665 if o.primary_key:
1666 if o.primary_key not in idDict:
1667 idDict[o.primary_key]=[]
1668 idDict[o.primary_key].append(o) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1669 for f in o.description:
1670 if f[-3:]=='_id' and f not in idDict:
1671 idDict[f]=[]
1672 return idDict
1676 def suffixSubset(tables,suffix):
1677 "Filter table index for those matching a specific suffix"
1678 subset={}
1679 for name,t in tables.items():
1680 if name.endswith(suffix):
1681 subset[name]=t
1682 return subset
1685 PRIMARY_KEY=1
1687 def graphDBTables(tables,idDict):
1688 g=dictgraph()
1689 for t in tables.values():
1690 for f in t.description:
1691 if f==t.primary_key:
1692 edgeInfo=PRIMARY_KEY
1693 else:
1694 edgeInfo=None
1695 g.setEdge(f,t,edgeInfo)
1696 g.setEdge(t,f,edgeInfo)
1697 return g
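# Illustrative sketch showing how the introspection helpers above chain
# together: describeDBTables() builds SQLTable interfaces for every table in a
# MySQL database, indexIDs() collects primary-key / *_id columns, and
# graphDBTables() links each table to its key columns.  The database name
# 'test' is an assumption for the example.
def _example_schema_graph(cursor):
    'return (tables, idDict, graph) describing the assumed database "test"'
    idDict = {}
    tables = describeDBTables('test', cursor, idDict)
    idDict = indexIDs(tables, idDict)
    return tables, idDict, graphDBTables(tables, idDict)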
1699 SQLTypeTranslation= {types.StringType:'varchar(32)',
1700 types.IntType:'int',
1701 types.FloatType:'float'}
1703 def createTableFromRepr(rows,tableName,cursor,typeTranslation=None,
1704 optionalDict=None,indexDict=()):
1705 """Save rows into SQL tableName using cursor, with optional
1706 translations of columns to specific SQL types (specified
1707 by typeTranslation dict).
1708 - optionalDict can specify columns that are allowed to be NULL.
1709 - indexDict can specify columns that must be indexed; columns
1710 whose names end in _id will be indexed by default.
1711 - rows must be an iterator which in turn returns dictionaries,
1712 each representing a tuple of values (indexed by their column
1713 names).
1714 """
1715 try:
1716 row=rows.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1717 except StopIteration:
1718 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1719 try:
1720 createTableFromRow(cursor, tableName,row,typeTranslation,
1721 optionalDict,indexDict)
1722 except: # IGNORE CREATION ERRORS, E.G. TABLE ALREADY EXISTS
1723 pass
1724 storeRow(cursor,tableName,row) # SAVE OUR FIRST ROW
1725 for row in rows: # NOW SAVE ALL THE ROWS
1726 storeRow(cursor,tableName,row)
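# Illustrative sketch: saving an iterator of row dictionaries to a new table
# with createTableFromRepr().  The table name and columns are assumptions;
# *_id columns are indexed automatically, and 'score' is allowed to be NULL
# because it appears in optionalDict.
def _example_save_rows(cursor):
    rows = iter([dict(hit_id=1, query_id=10, score=0.5),
                 dict(hit_id=2, query_id=11, score=None)])
    createTableFromRepr(rows, 'test.hits', cursor,
                        typeTranslation=dict(score='float'),
                        optionalDict=dict(score=None))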
1728 def createTableFromRow(cursor, tableName, row,typeTranslation=None,
1729 optionalDict=None,indexDict=()):
1730 create_defs=[]
1731 for col,val in row.items(): # PREPARE SQL TYPES FOR COLUMNS
1732 coltype=None
1733 if typeTranslation!=None and col in typeTranslation:
1734 coltype=typeTranslation[col] # USER-SUPPLIED TRANSLATION
1735 elif type(val) in SQLTypeTranslation:
1736 coltype=SQLTypeTranslation[type(val)]
1737 else: # SEARCH FOR A COMPATIBLE TYPE
1738 for t in SQLTypeTranslation:
1739 if isinstance(val,t):
1740 coltype=SQLTypeTranslation[t]
1741 break
1742 if coltype==None:
1743 raise TypeError("Don't know SQL type to use for %s" % col)
1744 create_def='%s %s' %(col,coltype)
1745 if optionalDict==None or col not in optionalDict:
1746 create_def+=' not null'
1747 create_defs.append(create_def)
1748 for col in row: # CREATE INDEXES FOR ID COLUMNS
1749 if col[-3:]=='_id' or col in indexDict:
1750 create_defs.append('index(%s)' % col)
1751 cmd='create table if not exists %s (%s)' % (tableName,','.join(create_defs))
1752 cursor.execute(cmd) # CREATE THE TABLE IN THE DATABASE
1755 def storeRow(cursor, tableName, row):
1756 row_format=','.join(len(row)*['%s'])
1757 cmd='insert into %s values (%s)' % (tableName,row_format)
1758 cursor.execute(cmd,tuple(row.values()))
1760 def storeRowDelayed(cursor, tableName, row):
1761 row_format=','.join(len(row)*['%s'])
1762 cmd='insert delayed into %s values (%s)' % (tableName,row_format)
1763 cursor.execute(cmd,tuple(row.values()))
1766 class TableGroup(dict):
1767 'provide attribute access to dbname qualified tablenames'
1768 def __init__(self,db='test',suffix=None,**kw):
1769 dict.__init__(self)
1770 self.db=db
1771 if suffix is not None:
1772 self.suffix=suffix
1773 for k,v in kw.items():
1774 if v is not None and '.' not in v:
1775 v=self.db+'.'+v # ADD DATABASE NAME AS PREFIX
1776 self[k]=v
1777 def __getattr__(self,k):
1778 return self[k]
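# Illustrative sketch: TableGroup simply prefixes bare table names with the
# database name and exposes them as attributes; the names are assumptions.
def _example_table_group():
    t = TableGroup(db='splicedb', exons='exon_table', splices='splice_table')
    return t.exons # -> 'splicedb.exon_table'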
1780 def sqlite_connect(*args, **kwargs):
1781 sqlite = import_sqlite()
1782 connection = sqlite.connect(*args, **kwargs)
1783 cursor = connection.cursor()
1784 return connection, cursor
1786 class DBServerInfo(object):
1787 'picklable reference to a database server'
1788 def __init__(self, moduleName='MySQLdb', serverSideCursors=True,
1789 blockIterators=True, *args, **kwargs):
1790 try:
1791 self.__class__ = _DBServerModuleDict[moduleName]
1792 except KeyError:
1793 raise ValueError('Module name not found in _DBServerModuleDict: '\
1794 + moduleName)
1795 self.moduleName = moduleName
1796 self.args = args # connection arguments
1797 self.kwargs = kwargs
1798 self.serverSideCursors = serverSideCursors
1799 self.custom_iter_keys = blockIterators
1800 if self.serverSideCursors and not self.custom_iter_keys:
1801 raise ValueError('serverSideCursors=True requires blockIterators=True!')
1803 def cursor(self):
1804 """returns cursor associated with the DB server info (reused)"""
1805 try:
1806 return self._cursor
1807 except AttributeError:
1808 self._start_connection()
1809 return self._cursor
1811 def new_cursor(self, arraysize=None):
1812 """returns a NEW cursor; you must close it yourself! """
1813 if not hasattr(self, '_connection'):
1814 self._start_connection()
1815 cursor = self._connection.cursor()
1816 if arraysize is not None:
1817 cursor.arraysize = arraysize
1818 return cursor
1820 def close(self):
1821 """Close file containing this database"""
1822 self._cursor.close()
1823 self._connection.close()
1824 del self._cursor
1825 del self._connection
1827 def __getstate__(self):
1828 """return all picklable arguments"""
1829 return dict(args=self.args, kwargs=self.kwargs,
1830 moduleName=self.moduleName,
1831 serverSideCursors=self.serverSideCursors,
1832 blockIterators=self.custom_iter_keys)
1834 def __setstate__(self, state):
1835 """rebuild from the dict returned by __getstate__(); pickle passes it as a single argument"""
1836 self.__init__(state['moduleName'], serverSideCursors=state['serverSideCursors'],
1837 blockIterators=state['blockIterators'], *state['args'], **state['kwargs'])
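# Illustrative sketch (connection parameters are assumptions): a DBServerInfo
# is a picklable stand-in for a connection + cursor, so interfaces built from
# it can themselves be pickled and will reconnect on demand.  Setting
# serverSideCursors=False falls back to ordinary client-side cursors while
# keeping block-mode iteration (blockIterators=True).
def _example_server_info():
    serverInfo = DBServerInfo(host='localhost', user='reader', passwd='secret',
                              db='test', serverSideCursors=False,
                              blockIterators=True)
    return SQLTable('test.mytable', serverInfo=serverInfo) # assumed table name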
1840 class MySQLServerInfo(DBServerInfo):
1841 'customized for MySQLdb SSCursor support via new_cursor()'
1842 _serverType = 'mysql'
1843 def _start_connection(self):
1844 self._connection,self._cursor = mysql_connect(*self.args, **self.kwargs)
1845 def new_cursor(self, arraysize=None):
1846 'provide streaming cursor support'
1847 if not self.serverSideCursors: # use regular MySQLdb cursor
1848 return DBServerInfo.new_cursor(self, arraysize)
1849 try:
1850 conn = self._conn_sscursor
1851 except AttributeError:
1852 self._conn_sscursor,cursor = mysql_connect(useStreaming=True,
1853 *self.args, **self.kwargs)
1854 else:
1855 cursor = self._conn_sscursor.cursor()
1856 if arraysize is not None:
1857 cursor.arraysize = arraysize
1858 return cursor
1859 def close(self):
1860 DBServerInfo.close(self)
1861 try:
1862 self._conn_sscursor.close()
1863 del self._conn_sscursor
1864 except AttributeError:
1865 pass
1866 def iter_keys(self, db, cursor, map_f=iter,
1867 cache_f=lambda x:[t[0] for t in x], **kwargs):
1868 block_generator = BlockGenerator(db, self, cursor, **kwargs)
1869 return db.generic_iterator(cursor=cursor, cache_f=cache_f,
1870 map_f=map_f, fetch_f=block_generator)
1872 class CursorCloser(object):
1873 """container for ensuring cursor.close() is called, when this obj deleted.
1874 For Python 2.5+, we could replace this with a try... finally clause
1875 in a generator function such as generic_iterator(); see PEP 342 or
1876 What's New in Python 2.5. """
1877 def __init__(self, cursor):
1878 self.cursor = cursor
1879 def __del__(self):
1880 self.cursor.close()
1882 class BlockGenerator(CursorCloser):
1883 "workaround for MySQLdb's poor row-by-row iteration performance"
1884 def __init__(self, db, serverInfo, cursor, whereClause='', **kwargs):
1885 self.db = db
1886 self.serverInfo = serverInfo
1887 self.cursor = cursor
1888 self.kwargs = kwargs
1889 self.whereClause = ''
1890 if kwargs['orderBy']: # use iterSQL/iterColumns for WHERE / SELECT
1891 self.whereSQL = db.iterSQL
1892 if kwargs['selectCols'] == db.primary_key: # extract iterColumns
1893 self.whereColumns = ','.join(db.iterColumns) # required!!
1894 else: # extracting all columns
1895 self.whereParams = [db.data[col] for col in db.iterColumns]
1896 else: # just use primary key
1897 self.whereSQL = 'WHERE %s>%%s' % self.db.primary_key
1898 self.whereParams = (db.data['id'],)
1899 self.params = ()
1900 self.done = False
1902 def __call__(self):
1903 'get the next block of data'
1904 if self.done:
1905 return ()
1906 self.db._select(self.whereClause, self.params, cursor=self.cursor,
1907 limit='LIMIT %s' % self.cursor.arraysize, **(self.kwargs))
1908 rows = self.cursor.fetchall()
1909 if len(rows) < self.cursor.arraysize: # iteration complete
1910 self.done = True
1911 return rows
1912 lastrow = rows[-1] # extract params from the last row in this block
1913 if len(lastrow) > 1:
1914 self.params = [lastrow[icol] for icol in self.whereParams]
1915 else:
1916 try: # get whereColumns values for last row
1917 self.db._select('WHERE %s=%%s' % self.db.primary_key,
1918 lastrow, self.whereColumns, self.cursor)
1919 except AttributeError:
1920 self.params = lastrow
1921 else:
1922 self.params = self.cursor.fetchall()[0]
1923 self.whereClause = self.whereSQL
1924 return rows
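# The class above resumes each block with a WHERE clause on the ordering
# columns of the last row it returned, instead of using OFFSET.  A minimal
# standalone sketch of that keyset-pagination idea, independent of the pygr
# internals, using a plain DB-API cursor with MySQLdb-style %s parameters and
# an assumed 'test.mytable' table with an 'id' primary key:
def _example_keyset_pagination(cursor, blockSize=1000):
    'yield rows from an assumed test.mytable in id order, one block at a time'
    lastID = None
    while True:
        if lastID is None: # first block needs no WHERE clause
            cursor.execute('select id, seq from test.mytable'
                           ' order by id limit %d' % blockSize)
        else: # resume just past the last id already yielded
            cursor.execute('select id, seq from test.mytable where id>%%s'
                           ' order by id limit %d' % blockSize, (lastID,))
        rows = cursor.fetchall()
        for row in rows:
            yield row
        if len(rows) < blockSize: # a short block means we are done
            return
        lastID = rows[-1][0]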
1928 class SQLiteServerInfo(DBServerInfo):
1929 """picklable reference to a sqlite database"""
1930 _serverType = 'sqlite'
1931 def __init__(self, database, *args, **kwargs):
1932 """Takes same arguments as sqlite3.connect()"""
1933 DBServerInfo.__init__(self, 'sqlite', # save abs path!
1934 database=SourceFileName(database),
1935 *args, **kwargs)
1936 def _start_connection(self):
1937 self._connection,self._cursor = sqlite_connect(*self.args, **self.kwargs)
1938 def __getstate__(self):
1939 if self.args[0] == ':memory:':
1940 raise ValueError('SQLite in-memory database is not picklable!')
1941 return DBServerInfo.__getstate__(self)
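# Illustrative sketch: a SQLiteServerInfo wraps a sqlite database file (the
# path and table name are assumptions) and is passed to table constructors
# exactly like the MySQL version; only an in-memory ':memory:' database
# refuses to be pickled.
def _example_sqlite_table():
    serverInfo = SQLiteServerInfo('mydata.sqlite')
    return SQLTable('mytable', serverInfo=serverInfo)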
1943 # list of DBServerInfo subclasses for different modules
1944 _DBServerModuleDict = dict(MySQLdb=MySQLServerInfo, sqlite=SQLiteServerInfo)
1947 class MapView(object, UserDict.DictMixin):
1948 'general purpose 1:1 mapping defined by any SQL query'
1949 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
1950 serverInfo=None, inverseSQL=None, **kwargs):
1951 self.sourceDB = sourceDB
1952 self.targetDB = targetDB
1953 self.viewSQL = viewSQL
1954 self.inverseSQL = inverseSQL
1955 if cursor is None:
1956 if serverInfo is not None: # get cursor from serverInfo
1957 cursor = serverInfo.cursor()
1958 else:
1959 try: # can we get it from our other db?
1960 serverInfo = sourceDB.serverInfo
1961 except AttributeError:
1962 raise ValueError('you must provide serverInfo or cursor!')
1963 else:
1964 cursor = serverInfo.cursor()
1965 self.cursor = cursor
1966 self.serverInfo = serverInfo
1967 self.get_sql_format(False) # get sql formatter for this db interface
1968 _schemaModuleDict = _schemaModuleDict # default module list
1969 get_sql_format = get_table_schema
1970 def __getitem__(self, k):
1971 if not hasattr(k,'db') or k.db is not self.sourceDB:
1972 raise KeyError('object is not in the sourceDB bound to this map!')
1973 sql,params = self._format_query(self.viewSQL, (k.id,))
1974 self.cursor.execute(sql, params) # formatted for this db interface
1975 t = self.cursor.fetchmany(2) # get at most two rows
1976 if len(t) != 1:
1977 raise KeyError('%s not found in MapView, or not unique'
1978 % str(k))
1979 return self.targetDB[t[0][0]] # get the corresponding object
1980 _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
1981 inverseSQL=0)
1982 __getstate__ = standard_getstate
1983 __setstate__ = standard_setstate
1984 __setitem__ = __delitem__ = clear = pop = popitem = update = \
1985 setdefault = read_only_error
1986 def __iter__(self):
1987 'only yield sourceDB items that are actually in this mapping!'
1988 for k in self.sourceDB.itervalues():
1989 try:
1990 self[k]
1991 yield k
1992 except KeyError:
1993 pass
1994 def keys(self):
1995 return [k for k in self] # don't use list(self); causes infinite loop!
1996 def __invert__(self):
1997 try:
1998 return self._inverse
1999 except AttributeError:
2000 if self.inverseSQL is None:
2001 raise ValueError('this MapView has no inverseSQL!')
2002 self._inverse = self.__class__(self.targetDB, self.sourceDB,
2003 self.inverseSQL, self.cursor,
2004 serverInfo=self.serverInfo,
2005 inverseSQL=self.viewSQL)
2006 self._inverse._inverse = self
2007 return self._inverse
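# Illustrative sketch: a MapView turns any single-row SQL query into a 1:1
# mapping between two pygr databases; the %s placeholder receives the source
# object's id.  The table / column names are assumptions for this example.
def _example_mrna_protein_map(mrnaDB, proteinDB, serverInfo):
    return MapView(mrnaDB, proteinDB,
                   'select protein_id from test.translation where mrna_id=%s',
                   serverInfo=serverInfo,
                   inverseSQL='select mrna_id from test.translation'
                              ' where protein_id=%s')
# usage sketch: protein = m[mrna]; mrna = (~m)[protein]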
2009 class GraphViewEdgeDict(UserDict.DictMixin):
2010 'edge dictionary for GraphView: just pre-loaded on init'
2011 def __init__(self, g, k):
2012 self.g = g
2013 self.k = k
2014 sql,params = self.g._format_query(self.g.viewSQL, (k.id,))
2015 self.g.cursor.execute(sql, params) # run the query
2016 l = self.g.cursor.fetchall() # get results
2017 if len(l) <= 0:
2018 raise KeyError('key %s not in GraphView' % k.id)
2019 self.targets = [t[0] for t in l] # preserve order of the results
2020 d = {} # also keep targetID:edgeID mapping
2021 if self.g.edgeDB is not None: # save with edge info
2022 for t in l:
2023 d[t[0]] = t[1]
2024 else:
2025 for t in l:
2026 d[t[0]] = None
2027 self.targetDict = d
2028 def __len__(self):
2029 return len(self.targets)
2030 def __iter__(self):
2031 for k in self.targets:
2032 yield self.g.targetDB[k]
2033 def keys(self):
2034 return list(self)
2035 def iteritems(self):
2036 if self.g.edgeDB is not None: # save with edge info
2037 for k in self.targets:
2038 yield (self.g.targetDB[k], self.g.edgeDB[self.targetDict[k]])
2039 else: # just save the list of targets, no edge info
2040 for k in self.targets:
2041 yield (self.g.targetDB[k], None)
2042 def __getitem__(self, o, exitIfFound=False):
2043 'for the specified target object, return its associated edge object'
2044 try:
2045 if o.db is not self.g.targetDB:
2046 raise KeyError('key is not part of targetDB!')
2047 edgeID = self.targetDict[o.id]
2048 except AttributeError:
2049 raise KeyError('key has no id or db attribute?!')
2050 if exitIfFound:
2051 return
2052 if self.g.edgeDB is not None: # return the edge object
2053 return self.g.edgeDB[edgeID]
2054 else: # no edge info
2055 return None
2056 def __contains__(self, o):
2057 try:
2058 self.__getitem__(o, True) # raise KeyError if not found
2059 return True
2060 except KeyError:
2061 return False
2062 __setitem__ = __delitem__ = clear = pop = popitem = update = \
2063 setdefault = read_only_error
2065 class GraphView(MapView):
2066 'general purpose graph interface defined by any SQL query'
2067 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
2068 **kwargs):
2069 'if edgeDB is not None, the viewSQL query must return (targetID,edgeID) tuples'
2070 self.edgeDB = edgeDB
2071 MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)
2072 def __getitem__(self, k):
2073 if not hasattr(k,'db') or k.db is not self.sourceDB:
2074 raise KeyError('object is not in the sourceDB bound to this map!')
2075 return GraphViewEdgeDict(self, k)
2076 _pickleAttrs = MapView._pickleAttrs.copy()
2077 _pickleAttrs.update(dict(edgeDB=0))
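# Illustrative sketch: GraphView is the one-to-many counterpart of MapView;
# because edgeDB is supplied here, the viewSQL query must return
# (targetID, edgeID) tuples for each source id.  Names are assumptions.
def _example_splice_graph_view(exonDB, spliceDB, serverInfo):
    return GraphView(exonDB, exonDB,
                     'select right_exon_id, splice_id from test.splices'
                     ' where left_exon_id=%s',
                     edgeDB=spliceDB, serverInfo=serverInfo)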
2079 # @CTB move to sqlgraph.py?
2081 class SQLSequence(SQLRow, SequenceBase):
2082 """Transparent access to a DB row representing a sequence.
2084 Use attrAlias dict to rename 'length' to something else.
2085 """
2086 def _init_subclass(cls, db, **kwargs):
2087 db.seqInfoDict = db # db will act as its own seqInfoDict
2088 SQLRow._init_subclass(db=db, **kwargs)
2089 _init_subclass = classmethod(_init_subclass)
2090 def __init__(self, id):
2091 SQLRow.__init__(self, id)
2092 SequenceBase.__init__(self)
2093 def __len__(self):
2094 return self.length
2095 def strslice(self,start,end):
2096 "Efficient access to slice of a sequence, useful for huge contigs"
2097 return self._select('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
2098 %(self.db._attrSQL('seq'),start+1,end-start))
2100 class DNASQLSequence(SQLSequence):
2101 _seqtype=DNA_SEQTYPE
2103 class RNASQLSequence(SQLSequence):
2104 _seqtype=RNA_SEQTYPE
2106 class ProteinSQLSequence(SQLSequence):
2107 _seqtype=PROTEIN_SEQTYPE
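# Illustrative sketch: exposing a sequence table through DNASQLSequence so
# that slices are fetched with SQL substring operations instead of loading
# whole sequences.  This assumes the table has 'id', 'length' and 'seq'
# columns and that SQLTable's itemClass hook is used to pick the row class.
def _example_sequence_db(serverInfo):
    seqDB = SQLTable('test.genome', serverInfo=serverInfo,
                     itemClass=DNASQLSequence)
    chr1 = seqDB['chr1'] # assumed sequence id
    return str(chr1[100000:100050]) # only this slice is fetched via strslice()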