from __future__ import generators
from mapping import *
from sequence import SequenceBase, DNA_SEQTYPE, RNA_SEQTYPE, PROTEIN_SEQTYPE
import types
from classutil import methodFactory,standard_getstate,\
     override_rich_cmp,generate_items,get_bound_subclass,standard_setstate,\
     get_valid_path,standard_invert,RecentValueDictionary,read_only_error,\
     SourceFileName, split_kwargs
import os
import platform
import UserDict
import warnings
import logger

class TupleDescriptor(object):
    'return tuple entry corresponding to named attribute'
    def __init__(self, db, attr):
        self.icol = db.data[attr] # index of this attribute in the tuple
    def __get__(self, obj, klass):
        return obj._data[self.icol]
    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')

class TupleIDDescriptor(TupleDescriptor):
    def __set__(self, obj, val):
        raise AttributeError('''You cannot change obj.id directly.
        Instead, use db[newID] = obj''')

class TupleDescriptorRW(TupleDescriptor):
    'read-write interface to named attribute'
    def __init__(self, db, attr):
        self.attr = attr
        self.icol = db.data[attr] # index of this attribute in the tuple
        self.attrSQL = db._attrSQL(attr, sqlColumn=True) # SQL column name
    def __set__(self, obj, val):
        obj.db._update(obj.id, self.attrSQL, val) # AND UPDATE THE DATABASE
        obj.save_local(self.attr, val)

class SQLDescriptor(object):
    'return attribute value by querying the database'
    def __init__(self, db, attr):
        self.selectSQL = db._attrSQL(attr) # SQL expression for this attr
    def __get__(self, obj, klass):
        return obj._select(self.selectSQL)
    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')

class SQLDescriptorRW(SQLDescriptor):
    'writeable proxy to corresponding column in the database'
    def __set__(self, obj, val):
        obj.db._update(obj.id, self.selectSQL, val) # just update the database

class ReadOnlyDescriptor(object):
    'enforce read-only policy, e.g. for ID attribute'
    def __init__(self, db, attr):
        self.attr = '_'+attr
    def __get__(self, obj, klass):
        return getattr(obj, self.attr)
    def __set__(self, obj, val):
        raise AttributeError('attribute %s is read-only!' % self.attr)


def select_from_row(row, what):
    "return value from SQL expression applied to this row"
    sql,params = row.db._format_query('select %s from %s where %s=%%s limit 2'
                                      % (what,row.db.name,row.db.primary_key),
                                      (row.id,))
    row.db.cursor.execute(sql, params)
    t = row.db.cursor.fetchmany(2) # get at most two rows
    if len(t) != 1:
        raise KeyError('%s[%s].%s not found, or not unique'
                       % (row.db.name,str(row.id),what))
    return t[0][0] # return the single field we requested

def init_row_subclass(cls, db):
    'add descriptors for db attributes'
    for attr in db.data: # bind all database columns
        if attr == 'id': # handle ID attribute specially
            setattr(cls, attr, cls._idDescriptor(db, attr))
            continue
        try: # check if this attr maps to an SQL column
            db._attrSQL(attr, columnNumber=True)
        except AttributeError: # treat as SQL expression
            setattr(cls, attr, cls._sqlDescriptor(db, attr))
        else: # treat as interface to our stored tuple
            setattr(cls, attr, cls._columnDescriptor(db, attr))

def dir_row(self):
    """get list of column names as our attributes"""
    return self.db.data.keys()

class TupleO(object):
    """Provides attribute interface to a database tuple.
    Storing the data as a tuple instead of a standard Python object
    (which is stored using __dict__) uses about five-fold less
    memory and is also much faster (the tuples returned from the
    DB API fetch are simply referenced by the TupleO, with no
    need to copy their individual values into __dict__).

    This class follows the 'subclass binding' pattern, which
    means that instead of using __getattr__ to process all
    attribute requests (which is un-modular and leads to all
    sorts of trouble), we follow Python's new model for
    customizing attribute access, namely Descriptors.
    We use classutil.get_bound_subclass() to automatically
    create a subclass of this class, calling its _init_subclass()
    class method to add all the descriptors needed for the
    database table to which it is bound.

    See the Pygr Developer Guide section of the docs for a
    complete discussion of the subclass binding pattern."""
    _columnDescriptor = TupleDescriptor
    _idDescriptor = TupleIDDescriptor
    _sqlDescriptor = SQLDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row
    def __init__(self, data):
        self._data = data # save our data tuple
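
# Illustrative sketch (not part of the original module): once a TupleO subclass
# has been bound to a table by get_bound_subclass()/_init_subclass(), each row
# exposes its columns as plain attributes.  'mytable', 'gene_id' and 'score'
# below are hypothetical names used only for demonstration.
#
#   t = SQLTable('mytable', serverInfo=serverInfo)   # default itemClass=TupleO
#   row = t[someID]                  # one tuple fetched from the database
#   print row.gene_id, row.score     # descriptors simply index into row._data
#
# Attribute access therefore costs one tuple indexing operation per lookup,
# with no per-row __dict__ storage.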

def insert_and_cache_id(self, l, **kwargs):
    'insert tuple into db and cache its rowID on self'
    self.db._insert(l) # save to database
    try:
        rowID = kwargs['id'] # use the ID supplied by user
    except KeyError:
        rowID = self.db.get_insert_id() # get auto-inc ID value
    self.cache_id(rowID) # cache this ID on self

class TupleORW(TupleO):
    'read-write version of TupleO'
    _columnDescriptor = TupleDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, data, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            self._data = data
            return
        self._data = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(self._data, **kwargs)
    def cache_id(self,row_id):
        self.save_local('id',row_id)
    def save_local(self,attr,val):
        icol = self._attrcol[attr]
        try:
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
        except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
            self._data = list(self._data)
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE

TupleO._RWClass = TupleORW # record this as writeable interface class

class ColumnDescriptor(object):
    'read-write interface to column in a database, cached in obj.__dict__'
    def __init__(self, db, attr, readOnly = False):
        self.attr = attr
        self.col = db._attrSQL(attr, sqlColumn=True) # MAP THIS TO SQL COLUMN NAME
        self.db = db
        if readOnly:
            self.__class__ = self._readOnlyClass
    def __get__(self, obj, objtype):
        try:
            return obj.__dict__[self.attr]
        except KeyError: # NOT IN CACHE.  TRY TO GET IT FROM DATABASE
            if self.col == self.db.primary_key:
                raise AttributeError
            self.db._select('where %s=%%s' % self.db.primary_key,(obj.id,),self.col)
            l = self.db.cursor.fetchall()
            if len(l) != 1:
                raise AttributeError('db row not found or not unique!')
            obj.__dict__[self.attr] = l[0][0] # UPDATE THE CACHE
            return l[0][0]
    def __set__(self, obj, val):
        if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
            self.db._update(obj.id, self.col, val) # UPDATE THE DATABASE
        obj.__dict__[self.attr] = val # UPDATE THE CACHE
##         try:
##             m = self.consequences
##         except AttributeError:
##             return
##         m(obj,val) # GENERATE CONSEQUENCES
##     def bind_consequences(self,f):
##         'make function f be run as consequences method whenever value is set'
##         import new
##         self.consequences = new.instancemethod(f,self,self.__class__)

class ReadOnlyColumnDesc(ColumnDescriptor):
    def __set__(self, obj, val):
        raise AttributeError('The ID of a database object is not writeable.')
ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc


class SQLRow(object):
    """Provide transparent interface to a row in the database: attribute access
    will be mapped to SELECT of the appropriate column, but data is not
    cached on this object.
    """
    _columnDescriptor = _sqlDescriptor = SQLDescriptor
    _idDescriptor = ReadOnlyDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row
    def __init__(self, rowID):
        self._id = rowID


class SQLRowRW(SQLRow):
    'read-write version of SQLRow'
    _columnDescriptor = SQLDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, rowID, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            return self.cache_id(rowID)
        l = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(l, **kwargs)
    def cache_id(self, rowID):
        self._id = rowID

SQLRow._RWClass = SQLRowRW


def list_to_dict(names, values):
    'return dictionary of those named args that are present in values[]'
    d = {}
    for i,v in enumerate(values):
        try:
            d[names[i]] = v
        except IndexError:
            break
    return d

def get_name_cursor(name=None, **kwargs):
    '''get table name and cursor by parsing name or using configFile.
    If neither is provided, we will try to get the connection info
    from your MySQL config file.'''
    if name is not None:
        argList = name.split() # TREAT AS WS-SEPARATED LIST
        if len(argList) > 1:
            name = argList[0] # USE 1ST ARG AS TABLE NAME
            argnames = ('host','user','passwd') # READ ARGS IN THIS ORDER
            kwargs = kwargs.copy() # a copy we can overwrite
            kwargs.update(list_to_dict(argnames, argList[1:]))
    serverInfo = DBServerInfo(**kwargs)
    return name,serverInfo.cursor(),serverInfo

def mysql_connect(connect=None, configFile=None, useStreaming=False, **args):
    """return connection and cursor objects, using .my.cnf if necessary"""
    kwargs = args.copy() # a copy we can modify
    if 'user' not in kwargs and configFile is None: # find where config file is
        osname = platform.system()
        if osname in ('Microsoft', 'Windows'): # Machine is a Windows box
            paths = []
            try: # handle case where WINDIR not defined by Windows...
                windir = os.environ['WINDIR']
                paths += [(windir, 'my.ini'), (windir, 'my.cnf')]
            except KeyError:
                pass
            try:
                sysdrv = os.environ['SYSTEMDRIVE']
                paths += [(sysdrv, os.path.sep + 'my.ini'),
                          (sysdrv, os.path.sep + 'my.cnf')]
            except KeyError:
                pass
            if len(paths) > 0:
                configFile = get_valid_path(*paths)
        else: # treat as normal platform with home directories
            configFile = os.path.join(os.path.expanduser('~'), '.my.cnf')

    # allow a local MySQL configuration file to be read
    # from the current directory
    configFile = configFile or os.path.join(os.getcwd(), 'mysql.cnf')

    if configFile and os.path.exists(configFile):
        kwargs['read_default_file'] = configFile
        connect = None # force it to use MySQLdb
    if connect is None:
        import MySQLdb
        connect = MySQLdb.connect
        kwargs['compress'] = True
    if useStreaming: # use server side cursors for scalable result sets
        try:
            from MySQLdb import cursors
            kwargs['cursorclass'] = cursors.SSCursor
        except (ImportError, AttributeError):
            pass
    conn = connect(**kwargs)
    cursor = conn.cursor()
    return conn,cursor
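
# Usage sketch (assumes a reachable MySQL server plus credentials in ~/.my.cnf
# or passed as keyword args; 'mydb' is a hypothetical database name):
#
#   conn, cursor = mysql_connect(db='mydb', useStreaming=True)
#   cursor.execute('select 1')
#   print cursor.fetchall()
#   conn.close()
#
# useStreaming=True requests MySQLdb's server-side SSCursor, so large result
# sets are streamed in chunks rather than fetched into memory all at once.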

_mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
                    AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
                    SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')

def mysql_table_schema(self, analyzeSchema=True):
    'retrieve table schema from a MySQL database, save on self'
    import MySQLdb
    self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    self.cursor.execute('describe %s' % self.name) # get info about columns
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[0]
        self.columnName.append(field) # list of columns in same order as table
        if c[3] == "PRI": # record as primary key
            if self.primary_key is None:
                self.primary_key = field
            else:
                try:
                    self.primary_key.append(field)
                except AttributeError:
                    self.primary_key = [self.primary_key,field]
            if c[1][:3].lower() == 'int':
                self.usesIntID = True
            else:
                self.usesIntID = False
        elif c[3] == "MUL":
            self.indexed[field] = icol
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[1] # SQL COLUMN TYPE

_sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
                     AUTO_INCREMENT='', SUBSTRING='substr',
                     SUBSTR_FROM=',', SUBSTR_FOR=',')

def import_sqlite():
    'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
    try:
        import sqlite3 as sqlite
    except ImportError:
        from pysqlite2 import dbapi2 as sqlite
    return sqlite

def sqlite_table_schema(self, analyzeSchema=True):
    'retrieve table schema from a sqlite3 database, save on self'
    sqlite = import_sqlite()
    self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    self.cursor.execute('PRAGMA table_info("%s")' % self.name)
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[1]
        self.columnName.append(field) # list of columns in same order as table
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[2] # SQL COLUMN TYPE
    self.cursor.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self.name) # get primary key / unique indexes
    for indexname in self.cursor.fetchall(): # search indexes for primary key
        self.cursor.execute('PRAGMA index_info("%s")' % indexname)
        l = self.cursor.fetchall() # get list of columns in this index
        if len(l) == 1: # assume 1st single-column unique index is primary key!
            self.primary_key = l[0][2]
            break # done searching for primary key!
    if self.primary_key is None: # grrr, INTEGER PRIMARY KEY handled differently
        self.cursor.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self.name)
        sql = self.cursor.fetchall()[0][0]
        for columnSQL in sql[sql.index('(') + 1:].split(','):
            if 'primary key' in columnSQL.lower(): # must be the primary key!
                col = columnSQL.split()[0] # get column name
                if col in self.columnType:
                    self.primary_key = col
                    break # done searching for primary key!
                else:
                    raise ValueError('unknown primary key %s in table %s'
                                     % (col,self.name))
    if self.primary_key is not None: # check its type
        if self.columnType[self.primary_key] == 'int' or \
           self.columnType[self.primary_key] == 'integer':
            self.usesIntID = True
        else:
            self.usesIntID = False

class SQLFormatDict(object):
    '''Perform SQL keyword replacements for maintaining compatibility across
    a wide range of SQL backends.  Uses Python dict-based string format
    function to do simple string replacements, and also to convert
    params list to the paramstyle required for this interface.
    Create by passing the db-api paramstyle and a dict of macros:
    sfd = SQLFormatDict("qmark", substitutionDict)

    Then transform queries+params as follows; input should be "format" style:
    sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
    cursor.execute(sql, params)
    '''
    _paramFormats = dict(pyformat='%%(%d)s', numeric=':%d', named=':%d',
                         qmark='(ignore)', format='(ignore)')
    def __init__(self, paramstyle, substitutionDict={}):
        self.substitutionDict = substitutionDict.copy()
        self.paramstyle = paramstyle
        self.paramFormat = self._paramFormats[paramstyle]
        self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
        if paramstyle == 'qmark': # handle these as simple substitution
            self.substitutionDict['?'] = '?'
        elif paramstyle == 'format':
            self.substitutionDict['?'] = '%s'
    def __getitem__(self, k):
        'apply correct substitution for this SQL interface'
        try:
            return self.substitutionDict[k] # apply our substitutions
        except KeyError:
            pass
        if k == '?': # sequential parameter
            s = self.paramFormat % self.iparam
            self.iparam += 1 # advance to the next parameter
            return s
        raise KeyError('unknown macro: %s' % k)
    def __call__(self, sql, paramList):
        'returns corrected sql,params for this interface'
        self.iparam = 1 # DB-API param indexing begins at 1
        sql = sql.replace('%s', '%(?)s') # convert format into pyformat
        s = sql % self # apply all %(x)s replacements in sql
        if self.makeDict: # construct a params dict
            paramDict = {}
            for i,param in enumerate(paramList):
                paramDict[str(i + 1)] = param # DB-API param indexing begins at 1
            return s,paramDict
        else: # just return the original params list
            return s,paramList
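
# Worked example (illustrative only; 'mytable' and 'seq' are hypothetical):
#
#   sfd = SQLFormatDict('qmark', _sqliteMacros)
#   sql, params = sfd('select %(SUBSTRING)s(seq %(SUBSTR_FROM)s 1 %(SUBSTR_FOR)s 10) '
#                     'from mytable where id=%s', (17,))
#   # sql    --> 'select substr(seq , 1 , 10) from mytable where id=?'
#   # params --> (17,)  (qmark style keeps the positional parameter list)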

def get_table_schema(self, analyzeSchema=True):
    'run the right schema function based on type of db server connection'
    try:
        modname = self.cursor.__class__.__module__
    except AttributeError:
        raise ValueError('no cursor object or module information!')
    try:
        schema_func = self._schemaModuleDict[modname]
    except KeyError:
        raise KeyError('''unknown db module: %s. Use _schemaModuleDict
        attribute to supply a method for obtaining table schema
        for this module''' % modname)
    schema_func(self, analyzeSchema) # run the schema function

_schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
                     'pysqlite2.dbapi2':sqlite_table_schema,
                     'sqlite3':sqlite_table_schema}
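
# Sketch of extending backend support (assumption: 'mydbapi.cursors' is a
# hypothetical DB-API cursor module, not a real dependency).  The KeyError
# above tells you to register a schema function keyed by cursor module name:
#
#   def mydb_table_schema(self, analyzeSchema=True):
#       ...  # populate self.columnName, self.columnType, self.primary_key
#
#   _schemaModuleDict['mydbapi.cursors'] = mydb_table_schema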

class SQLTableBase(object, UserDict.DictMixin):
    "Store information about an SQL table as dict keyed by primary key"
    _schemaModuleDict = _schemaModuleDict # default module list
    get_table_schema = get_table_schema
    def __init__(self,name,cursor=None,itemClass=None,attrAlias=None,
                 clusterKey=None,createTable=None,graph=None,maxCache=None,
                 arraysize=1024, itemSliceClass=None, dropIfExists=False,
                 serverInfo=None, autoGC=True, orderBy=None,
                 writeable=False, **kwargs):
        if autoGC: # automatically garbage collect unused objects
            self._weakValueDict = RecentValueDictionary(autoGC) # object cache
        else:
            self._weakValueDict = {}
        self.autoGC = autoGC
        self.orderBy = orderBy
        self.writeable = writeable
        if cursor is None:
            if serverInfo is not None: # get cursor from serverInfo
                cursor = serverInfo.cursor()
            else: # try to read connection info from name or config file
                name,cursor,serverInfo = get_name_cursor(name,**kwargs)
        else:
            warnings.warn("""The cursor argument is deprecated. Use serverInfo instead!""",
                          DeprecationWarning, stacklevel=2)
        self.cursor = cursor
        if createTable is not None: # RUN COMMAND TO CREATE THIS TABLE
            if dropIfExists: # get rid of any existing table
                cursor.execute('drop table if exists ' + name)
            self.get_table_schema(False) # check dbtype, init _format_query
            sql,params = self._format_query(createTable, ()) # apply macros
            cursor.execute(sql) # create the table
        self.name = name
        if graph is not None:
            self.graph = graph
        if maxCache is not None:
            self.maxCache = maxCache
        if arraysize is not None:
            self.arraysize = arraysize
            cursor.arraysize = arraysize
        self.get_table_schema() # get schema of columns to serve as attrs
        self.data = {} # map of all attributes, including aliases
        for icol,field in enumerate(self.columnName):
            self.data[field] = icol # 1st add mappings to columns
        try:
            self.data['id'] = self.data[self.primary_key]
        except (KeyError,TypeError):
            pass
        if hasattr(self,'_attr_alias'): # apply attribute aliases for this class
            self.addAttrAlias(False,**self._attr_alias)
        self.objclass(itemClass) # NEED TO SUBCLASS OUR ITEM CLASS
        if itemSliceClass is not None:
            self.itemSliceClass = itemSliceClass
            get_bound_subclass(self, 'itemSliceClass', self.name) # need to subclass itemSliceClass
        if attrAlias is not None: # ADD ATTRIBUTE ALIASES
            self.attrAlias = attrAlias # RECORD FOR PICKLING PURPOSES
            self.data.update(attrAlias)
        if clusterKey is not None:
            self.clusterKey = clusterKey
        if serverInfo is not None:
            self.serverInfo = serverInfo

    def __len__(self):
        self._select(selectCols='count(*)')
        return self.cursor.fetchone()[0]
    def __hash__(self):
        return id(self)
    _pickleAttrs = dict(name=0, clusterKey=0, maxCache=0, arraysize=0,
                        attrAlias=0, serverInfo=0, autoGC=0, orderBy=0,
                        writeable=0)
    __getstate__ = standard_getstate
    def __setstate__(self,state):
        # default cursor provisioning by worldbase is deprecated!
##         if 'serverInfo' not in state: # hmm, no address for db server?
##             try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
##                 from Data import getResource
##                 state['cursor'] = getResource.getTableCursor(state['name'])
##             except ImportError:
##                 pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
        self.__init__(**state)
    def __repr__(self):
        return '<SQL table '+self.name+'>'

    def clear_schema(self):
        'reset all schema information for this table'
        self.description = {}
        self.columnName = []
        self.columnType = {}
        self.usesIntID = None
        self.primary_key = None
        self.indexed = {}
    def _attrSQL(self,attr,sqlColumn=False,columnNumber=False):
        "Translate python attribute name to appropriate SQL expression"
        try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
            field = self.data[attr]
        except KeyError:
            raise AttributeError('attribute %s not a valid column or alias in %s'
                                 % (attr,self.name))
        if sqlColumn: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
            try: # CHECK IF field IS COLUMN NUMBER
                return self.columnName[field] # RETURN SQL COLUMN NAME
            except TypeError:
                try: # CHECK IF field IS SQL COLUMN NAME
                    return self.columnName[self.data[field]] # THIS WILL JUST RETURN field...
                except (KeyError,TypeError):
                    raise AttributeError('attribute %s does not map to an SQL column in %s'
                                         % (attr,self.name))
        if columnNumber:
            try: # CHECK IF field IS A COLUMN NUMBER
                return field+0 # ONLY RETURN AN INTEGER
            except TypeError:
                try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
                    return self.data[field]+0 # ONLY RETURN AN INTEGER
                except (KeyError,TypeError):
                    raise ValueError('attribute %s does not map to a SQL column!' % attr)
        if isinstance(field,types.StringType):
            attr = field # USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
        elif attr == 'id':
            attr = self.primary_key
        return attr
    def addAttrAlias(self,saveToPickle=True,**kwargs):
        """Add new attributes as aliases of existing attributes.
        They can be specified either as named args:
        t.addAttrAlias(newattr=oldattr)
        or by passing a dictionary kwargs whose keys are newattr
        and values are oldattr:
        t.addAttrAlias(**kwargs)
        saveToPickle=True forces these aliases to be saved if object is pickled.
        """
        if saveToPickle:
            self.attrAlias.update(kwargs)
        for key,val in kwargs.items():
            try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
                self.data[val]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
                raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
            except TypeError: # val IS ITSELF AN ALIAS
                self.data[key] = self.data[val] # SO MAP TO WHAT IT MAPS TO
            except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
                self.data[key] = val
    def objclass(self,oclass=None):
        "Create class representing a row in this table by subclassing oclass, adding data"
        if oclass is not None: # use this as our base itemClass
            self.itemClass = oclass
        if self.writeable:
            self.itemClass = self.itemClass._RWClass # use its writeable version
        oclass = get_bound_subclass(self, 'itemClass', self.name,
                                    subclassArgs=dict(db=self)) # bind itemClass
        if issubclass(oclass, TupleO):
            oclass._attrcol = self.data # BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
        if hasattr(oclass,'_tableclass') and not isinstance(self,oclass._tableclass):
            self.__class__ = oclass._tableclass # ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
    def _select(self, whereClause='', params=(), selectCols='t1.*',
                cursor=None, orderBy=''):
        'execute the specified query but do not fetch'
        sql,params = self._format_query('select %s from %s t1 %s %s'
                                        % (selectCols, self.name, whereClause, orderBy),
                                        params)
        if cursor is None:
            self.cursor.execute(sql, params)
        else:
            cursor.execute(sql, params)
    def select(self,whereClause,params=None,oclass=None,selectCols='t1.*'):
        "Generate the list of objects that satisfy the database SELECT"
        if oclass is None:
            oclass = self.itemClass
        self._select(whereClause,params,selectCols)
        l = self.cursor.fetchall()
        for t in l:
            yield self.cacheItem(t,oclass)
    def query(self,**kwargs):
        'query for intersection of all specified kwargs, returned as iterator'
        criteria = []
        params = []
        for k,v in kwargs.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
            if v is None: # CONVERT TO SQL NULL TEST
                criteria.append('%s IS NULL' % self._attrSQL(k))
            else: # TEST FOR EQUALITY
                criteria.append('%s=%%s' % self._attrSQL(k))
                params.append(v)
        return self.select('where '+' and '.join(criteria),params)
    def _update(self,row_id,col,val):
        'update a single field in the specified row to the specified value'
        sql,params = self._format_query('update %s set %s=%%s where %s=%%s'
                                        % (self.name,col,self.primary_key),
                                        (val,row_id))
        self.cursor.execute(sql, params)
    def getID(self,t):
        try:
            return t[self.data['id']] # GET ID FROM TUPLE
        except TypeError: # treat as alias
            return t[self.data[self.data['id']]]
    def cacheItem(self,t,oclass):
        'get obj from cache if possible, or construct from tuple'
        try:
            id = self.getID(t)
        except KeyError: # NO PRIMARY KEY?  IGNORE THE CACHE.
            return oclass(t)
        try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
            return self._weakValueDict[id]
        except KeyError:
            pass
        o = oclass(t)
        self._weakValueDict[id] = o # CACHE THIS ITEM IN OUR DICTIONARY
        return o
    def cache_items(self,rows,oclass=None):
        if oclass is None:
            oclass = self.itemClass
        for t in rows:
            yield self.cacheItem(t,oclass)
    def foreignKey(self,attr,k):
        'get iterator for objects with specified foreign key value'
        return self.select('where %s=%%s' % attr,(k,))
    def limit_cache(self):
        'APPLY maxCache LIMIT TO CACHE SIZE'
        try:
            if self.maxCache < len(self._weakValueDict):
                self._weakValueDict.clear()
        except AttributeError:
            pass

    def get_new_cursor(self):
        """Return a new cursor object, or None if not possible"""
        try:
            new_cursor = self.serverInfo.new_cursor
        except AttributeError:
            return None
        return new_cursor(self.arraysize)

    def generic_iterator(self, cursor=None, fetch_f=None, cache_f=None,
                         map_f=iter):
        'generic iterator that runs fetch, cache and map functions'
        if fetch_f is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
            if cursor is None:
                fetch_f = self.cursor.fetchmany
            else: # isolate this iter from other queries
                fetch_f = cursor.fetchmany
        if cache_f is None:
            cache_f = self.cache_items
        while True:
            self.limit_cache()
            rows = fetch_f() # FETCH THE NEXT SET OF ROWS
            if len(rows) == 0: # NO MORE DATA SO ALL DONE
                break
            for v in map_f(cache_f(rows)): # CACHE AND GENERATE RESULTS
                yield v
        if cursor is not None: # close iterator now that we're done
            cursor.close()
    def tuple_from_dict(self, d):
        'transform kwarg dict into tuple for storing in database'
        l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
        for col,icol in self.data.items():
            try:
                l[icol] = d[col]
            except (KeyError,TypeError):
                pass
        return l
    def tuple_from_obj(self, obj):
        'transform object attributes into tuple for storing in database'
        l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
        for col,icol in self.data.items():
            try:
                l[icol] = getattr(obj,col)
            except (AttributeError,TypeError):
                pass
        return l
    def _insert(self, l):
        '''insert tuple into the database.  Note this uses the MySQL
        extension REPLACE, which overwrites any duplicate key.'''
        s = '%(REPLACE)s into ' + self.name + ' values (' \
            + ','.join(['%s']*len(l)) + ')'
        sql,params = self._format_query(s, l)
        self.cursor.execute(sql, params)
    def insert(self, obj):
        '''insert new row by transforming obj to tuple of values'''
        l = self.tuple_from_obj(obj)
        self._insert(l)
    def get_insert_id(self):
        'get the primary key value for the last INSERT'
        try: # ATTEMPT TO GET ASSIGNED ID FROM DB
            auto_id = self.cursor.lastrowid
        except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
            raise NotImplementedError('''your db lacks lastrowid support?''')
        if auto_id is None:
            raise ValueError('lastrowid is None so cannot get ID from INSERT!')
        return auto_id
    def new(self, **kwargs):
        'return a new record with the assigned attributes, added to DB'
        if not self.writeable:
            raise ValueError('this database is read only!')
        obj = self.itemClass(None, newRow=True, **kwargs) # saves itself to db
        self._weakValueDict[obj.id] = obj # AND SAVE TO OUR LOCAL DICT CACHE
        return obj
    def clear_cache(self):
        'empty the cache'
        self._weakValueDict.clear()
    def __delitem__(self, k):
        if not self.writeable:
            raise ValueError('this database is read only!')
        sql,params = self._format_query('delete from %s where %s=%%s'
                                        % (self.name,self.primary_key),(k,))
        self.cursor.execute(sql, params)
        try:
            del self._weakValueDict[k]
        except KeyError:
            pass

def getKeys(self,queryOption='', selectCols=None):
    'uses db select; does not force load'
    if selectCols is None:
        selectCols = self.primary_key
    if queryOption == '' and self.orderBy is not None:
        queryOption = self.orderBy # apply default ordering
    self.cursor.execute('select %s from %s %s'
                        % (selectCols,self.name,queryOption))
    return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...

def iter_keys(self, selectCols=None, orderBy='', map_f=iter,
              cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
    'guarantee correct iteration insulated from other queries'
    if selectCols is None:
        selectCols = self.primary_key
    if orderBy == '' and self.orderBy is not None:
        orderBy = self.orderBy # apply default ordering
    cursor = self.get_new_cursor()
    if cursor: # got our own cursor, guaranteeing query isolation
        self._select(cursor=cursor, selectCols=selectCols, orderBy=orderBy,
                     **kwargs)
        return self.generic_iterator(cursor=cursor, cache_f=cache_f, map_f=map_f)
    else: # must pre-fetch all keys to ensure query isolation
        if get_f is not None:
            return iter(get_f())
        else:
            return iter(self.keys())
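
# Usage sketch for the orderBy support (illustrative; 'mytable' and 'name' are
# hypothetical).  A table constructed with a default orderBy clause returns its
# keys, and iterates, in that order unless a query explicitly overrides it:
#
#   t = SQLTable('mytable', serverInfo=serverInfo, orderBy='order by name')
#   for k in t:          # iter_keys() falls back to t.orderBy
#       print k
#   print t.keys()       # getKeys() applies the same default ordering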

class SQLTable(SQLTableBase):
    "Provide on-the-fly access to rows in the database, caching the results in dict"
    itemClass = TupleO # our default itemClass; constructor can override
    keys = getKeys
    __iter__ = iter_keys
    def load(self,oclass=None):
        "Load all data from the table"
        try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
            return self._isLoaded
        except AttributeError:
            pass
        if oclass is None:
            oclass = self.itemClass
        self.cursor.execute('select * from %s' % self.name)
        l = self.cursor.fetchall()
        self._weakValueDict = {} # just store the whole dataset in memory
        for t in l:
            self.cacheItem(t,oclass) # CACHE IT IN LOCAL DICTIONARY
        self._isLoaded = True # MARK THIS CONTAINER AS FULLY LOADED

    def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            sql,params = self._format_query('select * from %s where %s=%%s limit 2'
                                            % (self.name,self.primary_key),(k,))
            self.cursor.execute(sql, params)
            l = self.cursor.fetchmany(2) # get at most 2 rows
            if len(l) != 1:
                raise KeyError('%s not found in %s, or not unique' % (str(k),self.name))
            self.limit_cache()
            return self.cacheItem(l[0],self.itemClass) # CACHE IT IN LOCAL DICTIONARY
    def __setitem__(self, k, v):
        if not self.writeable:
            raise ValueError('this database is read only!')
        try:
            if v.db != self:
                raise AttributeError
        except AttributeError:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            oldID = v.id
            if oldID is None:
                raise AttributeError
        except AttributeError:
            pass
        else: # delete row with old ID
            del self[v.id]
        v.cache_id(k) # cache the new ID on the object
        self.insert(v) # SAVE TO THE RELATIONAL DB SERVER
        self._weakValueDict[k] = v # CACHE THIS ITEM IN OUR DICTIONARY
    def items(self):
        'forces load of entire table into memory'
        self.load()
        return self._weakValueDict.items()
    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select(cursor=cursor)
        return self.generic_iterator(cursor=cursor, map_f=generate_items)
    def values(self):
        'forces load of entire table into memory'
        self.load()
        return self._weakValueDict.values()
    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select(cursor=cursor)
        return self.generic_iterator(cursor=cursor)
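
# Minimal usage sketch for SQLTable (illustrative; the table and column names
# are hypothetical, and serverInfo is assumed to point at a real server):
#
#   t = SQLTable('mytable', serverInfo=serverInfo)
#   row = t[1]                           # dict-style access by primary key
#   print len(t)                         # issues select count(*)
#   for rowID, row in t.iteritems():     # streams rows via fetchmany()
#       pass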

def getClusterKeys(self,queryOption=''):
    'uses db select; does not force load'
    self.cursor.execute('select distinct %s from %s %s'
                        % (self.clusterKey,self.name,queryOption))
    return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...


class SQLTableClustered(SQLTable):
    '''use clusterKey to load a whole cluster of rows at once,
    specifically, all rows that share the same clusterKey value.'''
    def __init__(self, *args, **kwargs):
        kwargs = kwargs.copy() # get a copy we can alter
        kwargs['autoGC'] = False # don't use WeakValueDictionary
        SQLTable.__init__(self, *args, **kwargs)
    def keys(self):
        return getKeys(self,'order by %s' % self.clusterKey)
    def clusterkeys(self):
        return getClusterKeys(self, 'order by %s' % self.clusterKey)
    def __getitem__(self,k):
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            sql,params = self._format_query('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
                                            % (self.name,self.name,self.primary_key,
                                               self.clusterKey,self.clusterKey),(k,))
            self.cursor.execute(sql, params)
            l = self.cursor.fetchall()
            self.limit_cache()
            for t in l: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
                self.cacheItem(t,self.itemClass)
            return self._weakValueDict[k] # should be in cache, if row k exists
    def itercluster(self,cluster_id):
        'iterate over all items from the specified cluster'
        self.limit_cache()
        return self.select('where %s=%%s' % self.clusterKey,(cluster_id,))
    def fetch_cluster(self):
        'use self.cursor.fetchmany to obtain all rows for next cluster'
        icol = self._attrSQL(self.clusterKey,columnNumber=True)
        result = []
        try:
            rows = self._fetch_cluster_cache # USE SAVED ROWS FROM PREVIOUS CALL
            del self._fetch_cluster_cache
        except AttributeError:
            rows = self.cursor.fetchmany()
        try:
            cluster_id = rows[0][icol]
        except IndexError:
            return result
        while len(rows) > 0:
            for i,t in enumerate(rows): # CHECK THAT ALL ROWS FROM THIS CLUSTER
                if cluster_id != t[icol]: # START OF A NEW CLUSTER
                    result += rows[:i] # RETURN ROWS OF CURRENT CLUSTER
                    self._fetch_cluster_cache = rows[i:] # SAVE NEXT CLUSTER
                    return result
            result += rows
            rows = self.cursor.fetchmany() # GET NEXT SET OF ROWS
        return result
    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select('order by %s' % self.clusterKey, cursor=cursor)
        return self.generic_iterator(cursor, self.fetch_cluster)
    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select('order by %s' % self.clusterKey, cursor=cursor)
        return self.generic_iterator(cursor, self.fetch_cluster,
                                     map_f=generate_items)
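
# Usage sketch (illustrative; 'mytable' and 'chromosome' are hypothetical).
# Clustering loads every row that shares a clusterKey value as soon as any one
# of them is requested, which cuts round-trips when access is grouped:
#
#   t = SQLTableClustered('mytable', clusterKey='chromosome',
#                         serverInfo=serverInfo)
#   row = t[someID]                    # caches the whole cluster for that row
#   for r in t.itercluster('chr1'):    # or iterate one cluster directly
#       pass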

class SQLForeignRelation(object):
    'mapping based on matching a foreign key in an SQL table'
    def __init__(self,table,keyName):
        self.table = table
        self.keyName = keyName
    def __getitem__(self,k):
        'get list of objects o with getattr(o,keyName)==k.id'
        l = []
        for o in self.table.select('where %s=%%s' % self.keyName,(k.id,)):
            l.append(o)
        if len(l) == 0:
            raise KeyError('%s not found in %s' % (str(k),self.table.name))
        return l

class SQLTableNoCache(SQLTableBase):
    '''Provide on-the-fly access to rows in the database;
    values are simply an object interface (SQLRow) to back-end db query.
    Row data are not stored locally, but always accessed by querying the db'''
    itemClass = SQLRow # DEFAULT OBJECT CLASS FOR ROWS...
    keys = getKeys
    __iter__ = iter_keys
    def getID(self,t): return t[0] # GET ID FROM TUPLE
    def select(self,whereClause,params):
        return SQLTableBase.select(self,whereClause,params,self.oclass,
                                   self._attrSQL('id'))
    def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            self._select('where %s=%%s' % self.primary_key, (k,),
                         self.primary_key)
            t = self.cursor.fetchmany(2)
            if len(t) != 1:
                raise KeyError('id %s non-existent or not unique' % k)
            o = self.itemClass(k) # create obj referencing this ID
            self._weakValueDict[k] = o # cache the SQLRow object
            return o
    def __setitem__(self, k, v):
        if not self.writeable:
            raise ValueError('this database is read only!')
        try:
            if v.db != self:
                raise AttributeError
        except AttributeError:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            del self[k] # delete row with new ID if any
        except KeyError:
            pass
        try:
            del self._weakValueDict[v.id] # delete from old cache location
        except KeyError:
            pass
        self._update(v.id, self.primary_key, k) # just change its ID in db
        v.cache_id(k) # change the cached ID value
        self._weakValueDict[k] = v # assign to new cache location
    def addAttrAlias(self,**kwargs):
        self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES

SQLRow._tableclass = SQLTableNoCache # SQLRow IS FOR NON-CACHING TABLE INTERFACE


class SQLTableMultiNoCache(SQLTableBase):
    "Trivial on-the-fly access for table with key that returns multiple rows"
    itemClass = TupleO # default itemClass; constructor can override
    _distinct_key = 'id' # DEFAULT COLUMN TO USE AS KEY
    def keys(self):
        return getKeys(self, selectCols='distinct(%s)'
                       % self._attrSQL(self._distinct_key))
    def __iter__(self):
        return iter_keys(self, 'distinct(%s)' % self._attrSQL(self._distinct_key))
    def __getitem__(self,id):
        sql,params = self._format_query('select * from %s where %s=%%s'
                                        % (self.name,self._attrSQL(self._distinct_key)),(id,))
        self.cursor.execute(sql, params)
        l = self.cursor.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for row in l:
            yield self.itemClass(row)
    def addAttrAlias(self,**kwargs):
        self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES


class SQLEdges(SQLTableMultiNoCache):
    '''provide iterator over edges as (source,target,edge)
    and getitem[edge] --> [(source,target),...]'''
    _distinct_key = 'edge_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))
    def keys(self):
        self.cursor.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
                            % (self._attrSQL('source_id'),self._attrSQL('target_id'),
                               self._attrSQL('edge_id'),self.name,
                               self._attrSQL('target_id'),self._attrSQL('source_id'),
                               self._attrSQL('target_id')))
        l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for source_id,target_id,edge_id in self.cursor.fetchall():
            l.append((self.graph.unpack_source(source_id),
                      self.graph.unpack_target(target_id),
                      self.graph.unpack_edge(edge_id)))
        return l
    __call__ = keys
    def __iter__(self):
        return iter(self.keys())
    def __getitem__(self,edge):
        sql,params = self._format_query('select %s,%s from %s where %s=%%s'
                                        % (self._attrSQL('source_id'),
                                           self._attrSQL('target_id'),
                                           self.name,
                                           self._attrSQL(self._distinct_key)),
                                        (self.graph.pack_edge(edge),))
        self.cursor.execute(sql, params)
        l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for source_id,target_id in self.cursor.fetchall():
            l.append((self.graph.unpack_source(source_id),
                      self.graph.unpack_target(target_id)))
        return l


class SQLEdgeDict(object):
    '2nd level graph interface to SQL database'
    def __init__(self,fromNode,table):
        self.fromNode = fromNode
        self.table = table
        if not hasattr(self.table,'allowMissingNodes'):
            sql,params = self.table._format_query('select %s from %s where %s=%%s limit 1'
                                                  % (self.table.sourceSQL,
                                                     self.table.name,
                                                     self.table.sourceSQL),
                                                  (self.fromNode,))
            self.table.cursor.execute(sql, params)
            if len(self.table.cursor.fetchall()) < 1:
                raise KeyError('node not in graph!')

    def __getitem__(self,target):
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
                                              % (self.table.edgeSQL,
                                                 self.table.name,
                                                 self.table.sourceSQL,
                                                 self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        l = self.table.cursor.fetchmany(2) # get at most two rows
        if len(l) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        try:
            return self.table.unpack_edge(l[0][0]) # RETURN EDGE
        except IndexError:
            raise KeyError('no edge from node to target')
    def __setitem__(self,target,edge):
        sql,params = self.table._format_query('replace into %s values (%%s,%%s,%%s)'
                                              % self.table.name,
                                              (self.fromNode,
                                               self.table.pack_target(target),
                                               self.table.pack_edge(edge)))
        self.table.cursor.execute(sql, params)
        if not hasattr(self.table,'sourceDB') or \
           (hasattr(self.table,'targetDB') and self.table.sourceDB == self.table.targetDB):
            self.table += target # ADD AS NODE TO GRAPH
    def __iadd__(self,target):
        self[target] = None
        return self # iadd MUST RETURN self!
    def __delitem__(self,target):
        sql,params = self.table._format_query('delete from %s where %s=%%s and %s=%%s'
                                              % (self.table.name,
                                                 self.table.sourceSQL,
                                                 self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        if self.table.cursor.rowcount < 1: # no rows deleted?
            raise KeyError('no edge from node to target')

    def iterator_query(self):
        sql,params = self.table._format_query('select %s,%s from %s where %s=%%s and %s is not null'
                                              % (self.table.targetSQL,
                                                 self.table.edgeSQL,
                                                 self.table.name,
                                                 self.table.sourceSQL,
                                                 self.table.targetSQL),
                                              (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return self.table.cursor.fetchall()
    def keys(self):
        return [self.table.unpack_target(target_id)
                for target_id,edge_id in self.iterator_query()]
    def values(self):
        return [self.table.unpack_edge(edge_id)
                for target_id,edge_id in self.iterator_query()]
    def edges(self):
        return [(self.table.unpack_source(self.fromNode),self.table.unpack_target(target_id),
                 self.table.unpack_edge(edge_id))
                for target_id,edge_id in self.iterator_query()]
    def items(self):
        return [(self.table.unpack_target(target_id),self.table.unpack_edge(edge_id))
                for target_id,edge_id in self.iterator_query()]
    def __iter__(self): return iter(self.keys())
    def itervalues(self): return iter(self.values())
    def iteritems(self): return iter(self.items())
    def __len__(self):
        return len(self.keys())
    __cmp__ = graph_cmp

class SQLEdgelessDict(SQLEdgeDict):
    'for SQLGraph tables that lack edge_id column'
    def __getitem__(self,target):
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
                                              % (self.table.targetSQL,
                                                 self.table.name,
                                                 self.table.sourceSQL,
                                                 self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        l = self.table.cursor.fetchmany(2)
        if len(l) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        return None # no edge info!
    def iterator_query(self):
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s is not null'
                                              % (self.table.targetSQL,
                                                 self.table.name,
                                                 self.table.sourceSQL,
                                                 self.table.targetSQL),
                                              (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return [(t[0],None) for t in self.table.cursor.fetchall()]

SQLEdgeDict._edgelessClass = SQLEdgelessDict

class SQLGraphEdgeDescriptor(object):
    'provide an SQLEdges interface on demand'
    def __get__(self,obj,objtype):
        try:
            attrAlias = obj.attrAlias.copy()
        except AttributeError:
            return SQLEdges(obj.name, obj.cursor, graph=obj)
        else:
            return SQLEdges(obj.name, obj.cursor, attrAlias=attrAlias,
                            graph=obj)

def getColumnTypes(createTable,attrAlias={},defaultColumnType='int',
                   columnAttrs=('source','target','edge'),**kwargs):
    'return list of [(colname,coltype),...] for source,target,edge'
    l = []
    for attr in columnAttrs:
        try:
            attrName = attrAlias[attr+'_id']
        except KeyError:
            attrName = attr+'_id'
        try: # SEE IF USER SPECIFIED A DESIRED TYPE
            l.append((attrName,createTable[attr+'_id']))
            continue
        except (KeyError,TypeError):
            pass
        try: # get type info from primary key for that database
            db = kwargs[attr+'DB']
            if db is None:
                raise KeyError # FORCE IT TO USE DEFAULT TYPE
        except KeyError:
            pass
        else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
            it = iter(db)
            try: # GET ONE IDENTIFIER FROM THE DATABASE
                k = it.next()
            except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
                try:
                    l.append((attrName,db.columnType[db.primary_key]))
                    continue
                except AttributeError:
                    pass
            else: # GET THE TYPE FROM THIS IDENTIFIER
                if isinstance(k,int) or isinstance(k,long):
                    l.append((attrName,'int'))
                    continue
                elif isinstance(k,str):
                    l.append((attrName,'varchar(32)'))
                    continue
                else:
                    raise ValueError('SQLGraph node / edge must be int or str!')
        l.append((attrName,defaultColumnType))
        logger.warn('no type info found for %s, so using default: %s'
                    % (attrName, defaultColumnType))
    return l

class SQLGraph(SQLTableMultiNoCache):
    '''provide a graph interface via a SQL table.  Key capabilities are:
       - setitem with an empty dictionary: a dummy operation
       - getitem with a key that exists: return a placeholder
       - setitem with non empty placeholder: again a dummy operation
       EXAMPLE TABLE SCHEMA:
       create table mygraph (source_id int not null,target_id int,edge_id int,
              unique(source_id,target_id));
       '''
    _distinct_key = 'source_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(sourceDB=0,targetDB=0,edgeDB=0,allowMissingNodes=0))
    _edgeClass = SQLEdgeDict
    def __init__(self,name,*l,**kwargs):
        graphArgs,tableArgs = split_kwargs(kwargs,
                                           ('attrAlias','defaultColumnType','columnAttrs',
                                            'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
                                            'edgeDictClass','graph'))
        if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
            c = getColumnTypes(**kwargs)
            tableArgs['createTable'] = \
                'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
                % (name,c[0][0],c[0][1],c[1][0],c[1][1],c[2][0],c[2][1],c[0][0],c[1][0])
        try:
            self.allowMissingNodes = kwargs['allowMissingNodes']
        except KeyError:
            pass
        SQLTableMultiNoCache.__init__(self,name,*l,**tableArgs)
        self.sourceSQL = self._attrSQL('source_id')
        self.targetSQL = self._attrSQL('target_id')
        try:
            self.edgeSQL = self._attrSQL('edge_id')
        except AttributeError:
            self.edgeSQL = None
            self._edgeClass = self._edgeClass._edgelessClass
        save_graph_db_refs(self,**kwargs)
    def __getitem__(self,k):
        return self._edgeClass(self.pack_source(k),self)
    def __iadd__(self,k):
        sql,params = self._format_query('delete from %s where %s=%%s and %s is null'
                                        % (self.name,self.sourceSQL,self.targetSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        sql,params = self._format_query('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
                                        % self.name,(self.pack_source(k),))
        self.cursor.execute(sql, params)
        return self # iadd MUST RETURN SELF!
    def __isub__(self,k):
        sql,params = self._format_query('delete from %s where %s=%%s'
                                        % (self.name,self.sourceSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        if self.cursor.rowcount == 0:
            raise KeyError('node not found in graph')
        return self # isub MUST RETURN SELF!
    __setitem__ = graph_setitem
    def __contains__(self,k):
        sql,params = self._format_query('select * from %s where %s=%%s limit 1'
                                        % (self.name,self.sourceSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        l = self.cursor.fetchmany(2)
        return len(l) > 0
    def __invert__(self):
        'get an interface to the inverse graph mapping'
        try: # CACHED
            return self._inverse
        except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
            attrAlias = dict(source_id=self.targetSQL, # SWAP SOURCE & TARGET
                             target_id=self.sourceSQL,
                             edge_id=self.edgeSQL)
            if self.edgeSQL is None: # no edge interface
                del attrAlias['edge_id']
            self._inverse = SQLGraph(self.name,self.cursor,
                                     attrAlias=attrAlias,
                                     **graph_db_inverse_refs(self))
            self._inverse._inverse = self
            return self._inverse
    def __iter__(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self.unpack_source(k)
    def iteritems(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield (self.unpack_source(k), self._edgeClass(k, self))
    def itervalues(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self._edgeClass(k, self)
    def keys(self):
        return [self.unpack_source(k) for k in SQLTableMultiNoCache.keys(self)]
    def values(self): return list(self.itervalues())
    def items(self): return list(self.iteritems())
    edges = SQLGraphEdgeDescriptor()
    update = update_graph
    def __len__(self):
        'get number of source nodes in graph'
        self.cursor.execute('select count(distinct %s) from %s'
                            % (self.sourceSQL,self.name))
        return self.cursor.fetchone()[0]
    __cmp__ = graph_cmp
    override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
##     def __cmp__(self,other):
##         node = ()
##         n = 0
##         d = None
##         it = iter(self.edges)
##         while True:
##             try:
##                 source,target,edge = it.next()
##             except StopIteration:
##                 source = None
##             if source != node:
##                 if d is not None:
##                     diff = cmp(n_target,len(d))
##                     if diff != 0:
##                         return diff
##                 if source is None:
##                     break
##                 node = source
##                 n += 1 # COUNT SOURCE NODES
##                 n_target = 0
##                 try:
##                     d = other[node]
##                 except KeyError:
##                     return 1
##             try:
##                 diff = cmp(edge,d[target])
##             except KeyError:
##                 return 1
##             if diff != 0:
##                 return diff
##             n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
##         return cmp(n,len(other))
    add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
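
# Usage sketch (illustrative; 'mygraph' and the node / edge IDs are made up).
# With simple integer IDs the graph behaves like a dict of dicts backed by SQL:
#
#   g = SQLGraph('mygraph', serverInfo=serverInfo, createTable=True)
#   g += 1                 # add node 1
#   g[1][2] = 3            # add edge 1 -> 2 carrying edge ID 3
#   print g[1].keys()      # targets reachable from node 1
#   print (~g)[2].keys()   # inverse mapping: sources that point at node 2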
1342 class SQLIDGraph(SQLGraph):
1343 add_trivial_packing_methods(locals())
1344 SQLGraph._IDGraphClass = SQLIDGraph
1348 class SQLEdgeDictClustered(dict):
1349 'simple cache for 2nd level dictionary of target_id:edge_id'
1350 def __init__(self,g,fromNode):
1351 self.g=g
1352 self.fromNode=fromNode
1353 dict.__init__(self)
1354 def __iadd__(self,l):
1355 for target_id,edge_id in l:
1356 dict.__setitem__(self,target_id,edge_id)
1357 return self # iadd MUST RETURN SELF!
1359 class SQLEdgesClusteredDescr(object):
1360 def __get__(self,obj,objtype):
1361 e=SQLEdgesClustered(obj.table,obj.edge_id,obj.source_id,obj.target_id,
1362 graph=obj,**graph_db_inverse_refs(obj,True))
1363 for source_id,d in obj.d.iteritems(): # COPY EDGE CACHE
1364 e.load([(edge_id,source_id,target_id)
1365 for (target_id,edge_id) in d.iteritems()])
1366 return e
1368 class SQLGraphClustered(object):
1369 'SQL graph with clustered caching -- loads an entire cluster at a time'
1370 _edgeDictClass=SQLEdgeDictClustered
1371 def __init__(self,table,source_id='source_id',target_id='target_id',
1372 edge_id='edge_id',clusterKey=None,**kwargs):
1373 import types
1374 if isinstance(table,types.StringType): # CREATE THE TABLE INTERFACE
1375 if clusterKey is None:
1376 raise ValueError('you must provide a clusterKey argument!')
1377 if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
1378 c = getColumnTypes(attrAlias=dict(source_id=source_id,target_id=target_id,
1379 edge_id=edge_id),**kwargs)
1380 kwargs['createTable'] = \
1381 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1382 % (table,c[0][0],c[0][1],c[1][0],c[1][1],
1383 c[2][0],c[2][1],c[0][0],c[1][0])
1384 table = SQLTableClustered(table,clusterKey=clusterKey,**kwargs)
1385 self.table=table
1386 self.source_id=source_id
1387 self.target_id=target_id
1388 self.edge_id=edge_id
1389 self.d={}
1390 save_graph_db_refs(self,**kwargs)
1391 _pickleAttrs = dict(table=0,source_id=0,target_id=0,edge_id=0,sourceDB=0,targetDB=0,
1392 edgeDB=0)
1393 def __getstate__(self):
1394 state = standard_getstate(self)
1395 state['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1396 return state
1397 def __getitem__(self,k):
1398 'get edgeDict for source node k, from cache or by loading its cluster'
1399 try: # GET DIRECTLY FROM CACHE
1400 return self.d[k]
1401 except KeyError:
1402 if hasattr(self,'_isLoaded'):
1403 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1404 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1405 sql,params = self.table._format_query('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1406 %(self.source_id,self.target_id,
1407 self.edge_id,self.table.name,
1408 self.table.name,self.source_id,
1409 self.table.clusterKey,self.table.clusterKey,
1410 self.table.primary_key),
1411 (self.pack_source(k),))
1412 self.table.cursor.execute(sql, params)
1413 self.load(self.table.cursor.fetchall()) # CACHE THIS CLUSTER
1414 return self.d[k] # RETURN EDGE DICT FOR THIS NODE
1415 def load(self,l=None,unpack=True):
1416 'load the specified rows (or all, if None provided) into local cache'
1417 if l is None:
1418 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1419 return self._isLoaded
1420 except AttributeError:
1421 pass
1422 self.table.cursor.execute('select %s,%s,%s from %s'
1423 %(self.source_id,self.target_id,
1424 self.edge_id,self.table.name))
1425 l=self.table.cursor.fetchall()
1426 self._isLoaded=True
1427 self.d.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1428 for source,target,edge in l: # SAVE TO OUR CACHE
1429 if unpack:
1430 source = self.unpack_source(source)
1431 target = self.unpack_target(target)
1432 edge = self.unpack_edge(edge)
1433 try:
1434 self.d[source] += [(target,edge)]
1435 except KeyError:
1436 d = self._edgeDictClass(self,source)
1437 d += [(target,edge)]
1438 self.d[source] = d
1439 def __invert__(self):
1440 'interface to reverse graph mapping'
1441 try:
1442 return self._inverse # INVERSE MAP ALREADY EXISTS
1443 except AttributeError:
1444 pass
1445 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1446 self._inverse=SQLGraphClustered(self.table,self.target_id,self.source_id,
1447 self.edge_id,**graph_db_inverse_refs(self))
1448 self._inverse._inverse=self
1449 for source,d in self.d.iteritems(): # INVERT OUR CACHE
1450 self._inverse.load([(target,source,edge)
1451 for (target,edge) in d.iteritems()],unpack=False)
1452 return self._inverse
1453 edges=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1454 update = update_graph
1455 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1456 def __iter__(self): ################# ITERATORS
1457 'uses db select; does not force load'
1458 return iter(self.keys())
1459 def keys(self):
1460 'uses db select; does not force load'
1461 self.table.cursor.execute('select distinct(%s) from %s'
1462 %(self.source_id,self.table.name))
1463 return [self.unpack_source(t[0])
1464 for t in self.table.cursor.fetchall()]
1465 methodFactory(['iteritems','items','itervalues','values'],
1466 'lambda self:(self.load(),self.d.%s())[1]',locals())
1467 def __contains__(self,k):
1468 try:
1469 x=self[k]
1470 return True
1471 except KeyError:
1472 return False
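# Usage sketch (editorial addition, not part of the original API): building a
# clustered graph over a hypothetical 'splicedb.splicegraph' table.  All table,
# column, database and serverInfo names below are illustrative assumptions.
#
#   splice_t = SQLTableClustered('splicedb.splicegraph', clusterKey='gene_id',
#                                serverInfo=serverInfo)
#   g = SQLGraphClustered(splice_t, source_id='left_exon_id',
#                         target_id='right_exon_id', edge_id='splice_id',
#                         sourceDB=exons, targetDB=exons, edgeDB=splices)
#   d = g[someExon]              # first access loads someExon's whole cluster
#   for exon2, splice in d.iteritems():
#       pass                     # later lookups in this cluster hit the cache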
1474 class SQLIDGraphClustered(SQLGraphClustered):
1475 add_trivial_packing_methods(locals())
1476 SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
1478 class SQLEdgesClustered(SQLGraphClustered):
1479 'edges interface for SQLGraphClustered'
1480 _edgeDictClass = list
1481 _pickleAttrs = SQLGraphClustered._pickleAttrs.copy()
1482 _pickleAttrs.update(dict(graph=0))
1483 def keys(self):
1484 self.load()
1485 result = []
1486 for edge_id,l in self.d.iteritems():
1487 for source_id,target_id in l:
1488 result.append((self.graph.unpack_source(source_id),
1489 self.graph.unpack_target(target_id),
1490 self.graph.unpack_edge(edge_id)))
1491 return result
1493 class ForeignKeyInverse(object):
1494 'map each key to a single value according to its foreign key'
1495 def __init__(self,g):
1496 self.g = g
1497 def __getitem__(self,obj):
1498 self.check_obj(obj)
1499 source_id = getattr(obj,self.g.keyColumn)
1500 if source_id is None:
1501 return None
1502 return self.g.sourceDB[source_id]
1503 def __setitem__(self,obj,source):
1504 self.check_obj(obj)
1505 if source is not None:
1506 self.g[source][obj] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1507 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1508 if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1509 old_source = self[obj]
1510 if old_source is not None:
1511 del self.g[old_source][obj]
1512 def check_obj(self,obj):
1513 'raise KeyError if obj not from this db'
1514 try:
1515 if obj.db != self.g.targetDB:
1516 raise AttributeError
1517 except AttributeError:
1518 raise KeyError('key is not from targetDB of this graph!')
1519 def __contains__(self,obj):
1520 try:
1521 self.check_obj(obj)
1522 return True
1523 except KeyError:
1524 return False
1525 def __iter__(self):
1526 return self.g.targetDB.itervalues()
1527 def keys(self):
1528 return self.g.targetDB.values()
1529 def iteritems(self):
1530 for obj in self:
1531 source_id = getattr(obj,self.g.keyColumn)
1532 if source_id is None:
1533 yield obj,None
1534 else:
1535 yield obj,self.g.sourceDB[source_id]
1536 def items(self):
1537 return list(self.iteritems())
1538 def itervalues(self):
1539 for obj,val in self.iteritems():
1540 yield val
1541 def values(self):
1542 return list(self.itervalues())
1543 def __invert__(self):
1544 return self.g
1546 class ForeignKeyEdge(dict):
1547 '''edge interface to a foreign key in an SQL table.
1548 Caches dict of target nodes in itself; provides dict interface.
1549 Adds or deletes edges by setting foreign key values in the table'''
1550 def __init__(self,g,k):
1551 dict.__init__(self)
1552 self.g = g
1553 self.src = k
1554 for v in g.targetDB.select('where %s=%%s' % g.keyColumn,(k.id,)): # SEARCH THE DB
1555 dict.__setitem__(self,v,None) # SAVE IN CACHE
1556 def __setitem__(self,dest,v):
1557 if not hasattr(dest,'db') or dest.db != self.g.targetDB:
1558 raise KeyError('dest is not in the targetDB bound to this graph!')
1559 if v is not None:
1560 raise ValueError('sorry, this graph cannot store edge information!')
1561 if not hasattr(dest,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1562 old_source = self.g._inverse[dest] # CHECK FOR PRE-EXISTING EDGE
1563 if old_source is not None: # REMOVE OLD EDGE FROM CACHE
1564 dict.__delitem__(self.g[old_source],dest)
1565 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1566 setattr(dest,self.g.keyColumn,self.src.id) # SAVE TO DB ATTRIBUTE
1567 dict.__setitem__(self,dest,None) # SAVE IN CACHE
1568 def __delitem__(self,dest):
1569 #self.g.targetDB._update(dest.id,self.g.keyColumn,None) # REMOVE FOREIGN KEY VALUE
1570 setattr(dest,self.g.keyColumn,None) # SAVE TO DB ATTRIBUTE
1571 dict.__delitem__(self,dest) # REMOVE FROM CACHE
1573 class ForeignKeyGraph(object, UserDict.DictMixin):
1574 '''graph interface to a foreign key in an SQL table
1575 Caches dict of target nodes in itself; provides dict interface.
1576 '''
1577 def __init__(self, sourceDB, targetDB, keyColumn, autoGC=True, **kwargs):
1578 '''sourceDB is any database of source nodes;
1579 targetDB must be an SQL database of target nodes;
1580 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1581 if autoGC: # automatically garbage collect unused objects
1582 self._weakValueDict = RecentValueDictionary(autoGC) # object cache
1583 else:
1584 self._weakValueDict = {}
1585 self.autoGC = autoGC
1586 self.sourceDB = sourceDB
1587 self.targetDB = targetDB
1588 self.keyColumn = keyColumn
1589 self._inverse = ForeignKeyInverse(self)
1590 _pickleAttrs = dict(sourceDB=0, targetDB=0, keyColumn=0, autoGC=0)
1591 __getstate__ = standard_getstate ########### SUPPORT FOR PICKLING
1592 __setstate__ = standard_setstate
1593 def _inverse_schema(self):
1594 'provide custom schema rule for inverting this graph... just use keyColumn!'
1595 return dict(invert=True,uniqueMapping=True)
1596 def __getitem__(self,k):
1597 if not hasattr(k,'db') or k.db != self.sourceDB:
1598 raise KeyError('object is not in the sourceDB bound to this graph!')
1599 try:
1600 return self._weakValueDict[k.id] # get from cache
1601 except KeyError:
1602 pass
1603 d = ForeignKeyEdge(self,k)
1604 self._weakValueDict[k.id] = d # save in cache
1605 return d
1606 def __setitem__(self, k, v):
1607 raise KeyError('''do not save as g[k]=v. Instead follow a graph
1608 interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
1609 def __delitem__(self, k):
1610 raise KeyError('''Instead of del g[k], follow a graph
1611 interface: del g[src][dest]''')
1612 def keys(self):
1613 return self.sourceDB.values()
1614 __invert__ = standard_invert
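# Usage sketch (editorial addition; 'genes', 'exons' and the 'gene_id' column
# are illustrative assumptions): a one-to-many mapping defined by a foreign key.
#
#   g = ForeignKeyGraph(genes, exons, 'gene_id')
#   for exon in g[gene]:        # exons whose gene_id equals gene.id
#       pass
#   g[gene][exon] = None        # sets exon.gene_id = gene.id in the database
#   (~g)[exon]                  # inverse: the gene object referenced by exon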
1616 def describeDBTables(name,cursor,idDict):
1617 """
1618 Get table info about database <name> via <cursor>, and store primary keys
1619 in idDict, along with a list of the tables each key indexes.
1620 """
1621 cursor.execute('use %s' % name)
1622 cursor.execute('show tables')
1623 tables={}
1624 l=[c[0] for c in cursor.fetchall()]
1625 for t in l:
1626 tname=name+'.'+t
1627 o=SQLTable(tname,cursor)
1628 tables[tname]=o
1629 for f in o.description:
1630 if f==o.primary_key:
1631 idDict.setdefault(f, []).append(o)
1632 elif f[-3:]=='_id' and f not in idDict:
1633 idDict[f]=[]
1634 return tables
1638 def indexIDs(tables,idDict=None):
1639 "Get an index of primary keys in the <tables> dictionary."
1640 if idDict is None:
1641 idDict={}
1642 for o in tables.values():
1643 if o.primary_key:
1644 if o.primary_key not in idDict:
1645 idDict[o.primary_key]=[]
1646 idDict[o.primary_key].append(o) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1647 for f in o.description:
1648 if f[-3:]=='_id' and f not in idDict:
1649 idDict[f]=[]
1650 return idDict
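# Usage sketch (editorial addition; relies on MySQL-style 'use' / 'show tables'
# statements, and the database name is an assumption):
#
#   idDict = {}
#   tables = describeDBTables('splicedb', cursor, idDict)
#   idDict = indexIDs(tables, idDict)  # map primary keys and *_id columns to tables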
1654 def suffixSubset(tables,suffix):
1655 "Filter table index for those matching a specific suffix"
1656 subset={}
1657 for name,t in tables.items():
1658 if name.endswith(suffix):
1659 subset[name]=t
1660 return subset
1663 PRIMARY_KEY=1
1665 def graphDBTables(tables,idDict):
1666 g=dictgraph()
1667 for t in tables.values():
1668 for f in t.description:
1669 if f==t.primary_key:
1670 edgeInfo=PRIMARY_KEY
1671 else:
1672 edgeInfo=None
1673 g.setEdge(f,t,edgeInfo)
1674 g.setEdge(t,f,edgeInfo)
1675 return g
1677 SQLTypeTranslation= {types.StringType:'varchar(32)',
1678 types.IntType:'int',
1679 types.FloatType:'float'}
1681 def createTableFromRepr(rows,tableName,cursor,typeTranslation=None,
1682 optionalDict=None,indexDict=()):
1683 """Save rows into SQL tableName using cursor, with optional
1684 translations of columns to specific SQL types (specified
1685 by typeTranslation dict).
1686 - optionalDict can specify columns that are allowed to be NULL.
1687 - indexDict can specify columns that must be indexed; columns
1688 whose names end in _id will be indexed by default.
1689 - rows must be an iterator which in turn returns dictionaries,
1690 each representing a tuple of values (indexed by their column
1691 names).
1692 """
1693 try:
1694 row=rows.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1695 except StopIteration:
1696 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1697 try:
1698 createTableFromRow(cursor, tableName,row,typeTranslation,
1699 optionalDict,indexDict)
1700 except Exception: # IGNORE CREATION ERRORS (E.G. TABLE MAY ALREADY EXIST)
1701 pass
1702 storeRow(cursor,tableName,row) # SAVE OUR FIRST ROW
1703 for row in rows: # NOW SAVE ALL THE ROWS
1704 storeRow(cursor,tableName,row)
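# Usage sketch (editorial addition; table and column names are assumptions):
# rows must be an iterator of dicts keyed by column name; the first row is used
# by createTableFromRow() to infer the schema.
#
#   rows = iter([dict(exon_id=1, gene_id='ENSG0001', score=0.5),
#                dict(exon_id=2, gene_id='ENSG0001', score=0.9)])
#   createTableFromRepr(rows, 'test.exon_scores', cursor,
#                       typeTranslation=dict(gene_id='varchar(20)'))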
1706 def createTableFromRow(cursor, tableName, row,typeTranslation=None,
1707 optionalDict=None,indexDict=()):
1708 create_defs=[]
1709 for col,val in row.items(): # PREPARE SQL TYPES FOR COLUMNS
1710 coltype=None
1711 if typeTranslation is not None and col in typeTranslation:
1712 coltype=typeTranslation[col] # USER-SUPPLIED TRANSLATION
1713 elif type(val) in SQLTypeTranslation:
1714 coltype=SQLTypeTranslation[type(val)]
1715 else: # SEARCH FOR A COMPATIBLE TYPE
1716 for t in SQLTypeTranslation:
1717 if isinstance(val,t):
1718 coltype=SQLTypeTranslation[t]
1719 break
1720 if coltype is None:
1721 raise TypeError("Don't know SQL type to use for %s" % col)
1722 create_def='%s %s' %(col,coltype)
1723 if optionalDict is None or col not in optionalDict:
1724 create_def+=' not null'
1725 create_defs.append(create_def)
1726 for col in row: # CREATE INDEXES FOR ID COLUMNS
1727 if col[-3:]=='_id' or col in indexDict:
1728 create_defs.append('index(%s)' % col)
1729 cmd='create table if not exists %s (%s)' % (tableName,','.join(create_defs))
1730 cursor.execute(cmd) # CREATE THE TABLE IN THE DATABASE
1733 def storeRow(cursor, tableName, row):
1734 row_format=','.join(len(row)*['%s'])
1735 cmd='insert into %s values (%s)' % (tableName,row_format)
1736 cursor.execute(cmd,tuple(row.values()))
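# Editorial note: 'insert delayed' below is MySQL-specific and deprecated in
# recent MySQL releases (newer servers treat it as a plain INSERT), so
# storeRow() is the safer general-purpose choice.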
1738 def storeRowDelayed(cursor, tableName, row):
1739 row_format=','.join(len(row)*['%s'])
1740 cmd='insert delayed into %s values (%s)' % (tableName,row_format)
1741 cursor.execute(cmd,tuple(row.values()))
1744 class TableGroup(dict):
1745 'provide attribute access to dbname qualified tablenames'
1746 def __init__(self,db='test',suffix=None,**kw):
1747 dict.__init__(self)
1748 self.db=db
1749 if suffix is not None:
1750 self.suffix=suffix
1751 for k,v in kw.items():
1752 if v is not None and '.' not in v:
1753 v=self.db+'.'+v # ADD DATABASE NAME AS PREFIX
1754 self[k]=v
1755 def __getattr__(self,k):
1756 return self[k]
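# Usage sketch (editorial addition; database and table names are assumptions):
#
#   t = TableGroup(db='splicedb', exons='exons', genes='genes', extra=None)
#   t.exons     # -> 'splicedb.exons'
#   t.genes     # -> 'splicedb.genes'
#   t.extra     # -> None (None values are kept as-is)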
1758 def sqlite_connect(*args, **kwargs):
1759 sqlite = import_sqlite()
1760 connection = sqlite.connect(*args, **kwargs)
1761 cursor = connection.cursor()
1762 return connection, cursor
1764 class DBServerInfo(object):
1765 'picklable reference to a database server'
1766 def __init__(self, moduleName='MySQLdb', *args, **kwargs):
1767 try:
1768 self.__class__ = _DBServerModuleDict[moduleName]
1769 except KeyError:
1770 raise ValueError('Module name not found in _DBServerModuleDict: '\
1771 + moduleName)
1772 self.moduleName = moduleName
1773 self.args = args
1774 self.kwargs = kwargs # connection arguments
1776 def cursor(self):
1777 """returns cursor associated with the DB server info (reused)"""
1778 try:
1779 return self._cursor
1780 except AttributeError:
1781 self._start_connection()
1782 return self._cursor
1784 def new_cursor(self, arraysize=None):
1785 """returns a NEW cursor; you must close it yourself! """
1786 if not hasattr(self, '_connection'):
1787 self._start_connection()
1788 cursor = self._connection.cursor()
1789 if arraysize is not None:
1790 cursor.arraysize = arraysize
1791 return cursor
1793 def close(self):
1794 """Close file containing this database"""
1795 self._cursor.close()
1796 self._connection.close()
1797 del self._cursor
1798 del self._connection
1800 def __getstate__(self):
1801 """return all picklable arguments"""
1802 return dict(args=self.args, kwargs=self.kwargs,
1803 moduleName=self.moduleName)
1806 class MySQLServerInfo(DBServerInfo):
1807 'customized for MySQLdb SSCursor support via new_cursor()'
1808 def _start_connection(self):
1809 self._connection,self._cursor = mysql_connect(*self.args, **self.kwargs)
1810 def new_cursor(self, arraysize=None):
1811 'provide streaming cursor support'
1812 try:
1813 conn = self._conn_sscursor
1814 except AttributeError:
1815 self._conn_sscursor,cursor = mysql_connect(useStreaming=True,
1816 *self.args, **self.kwargs)
1817 else:
1818 cursor = self._conn_sscursor.cursor()
1819 if arraysize is not None:
1820 cursor.arraysize = arraysize
1821 return cursor
1822 def close(self):
1823 DBServerInfo.close(self)
1824 try:
1825 self._conn_sscursor.close()
1826 del self._conn_sscursor
1827 except AttributeError:
1828 pass
1830 class SQLiteServerInfo(DBServerInfo):
1831 """picklable reference to a sqlite database"""
1832 def __init__(self, database, *args, **kwargs):
1833 """Takes same arguments as sqlite3.connect()"""
1834 DBServerInfo.__init__(self, 'sqlite',
1835 SourceFileName(database), # save abs path!
1836 *args, **kwargs)
1837 def _start_connection(self):
1838 self._connection,self._cursor = sqlite_connect(*self.args, **self.kwargs)
1839 def __getstate__(self):
1840 if self.args[0] == ':memory:':
1841 raise ValueError('SQLite in-memory database is not picklable!')
1842 return DBServerInfo.__getstate__(self)
1844 # list of DBServerInfo subclasses for different modules
1845 _DBServerModuleDict = dict(MySQLdb=MySQLServerInfo, sqlite=SQLiteServerInfo)
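# Usage sketch (editorial addition): a DBServerInfo is picklable and can be
# reused to obtain cursors; passing it to table classes via a serverInfo
# argument is assumed here, and all connection parameters are illustrative.
#
#   serverInfo = SQLiteServerInfo('/path/to/mydata.sqlite')
#   # or: serverInfo = MySQLServerInfo(host='localhost', user='reader',
#   #                                  db='splicedb')
#   cursor = serverInfo.cursor()             # shared, reused cursor
#   c2 = serverInfo.new_cursor(1024)         # independent cursor (streaming for
#                                            # MySQL); you must close it yourself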
1848 class MapView(object, UserDict.DictMixin):
1849 'general purpose 1:1 mapping defined by any SQL query'
1850 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
1851 serverInfo=None, inverseSQL=None, **kwargs):
1852 self.sourceDB = sourceDB
1853 self.targetDB = targetDB
1854 self.viewSQL = viewSQL
1855 self.inverseSQL = inverseSQL
1856 if cursor is None:
1857 if serverInfo is not None: # get cursor from serverInfo
1858 cursor = serverInfo.cursor()
1859 else:
1860 try: # can we get it from our other db?
1861 serverInfo = sourceDB.serverInfo
1862 except AttributeError:
1863 raise ValueError('you must provide serverInfo or cursor!')
1864 else:
1865 cursor = serverInfo.cursor()
1866 self.cursor = cursor
1867 self.serverInfo = serverInfo
1868 self.get_sql_format(False) # get sql formatter for this db interface
1869 _schemaModuleDict = _schemaModuleDict # default module list
1870 get_sql_format = get_table_schema
1871 def __getitem__(self, k):
1872 if not hasattr(k,'db') or k.db != self.sourceDB:
1873 raise KeyError('object is not in the sourceDB bound to this map!')
1874 sql,params = self._format_query(self.viewSQL, (k.id,))
1875 self.cursor.execute(sql, params) # formatted for this db interface
1876 t = self.cursor.fetchmany(2) # get at most two rows
1877 if len(t) != 1:
1878 raise KeyError('%s not found in MapView, or not unique'
1879 % str(k))
1880 return self.targetDB[t[0][0]] # get the corresponding object
1881 _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
1882 inverseSQL=0)
1883 __getstate__ = standard_getstate
1884 __setstate__ = standard_setstate
1885 __setitem__ = __delitem__ = clear = pop = popitem = update = \
1886 setdefault = read_only_error
1887 def __iter__(self):
1888 'only yield sourceDB items that are actually in this mapping!'
1889 for k in self.sourceDB.itervalues():
1890 try:
1891 self[k]
1892 yield k
1893 except KeyError:
1894 pass
1895 def keys(self):
1896 return [k for k in self] # don't use list(self); causes infinite loop!
1897 def __invert__(self):
1898 try:
1899 return self._inverse
1900 except AttributeError:
1901 if self.inverseSQL is None:
1902 raise ValueError('this MapView has no inverseSQL!')
1903 self._inverse = self.__class__(self.targetDB, self.sourceDB,
1904 self.inverseSQL, self.cursor,
1905 serverInfo=self.serverInfo,
1906 inverseSQL=self.viewSQL)
1907 self._inverse._inverse = self
1908 return self._inverse
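# Usage sketch (editorial addition; databases, SQL and column names are
# assumptions): a 1:1 mapping defined by a query taking the source object's id
# as its single %s parameter.
#
#   m = MapView(genes, proteins,
#               'select protein_id from splicedb.gene2protein where gene_id=%s',
#               serverInfo=serverInfo,
#               inverseSQL='select gene_id from splicedb.gene2protein where protein_id=%s')
#   protein = m[gene]        # KeyError if no match, or if not unique
#   gene2 = (~m)[protein]    # inverse mapping, available because inverseSQL was given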
1910 class GraphViewEdgeDict(UserDict.DictMixin):
1911 'edge dictionary for GraphView: just pre-loaded on init'
1912 def __init__(self, g, k):
1913 self.g = g
1914 self.k = k
1915 sql,params = self.g._format_query(self.g.viewSQL, (k.id,))
1916 self.g.cursor.execute(sql, params) # run the query
1917 l = self.g.cursor.fetchall() # get results
1918 if len(l) <= 0:
1919 raise KeyError('key %s not in GraphView' % k.id)
1920 self.targets = [t[0] for t in l] # preserve order of the results
1921 d = {} # also keep targetID:edgeID mapping
1922 if self.g.edgeDB is not None: # save with edge info
1923 for t in l:
1924 d[t[0]] = t[1]
1925 else:
1926 for t in l:
1927 d[t[0]] = None
1928 self.targetDict = d
1929 def __len__(self):
1930 return len(self.targets)
1931 def __iter__(self):
1932 for k in self.targets:
1933 yield self.g.targetDB[k]
1934 def keys(self):
1935 return list(self)
1936 def iteritems(self):
1937 if self.g.edgeDB is not None: # save with edge info
1938 for k in self.targets:
1939 yield (self.g.targetDB[k], self.g.edgeDB[self.targetDict[k]])
1940 else: # just save the list of targets, no edge info
1941 for k in self.targets:
1942 yield (self.g.targetDB[k], None)
1943 def __getitem__(self, o, exitIfFound=False):
1944 'for the specified target object, return its associated edge object'
1945 try:
1946 if o.db is not self.g.targetDB:
1947 raise KeyError('key is not part of targetDB!')
1948 edgeID = self.targetDict[o.id]
1949 except AttributeError:
1950 raise KeyError('key has no id or db attribute?!')
1951 if exitIfFound:
1952 return
1953 if self.g.edgeDB is not None: # return the edge object
1954 return self.g.edgeDB[edgeID]
1955 else: # no edge info
1956 return None
1957 def __contains__(self, o):
1958 try:
1959 self.__getitem__(o, True) # raise KeyError if not found
1960 return True
1961 except KeyError:
1962 return False
1963 __setitem__ = __delitem__ = clear = pop = popitem = update = \
1964 setdefault = read_only_error
1966 class GraphView(MapView):
1967 'general purpose graph interface defined by any SQL query'
1968 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
1969 **kwargs):
1970 'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
1971 self.edgeDB = edgeDB
1972 MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)
1973 def __getitem__(self, k):
1974 if not hasattr(k,'db') or k.db != self.sourceDB:
1975 raise KeyError('object is not in the sourceDB bound to this map!')
1976 return GraphViewEdgeDict(self, k)
1977 _pickleAttrs = MapView._pickleAttrs.copy()
1978 _pickleAttrs.update(dict(edgeDB=0))
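# Usage sketch (editorial addition; names and SQL are assumptions): like
# MapView, but the query may return multiple rows; with edgeDB given, each row
# must be a (targetID, edgeID) tuple.
#
#   gv = GraphView(genes, exons,
#                  'select exon_id, splice_id from splicedb.gene_exon where gene_id=%s',
#                  serverInfo=serverInfo, edgeDB=splices)
#   for exon, splice in gv[gene].iteritems():
#       pass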
1980 # @CTB move to sqlgraph.py?
1982 class SQLSequence(SQLRow, SequenceBase):
1983 """Transparent access to a DB row representing a sequence.
1985 Use attrAlias dict to rename 'length' to something else.
1986 """
1987 def _init_subclass(cls, db, **kwargs):
1988 db.seqInfoDict = db # db will act as its own seqInfoDict
1989 SQLRow._init_subclass(db=db, **kwargs)
1990 _init_subclass = classmethod(_init_subclass)
1991 def __init__(self, id):
1992 SQLRow.__init__(self, id)
1993 SequenceBase.__init__(self)
1994 def __len__(self):
1995 return self.length
1996 def strslice(self,start,end):
1997 "Efficient access to slice of a sequence, useful for huge contigs"
1998 return self._select('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
1999 %(self.db._attrSQL('seq'),start+1,end-start))
2001 class DNASQLSequence(SQLSequence):
2002 _seqtype=DNA_SEQTYPE
2004 class RNASQLSequence(SQLSequence):
2005 _seqtype=RNA_SEQTYPE
2007 class ProteinSQLSequence(SQLSequence):
2008 _seqtype=PROTEIN_SEQTYPE
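# Usage sketch (editorial addition): these row classes are intended to serve as
# the itemClass of an SQL table whose rows store sequences; the itemClass
# argument and the 'length'/'seq' columns are assumptions for this example.
#
#   contigs = SQLTable('genomedb.contigs', serverInfo=serverInfo,
#                      itemClass=DNASQLSequence)
#   c = contigs[someID]
#   len(c)                 # uses the 'length' column
#   c.strslice(0, 50)      # fetches just this substring via SQL SUBSTRING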