fixes to prevent unnecessary invocation of iter()
[pygr.git] / pygr / sqlgraph.py
3 from __future__ import generators
4 from mapping import *
5 from sequence import SequenceBase, DNA_SEQTYPE, RNA_SEQTYPE, PROTEIN_SEQTYPE
6 import types
7 from classutil import methodFactory,standard_getstate,\
8 override_rich_cmp,generate_items,get_bound_subclass,standard_setstate,\
9 get_valid_path,standard_invert,RecentValueDictionary,read_only_error,\
10 SourceFileName, split_kwargs
11 import os
12 import platform
13 import UserDict
14 import warnings
15 import logger
17 class TupleDescriptor(object):
18 'return tuple entry corresponding to named attribute'
19 def __init__(self, db, attr):
20 self.icol = db.data[attr] # index of this attribute in the tuple
21 def __get__(self, obj, klass):
22 return obj._data[self.icol]
23 def __set__(self, obj, val):
24 raise AttributeError('this database is read-only!')
26 class TupleIDDescriptor(TupleDescriptor):
27 def __set__(self, obj, val):
28 raise AttributeError('''You cannot change obj.id directly.
29 Instead, use db[newID] = obj''')
31 class TupleDescriptorRW(TupleDescriptor):
32 'read-write interface to named attribute'
33 def __init__(self, db, attr):
34 self.attr = attr
35 self.icol = db.data[attr] # index of this attribute in the tuple
36 self.attrSQL = db._attrSQL(attr, sqlColumn=True) # SQL column name
37 def __set__(self, obj, val):
38 obj.db._update(obj.id, self.attrSQL, val) # AND UPDATE THE DATABASE
39 obj.save_local(self.attr, val)
41 class SQLDescriptor(object):
42 'return attribute value by querying the database'
43 def __init__(self, db, attr):
44 self.selectSQL = db._attrSQL(attr) # SQL expression for this attr
45 def __get__(self, obj, klass):
46 return obj._select(self.selectSQL)
47 def __set__(self, obj, val):
48 raise AttributeError('this database is read-only!')
50 class SQLDescriptorRW(SQLDescriptor):
51 'writeable proxy to corresponding column in the database'
52 def __set__(self, obj, val):
53 obj.db._update(obj.id, self.selectSQL, val) #just update the database
55 class ReadOnlyDescriptor(object):
56 'enforce read-only policy, e.g. for ID attribute'
57 def __init__(self, db, attr):
58 self.attr = '_'+attr
59 def __get__(self, obj, klass):
60 return getattr(obj, self.attr)
61 def __set__(self, obj, val):
62 raise AttributeError('attribute %s is read-only!' % self.attr)
65 def select_from_row(row, what):
66 "return value from SQL expression applied to this row"
67 sql,params = row.db._format_query('select %s from %s where %s=%%s limit 2'
68 % (what,row.db.name,row.db.primary_key),
69 (row.id,))
70 row.db.cursor.execute(sql, params)
71 t = row.db.cursor.fetchmany(2) # get at most two rows
72 if len(t) != 1:
73 raise KeyError('%s[%s].%s not found, or not unique'
74 % (row.db.name,str(row.id),what))
75 return t[0][0] #return the single field we requested
77 def init_row_subclass(cls, db):
78 'add descriptors for db attributes'
79 for attr in db.data: # bind all database columns
80 if attr == 'id': # handle ID attribute specially
81 setattr(cls, attr, cls._idDescriptor(db, attr))
82 continue
83 try: # check if this attr maps to an SQL column
84 db._attrSQL(attr, columnNumber=True)
85 except AttributeError: # treat as SQL expression
86 setattr(cls, attr, cls._sqlDescriptor(db, attr))
87 else: # treat as interface to our stored tuple
88 setattr(cls, attr, cls._columnDescriptor(db, attr))
90 def dir_row(self):
91 """get list of column names as our attributes """
92 return self.db.data.keys()
94 class TupleO(object):
95 """Provides attribute interface to a database tuple.
96 Storing the data as a tuple instead of a standard Python object
97 (which is stored using __dict__) uses about five-fold less
98 memory and is also much faster (the tuples returned from the
99 DB API fetch are simply referenced by the TupleO, with no
100 need to copy their individual values into __dict__).
102 This class follows the 'subclass binding' pattern, which
103 means that instead of using __getattr__ to process all
104 attribute requests (which is un-modular and leads to all
105 sorts of trouble), we follow Python's new model for
106 customizing attribute access, namely Descriptors.
107 We use classutil.get_bound_subclass() to automatically
108 create a subclass of this class, calling its _init_subclass()
109 class method to add all the descriptors needed for the
110 database table to which it is bound.
112 See the Pygr Developer Guide section of the docs for a
113 complete discussion of the subclass binding pattern."""
114 _columnDescriptor = TupleDescriptor
115 _idDescriptor = TupleIDDescriptor
116 _sqlDescriptor = SQLDescriptor
117 _init_subclass = classmethod(init_row_subclass)
118 _select = select_from_row
119 __dir__ = dir_row
120 def __init__(self, data):
121 self._data = data # save our data tuple
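# Hedged usage sketch (not part of the original module): once a table binds a
# TupleO subclass via _init_subclass(), every column becomes a descriptor, so
# attribute access reads straight from the cached result tuple. Assuming a
# table object myTable whose rows have a 'name' column:
#   row = myTable[someID]   # returns an instance of the bound TupleO subclass
#   row.name                # tuple lookup via TupleDescriptor, no __dict__ copy
#   row.id = 42             # AttributeError: use db[newID] = obj instead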
123 def insert_and_cache_id(self, l, **kwargs):
124 'insert tuple into db and cache its rowID on self'
125 self.db._insert(l) # save to database
126 try:
127 rowID = kwargs['id'] # use the ID supplied by user
128 except KeyError:
129 rowID = self.db.get_insert_id() # get auto-inc ID value
130 self.cache_id(rowID) # cache this ID on self
132 class TupleORW(TupleO):
133 'read-write version of TupleO'
134 _columnDescriptor = TupleDescriptorRW
135 insert_and_cache_id = insert_and_cache_id
136 def __init__(self, data, newRow=False, **kwargs):
137 if not newRow: # just cache data from the database
138 self._data = data
139 return
140 self._data = self.db.tuple_from_dict(kwargs) # convert to tuple
141 self.insert_and_cache_id(self._data, **kwargs)
142 def cache_id(self,row_id):
143 self.save_local('id',row_id)
144 def save_local(self,attr,val):
145 icol = self._attrcol[attr]
146 try:
147 self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
148 except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
149 self._data = list(self._data)
150 self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
152 TupleO._RWClass = TupleORW # record this as writeable interface class
154 class ColumnDescriptor(object):
155 'read-write interface to column in a database, cached in obj.__dict__'
156 def __init__(self, db, attr, readOnly = False):
157 self.attr = attr
158 self.col = db._attrSQL(attr, sqlColumn=True) # MAP THIS TO SQL COLUMN NAME
159 self.db = db
160 if readOnly:
161 self.__class__ = self._readOnlyClass
162 def __get__(self, obj, objtype):
163 try:
164 return obj.__dict__[self.attr]
165 except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
166 if self.col==self.db.primary_key:
167 raise AttributeError
168 self.db._select('where %s=%%s' % self.db.primary_key,(obj.id,),self.col)
169 l = self.db.cursor.fetchall()
170 if len(l)!=1:
171 raise AttributeError('db row not found or not unique!')
172 obj.__dict__[self.attr] = l[0][0] # UPDATE THE CACHE
173 return l[0][0]
174 def __set__(self, obj, val):
175 if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
176 self.db._update(obj.id, self.col, val) # UPDATE THE DATABASE
177 obj.__dict__[self.attr] = val # UPDATE THE CACHE
178 ## try:
179 ## m = self.consequences
180 ## except AttributeError:
181 ## return
182 ## m(obj,val) # GENERATE CONSEQUENCES
183 ## def bind_consequences(self,f):
184 ## 'make function f be run as consequences method whenever value is set'
185 ## import new
186 ## self.consequences = new.instancemethod(f,self,self.__class__)
187 class ReadOnlyColumnDesc(ColumnDescriptor):
188 def __set__(self, obj, val):
189 raise AttributeError('The ID of a database object is not writeable.')
190 ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc
194 class SQLRow(object):
195 """Provide transparent interface to a row in the database: attribute access
196 will be mapped to SELECT of the appropriate column, but data is not
197 cached on this object.
198 """
199 _columnDescriptor = _sqlDescriptor = SQLDescriptor
200 _idDescriptor = ReadOnlyDescriptor
201 _init_subclass = classmethod(init_row_subclass)
202 _select = select_from_row
203 __dir__ = dir_row
204 def __init__(self, rowID):
205 self._id = rowID
208 class SQLRowRW(SQLRow):
209 'read-write version of SQLRow'
210 _columnDescriptor = SQLDescriptorRW
211 insert_and_cache_id = insert_and_cache_id
212 def __init__(self, rowID, newRow=False, **kwargs):
213 if not newRow: # just cache data from the database
214 return self.cache_id(rowID)
215 l = self.db.tuple_from_dict(kwargs) # convert to tuple
216 self.insert_and_cache_id(l, **kwargs)
217 def cache_id(self, rowID):
218 self._id = rowID
220 SQLRow._RWClass = SQLRowRW
224 def list_to_dict(names, values):
225 'return dictionary of those named args that are present in values[]'
226 d={}
227 for i,v in enumerate(values):
228 try:
229 d[names[i]] = v
230 except IndexError:
231 break
232 return d
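# Example (sketch): names and values are paired positionally and pairing stops
# at the shorter sequence, e.g.
#   list_to_dict(('host', 'user', 'passwd'), ['localhost', 'joe'])
# returns {'host': 'localhost', 'user': 'joe'}.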
235 def get_name_cursor(name=None, **kwargs):
236 '''get table name and cursor by parsing name or using configFile.
237 If neither provided, will try to get via your MySQL config file.
238 If connect is None, will use MySQLdb.connect()'''
239 if name is not None:
240 argList = name.split() # TREAT AS WS-SEPARATED LIST
241 if len(argList)>1:
242 name = argList[0] # USE 1ST ARG AS TABLE NAME
243 argnames = ('host','user','passwd') # READ ARGS IN THIS ORDER
244 kwargs = kwargs.copy() # a copy we can overwrite
245 kwargs.update(list_to_dict(argnames, argList[1:]))
246 serverInfo = DBServerInfo(**kwargs)
247 return name,serverInfo.cursor(),serverInfo
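# Hedged example (sketch): the whitespace-separated form packs the table name
# plus optional host, user and passwd (in that order) into one string, so
#   name, cursor, serverInfo = get_name_cursor('test.mytable localhost joe secret')
# behaves like
#   get_name_cursor('test.mytable', host='localhost', user='joe', passwd='secret')
# (the table name 'test.mytable' and the credentials are made up for illustration).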
249 def mysql_connect(connect=None, configFile=None, useStreaming=False, **args):
250 """return connection and cursor objects, using .my.cnf if necessary"""
251 kwargs = args.copy() # a copy we can modify
252 if 'user' not in kwargs and configFile is None: #Find where config file is
253 osname = platform.system()
254 if osname in('Microsoft', 'Windows'): # Machine is a Windows box
255 paths = []
256 try: # handle case where WINDIR not defined by Windows...
257 windir = os.environ['WINDIR']
258 paths += [(windir, 'my.ini'), (windir, 'my.cnf')]
259 except KeyError:
260 pass
261 try:
262 sysdrv = os.environ['SYSTEMDRIVE']
263 paths += [(sysdrv, os.path.sep + 'my.ini'),
264 (sysdrv, os.path.sep + 'my.cnf')]
265 except KeyError:
266 pass
267 if len(paths) > 0:
268 configFile = get_valid_path(*paths)
269 else: # treat as normal platform with home directories
270 configFile = os.path.join(os.path.expanduser('~'), '.my.cnf')
272 # allows a local MySQL configuration file to be read
273 # from the current directory
274 configFile = configFile or os.path.join( os.getcwd(), 'mysql.cnf' )
276 if configFile and os.path.exists(configFile):
277 kwargs['read_default_file'] = configFile
278 connect = None # force it to use MySQLdb
279 if connect is None:
280 import MySQLdb
281 connect = MySQLdb.connect
282 kwargs['compress'] = True
283 if useStreaming: # use server side cursors for scalable result sets
284 try:
285 from MySQLdb import cursors
286 kwargs['cursorclass'] = cursors.SSCursor
287 except (ImportError, AttributeError):
288 pass
289 conn = connect(**kwargs)
290 cursor = conn.cursor()
291 return conn,cursor
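# Hedged example (sketch): with no explicit user and no configFile argument the
# connection options are read from ~/.my.cnf (or my.ini / my.cnf under WINDIR /
# SYSTEMDRIVE on Windows), so a typical call reduces to
#   conn, cursor = mysql_connect(db='test', useStreaming=True)
# where 'test' is an assumed database name; useStreaming requests a server-side
# SSCursor when MySQLdb provides one.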
293 _mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
294 AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
295 SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')
297 def mysql_table_schema(self, analyzeSchema=True):
298 'retrieve table schema from a MySQL database, save on self'
299 import MySQLdb
300 self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
301 if not analyzeSchema:
302 return
303 self.clear_schema() # reset settings and dictionaries
304 self.cursor.execute('describe %s' % self.name) # get info about columns
305 columns = self.cursor.fetchall()
306 self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
307 for icol,c in enumerate(columns):
308 field = c[0]
309 self.columnName.append(field) # list of columns in same order as table
310 if c[3] == "PRI": # record as primary key
311 if self.primary_key is None:
312 self.primary_key = field
313 else:
314 try:
315 self.primary_key.append(field)
316 except AttributeError:
317 self.primary_key = [self.primary_key,field]
318 if c[1][:3].lower() == 'int':
319 self.usesIntID = True
320 else:
321 self.usesIntID = False
322 elif c[3] == "MUL":
323 self.indexed[field] = icol
324 self.description[field] = self.cursor.description[icol]
325 self.columnType[field] = c[1] # SQL COLUMN TYPE
327 _sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
328 AUTO_INCREMENT='', SUBSTRING='substr',
329 SUBSTR_FROM=',', SUBSTR_FOR=',')
331 def import_sqlite():
332 'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
333 try:
334 import sqlite3 as sqlite
335 except ImportError:
336 from pysqlite2 import dbapi2 as sqlite
337 return sqlite
339 def sqlite_table_schema(self, analyzeSchema=True):
340 'retrieve table schema from a sqlite3 database, save on self'
341 sqlite = import_sqlite()
342 self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
343 if not analyzeSchema:
344 return
345 self.clear_schema() # reset settings and dictionaries
346 self.cursor.execute('PRAGMA table_info("%s")' % self.name)
347 columns = self.cursor.fetchall()
348 self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
349 for icol,c in enumerate(columns):
350 field = c[1]
351 self.columnName.append(field) # list of columns in same order as table
352 self.description[field] = self.cursor.description[icol]
353 self.columnType[field] = c[2] # SQL COLUMN TYPE
354 self.cursor.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self.name) # get primary key / unique indexes
355 for indexname in self.cursor.fetchall(): # search indexes for primary key
356 self.cursor.execute('PRAGMA index_info("%s")' % indexname)
357 l = self.cursor.fetchall() # get list of columns in this index
358 if len(l) == 1: # assume 1st single-column unique index is primary key!
359 self.primary_key = l[0][2]
360 break # done searching for primary key!
361 if self.primary_key is None: # grrr, INTEGER PRIMARY KEY handled differently
362 self.cursor.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self.name)
363 sql = self.cursor.fetchall()[0][0]
364 for columnSQL in sql[sql.index('(') + 1 :].split(','):
365 if 'primary key' in columnSQL.lower(): # must be the primary key!
366 col = columnSQL.split()[0] # get column name
367 if col in self.columnType:
368 self.primary_key = col
369 break # done searching for primary key!
370 else:
371 raise ValueError('unknown primary key %s in table %s'
372 % (col,self.name))
373 if self.primary_key is not None: # check its type
374 if self.columnType[self.primary_key] == 'int' or \
375 self.columnType[self.primary_key] == 'integer':
376 self.usesIntID = True
377 else:
378 self.usesIntID = False
380 class SQLFormatDict(object):
381 '''Perform SQL keyword replacements for maintaining compatibility across
382 a wide range of SQL backends. Uses Python dict-based string format
383 function to do simple string replacements, and also to convert
384 params list to the paramstyle required for this interface.
385 Create by passing a dict of macros and the db-api paramstyle:
386 sfd = SQLFormatDict("qmark", substitutionDict)
388 Then transform queries+params as follows; input should be "format" style:
389 sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
390 cursor.execute(sql, params)
391 '''
392 _paramFormats = dict(pyformat='%%(%d)s', numeric=':%d', named=':%d',
393 qmark='(ignore)', format='(ignore)')
394 def __init__(self, paramstyle, substitutionDict={}):
395 self.substitutionDict = substitutionDict.copy()
396 self.paramstyle = paramstyle
397 self.paramFormat = self._paramFormats[paramstyle]
398 self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
399 if paramstyle == 'qmark': # handle these as simple substitution
400 self.substitutionDict['?'] = '?'
401 elif paramstyle == 'format':
402 self.substitutionDict['?'] = '%s'
403 def __getitem__(self, k):
404 'apply correct substitution for this SQL interface'
405 try:
406 return self.substitutionDict[k] # apply our substitutions
407 except KeyError:
408 pass
409 if k == '?': # sequential parameter
410 s = self.paramFormat % self.iparam
411 self.iparam += 1 # advance to the next parameter
412 return s
413 raise KeyError('unknown macro: %s' % k)
414 def __call__(self, sql, paramList):
415 'returns corrected sql,params for this interface'
416 self.iparam = 1 # DB-API param indexing begins at 1
417 sql = sql.replace('%s', '%(?)s') # convert format into pyformat
418 s = sql % self # apply all %(x)s replacements in sql
419 if self.makeDict: # construct a params dict
420 paramDict = {}
421 for i,param in enumerate(paramList):
422 paramDict[str(i + 1)] = param # DB-API param indexing begins at 1
423 return s,paramDict
424 else: # just return the original params list
425 return s,paramList
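# Minimal sketch (not in the original source) showing how SQLFormatDict rewrites
# a "format"-style query for a qmark-style driver such as sqlite3; the table
# name and parameter values below are made up for illustration.
def _example_format_query():
    sfd = SQLFormatDict('qmark', _sqliteMacros)
    sql, params = sfd('select * from foo where id=%s and val=%s', (17, 'bar'))
    # sql is now 'select * from foo where id=? and val=?'; params is unchanged
    return sql, params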
427 def get_table_schema(self, analyzeSchema=True):
428 'run the right schema function based on type of db server connection'
429 try:
430 modname = self.cursor.__class__.__module__
431 except AttributeError:
432 raise ValueError('no cursor object or module information!')
433 try:
434 schema_func = self._schemaModuleDict[modname]
435 except KeyError:
436 raise KeyError('''unknown db module: %s. Use _schemaModuleDict
437 attribute to supply a method for obtaining table schema
438 for this module''' % modname)
439 schema_func(self, analyzeSchema) # run the schema function
442 _schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
443 'pysqlite2.dbapi2':sqlite_table_schema,
444 'sqlite3':sqlite_table_schema}
446 class SQLTableBase(object, UserDict.DictMixin):
447 "Store information about an SQL table as dict keyed by primary key"
448 _schemaModuleDict = _schemaModuleDict # default module list
449 get_table_schema = get_table_schema
450 def __init__(self,name,cursor=None,itemClass=None,attrAlias=None,
451 clusterKey=None,createTable=None,graph=None,maxCache=None,
452 arraysize=1024, itemSliceClass=None, dropIfExists=False,
453 serverInfo=None, autoGC=True, orderBy=None,
454 writeable=False, **kwargs):
455 if autoGC: # automatically garbage collect unused objects
456 self._weakValueDict = RecentValueDictionary(autoGC) # object cache
457 else:
458 self._weakValueDict = {}
459 self.autoGC = autoGC
460 self.orderBy = orderBy
461 self.writeable = writeable
462 if cursor is None:
463 if serverInfo is not None: # get cursor from serverInfo
464 cursor = serverInfo.cursor()
465 else: # try to read connection info from name or config file
466 name,cursor,serverInfo = get_name_cursor(name,**kwargs)
467 else:
468 warnings.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
469 DeprecationWarning, stacklevel=2)
470 self.cursor = cursor
471 if createTable is not None: # RUN COMMAND TO CREATE THIS TABLE
472 if dropIfExists: # get rid of any existing table
473 cursor.execute('drop table if exists ' + name)
474 self.get_table_schema(False) # check dbtype, init _format_query
475 sql,params = self._format_query(createTable, ()) # apply macros
476 cursor.execute(sql) # create the table
477 self.name = name
478 if graph is not None:
479 self.graph = graph
480 if maxCache is not None:
481 self.maxCache = maxCache
482 if arraysize is not None:
483 self.arraysize = arraysize
484 cursor.arraysize = arraysize
485 self.get_table_schema() # get schema of columns to serve as attrs
486 self.data = {} # map of all attributes, including aliases
487 for icol,field in enumerate(self.columnName):
488 self.data[field] = icol # 1st add mappings to columns
489 try:
490 self.data['id']=self.data[self.primary_key]
491 except (KeyError,TypeError):
492 pass
493 if hasattr(self,'_attr_alias'): # apply attribute aliases for this class
494 self.addAttrAlias(False,**self._attr_alias)
495 self.objclass(itemClass) # NEED TO SUBCLASS OUR ITEM CLASS
496 if itemSliceClass is not None:
497 self.itemSliceClass = itemSliceClass
498 get_bound_subclass(self, 'itemSliceClass', self.name) # need to subclass itemSliceClass
499 if attrAlias is not None: # ADD ATTRIBUTE ALIASES
500 self.attrAlias = attrAlias # RECORD FOR PICKLING PURPOSES
501 self.data.update(attrAlias)
502 if clusterKey is not None:
503 self.clusterKey=clusterKey
504 if serverInfo is not None:
505 self.serverInfo = serverInfo
507 def __len__(self):
508 self._select(selectCols='count(*)')
509 return self.cursor.fetchone()[0]
510 def __hash__(self):
511 return id(self)
512 def __cmp__(self, other):
513 'only match self and no other!'
514 if self is other:
515 return 0
516 else:
517 return cmp(id(self), id(other))
518 _pickleAttrs = dict(name=0, clusterKey=0, maxCache=0, arraysize=0,
519 attrAlias=0, serverInfo=0, autoGC=0, orderBy=0,
520 writeable=0)
521 __getstate__ = standard_getstate
522 def __setstate__(self,state):
523 # default cursor provisioning by worldbase is deprecated!
524 ## if 'serverInfo' not in state: # hmm, no address for db server?
525 ## try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
526 ## from Data import getResource
527 ## state['cursor'] = getResource.getTableCursor(state['name'])
528 ## except ImportError:
529 ## pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
530 self.__init__(**state)
531 def __repr__(self):
532 return '<SQL table '+self.name+'>'
534 def clear_schema(self):
535 'reset all schema information for this table'
536 self.description={}
537 self.columnName = []
538 self.columnType = {}
539 self.usesIntID = None
540 self.primary_key = None
541 self.indexed = {}
542 def _attrSQL(self,attr,sqlColumn=False,columnNumber=False):
543 "Translate python attribute name to appropriate SQL expression"
544 try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
545 field=self.data[attr]
546 except KeyError:
547 raise AttributeError('attribute %s not a valid column or alias in %s'
548 % (attr,self.name))
549 if sqlColumn: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
550 try: # CHECK IF field IS COLUMN NUMBER
551 return self.columnName[field] # RETURN SQL COLUMN NAME
552 except TypeError:
553 try: # CHECK IF field IS SQL COLUMN NAME
554 return self.columnName[self.data[field]] # THIS WILL JUST RETURN field...
555 except (KeyError,TypeError):
556 raise AttributeError('attribute %s does not map to an SQL column in %s'
557 % (attr,self.name))
558 if columnNumber:
559 try: # CHECK IF field IS A COLUMN NUMBER
560 return field+0 # ONLY RETURN AN INTEGER
561 except TypeError:
562 try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
563 return self.data[field]+0 # ONLY RETURN AN INTEGER
564 except (KeyError,TypeError):
565 raise ValueError('attribute %s does not map to a SQL column!' % attr)
566 if isinstance(field,types.StringType):
567 attr=field # USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
568 elif attr=='id':
569 attr=self.primary_key
570 return attr
571 def addAttrAlias(self,saveToPickle=True,**kwargs):
572 """Add new attributes as aliases of existing attributes.
573 They can be specified either as named args:
574 t.addAttrAlias(newattr=oldattr)
575 or by passing a dictionary kwargs whose keys are newattr
576 and values are oldattr:
577 t.addAttrAlias(**kwargs)
578 saveToPickle=True forces these aliases to be saved if object is pickled.
579 """
580 if saveToPickle:
581 self.attrAlias.update(kwargs)
582 for key,val in kwargs.items():
583 try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
584 self.data[val]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
585 raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
586 except TypeError: # val IS ITSELF AN ALIAS
587 self.data[key] = self.data[val] # SO MAP TO WHAT IT MAPS TO
588 except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
589 self.data[key] = val
590 def objclass(self,oclass=None):
591 "Create class representing a row in this table by subclassing oclass, adding data"
592 if oclass is not None: # use this as our base itemClass
593 self.itemClass = oclass
594 if self.writeable:
595 self.itemClass = self.itemClass._RWClass # use its writeable version
596 oclass = get_bound_subclass(self, 'itemClass', self.name,
597 subclassArgs=dict(db=self)) # bind itemClass
598 if issubclass(oclass, TupleO):
599 oclass._attrcol = self.data # BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
600 if hasattr(oclass,'_tableclass') and not isinstance(self,oclass._tableclass):
601 self.__class__=oclass._tableclass # ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
602 def _select(self, whereClause='', params=(), selectCols='t1.*',
603 cursor=None, orderBy='', limit=''):
604 'execute the specified query but do not fetch'
605 sql,params = self._format_query('select %s from %s t1 %s %s %s'
606 % (selectCols, self.name, whereClause, orderBy,
607 limit), params)
608 if cursor is None:
609 self.cursor.execute(sql, params)
610 else:
611 cursor.execute(sql, params)
612 def select(self,whereClause,params=None,oclass=None,selectCols='t1.*'):
613 "Generate the list of objects that satisfy the database SELECT"
614 if oclass is None:
615 oclass=self.itemClass
616 self._select(whereClause,params,selectCols)
617 l=self.cursor.fetchall()
618 for t in l:
619 yield self.cacheItem(t,oclass)
620 def query(self,**kwargs):
621 'query for intersection of all specified kwargs, returned as iterator'
622 criteria = []
623 params = []
624 for k,v in kwargs.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
625 if v is None: # CONVERT TO SQL NULL TEST
626 criteria.append('%s IS NULL' % self._attrSQL(k))
627 else: # TEST FOR EQUALITY
628 criteria.append('%s=%%s' % self._attrSQL(k))
629 params.append(v)
630 return self.select('where '+' and '.join(criteria),params)
631 def _update(self,row_id,col,val):
632 'update a single field in the specified row to the specified value'
633 sql,params = self._format_query('update %s set %s=%%s where %s=%%s'
634 %(self.name,col,self.primary_key),
635 (val,row_id))
636 self.cursor.execute(sql, params)
637 def getID(self,t):
638 try:
639 return t[self.data['id']] # GET ID FROM TUPLE
640 except TypeError: # treat as alias
641 return t[self.data[self.data['id']]]
642 def cacheItem(self,t,oclass):
643 'get obj from cache if possible, or construct from tuple'
644 try:
645 id=self.getID(t)
646 except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
647 return oclass(t)
648 try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
649 return self._weakValueDict[id]
650 except KeyError:
651 pass
652 o = oclass(t)
653 self._weakValueDict[id] = o # CACHE THIS ITEM IN OUR DICTIONARY
654 return o
655 def cache_items(self,rows,oclass=None):
656 if oclass is None:
657 oclass=self.itemClass
658 for t in rows:
659 yield self.cacheItem(t,oclass)
660 def foreignKey(self,attr,k):
661 'get iterator for objects with specified foreign key value'
662 return self.select('where %s=%%s'%attr,(k,))
663 def limit_cache(self):
664 'APPLY maxCache LIMIT TO CACHE SIZE'
665 try:
666 if self.maxCache<len(self._weakValueDict):
667 self._weakValueDict.clear()
668 except AttributeError:
669 pass
671 def get_new_cursor(self):
672 """Return a new cursor object, or None if not possible """
673 try:
674 new_cursor = self.serverInfo.new_cursor
675 except AttributeError:
676 return None
677 return new_cursor(self.arraysize)
679 def generic_iterator(self, cursor=None, fetch_f=None, cache_f=None,
680 map_f=iter):
681 'generic iterator that runs fetch, cache and map functions'
682 if fetch_f is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
683 if cursor is None:
684 fetch_f = self.cursor.fetchmany
685 else: # isolate this iter from other queries
686 fetch_f = cursor.fetchmany
687 if cache_f is None:
688 cache_f = self.cache_items
689 while True:
690 self.limit_cache()
691 rows = fetch_f() # FETCH THE NEXT SET OF ROWS
693 if len(rows)==0: # NO MORE DATA SO ALL DONE
694 break
695 for v in map_f(cache_f(rows)): # CACHE AND GENERATE RESULTS
696 yield v
697 if cursor is not None: # close iterator now that we're done
698 cursor.close()
699 def tuple_from_dict(self, d):
700 'transform kwarg dict into tuple for storing in database'
701 l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
702 for col,icol in self.data.items():
703 try:
704 l[icol] = d[col]
705 except (KeyError,TypeError):
706 pass
707 return l
708 def tuple_from_obj(self, obj):
709 'transform object attributes into tuple for storing in database'
710 l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
711 for col,icol in self.data.items():
712 try:
713 l[icol] = getattr(obj,col)
714 except (AttributeError,TypeError):
715 pass
716 return l
717 def _insert(self, l):
718 '''insert tuple into the database. Note this uses the MySQL
719 extension REPLACE, which overwrites any duplicate key.'''
720 s = '%(REPLACE)s into ' + self.name + ' values (' \
721 + ','.join(['%s']*len(l)) + ')'
722 sql,params = self._format_query(s, l)
723 self.cursor.execute(sql, params)
724 def insert(self, obj):
725 '''insert new row by transforming obj to tuple of values'''
726 l = self.tuple_from_obj(obj)
727 self._insert(l)
728 def get_insert_id(self):
729 'get the primary key value for the last INSERT'
730 try: # ATTEMPT TO GET ASSIGNED ID FROM DB
731 auto_id = self.cursor.lastrowid
732 except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
733 raise NotImplementedError('''your db lacks lastrowid support?''')
734 if auto_id is None:
735 raise ValueError('lastrowid is None so cannot get ID from INSERT!')
736 return auto_id
737 def new(self, **kwargs):
738 'return a new record with the assigned attributes, added to DB'
739 if not self.writeable:
740 raise ValueError('this database is read only!')
741 obj = self.itemClass(None, newRow=True, **kwargs) # saves itself to db
742 self._weakValueDict[obj.id] = obj # AND SAVE TO OUR LOCAL DICT CACHE
743 return obj
744 def clear_cache(self):
745 'empty the cache'
746 self._weakValueDict.clear()
747 def __delitem__(self, k):
748 if not self.writeable:
749 raise ValueError('this database is read only!')
750 sql,params = self._format_query('delete from %s where %s=%%s'
751 % (self.name,self.primary_key),(k,))
752 self.cursor.execute(sql, params)
753 try:
754 del self._weakValueDict[k]
755 except KeyError:
756 pass
758 def getKeys(self,queryOption='', selectCols=None):
759 'uses db select; does not force load'
760 if selectCols is None:
761 selectCols=self.primary_key
762 if queryOption=='' and self.orderBy is not None:
763 queryOption = self.orderBy # apply default ordering
764 self.cursor.execute('select %s from %s %s'
765 %(selectCols,self.name,queryOption))
766 return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
768 def iter_keys(self, selectCols=None, orderBy='', map_f=iter,
769 cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
770 'guarantee correct iteration insulated from other queries'
771 if selectCols is None:
772 selectCols=self.primary_key
773 if orderBy=='' and self.orderBy is not None:
774 orderBy = self.orderBy # apply default ordering
775 cursor = self.get_new_cursor()
776 if cursor: # got our own cursor, guaranteeing query isolation
777 try:
778 iter_f = self.serverInfo.iter_keys
779 except AttributeError:
780 self._select(cursor=cursor, selectCols=selectCols,
781 orderBy=orderBy, **kwargs)
782 return self.generic_iterator(cursor=cursor, cache_f=cache_f,
783 map_f=map_f)
784 else: # use custom iter_keys() method from serverInfo
785 return iter_f(self, cursor, selectCols=selectCols,
786 orderBy=orderBy, map_f=map_f, cache_f=cache_f,
787 **kwargs)
788 else: # must pre-fetch all keys to ensure query isolation
789 if get_f is not None:
790 return iter(get_f())
791 else:
792 return iter(self.keys())
794 class SQLTable(SQLTableBase):
795 "Provide on-the-fly access to rows in the database, caching the results in dict"
796 itemClass = TupleO # our default itemClass; constructor can override
797 keys=getKeys
798 __iter__ = iter_keys
799 def load(self,oclass=None):
800 "Load all data from the table"
801 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
802 return self._isLoaded
803 except AttributeError:
804 pass
805 if oclass is None:
806 oclass=self.itemClass
807 self.cursor.execute('select * from %s' % self.name)
808 l=self.cursor.fetchall()
809 self._weakValueDict = {} # just store the whole dataset in memory
810 for t in l:
811 self.cacheItem(t,oclass) # CACHE IT IN LOCAL DICTIONARY
812 self._isLoaded=True # MARK THIS CONTAINER AS FULLY LOADED
814 def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
815 try:
816 return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
817 except KeyError: # NOT FOUND, SO TRY THE DATABASE
818 sql,params = self._format_query('select * from %s where %s=%%s limit 2'
819 % (self.name,self.primary_key),(k,))
820 self.cursor.execute(sql, params)
821 l = self.cursor.fetchmany(2) # get at most 2 rows
822 if len(l) != 1:
823 raise KeyError('%s not found in %s, or not unique' %(str(k),self.name))
824 self.limit_cache()
825 return self.cacheItem(l[0],self.itemClass) # CACHE IT IN LOCAL DICTIONARY
826 def __setitem__(self, k, v):
827 if not self.writeable:
828 raise ValueError('this database is read only!')
829 try:
830 if v.db is not self:
831 raise AttributeError
832 except AttributeError:
833 raise ValueError('object not bound to itemClass for this db!')
834 try:
835 oldID = v.id
836 if oldID is None:
837 raise AttributeError
838 except AttributeError:
839 pass
840 else: # delete row with old ID
841 del self[v.id]
842 v.cache_id(k) # cache the new ID on the object
843 self.insert(v) # SAVE TO THE RELATIONAL DB SERVER
844 self._weakValueDict[k] = v # CACHE THIS ITEM IN OUR DICTIONARY
845 def items(self):
846 'forces load of entire table into memory'
847 self.load()
848 return [(k,self[k]) for k in self] # apply orderBy rules...
849 def iteritems(self):
850 'uses arraysize / maxCache and fetchmany() to manage data transfer'
851 return iter_keys(self, selectCols='*', cache_f=None,
852 map_f=generate_items, get_f=self.items)
853 def values(self):
854 'forces load of entire table into memory'
855 self.load()
856 return [self[k] for k in self] # apply orderBy rules...
857 def itervalues(self):
858 'uses arraysize / maxCache and fetchmany() to manage data transfer'
859 return iter_keys(self, selectCols='*', cache_f=None, get_f=self.values)
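# Hedged usage sketch (not part of the original module), assuming a DBServerInfo
# connection and an existing 'sprot' table with a single-column primary key:
#   t = SQLTable('sprot', serverInfo=serverInfo)
#   row = t[someID]      # SELECT ... limit 2, then cached in _weakValueDict
#   for k in t:          # __iter__ = iter_keys, insulated from other queries
#       print t[k]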
861 def getClusterKeys(self,queryOption=''):
862 'uses db select; does not force load'
863 self.cursor.execute('select distinct %s from %s %s'
864 %(self.clusterKey,self.name,queryOption))
865 return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
868 class SQLTableClustered(SQLTable):
869 '''use clusterKey to load a whole cluster of rows at once,
870 specifically, all rows that share the same clusterKey value.'''
871 def __init__(self, *args, **kwargs):
872 kwargs = kwargs.copy() # get a copy we can alter
873 kwargs['autoGC'] = False # don't use WeakValueDictionary
874 SQLTable.__init__(self, *args, **kwargs)
875 def keys(self):
876 return getKeys(self,'order by %s' %self.clusterKey)
877 def clusterkeys(self):
878 return getClusterKeys(self, 'order by %s' %self.clusterKey)
879 def __getitem__(self,k):
880 try:
881 return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
882 except KeyError: # NOT FOUND, SO TRY THE DATABASE
883 sql,params = self._format_query('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
884 % (self.name,self.name,self.primary_key,
885 self.clusterKey,self.clusterKey),(k,))
886 self.cursor.execute(sql, params)
887 l=self.cursor.fetchall()
888 self.limit_cache()
889 for t in l: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
890 self.cacheItem(t,self.itemClass)
891 return self._weakValueDict[k] # should be in cache, if row k exists
892 def itercluster(self,cluster_id):
893 'iterate over all items from the specified cluster'
894 self.limit_cache()
895 return self.select('where %s=%%s'%self.clusterKey,(cluster_id,))
896 def fetch_cluster(self):
897 'use self.cursor.fetchmany to obtain all rows for next cluster'
898 icol = self._attrSQL(self.clusterKey,columnNumber=True)
899 result = []
900 try:
901 rows = self._fetch_cluster_cache # USE SAVED ROWS FROM PREVIOUS CALL
902 del self._fetch_cluster_cache
903 except AttributeError:
904 rows = self.cursor.fetchmany()
905 try:
906 cluster_id = rows[0][icol]
907 except IndexError:
908 return result
909 while len(rows)>0:
910 for i,t in enumerate(rows): # CHECK THAT ALL ROWS FROM THIS CLUSTER
911 if cluster_id != t[icol]: # START OF A NEW CLUSTER
912 result += rows[:i] # RETURN ROWS OF CURRENT CLUSTER
913 self._fetch_cluster_cache = rows[i:] # SAVE NEXT CLUSTER
914 return result
915 result += rows
916 rows = self.cursor.fetchmany() # GET NEXT SET OF ROWS
917 return result
918 def itervalues(self):
919 'uses arraysize / maxCache and fetchmany() to manage data transfer'
920 cursor = self.get_new_cursor()
921 self._select('order by %s' %self.clusterKey, cursor=cursor)
922 return self.generic_iterator(cursor, self.fetch_cluster)
923 def iteritems(self):
924 'uses arraysize / maxCache and fetchmany() to manage data transfer'
925 cursor = self.get_new_cursor()
926 self._select('order by %s' %self.clusterKey, cursor=cursor)
927 return self.generic_iterator(cursor, self.fetch_cluster,
928 map_f=generate_items)
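# Hedged usage sketch (assuming an 'exons' table whose rows cluster by a
# 'gene_id' column): one query loads every row sharing the clusterKey value,
# so sibling rows come from the local cache afterwards:
#   t = SQLTableClustered('exons', clusterKey='gene_id', serverInfo=serverInfo)
#   row = t[someExonID]                          # caches the whole gene's cluster
#   siblings = list(t.itercluster(row.gene_id))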
930 class SQLForeignRelation(object):
931 'mapping based on matching a foreign key in an SQL table'
932 def __init__(self,table,keyName):
933 self.table=table
934 self.keyName=keyName
935 def __getitem__(self,k):
936 'get list of objects o with getattr(o,keyName)==k.id'
937 l=[]
938 for o in self.table.select('where %s=%%s'%self.keyName,(k.id,)):
939 l.append(o)
940 if len(l)==0:
941 raise KeyError('%s not found in %s' %(str(k),self.table.name))
942 return l
945 class SQLTableNoCache(SQLTableBase):
946 '''Provide on-the-fly access to rows in the database;
947 values are simply an object interface (SQLRow) to back-end db query.
948 Row data are not stored locally, but always accessed by querying the db'''
949 itemClass=SQLRow # DEFAULT OBJECT CLASS FOR ROWS...
950 keys=getKeys
951 __iter__ = iter_keys
952 def getID(self,t): return t[0] # GET ID FROM TUPLE
953 def select(self,whereClause,params):
954 return SQLTableBase.select(self,whereClause,params,self.oclass,
955 self._attrSQL('id'))
956 def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
957 try:
958 return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
959 except KeyError: # NOT FOUND, SO TRY THE DATABASE
960 self._select('where %s=%%s' % self.primary_key, (k,),
961 self.primary_key)
962 t = self.cursor.fetchmany(2)
963 if len(t) != 1:
964 raise KeyError('id %s non-existent or not unique' % k)
965 o = self.itemClass(k) # create obj referencing this ID
966 self._weakValueDict[k] = o # cache the SQLRow object
967 return o
968 def __setitem__(self, k, v):
969 if not self.writeable:
970 raise ValueError('this database is read only!')
971 try:
972 if v.db != self:
973 raise AttributeError
974 except AttributeError:
975 raise ValueError('object not bound to itemClass for this db!')
976 try:
977 del self[k] # delete row with new ID if any
978 except KeyError:
979 pass
980 try:
981 del self._weakValueDict[v.id] # delete from old cache location
982 except KeyError:
983 pass
984 self._update(v.id, self.primary_key, k) # just change its ID in db
985 v.cache_id(k) # change the cached ID value
986 self._weakValueDict[k] = v # assign to new cache location
987 def addAttrAlias(self,**kwargs):
988 self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES
990 SQLRow._tableclass=SQLTableNoCache # SQLRow IS FOR NON-CACHING TABLE INTERFACE
993 class SQLTableMultiNoCache(SQLTableBase):
994 "Trivial on-the-fly access for table with key that returns multiple rows"
995 itemClass = TupleO # default itemClass; constructor can override
996 _distinct_key='id' # DEFAULT COLUMN TO USE AS KEY
997 def keys(self):
998 return getKeys(self, selectCols='distinct(%s)'
999 % self._attrSQL(self._distinct_key))
1000 def __iter__(self):
1001 return iter_keys(self, 'distinct(%s)' % self._attrSQL(self._distinct_key))
1002 def __getitem__(self,id):
1003 sql,params = self._format_query('select * from %s where %s=%%s'
1004 %(self.name,self._attrSQL(self._distinct_key)),(id,))
1005 self.cursor.execute(sql, params)
1006 l=self.cursor.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1007 for row in l:
1008 yield self.itemClass(row)
1009 def addAttrAlias(self,**kwargs):
1010 self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES
1014 class SQLEdges(SQLTableMultiNoCache):
1015 '''provide iterator over edges as (source,target,edge)
1016 and getitem[edge] --> [(source,target),...]'''
1017 _distinct_key='edge_id'
1018 _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
1019 _pickleAttrs.update(dict(graph=0))
1020 def keys(self):
1021 self.cursor.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
1022 %(self._attrSQL('source_id'),self._attrSQL('target_id'),
1023 self._attrSQL('edge_id'),self.name,
1024 self._attrSQL('target_id'),self._attrSQL('source_id'),
1025 self._attrSQL('target_id')))
1026 l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1027 for source_id,target_id,edge_id in self.cursor.fetchall():
1028 l.append((self.graph.unpack_source(source_id),
1029 self.graph.unpack_target(target_id),
1030 self.graph.unpack_edge(edge_id)))
1031 return l
1032 __call__=keys
1033 def __iter__(self):
1034 return iter(self.keys())
1035 def __getitem__(self,edge):
1036 sql,params = self._format_query('select %s,%s from %s where %s=%%s'
1037 %(self._attrSQL('source_id'),
1038 self._attrSQL('target_id'),
1039 self.name,
1040 self._attrSQL(self._distinct_key)),
1041 (self.graph.pack_edge(edge),))
1042 self.cursor.execute(sql, params)
1043 l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1044 for source_id,target_id in self.cursor.fetchall():
1045 l.append((self.graph.unpack_source(source_id),
1046 self.graph.unpack_target(target_id)))
1047 return l
1050 class SQLEdgeDict(object):
1051 '2nd level graph interface to SQL database'
1052 def __init__(self,fromNode,table):
1053 self.fromNode=fromNode
1054 self.table=table
1055 if not hasattr(self.table,'allowMissingNodes'):
1056 sql,params = self.table._format_query('select %s from %s where %s=%%s limit 1'
1057 %(self.table.sourceSQL,
1058 self.table.name,
1059 self.table.sourceSQL),
1060 (self.fromNode,))
1061 self.table.cursor.execute(sql, params)
1062 if len(self.table.cursor.fetchall())<1:
1063 raise KeyError('node not in graph!')
1065 def __getitem__(self,target):
1066 sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
1067 %(self.table.edgeSQL,
1068 self.table.name,
1069 self.table.sourceSQL,
1070 self.table.targetSQL),
1071 (self.fromNode,
1072 self.table.pack_target(target)))
1073 self.table.cursor.execute(sql, params)
1074 l = self.table.cursor.fetchmany(2) # get at most two rows
1075 if len(l) != 1:
1076 raise KeyError('either no edge from source to target or not unique!')
1077 try:
1078 return self.table.unpack_edge(l[0][0]) # RETURN EDGE
1079 except IndexError:
1080 raise KeyError('no edge from node to target')
1081 def __setitem__(self,target,edge):
1082 sql,params = self.table._format_query('replace into %s values (%%s,%%s,%%s)'
1083 %self.table.name,
1084 (self.fromNode,
1085 self.table.pack_target(target),
1086 self.table.pack_edge(edge)))
1087 self.table.cursor.execute(sql, params)
1088 if not hasattr(self.table,'sourceDB') or \
1089 (hasattr(self.table,'targetDB') and self.table.sourceDB==self.table.targetDB):
1090 self.table += target # ADD AS NODE TO GRAPH
1091 def __iadd__(self,target):
1092 self[target] = None
1093 return self # iadd MUST RETURN self!
1094 def __delitem__(self,target):
1095 sql,params = self.table._format_query('delete from %s where %s=%%s and %s=%%s'
1096 %(self.table.name,
1097 self.table.sourceSQL,
1098 self.table.targetSQL),
1099 (self.fromNode,
1100 self.table.pack_target(target)))
1101 self.table.cursor.execute(sql, params)
1102 if self.table.cursor.rowcount < 1: # no rows deleted?
1103 raise KeyError('no edge from node to target')
1105 def iterator_query(self):
1106 sql,params = self.table._format_query('select %s,%s from %s where %s=%%s and %s is not null'
1107 %(self.table.targetSQL,
1108 self.table.edgeSQL,
1109 self.table.name,
1110 self.table.sourceSQL,
1111 self.table.targetSQL),
1112 (self.fromNode,))
1113 self.table.cursor.execute(sql, params)
1114 return self.table.cursor.fetchall()
1115 def keys(self):
1116 return [self.table.unpack_target(target_id)
1117 for target_id,edge_id in self.iterator_query()]
1118 def values(self):
1119 return [self.table.unpack_edge(edge_id)
1120 for target_id,edge_id in self.iterator_query()]
1121 def edges(self):
1122 return [(self.table.unpack_source(self.fromNode),self.table.unpack_target(target_id),
1123 self.table.unpack_edge(edge_id))
1124 for target_id,edge_id in self.iterator_query()]
1125 def items(self):
1126 return [(self.table.unpack_target(target_id),self.table.unpack_edge(edge_id))
1127 for target_id,edge_id in self.iterator_query()]
1128 def __iter__(self): return iter(self.keys())
1129 def itervalues(self): return iter(self.values())
1130 def iteritems(self): return iter(self.items())
1131 def __len__(self):
1132 return len(self.keys())
1133 __cmp__ = graph_cmp
1135 class SQLEdgelessDict(SQLEdgeDict):
1136 'for SQLGraph tables that lack edge_id column'
1137 def __getitem__(self,target):
1138 sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
1139 %(self.table.targetSQL,
1140 self.table.name,
1141 self.table.sourceSQL,
1142 self.table.targetSQL),
1143 (self.fromNode,
1144 self.table.pack_target(target)))
1145 self.table.cursor.execute(sql, params)
1146 l = self.table.cursor.fetchmany(2)
1147 if len(l) != 1:
1148 raise KeyError('either no edge from source to target or not unique!')
1149 return None # no edge info!
1150 def iterator_query(self):
1151 sql,params = self.table._format_query('select %s from %s where %s=%%s and %s is not null'
1152 %(self.table.targetSQL,
1153 self.table.name,
1154 self.table.sourceSQL,
1155 self.table.targetSQL),
1156 (self.fromNode,))
1157 self.table.cursor.execute(sql, params)
1158 return [(t[0],None) for t in self.table.cursor.fetchall()]
1160 SQLEdgeDict._edgelessClass = SQLEdgelessDict
1162 class SQLGraphEdgeDescriptor(object):
1163 'provide an SQLEdges interface on demand'
1164 def __get__(self,obj,objtype):
1165 try:
1166 attrAlias=obj.attrAlias.copy()
1167 except AttributeError:
1168 return SQLEdges(obj.name, obj.cursor, graph=obj)
1169 else:
1170 return SQLEdges(obj.name, obj.cursor, attrAlias=attrAlias,
1171 graph=obj)
1173 def getColumnTypes(createTable,attrAlias={},defaultColumnType='int',
1174 columnAttrs=('source','target','edge'),**kwargs):
1175 'return list of [(colname,coltype),...] for source,target,edge'
1176 l = []
1177 for attr in columnAttrs:
1178 try:
1179 attrName = attrAlias[attr+'_id']
1180 except KeyError:
1181 attrName = attr+'_id'
1182 try: # SEE IF USER SPECIFIED A DESIRED TYPE
1183 l.append((attrName,createTable[attr+'_id']))
1184 continue
1185 except (KeyError,TypeError):
1186 pass
1187 try: # get type info from primary key for that database
1188 db = kwargs[attr+'DB']
1189 if db is None:
1190 raise KeyError # FORCE IT TO USE DEFAULT TYPE
1191 except KeyError:
1192 pass
1193 else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
1194 it = iter(db)
1195 try: # GET ONE IDENTIFIER FROM THE DATABASE
1196 k = it.next()
1197 except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
1198 try:
1199 l.append((attrName,db.columnType[db.primary_key]))
1200 continue
1201 except AttributeError:
1202 pass
1203 else: # GET THE TYPE FROM THIS IDENTIFIER
1204 if isinstance(k,int) or isinstance(k,long):
1205 l.append((attrName,'int'))
1206 continue
1207 elif isinstance(k,str):
1208 l.append((attrName,'varchar(32)'))
1209 continue
1210 else:
1211 raise ValueError('SQLGraph node / edge must be int or str!')
1212 l.append((attrName,defaultColumnType))
1213 logger.warn('no type info found for %s, so using default: %s'
1214 % (attrName, defaultColumnType))
1217 return l
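# Example (sketch): with no usable type information, e.g.
#   getColumnTypes(createTable=True, sourceDB=None, targetDB=None, edgeDB=None)
# every column falls back to defaultColumnType, giving
#   [('source_id', 'int'), ('target_id', 'int'), ('edge_id', 'int')]
# and a warning is logged for each fallback.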
1220 class SQLGraph(SQLTableMultiNoCache):
1221 '''provide a graph interface via a SQL table. Key capabilities are:
1222 - setitem with an empty dictionary: a dummy operation
1223 - getitem with a key that exists: return a placeholder
1224 - setitem with non empty placeholder: again a dummy operation
1225 EXAMPLE TABLE SCHEMA:
1226 create table mygraph (source_id int not null,target_id int,edge_id int,
1227 unique(source_id,target_id));
1228 '''
1229 _distinct_key='source_id'
1230 _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
1231 _pickleAttrs.update(dict(sourceDB=0,targetDB=0,edgeDB=0,allowMissingNodes=0))
1232 _edgeClass = SQLEdgeDict
1233 def __init__(self,name,*l,**kwargs):
1234 graphArgs,tableArgs = split_kwargs(kwargs,
1235 ('attrAlias','defaultColumnType','columnAttrs',
1236 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
1237 'edgeDictClass','graph'))
1238 if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
1239 c = getColumnTypes(**kwargs)
1240 tableArgs['createTable'] = \
1241 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1242 % (name,c[0][0],c[0][1],c[1][0],c[1][1],c[2][0],c[2][1],c[0][0],c[1][0])
1243 try:
1244 self.allowMissingNodes = kwargs['allowMissingNodes']
1245 except KeyError: pass
1246 SQLTableMultiNoCache.__init__(self,name,*l,**tableArgs)
1247 self.sourceSQL = self._attrSQL('source_id')
1248 self.targetSQL = self._attrSQL('target_id')
1249 try:
1250 self.edgeSQL = self._attrSQL('edge_id')
1251 except AttributeError:
1252 self.edgeSQL = None
1253 self._edgeClass = self._edgeClass._edgelessClass
1254 save_graph_db_refs(self,**kwargs)
1255 def __getitem__(self,k):
1256 return self._edgeClass(self.pack_source(k),self)
1257 def __iadd__(self,k):
1258 sql,params = self._format_query('delete from %s where %s=%%s and %s is null'
1259 % (self.name,self.sourceSQL,self.targetSQL),
1260 (self.pack_source(k),))
1261 self.cursor.execute(sql, params)
1262 sql,params = self._format_query('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
1263 % self.name,(self.pack_source(k),))
1264 self.cursor.execute(sql, params)
1265 return self # iadd MUST RETURN SELF!
1266 def __isub__(self,k):
1267 sql,params = self._format_query('delete from %s where %s=%%s'
1268 % (self.name,self.sourceSQL),
1269 (self.pack_source(k),))
1270 self.cursor.execute(sql, params)
1271 if self.cursor.rowcount == 0:
1272 raise KeyError('node not found in graph')
1273 return self # isub MUST RETURN SELF!
1274 __setitem__ = graph_setitem
1275 def __contains__(self,k):
1276 sql,params = self._format_query('select * from %s where %s=%%s limit 1'
1277 %(self.name,self.sourceSQL),
1278 (self.pack_source(k),))
1279 self.cursor.execute(sql, params)
1280 l = self.cursor.fetchmany(2)
1281 return len(l) > 0
1282 def __invert__(self):
1283 'get an interface to the inverse graph mapping'
1284 try: # CACHED
1285 return self._inverse
1286 except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
1287 attrAlias = dict(source_id=self.targetSQL, # SWAP SOURCE & TARGET
1288 target_id=self.sourceSQL,
1289 edge_id=self.edgeSQL)
1290 if self.edgeSQL is None: # no edge interface
1291 del attrAlias['edge_id']
1292 self._inverse=SQLGraph(self.name,self.cursor,
1293 attrAlias=attrAlias,
1294 **graph_db_inverse_refs(self))
1295 self._inverse._inverse=self
1296 return self._inverse
1297 def __iter__(self):
1298 for k in SQLTableMultiNoCache.__iter__(self):
1299 yield self.unpack_source(k)
1300 def iteritems(self):
1301 for k in SQLTableMultiNoCache.__iter__(self):
1302 yield (self.unpack_source(k), self._edgeClass(k, self))
1303 def itervalues(self):
1304 for k in SQLTableMultiNoCache.__iter__(self):
1305 yield self._edgeClass(k, self)
1306 def keys(self):
1307 return [self.unpack_source(k) for k in SQLTableMultiNoCache.keys(self)]
1308 def values(self): return list(self.itervalues())
1309 def items(self): return list(self.iteritems())
1310 edges=SQLGraphEdgeDescriptor()
1311 update = update_graph
1312 def __len__(self):
1313 'get number of source nodes in graph'
1314 self.cursor.execute('select count(distinct %s) from %s'
1315 %(self.sourceSQL,self.name))
1316 return self.cursor.fetchone()[0]
1317 __cmp__ = graph_cmp
1318 override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
1319 ## def __cmp__(self,other):
1320 ## node = ()
1321 ## n = 0
1322 ## d = None
1323 ## it = iter(self.edges)
1324 ## while True:
1325 ## try:
1326 ## source,target,edge = it.next()
1327 ## except StopIteration:
1328 ## source = None
1329 ## if source!=node:
1330 ## if d is not None:
1331 ## diff = cmp(n_target,len(d))
1332 ## if diff!=0:
1333 ## return diff
1334 ## if source is None:
1335 ## break
1336 ## node = source
1337 ## n += 1 # COUNT SOURCE NODES
1338 ## n_target = 0
1339 ## try:
1340 ## d = other[node]
1341 ## except KeyError:
1342 ## return 1
1343 ## try:
1344 ## diff = cmp(edge,d[target])
1345 ## except KeyError:
1346 ## return 1
1347 ## if diff!=0:
1348 ## return diff
1349 ## n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
1350 ## return cmp(n,len(other))
1352 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1354 class SQLIDGraph(SQLGraph):
1355 add_trivial_packing_methods(locals())
1356 SQLGraph._IDGraphClass = SQLIDGraph
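# Hedged usage sketch (assuming a serverInfo connection and an existing
# 'mygraph' table matching the schema in the SQLGraph docstring):
#   g = SQLGraph('mygraph', serverInfo=serverInfo)
#   g += 1           # add node 1 as a row (1, NULL, NULL)
#   g[1][2] = 3      # add edge 1 -> 2 with edge value 3 (REPLACE into the table)
#   g[1].keys()      # target nodes reachable from node 1
#   ~g               # interface to the inverse (target -> source) mapping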
1360 class SQLEdgeDictClustered(dict):
1361 'simple cache for 2nd level dictionary of target_id:edge_id'
1362 def __init__(self,g,fromNode):
1363 self.g=g
1364 self.fromNode=fromNode
1365 dict.__init__(self)
1366 def __iadd__(self,l):
1367 for target_id,edge_id in l:
1368 dict.__setitem__(self,target_id,edge_id)
1369 return self # iadd MUST RETURN SELF!
1371 class SQLEdgesClusteredDescr(object):
1372 def __get__(self,obj,objtype):
1373 e=SQLEdgesClustered(obj.table,obj.edge_id,obj.source_id,obj.target_id,
1374 graph=obj,**graph_db_inverse_refs(obj,True))
1375 for source_id,d in obj.d.iteritems(): # COPY EDGE CACHE
1376 e.load([(edge_id,source_id,target_id)
1377 for (target_id,edge_id) in d.iteritems()])
1378 return e
1380 class SQLGraphClustered(object):
1381 'SQL graph with clustered caching -- loads an entire cluster at a time'
1382 _edgeDictClass=SQLEdgeDictClustered
1383 def __init__(self,table,source_id='source_id',target_id='target_id',
1384 edge_id='edge_id',clusterKey=None,**kwargs):
1385 import types
1386 if isinstance(table,types.StringType): # CREATE THE TABLE INTERFACE
1387 if clusterKey is None:
1388 raise ValueError('you must provide a clusterKey argument!')
1389 if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
1390 c = getColumnTypes(attrAlias=dict(source_id=source_id,target_id=target_id,
1391 edge_id=edge_id),**kwargs)
1392 kwargs['createTable'] = \
1393 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1394 % (table,c[0][0],c[0][1],c[1][0],c[1][1],
1395 c[2][0],c[2][1],c[0][0],c[1][0])
1396 table = SQLTableClustered(table,clusterKey=clusterKey,**kwargs)
1397 self.table=table
1398 self.source_id=source_id
1399 self.target_id=target_id
1400 self.edge_id=edge_id
1401 self.d={}
1402 save_graph_db_refs(self,**kwargs)
1403 _pickleAttrs = dict(table=0,source_id=0,target_id=0,edge_id=0,sourceDB=0,targetDB=0,
1404 edgeDB=0)
1405 def __getstate__(self):
1406 state = standard_getstate(self)
1407 state['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1408 return state
1409 def __getitem__(self,k):
1410 'get edgeDict for source node k, from cache or by loading its cluster'
1411 try: # GET DIRECTLY FROM CACHE
1412 return self.d[k]
1413 except KeyError:
1414 if hasattr(self,'_isLoaded'):
1415 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1416 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1417 sql,params = self.table._format_query('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1418 %(self.source_id,self.target_id,
1419 self.edge_id,self.table.name,
1420 self.table.name,self.source_id,
1421 self.table.clusterKey,self.table.clusterKey,
1422 self.table.primary_key),
1423 (self.pack_source(k),))
1424 self.table.cursor.execute(sql, params)
1425 self.load(self.table.cursor.fetchall()) # CACHE THIS CLUSTER
1426 return self.d[k] # RETURN EDGE DICT FOR THIS NODE
1427 def load(self,l=None,unpack=True):
1428 'load the specified rows (or all, if None provided) into local cache'
1429 if l is None:
1430 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1431 return self._isLoaded
1432 except AttributeError:
1433 pass
1434 self.table.cursor.execute('select %s,%s,%s from %s'
1435 %(self.source_id,self.target_id,
1436 self.edge_id,self.table.name))
1437 l=self.table.cursor.fetchall()
1438 self._isLoaded=True
1439 self.d.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1440 for source,target,edge in l: # SAVE TO OUR CACHE
1441 if unpack:
1442 source = self.unpack_source(source)
1443 target = self.unpack_target(target)
1444 edge = self.unpack_edge(edge)
1445 try:
1446 self.d[source] += [(target,edge)]
1447 except KeyError:
1448 d = self._edgeDictClass(self,source)
1449 d += [(target,edge)]
1450 self.d[source] = d
1451 def __invert__(self):
1452 'interface to reverse graph mapping'
1453 try:
1454 return self._inverse # INVERSE MAP ALREADY EXISTS
1455 except AttributeError:
1456 pass
1457 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1458 self._inverse=SQLGraphClustered(self.table,self.target_id,self.source_id,
1459 self.edge_id,**graph_db_inverse_refs(self))
1460 self._inverse._inverse=self
1461 for source,d in self.d.iteritems(): # INVERT OUR CACHE
1462 self._inverse.load([(target,source,edge)
1463 for (target,edge) in d.iteritems()],unpack=False)
1464 return self._inverse
1465 edges=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1466 update = update_graph
1467 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1468 def __iter__(self): ################# ITERATORS
1469 'uses db select; does not force load'
1470 return iter(self.keys())
1471 def keys(self):
1472 'uses db select; does not force load'
1473 self.table.cursor.execute('select distinct(%s) from %s'
1474 %(self.source_id,self.table.name))
1475 return [self.unpack_source(t[0])
1476 for t in self.table.cursor.fetchall()]
1477 methodFactory(['iteritems','items','itervalues','values'],
1478 'lambda self:(self.load(),self.d.%s())[1]',locals())
1479 def __contains__(self,k):
1480 try:
1481 x=self[k]
1482 return True
1483 except KeyError:
1484 return False
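# Illustrative usage sketch (comment only, not executed): building a clustered
# graph on an existing edge table.  The table/column names and the serverInfo
# keyword passed through to SQLTableClustered are placeholders, not part of this
# module.
#
#   g = SQLGraphClustered('splicedb.exon_edges', source_id='left_exon',
#                         target_id='right_exon', edge_id='splice_id',
#                         clusterKey='gene_id', serverInfo=info)
#   d = g[someNode]        # one query caches the whole gene_id cluster
#   for target, edge in d.items():  # cached target -> edge mapping for someNode
#       pass
#   g.load()               # or pre-load the entire graph into the local cache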
1486 class SQLIDGraphClustered(SQLGraphClustered):
1487 add_trivial_packing_methods(locals())
1488 SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
1490 class SQLEdgesClustered(SQLGraphClustered):
1491 'edges interface for SQLGraphClustered'
1492 _edgeDictClass = list
1493 _pickleAttrs = SQLGraphClustered._pickleAttrs.copy()
1494 _pickleAttrs.update(dict(graph=0))
1495 def keys(self):
1496 self.load()
1497 result = []
1498 for edge_id,l in self.d.iteritems():
1499 for source_id,target_id in l:
1500 result.append((self.graph.unpack_source(source_id),
1501 self.graph.unpack_target(target_id),
1502 self.graph.unpack_edge(edge_id)))
1503 return result
1505 class ForeignKeyInverse(object):
1506 'map each key to a single value according to its foreign key'
1507 def __init__(self,g):
1508 self.g = g
1509 def __getitem__(self,obj):
1510 self.check_obj(obj)
1511 source_id = getattr(obj,self.g.keyColumn)
1512 if source_id is None:
1513 return None
1514 return self.g.sourceDB[source_id]
1515 def __setitem__(self,obj,source):
1516 self.check_obj(obj)
1517 if source is not None:
1518 self.g[source][obj] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1519 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1520 if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1521 old_source = self[obj]
1522 if old_source is not None:
1523 del self.g[old_source][obj]
1524 def check_obj(self,obj):
1525 'raise KeyError if obj not from this db'
1526 try:
1527 if obj.db != self.g.targetDB:
1528 raise AttributeError
1529 except AttributeError:
1530 raise KeyError('key is not from targetDB of this graph!')
1531 def __contains__(self,obj):
1532 try:
1533 self.check_obj(obj)
1534 return True
1535 except KeyError:
1536 return False
1537 def __iter__(self):
1538 return self.g.targetDB.itervalues()
1539 def keys(self):
1540 return self.g.targetDB.values()
1541 def iteritems(self):
1542 for obj in self:
1543 source_id = getattr(obj,self.g.keyColumn)
1544 if source_id is None:
1545 yield obj,None
1546 else:
1547 yield obj,self.g.sourceDB[source_id]
1548 def items(self):
1549 return list(self.iteritems())
1550 def itervalues(self):
1551 for obj,val in self.iteritems():
1552 yield val
1553 def values(self):
1554 return list(self.itervalues())
1555 def __invert__(self):
1556 return self.g
1558 class ForeignKeyEdge(dict):
1559 '''edge interface to a foreign key in an SQL table.
1560 Caches dict of target nodes in itself; provides dict interface.
1561 Adds or deletes edges by setting foreign key values in the table'''
1562 def __init__(self,g,k):
1563 dict.__init__(self)
1564 self.g = g
1565 self.src = k
1566 for v in g.targetDB.select('where %s=%%s' % g.keyColumn,(k.id,)): # SEARCH THE DB
1567 dict.__setitem__(self,v,None) # SAVE IN CACHE
1568 def __setitem__(self,dest,v):
1569 if not hasattr(dest,'db') or dest.db != self.g.targetDB:
1570 raise KeyError('dest is not in the targetDB bound to this graph!')
1571 if v is not None:
1572 raise ValueError('sorry, this graph cannot store edge information!')
1573 if not hasattr(dest,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1574 old_source = self.g._inverse[dest] # CHECK FOR PRE-EXISTING EDGE
1575 if old_source is not None: # REMOVE OLD EDGE FROM CACHE
1576 dict.__delitem__(self.g[old_source],dest)
1577 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1578 setattr(dest,self.g.keyColumn,self.src.id) # SAVE TO DB ATTRIBUTE
1579 dict.__setitem__(self,dest,None) # SAVE IN CACHE
1580 def __delitem__(self,dest):
1581 #self.g.targetDB._update(dest.id,self.g.keyColumn,None) # REMOVE FOREIGN KEY VALUE
1582 setattr(dest,self.g.keyColumn,None) # SAVE TO DB ATTRIBUTE
1583 dict.__delitem__(self,dest) # REMOVE FROM CACHE
1585 class ForeignKeyGraph(object, UserDict.DictMixin):
1586 '''graph interface to a foreign key in an SQL table
1587 Caches dict of target nodes in itself; provides dict interface.'''
1589 def __init__(self, sourceDB, targetDB, keyColumn, autoGC=True, **kwargs):
1590 '''sourceDB is any database of source nodes;
1591 targetDB must be an SQL database of target nodes;
1592 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1593 if autoGC: # automatically garbage collect unused objects
1594 self._weakValueDict = RecentValueDictionary(autoGC) # object cache
1595 else:
1596 self._weakValueDict = {}
1597 self.autoGC = autoGC
1598 self.sourceDB = sourceDB
1599 self.targetDB = targetDB
1600 self.keyColumn = keyColumn
1601 self._inverse = ForeignKeyInverse(self)
1602 _pickleAttrs = dict(sourceDB=0, targetDB=0, keyColumn=0, autoGC=0)
1603 __getstate__ = standard_getstate ########### SUPPORT FOR PICKLING
1604 __setstate__ = standard_setstate
1605 def _inverse_schema(self):
1606 'provide custom schema rule for inverting this graph... just use keyColumn!'
1607 return dict(invert=True,uniqueMapping=True)
1608 def __getitem__(self,k):
1609 if not hasattr(k,'db') or k.db != self.sourceDB:
1610 raise KeyError('object is not in the sourceDB bound to this graph!')
1611 try:
1612 return self._weakValueDict[k.id] # get from cache
1613 except KeyError:
1614 pass
1615 d = ForeignKeyEdge(self,k)
1616 self._weakValueDict[k.id] = d # save in cache
1617 return d
1618 def __setitem__(self, k, v):
1619 raise KeyError('''do not save as g[k]=v. Instead follow a graph
1620 interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
1621 def __delitem__(self, k):
1622 raise KeyError('''Instead of del g[k], follow a graph
1623 interface: del g[src][dest]''')
1624 def keys(self):
1625 return self.sourceDB.values()
1626 __invert__ = standard_invert
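# Illustrative usage sketch (comment only, not executed): a ForeignKeyGraph maps
# each source row to the set of target rows whose foreign-key column refers to it.
# The table objects and column name below are placeholders.
#
#   genes = SQLTable('mydb.genes', serverInfo=info)         # source nodes
#   exons = SQLTable('mydb.exons', serverInfo=info)         # target nodes
#   gene_exons = ForeignKeyGraph(genes, exons, 'gene_id')   # exons.gene_id -> genes
#   for exon in gene_exons[genes[17]]:   # all exons whose gene_id is 17
#       pass
#   exon_to_gene = ~gene_exons           # inverse: each exon -> its gene (or None)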
1628 def describeDBTables(name,cursor,idDict):
1630 """Get table info about database <name> via <cursor>, and store primary keys
1631 in idDict, along with a list of the tables each key indexes."""
1633 cursor.execute('use %s' % name)
1634 cursor.execute('show tables')
1635 tables={}
1636 l=[c[0] for c in cursor.fetchall()]
1637 for t in l:
1638 tname=name+'.'+t
1639 o=SQLTable(tname,cursor)
1640 tables[tname]=o
1641 for f in o.description:
1642 if f==o.primary_key:
1643 idDict.setdefault(f, []).append(o)
1644 elif f[-3:]=='_id' and f not in idDict:
1645 idDict[f]=[]
1646 return tables
1650 def indexIDs(tables,idDict=None):
1651 "Get an index of primary keys in the <tables> dictionary."
1652 if idDict is None:
1653 idDict={}
1654 for o in tables.values():
1655 if o.primary_key:
1656 if o.primary_key not in idDict:
1657 idDict[o.primary_key]=[]
1658 idDict[o.primary_key].append(o) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1659 for f in o.description:
1660 if f[-3:]=='_id' and f not in idDict:
1661 idDict[f]=[]
1662 return idDict
1666 def suffixSubset(tables,suffix):
1667 "Filter table index for those matching a specific suffix"
1668 subset={}
1669 for name,t in tables.items():
1670 if name.endswith(suffix):
1671 subset[name]=t
1672 return subset
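# Illustrative usage sketch (comment only, not executed): introspecting a MySQL
# database with the helpers above.  The database name, cursor and suffix are
# placeholders.
#
#   idDict = {}
#   tables = describeDBTables('mydb', cursor, idDict)    # tablename -> SQLTable
#   idDict = indexIDs(tables, idDict)                    # primary key -> [tables]
#   annot = suffixSubset(tables, '_annotation')          # only matching table names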
1675 PRIMARY_KEY=1
1677 def graphDBTables(tables,idDict):
1678 g=dictgraph()
1679 for t in tables.values():
1680 for f in t.description:
1681 if f==t.primary_key:
1682 edgeInfo=PRIMARY_KEY
1683 else:
1684 edgeInfo=None
1685 g.setEdge(f,t,edgeInfo)
1686 g.setEdge(t,f,edgeInfo)
1687 return g
1689 SQLTypeTranslation= {types.StringType:'varchar(32)',
1690 types.IntType:'int',
1691 types.FloatType:'float'}
1693 def createTableFromRepr(rows,tableName,cursor,typeTranslation=None,
1694 optionalDict=None,indexDict=()):
1695 """Save rows into SQL tableName using cursor, with optional
1696 translations of columns to specific SQL types (specified
1697 by typeTranslation dict).
1698 - optionalDict can specify columns that are allowed to be NULL.
1699 - indexDict can specify columns that must be indexed; columns
1700 whose names end in _id will be indexed by default.
1701 - rows must be an iterator which in turn returns dictionaries,
1702 each representing a tuple of values (indexed by their column
1703 names)."""
1705 try:
1706 row=rows.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1707 except StopIteration:
1708 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1709 try:
1710 createTableFromRow(cursor, tableName,row,typeTranslation,
1711 optionalDict,indexDict)
1712 except:
1713 pass # IGNORE ERRORS FROM TABLE CREATION (E.G. TABLE ALREADY EXISTS)
1714 storeRow(cursor,tableName,row) # SAVE OUR FIRST ROW
1715 for row in rows: # NOW SAVE ALL THE ROWS
1716 storeRow(cursor,tableName,row)
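# Illustrative usage sketch (comment only, not executed): saving an iterator of
# row dictionaries to a new table.  Column types are taken from typeTranslation
# or inferred via SQLTypeTranslation; the names below are placeholders.
#
#   rows = iter([dict(exon_id=1, gene_id=10, score=0.5),
#                dict(exon_id=2, gene_id=10, score=0.9)])
#   createTableFromRepr(rows, 'mydb.exon_scores', cursor,
#                       typeTranslation=dict(score='float'))
#   # exon_id and gene_id end in _id, so they are indexed automatically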
1718 def createTableFromRow(cursor, tableName, row,typeTranslation=None,
1719 optionalDict=None,indexDict=()):
1720 create_defs=[]
1721 for col,val in row.items(): # PREPARE SQL TYPES FOR COLUMNS
1722 coltype=None
1723 if typeTranslation is not None and col in typeTranslation:
1724 coltype=typeTranslation[col] # USER-SUPPLIED TRANSLATION
1725 elif type(val) in SQLTypeTranslation:
1726 coltype=SQLTypeTranslation[type(val)]
1727 else: # SEARCH FOR A COMPATIBLE TYPE
1728 for t in SQLTypeTranslation:
1729 if isinstance(val,t):
1730 coltype=SQLTypeTranslation[t]
1731 break
1732 if coltype is None:
1733 raise TypeError("Don't know SQL type to use for %s" % col)
1734 create_def='%s %s' %(col,coltype)
1735 if optionalDict is None or col not in optionalDict:
1736 create_def+=' not null'
1737 create_defs.append(create_def)
1738 for col in row: # CREATE INDEXES FOR ID COLUMNS
1739 if col[-3:]=='_id' or col in indexDict:
1740 create_defs.append('index(%s)' % col)
1741 cmd='create table if not exists %s (%s)' % (tableName,','.join(create_defs))
1742 cursor.execute(cmd) # CREATE THE TABLE IN THE DATABASE
1745 def storeRow(cursor, tableName, row):
1746 row_format=','.join(len(row)*['%s'])
1747 cmd='insert into %s values (%s)' % (tableName,row_format)
1748 cursor.execute(cmd,tuple(row.values()))
1750 def storeRowDelayed(cursor, tableName, row):
1751 row_format=','.join(len(row)*['%s'])
1752 cmd='insert delayed into %s values (%s)' % (tableName,row_format)
1753 cursor.execute(cmd,tuple(row.values()))
1756 class TableGroup(dict):
1757 'provide attribute access to dbname qualified tablenames'
1758 def __init__(self,db='test',suffix=None,**kw):
1759 dict.__init__(self)
1760 self.db=db
1761 if suffix is not None:
1762 self.suffix=suffix
1763 for k,v in kw.items():
1764 if v is not None and '.' not in v:
1765 v=self.db+'.'+v # ADD DATABASE NAME AS PREFIX
1766 self[k]=v
1767 def __getattr__(self,k):
1768 return self[k]
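# Illustrative usage sketch (comment only, not executed): TableGroup just
# prefixes the database name onto unqualified table names and exposes them as
# attributes.  Names below are placeholders.
#
#   t = TableGroup(db='splicedb', exons='exons', genes='hg17.genes')
#   t.exons    # -> 'splicedb.exons'
#   t.genes    # -> 'hg17.genes' (already qualified, so left unchanged)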
1770 def sqlite_connect(*args, **kwargs):
1771 sqlite = import_sqlite()
1772 connection = sqlite.connect(*args, **kwargs)
1773 cursor = connection.cursor()
1774 return connection, cursor
1776 class DBServerInfo(object):
1777 'picklable reference to a database server'
1778 def __init__(self, moduleName='MySQLdb', *args, **kwargs):
1779 try:
1780 self.__class__ = _DBServerModuleDict[moduleName]
1781 except KeyError:
1782 raise ValueError('Module name not found in _DBServerModuleDict: '\
1783 + moduleName)
1784 self.moduleName = moduleName
1785 self.args = args
1786 self.kwargs = kwargs # connection arguments
1788 def cursor(self):
1789 """returns cursor associated with the DB server info (reused)"""
1790 try:
1791 return self._cursor
1792 except AttributeError:
1793 self._start_connection()
1794 return self._cursor
1796 def new_cursor(self, arraysize=None):
1797 """returns a NEW cursor; you must close it yourself! """
1798 if not hasattr(self, '_connection'):
1799 self._start_connection()
1800 cursor = self._connection.cursor()
1801 if arraysize is not None:
1802 cursor.arraysize = arraysize
1803 return cursor
1805 def close(self):
1806 """Close file containing this database"""
1807 self._cursor.close()
1808 self._connection.close()
1809 del self._cursor
1810 del self._connection
1812 def __getstate__(self):
1813 """return all picklable arguments"""
1814 return dict(args=self.args, kwargs=self.kwargs,
1815 moduleName=self.moduleName)
1818 class MySQLServerInfo(DBServerInfo):
1819 'customized for MySQLdb SSCursor support via new_cursor()'
1820 def _start_connection(self):
1821 self._connection,self._cursor = mysql_connect(*self.args, **self.kwargs)
1822 def new_cursor(self, arraysize=None):
1823 'provide streaming cursor support'
1824 try:
1825 conn = self._conn_sscursor
1826 except AttributeError:
1827 self._conn_sscursor,cursor = mysql_connect(useStreaming=True,
1828 *self.args, **self.kwargs)
1829 else:
1830 cursor = self._conn_sscursor.cursor()
1831 if arraysize is not None:
1832 cursor.arraysize = arraysize
1833 return cursor
1834 def close(self):
1835 DBServerInfo.close(self)
1836 try:
1837 self._conn_sscursor.close()
1838 del self._conn_sscursor
1839 except AttributeError:
1840 pass
1841 def iter_keys(self, db, cursor, map_f=iter,
1842 cache_f=lambda x:[t[0] for t in x], **kwargs):
1843 block_generator = BlockGenerator(db, self, cursor, **kwargs)
1844 return db.generic_iterator(cursor=cursor, cache_f=cache_f,
1845 map_f=map_f, fetch_f=block_generator)
1848 class BlockGenerator(object):
1849 def __init__(self, db, serverInfo, cursor, whereClause='', **kwargs):
1850 self.db = db
1851 self.serverInfo = serverInfo
1852 self.cursor = cursor
1853 self.kwargs = kwargs
1854 self.blockSize = 10000
1855 self.whereClause = whereClause
1856 #self.__iter__() # start me up!
1858 ## def __iter__(self):
1859 ## 'initialize this iterator'
1860 ## self.db._select(cursor=cursor, selectCols='min(%s),max(%s),count(*)'
1861 ## % (self.db.name, self.db.name))
1862 ## l = self.cursor.fetchall()
1863 ## self.minID, self.maxID, self.count = l[0]
1864 ## self.start = self.minID - 1 # only works for int
1865 ## return self
1867 ## def next(self):
1868 ## 'get the next start position'
1869 ## if self.start >= self.maxID:
1870 ## raise StopIteration
1871 ## return start
1873 def __call__(self):
1874 'get the next block of data'
1875 ## try:
1876 ## start = self.next()
1877 ## except StopIteration:
1878 ## return ()
1879 print 'SELECT ... %s LIMIT %s' % (self.whereClause, self.blockSize)
1880 self.db._select(cursor=self.cursor, whereClause=self.whereClause,
1881 limit='LIMIT %s' % self.blockSize, **self.kwargs)
1882 rows = self.cursor.fetchall()
1883 lastrow = rows[-1]
1884 if len(lastrow) > 1: # extract the last ID value in this block
1885 start = lastrow[self.db.data['id']]
1886 else:
1887 start = lastrow[0]
1888 self.whereClause = 'WHERE %s>%s' %(self.db.primary_key,start)
1889 return rows
1893 class SQLiteServerInfo(DBServerInfo):
1894 """picklable reference to a sqlite database"""
1895 def __init__(self, database, *args, **kwargs):
1896 """Takes same arguments as sqlite3.connect()"""
1897 DBServerInfo.__init__(self, 'sqlite',
1898 SourceFileName(database), # save abs path!
1899 *args, **kwargs)
1900 def _start_connection(self):
1901 self._connection,self._cursor = sqlite_connect(*self.args, **self.kwargs)
1902 def __getstate__(self):
1903 if self.args[0] == ':memory:':
1904 raise ValueError('SQLite in-memory database is not picklable!')
1905 return DBServerInfo.__getstate__(self)
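# Illustrative usage sketch (comment only, not executed): a DBServerInfo subclass
# is a picklable handle for reopening the same server later; the path below is a
# placeholder.
#
#   info = SQLiteServerInfo('/path/to/genes.sqlite')
#   cursor = info.cursor()    # shared, reusable cursor for this connection
#   import pickle
#   s = pickle.dumps(info)    # records module name + connect args, not the live connection
#   # SQLiteServerInfo(':memory:') would refuse to pickle (see __getstate__ above)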
1907 # list of DBServerInfo subclasses for different modules
1908 _DBServerModuleDict = dict(MySQLdb=MySQLServerInfo, sqlite=SQLiteServerInfo)
1911 class MapView(object, UserDict.DictMixin):
1912 'general purpose 1:1 mapping defined by any SQL query'
1913 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
1914 serverInfo=None, inverseSQL=None, **kwargs):
1915 self.sourceDB = sourceDB
1916 self.targetDB = targetDB
1917 self.viewSQL = viewSQL
1918 self.inverseSQL = inverseSQL
1919 if cursor is None:
1920 if serverInfo is not None: # get cursor from serverInfo
1921 cursor = serverInfo.cursor()
1922 else:
1923 try: # can we get it from our other db?
1924 serverInfo = sourceDB.serverInfo
1925 except AttributeError:
1926 raise ValueError('you must provide serverInfo or cursor!')
1927 else:
1928 cursor = serverInfo.cursor()
1929 self.cursor = cursor
1930 self.serverInfo = serverInfo
1931 self.get_sql_format(False) # get sql formatter for this db interface
1932 _schemaModuleDict = _schemaModuleDict # default module list
1933 get_sql_format = get_table_schema
1934 def __getitem__(self, k):
1935 if not hasattr(k,'db') or k.db is not self.sourceDB:
1936 raise KeyError('object is not in the sourceDB bound to this map!')
1937 sql,params = self._format_query(self.viewSQL, (k.id,))
1938 self.cursor.execute(sql, params) # formatted for this db interface
1939 t = self.cursor.fetchmany(2) # get at most two rows
1940 if len(t) != 1:
1941 raise KeyError('%s not found in MapView, or not unique'
1942 % str(k))
1943 return self.targetDB[t[0][0]] # get the corresponding object
1944 _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
1945 inverseSQL=0)
1946 __getstate__ = standard_getstate
1947 __setstate__ = standard_setstate
1948 __setitem__ = __delitem__ = clear = pop = popitem = update = \
1949 setdefault = read_only_error
1950 def __iter__(self):
1951 'only yield sourceDB items that are actually in this mapping!'
1952 for k in self.sourceDB.itervalues():
1953 try:
1954 self[k]
1955 yield k
1956 except KeyError:
1957 pass
1958 def keys(self):
1959 return [k for k in self] # don't use list(self); causes infinite loop!
1960 def __invert__(self):
1961 try:
1962 return self._inverse
1963 except AttributeError:
1964 if self.inverseSQL is None:
1965 raise ValueError('this MapView has no inverseSQL!')
1966 self._inverse = self.__class__(self.targetDB, self.sourceDB,
1967 self.inverseSQL, self.cursor,
1968 serverInfo=self.serverInfo,
1969 inverseSQL=self.viewSQL)
1970 self._inverse._inverse = self
1971 return self._inverse
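# Illustrative usage sketch (comment only, not executed): a MapView maps each
# source row to exactly one target row via an SQL query returning a single target
# ID.  The table objects, query strings and serverInfo below are placeholders.
#
#   viewSQL = 'select transcript_id from gene2transcript where gene_id=%s'
#   inverseSQL = 'select gene_id from gene2transcript where transcript_id=%s'
#   m = MapView(genes, transcripts, viewSQL, serverInfo=info, inverseSQL=inverseSQL)
#   tx = m[genes[17]]     # the unique transcript for gene 17, or KeyError
#   g = (~m)[tx]          # inverse mapping built from inverseSQL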
1973 class GraphViewEdgeDict(UserDict.DictMixin):
1974 'edge dictionary for GraphView: just pre-loaded on init'
1975 def __init__(self, g, k):
1976 self.g = g
1977 self.k = k
1978 sql,params = self.g._format_query(self.g.viewSQL, (k.id,))
1979 self.g.cursor.execute(sql, params) # run the query
1980 l = self.g.cursor.fetchall() # get results
1981 if len(l) <= 0:
1982 raise KeyError('key %s not in GraphView' % k.id)
1983 self.targets = [t[0] for t in l] # preserve order of the results
1984 d = {} # also keep targetID:edgeID mapping
1985 if self.g.edgeDB is not None: # save with edge info
1986 for t in l:
1987 d[t[0]] = t[1]
1988 else:
1989 for t in l:
1990 d[t[0]] = None
1991 self.targetDict = d
1992 def __len__(self):
1993 return len(self.targets)
1994 def __iter__(self):
1995 for k in self.targets:
1996 yield self.g.targetDB[k]
1997 def keys(self):
1998 return list(self)
1999 def iteritems(self):
2000 if self.g.edgeDB is not None: # save with edge info
2001 for k in self.targets:
2002 yield (self.g.targetDB[k], self.g.edgeDB[self.targetDict[k]])
2003 else: # just save the list of targets, no edge info
2004 for k in self.targets:
2005 yield (self.g.targetDB[k], None)
2006 def __getitem__(self, o, exitIfFound=False):
2007 'for the specified target object, return its associated edge object'
2008 try:
2009 if o.db is not self.g.targetDB:
2010 raise KeyError('key is not part of targetDB!')
2011 edgeID = self.targetDict[o.id]
2012 except AttributeError:
2013 raise KeyError('key has no id or db attribute?!')
2014 if exitIfFound:
2015 return
2016 if self.g.edgeDB is not None: # return the edge object
2017 return self.g.edgeDB[edgeID]
2018 else: # no edge info
2019 return None
2020 def __contains__(self, o):
2021 try:
2022 self.__getitem__(o, True) # raise KeyError if not found
2023 return True
2024 except KeyError:
2025 return False
2026 __setitem__ = __delitem__ = clear = pop = popitem = update = \
2027 setdefault = read_only_error
2029 class GraphView(MapView):
2030 'general purpose graph interface defined by any SQL query'
2031 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
2032 **kwargs):
2033 'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
2034 self.edgeDB = edgeDB
2035 MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)
2036 def __getitem__(self, k):
2037 if not hasattr(k,'db') or k.db is not self.sourceDB:
2038 raise KeyError('object is not in the sourceDB bound to this map!')
2039 return GraphViewEdgeDict(self, k)
2040 _pickleAttrs = MapView._pickleAttrs.copy()
2041 _pickleAttrs.update(dict(edgeDB=0))
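# Illustrative usage sketch (comment only, not executed): GraphView is the
# many-to-many analogue of MapView; its query returns one row per edge, and with
# edgeDB given each row must be a (targetID, edgeID) pair.  Names are placeholders.
#
#   sql = 'select exon2_id, splice_id from splices where exon1_id=%s'
#   g = GraphView(exons, exons, sql, serverInfo=info, edgeDB=splices)
#   for exon2, splice in g[exons[5]].iteritems():
#       pass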
2043 # @CTB move to sqlgraph.py?
2045 class SQLSequence(SQLRow, SequenceBase):
2046 """Transparent access to a DB row representing a sequence.
2048 Use attrAlias dict to rename 'length' to something else."""
2050 def _init_subclass(cls, db, **kwargs):
2051 db.seqInfoDict = db # db will act as its own seqInfoDict
2052 SQLRow._init_subclass(db=db, **kwargs)
2053 _init_subclass = classmethod(_init_subclass)
2054 def __init__(self, id):
2055 SQLRow.__init__(self, id)
2056 SequenceBase.__init__(self)
2057 def __len__(self):
2058 return self.length
2059 def strslice(self,start,end):
2060 "Efficient access to slice of a sequence, useful for huge contigs"
2061 return self._select('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
2062 %(self.db._attrSQL('seq'),start+1,end-start))
2064 class DNASQLSequence(SQLSequence):
2065 _seqtype=DNA_SEQTYPE
2067 class RNASQLSequence(SQLSequence):
2068 _seqtype=RNA_SEQTYPE
2070 class ProteinSQLSequence(SQLSequence):
2071 _seqtype=PROTEIN_SEQTYPE
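# Illustrative usage sketch (comment only, not executed): bind a sequence table to
# one of the subclasses above so slices are fetched with SQL SUBSTRING instead of
# loading whole sequences.  The table name and the SQLTable keywords used here
# (serverInfo, itemClass, attrAlias) are assumptions for illustration.
#
#   contigs = SQLTable('genomedb.contigs', serverInfo=info,
#                      itemClass=DNASQLSequence,
#                      attrAlias=dict(seq='sequence_text', length='seq_length'))
#   s = contigs['chr1_contig42']
#   len(s)                  # read from the (aliased) length column
#   s.strslice(1000, 1050)  # 50 letters fetched directly via SUBSTRING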