changed BlockGenerator to select multiple columns at once, to handle cases where...
[pygr.git] / pygr / sqlgraph.py
blob65f904081c145c66f388d300dfa27e43ee07d678
3 from __future__ import generators
4 from mapping import *
5 from sequence import SequenceBase, DNA_SEQTYPE, RNA_SEQTYPE, PROTEIN_SEQTYPE
6 import types
7 from classutil import methodFactory,standard_getstate,\
8 override_rich_cmp,generate_items,get_bound_subclass,standard_setstate,\
9 get_valid_path,standard_invert,RecentValueDictionary,read_only_error,\
10 SourceFileName, split_kwargs
11 import os
12 import platform
13 import UserDict
14 import warnings
15 import logger
class TupleDescriptor(object):
    """Read-only descriptor exposing one slot of a row's data tuple.

    The slot position is looked up once, at construction time, from
    ``db.data[attr]``; reads then index ``obj._data`` with it.
    """

    def __init__(self, db, attr):
        column_index = db.data[attr]
        self.icol = column_index  # position of this attribute in obj._data

    def __get__(self, obj, klass):
        row_values = obj._data
        return row_values[self.icol]

    def __set__(self, obj, val):
        # assignment is never allowed through this descriptor
        raise AttributeError('this database is read-only!')
class TupleIDDescriptor(TupleDescriptor):
    """Descriptor for the ID slot: read like any tuple column, never assigned.

    Reading is inherited from TupleDescriptor; assignment raises an
    AttributeError telling the caller to use ``db[newID] = obj`` instead.
    """

    def __set__(self, obj, val):
        message = ('You cannot change obj.id directly.\n'
                   'Instead, use db[newID] = obj')
        raise AttributeError(message)
class TupleDescriptorRW(TupleDescriptor):
    """Writable tuple-slot descriptor.

    Assignment first pushes the new value to the database through
    ``obj.db._update`` and then stores it locally via ``obj.save_local``.
    """

    def __init__(self, db, attr):
        TupleDescriptor.__init__(self, db, attr)  # sets self.icol
        self.attr = attr
        # name of the SQL column backing this attribute
        self.attrSQL = db._attrSQL(attr, sqlColumn=True)

    def __set__(self, obj, val):
        table = obj.db
        table._update(obj.id, self.attrSQL, val)  # write to the database
        obj.save_local(self.attr, val)  # then refresh the local copy
class SQLDescriptor(object):
    """Read-only descriptor whose value is fetched by querying the database.

    The SQL expression for the attribute is obtained once from
    ``db._attrSQL(attr)``; each read passes it to ``obj._select``.
    """

    def __init__(self, db, attr):
        expression = db._attrSQL(attr)
        self.selectSQL = expression  # SQL expression standing for attr

    def __get__(self, obj, klass):
        select = obj._select
        return select(self.selectSQL)

    def __set__(self, obj, val):
        # assignment is never allowed through this descriptor
        raise AttributeError('this database is read-only!')
class SQLDescriptorRW(SQLDescriptor):
    """Writable variant of SQLDescriptor.

    Assignment is forwarded straight to ``obj.db._update``; nothing is
    stored on the row object itself.
    """

    def __set__(self, obj, val):
        table = obj.db
        table._update(obj.id, self.selectSQL, val)  # database-only update
class ReadOnlyDescriptor(object):
    """Expose a private ``_<attr>`` attribute for reading and refuse writes.

    Used e.g. for the ID attribute: reads return ``getattr(obj, '_' + attr)``.
    """

    def __init__(self, db, attr):
        # db is accepted for signature compatibility but not used here
        self.attr = '_' + attr

    def __get__(self, obj, klass):
        private_name = self.attr
        return getattr(obj, private_name)

    def __set__(self, obj, val):
        raise AttributeError('attribute %s is read-only!' % self.attr)
def select_from_row(row, what):
    """Evaluate the SQL expression `what` against the table row of `row`.

    Builds ``select <what> from <table> where <primary_key>=<row.id>
    limit 2`` through ``row.db._format_query`` and runs it on
    ``row.db.cursor``.

    Returns the first field of the single matching row.
    Raises KeyError when zero or more than one row comes back.
    """
    table = row.db
    template = 'select %s from %s where %s=%%s limit 2' % (
        what, table.name, table.primary_key)
    sql, params = table._format_query(template, (row.id,))
    table.cursor.execute(sql, params)
    matches = table.cursor.fetchmany(2)  # two rows suffice to detect dupes
    if len(matches) == 1:
        return matches[0][0]  # the one field that was requested
    raise KeyError('%s[%s].%s not found, or not unique'
                   % (table.name, str(row.id), what))
def init_row_subclass(cls, db):
    """Install one descriptor on `cls` for every attribute in ``db.data``.

    * ``'id'`` gets ``cls._idDescriptor``.
    * An attribute that ``db._attrSQL(attr, columnNumber=True)`` can map
      to a column gets ``cls._columnDescriptor``.
    * Anything else is treated as an SQL expression and gets
      ``cls._sqlDescriptor``.

    ``_attrSQL`` signals "not a plain column" with AttributeError on some
    paths and ValueError on the columnNumber path, so both are treated as
    "use the SQL-expression descriptor" here.
    """
    for attr in db.data:  # bind all database columns and aliases
        if attr == 'id':  # the ID attribute is handled specially
            descriptor_class = cls._idDescriptor
        else:
            try:  # does this attribute map to a real SQL column?
                db._attrSQL(attr, columnNumber=True)
            except (AttributeError, ValueError):  # no: SQL expression
                descriptor_class = cls._sqlDescriptor
            else:  # yes: read it from the stored tuple
                descriptor_class = cls._columnDescriptor
        setattr(cls, attr, descriptor_class(db, attr))
def dir_row(self):
    """Return the attribute names of this row: the keys of ``self.db.data``."""
    attribute_map = self.db.data
    return attribute_map.keys()
class TupleO(object):
    """Attribute-style, read-only view of one database row kept as a tuple.

    The row is held in ``self._data`` exactly as the DB API fetch
    returned it, instead of being copied value by value into
    ``__dict__``.  According to the original author this takes roughly
    five-fold less memory than a standard Python object and is also
    much faster.

    The class follows the 'subclass binding' pattern: rather than
    routing every attribute request through ``__getattr__``, attribute
    access is customised with descriptors.
    ``classutil.get_bound_subclass()`` creates a subclass of this class
    and calls its ``_init_subclass()`` class method, which adds the
    descriptors needed for the database table it is bound to.

    See the Pygr Developer Guide section of the docs for a complete
    discussion of the subclass binding pattern.
    """

    # descriptor classes used by init_row_subclass for each kind of attribute
    _columnDescriptor = TupleDescriptor
    _idDescriptor = TupleIDDescriptor
    _sqlDescriptor = SQLDescriptor

    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row

    def __init__(self, data):
        self._data = data  # keep a reference to the fetched tuple
def insert_and_cache_id(self, l, **kwargs):
    """Store the value sequence `l` in the database and remember its row ID.

    The ID is taken from ``kwargs['id']`` when the caller supplied one;
    otherwise ``self.db.get_insert_id()`` is asked for it.  The ID is then
    handed to ``self.cache_id``.
    """
    table = self.db
    table._insert(l)  # save to the database first
    if 'id' in kwargs:
        rowID = kwargs['id']  # caller-supplied ID wins
    else:
        rowID = table.get_insert_id()  # ID assigned by the database
    self.cache_id(rowID)
class TupleORW(TupleO):
    """Writable counterpart of TupleO.

    Column assignments go through TupleDescriptorRW, and a brand-new row
    can be created (and inserted into the database) at construction.
    """

    _columnDescriptor = TupleDescriptorRW
    insert_and_cache_id = insert_and_cache_id

    def __init__(self, data, newRow=False, **kwargs):
        """Wrap an existing tuple, or build and insert a new row.

        With ``newRow`` false, `data` is simply stored.  With ``newRow``
        true, `data` is ignored: the row is built from `kwargs` by
        ``self.db.tuple_from_dict`` and inserted into the database.
        """
        if newRow:
            self._data = self.db.tuple_from_dict(kwargs)  # kwargs -> values
            self.insert_and_cache_id(self._data, **kwargs)
        else:
            self._data = data  # just cache data from the database

    def cache_id(self, row_id):
        """Store `row_id` in the local ``'id'`` slot."""
        self.save_local('id', row_id)

    def save_local(self, attr, val):
        """Write `val` into the local slot for `attr` (no database access)."""
        slot = self._attrcol[attr]
        try:
            self._data[slot] = val
        except TypeError:
            # tuples reject item assignment, so switch to a list and retry
            self._data = list(self._data)
            self._data[slot] = val


TupleO._RWClass = TupleORW  # register the writable interface class
class ColumnDescriptor(object):
    """Read-write view of one database column, cached in ``obj.__dict__``.

    Reads are served from the cache when possible and otherwise fetched
    from the database (and then cached).  Writes update the database,
    unless the object carries a ``_localOnly`` attribute, and always
    update the cache.  Passing ``readOnly=True`` turns the instance into
    ``self._readOnlyClass``.
    """

    def __init__(self, db, attr, readOnly = False):
        self.attr = attr
        self.col = db._attrSQL(attr, sqlColumn=True)  # SQL column name
        self.db = db
        if readOnly:
            self.__class__ = self._readOnlyClass

    def __get__(self, obj, objtype):
        cache = obj.__dict__
        if self.attr in cache:
            return cache[self.attr]
        # not cached yet: try to read the value from the database
        table = self.db
        if self.col == table.primary_key:
            raise AttributeError
        table._select('where %s=%%s' % table.primary_key, (obj.id,),
                      self.col)
        rows = table.cursor.fetchall()
        if len(rows) != 1:
            raise AttributeError('db row not found or not unique!')
        value = rows[0][0]
        cache[self.attr] = value  # remember it for next time
        return value

    def __set__(self, obj, val):
        if not hasattr(obj, '_localOnly'):
            # objects flagged _localOnly are cached only, never saved
            self.db._update(obj.id, self.col, val)
        obj.__dict__[self.attr] = val
        # NOTE: the original source carried a commented-out "consequences"
        # hook here (a callback run after each set, bound through a
        # bind_consequences() method); it is disabled and not reproduced.
class ReadOnlyColumnDesc(ColumnDescriptor):
    """ColumnDescriptor variant that rejects every assignment."""

    def __set__(self, obj, val):
        message = 'The ID of a database object is not writeable.'
        raise AttributeError(message)


# class that ColumnDescriptor switches to when built with readOnly=True
ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc
class SQLRow(object):
    """Transparent, uncached interface to one row of the database.

    Every attribute access is mapped to a SELECT of the corresponding
    column or expression; only the row ID is kept on the object (as
    ``self._id``).
    """

    # both plain columns and SQL expressions are read by querying
    _sqlDescriptor = SQLDescriptor
    _columnDescriptor = SQLDescriptor
    _idDescriptor = ReadOnlyDescriptor

    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row

    def __init__(self, rowID):
        self._id = rowID
class SQLRowRW(SQLRow):
    """Writable counterpart of SQLRow.

    Column assignments go through SQLDescriptorRW, and a brand-new row
    can be inserted into the database at construction.
    """

    _columnDescriptor = SQLDescriptorRW
    insert_and_cache_id = insert_and_cache_id

    def __init__(self, rowID, newRow=False, **kwargs):
        """Bind to an existing row ID, or insert a new row built from kwargs."""
        if newRow:
            values = self.db.tuple_from_dict(kwargs)  # kwargs -> values
            self.insert_and_cache_id(values, **kwargs)
        else:
            self.cache_id(rowID)  # existing row: just remember its ID

    def cache_id(self, rowID):
        """Remember `rowID` as this row's ID."""
        self._id = rowID


SQLRow._RWClass = SQLRowRW  # register the writable interface class
def list_to_dict(names, values):
    """Pair `names` with `values` positionally and return the result as a dict.

    Pairing stops at the shorter of the two sequences, so surplus values
    (or names without a value) are simply left out.
    """
    return dict(zip(names, values))
def get_name_cursor(name=None, **kwargs):
    """Return ``(table name, cursor, serverInfo)`` for a table specification.

    `name` may be a whitespace-separated string of the form
    ``"table host user passwd"``.  When it has more than one field, the
    first becomes the table name and the rest are added, in that order,
    to the keyword arguments as ``host``, ``user`` and ``passwd``.  All
    keyword arguments are passed to ``DBServerInfo``, whose ``cursor()``
    supplies the returned cursor.
    """
    if name is not None:
        fields = name.split()  # treat as whitespace-separated list
        if len(fields) > 1:
            name = fields[0]  # first field is the table name
            kwargs = kwargs.copy()  # a copy we can overwrite
            connection_keys = ('host', 'user', 'passwd')
            kwargs.update(list_to_dict(connection_keys, fields[1:]))
    serverInfo = DBServerInfo(**kwargs)
    return name, serverInfo.cursor(), serverInfo
def mysql_connect(connect=None, configFile=None, useStreaming=False, **args):
    """Open a MySQL connection and return ``(connection, cursor)``.

    When neither a ``user`` keyword nor `configFile` is given, a config
    file is looked for: on Windows via ``get_valid_path`` over
    ``my.ini`` / ``my.cnf`` in ``%WINDIR%`` and the root of
    ``%SYSTEMDRIVE%``; elsewhere ``~/.my.cnf`` is used.  If there is
    still no config file name, ``mysql.cnf`` in the current directory is
    tried.  A config file that exists is passed as ``read_default_file``
    and forces the use of ``MySQLdb.connect``.

    `connect` is the connection factory; when None (or overridden as
    above) ``MySQLdb.connect`` is used with ``compress=True``.
    `useStreaming` requests ``MySQLdb.cursors.SSCursor`` as the cursor
    class when it can be imported.
    """
    kwargs = args.copy()  # a copy we can modify
    if configFile is None and 'user' not in kwargs:
        # no credentials supplied: locate the platform's config file
        if platform.system() in ('Microsoft', 'Windows'):
            candidates = []
            # either environment variable may be undefined on Windows
            if 'WINDIR' in os.environ:
                windir = os.environ['WINDIR']
                candidates += [(windir, 'my.ini'), (windir, 'my.cnf')]
            if 'SYSTEMDRIVE' in os.environ:
                sysdrv = os.environ['SYSTEMDRIVE']
                candidates += [(sysdrv, os.path.sep + 'my.ini'),
                               (sysdrv, os.path.sep + 'my.cnf')]
            if candidates:
                configFile = get_valid_path(*candidates)
        else:  # normal platform with home directories
            configFile = os.path.join(os.path.expanduser('~'), '.my.cnf')

    if not configFile:
        # fall back to a local config file in the current directory
        configFile = os.path.join(os.getcwd(), 'mysql.cnf')

    if configFile and os.path.exists(configFile):
        kwargs['read_default_file'] = configFile
        connect = None  # force it to use MySQLdb
    if connect is None:
        import MySQLdb
        connect = MySQLdb.connect
        kwargs['compress'] = True
    if useStreaming:  # server-side cursors for scalable result sets
        try:
            from MySQLdb import cursors
            kwargs['cursorclass'] = cursors.SSCursor
        except (ImportError, AttributeError):
            pass  # streaming is best-effort only
    conn = connect(**kwargs)
    return conn, conn.cursor()
293 _mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
294 AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
295 SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')
def mysql_table_schema(self, analyzeSchema=True):
    """Read the schema of MySQL table ``self.name`` and record it on `self`.

    Always installs ``self._format_query`` for MySQLdb's paramstyle.
    With `analyzeSchema` false nothing else happens.  Otherwise the
    schema is reset with ``self.clear_schema()`` and refilled from
    ``describe <table>``: ``columnName``, ``columnType``,
    ``description``, ``indexed`` (columns whose key field is ``MUL``),
    ``primary_key`` (a list when several columns are flagged ``PRI``)
    and ``usesIntID``.
    """
    import MySQLdb
    self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
    if not analyzeSchema:
        return
    self.clear_schema()  # reset settings and dictionaries
    self.cursor.execute('describe %s' % self.name)  # info about columns
    columns = self.cursor.fetchall()
    # run a select so that cursor.description covers every column
    self.cursor.execute('select * from %s limit 1' % self.name)
    icol = 0
    for c in columns:
        field = c[0]
        self.columnName.append(field)  # same order as in the table
        key_kind = c[3]
        if key_kind == "PRI":  # part of the primary key
            if self.primary_key is None:
                self.primary_key = field
            elif hasattr(self.primary_key, 'append'):
                self.primary_key.append(field)
            else:  # second key column: switch from a name to a list
                self.primary_key = [self.primary_key, field]
            self.usesIntID = (c[1][:3].lower() == 'int')
        elif key_kind == "MUL":
            self.indexed[field] = icol
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[1]  # SQL column type
        icol += 1
327 _sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
328 AUTO_INCREMENT='', SUBSTRING='substr',
329 SUBSTR_FROM=',', SUBSTR_FOR=',')
def import_sqlite():
    """Return the sqlite DB-API module available in this interpreter.

    ``sqlite3`` (Python 2.5+) is preferred; if it cannot be imported,
    ``pysqlite2.dbapi2`` is used instead.
    """
    try:
        import sqlite3
    except ImportError:
        from pysqlite2 import dbapi2
        return dbapi2
    return sqlite3
def sqlite_table_schema(self, analyzeSchema=True):
    """Read the schema of sqlite table ``self.name`` and record it on `self`.

    Always installs ``self._format_query`` for the sqlite module's
    paramstyle.  With `analyzeSchema` false nothing else happens.
    Otherwise the schema is reset with ``self.clear_schema()`` and
    ``columnName``, ``columnType`` and ``description`` are filled from
    ``PRAGMA table_info``.

    The primary key is found in two steps:

    1. the first automatically created single-column index (``sql is
       null`` in ``sqlite_master``) is assumed to be the primary key;
    2. failing that, the ``CREATE TABLE`` text is scanned for a column
       definition containing ``primary key``.

    Finally ``usesIntID`` is set according to the declared type of the
    primary key column, compared case-insensitively because PRAGMA
    table_info reports the type exactly as it was written in the DDL
    (e.g. ``INTEGER``).

    Raises ValueError when the column named in a ``primary key`` clause
    is not one of the table's columns.
    """
    sqlite = import_sqlite()
    self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
    if not analyzeSchema:
        return
    self.clear_schema()  # reset settings and dictionaries
    self.cursor.execute('PRAGMA table_info("%s")' % self.name)
    columns = self.cursor.fetchall()
    # run a select so that cursor.description covers every column
    self.cursor.execute('select * from %s limit 1' % self.name)
    for icol, c in enumerate(columns):
        field = c[1]
        self.columnName.append(field)  # same order as in the table
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[2]  # SQL column type
    # automatically created indexes: primary key / unique constraints
    self.cursor.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self.name)
    for indexRow in self.cursor.fetchall():
        # each fetched row is a one-field sequence holding the index name
        self.cursor.execute('PRAGMA index_info("%s")' % indexRow[0])
        indexColumns = self.cursor.fetchall()
        if len(indexColumns) == 1:
            # assume 1st single-column unique index is the primary key
            self.primary_key = indexColumns[0][2]
            break
    if self.primary_key is None:
        # INTEGER PRIMARY KEY creates no index, so inspect the DDL text
        self.cursor.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self.name)
        sql = self.cursor.fetchall()[0][0]
        for columnSQL in sql[sql.index('(') + 1:].split(','):
            if 'primary key' in columnSQL.lower():
                # first word is the column name; drop any identifier quotes
                col = columnSQL.split()[0].strip('"\'`[]')
                if col in self.columnType:
                    self.primary_key = col
                    break
                else:
                    raise ValueError('unknown primary key %s in table %s'
                                     % (col, self.name))
    if self.primary_key is not None:  # check its declared type
        declaredType = (self.columnType[self.primary_key] or '').lower()
        self.usesIntID = declaredType in ('int', 'integer')
class SQLFormatDict(object):
    """Translate "format"-style SQL into the dialect of one DB-API module.

    Two things are rewritten by calling an instance:

    * ``%(NAME)s`` macros are replaced using the substitution dict given
      at construction;
    * each ``%s`` parameter marker is converted to the module's
      paramstyle, and the parameter list is converted to a dict when the
      paramstyle needs one (``pyformat`` and ``named``).

    Example::

        sfd = SQLFormatDict("qmark", substitutionDict)
        sql, params = sfd("select * from foo where id=%s and val=%s",
                          (myID, myVal))
        cursor.execute(sql, params)
    """

    # marker template per paramstyle; qmark/format never use theirs
    # because '?' is registered as a plain substitution for them
    _paramFormats = {
        'pyformat': '%%(%d)s',
        'numeric': ':%d',
        'named': ':%d',
        'qmark': '(ignore)',
        'format': '(ignore)',
    }

    def __init__(self, paramstyle, substitutionDict={}):
        macros = substitutionDict.copy()  # never modify the caller's dict
        self.substitutionDict = macros
        self.paramstyle = paramstyle
        self.paramFormat = self._paramFormats[paramstyle]
        self.makeDict = paramstyle in ('pyformat', 'named')
        # these two styles need no numbering: substitute the marker as-is
        if paramstyle == 'qmark':
            macros['?'] = '?'
        elif paramstyle == 'format':
            macros['?'] = '%s'

    def __getitem__(self, k):
        """Return the replacement text for macro `k`.

        ``'?'`` stands for the next sequential parameter and advances the
        internal counter.  Raises KeyError for an unknown macro.
        """
        if k in self.substitutionDict:
            return self.substitutionDict[k]
        if k != '?':
            raise KeyError('unknown macro: %s' % k)
        marker = self.paramFormat % self.iparam
        self.iparam += 1  # advance to the next parameter
        return marker

    def __call__(self, sql, paramList):
        """Return ``(sql, params)`` rewritten for this interface."""
        self.iparam = 1  # parameter numbering starts at 1
        # turn every %s into the %(?)s macro, then let __getitem__
        # resolve all %(x)s replacements in one formatting pass
        translated = sql.replace('%s', '%(?)s') % self
        if not self.makeDict:
            return translated, paramList  # original params, untouched
        numbered = dict([(str(i + 1), param)
                         for i, param in enumerate(paramList)])
        return translated, numbered
def get_table_schema(self, analyzeSchema=True):
    """Run the schema reader that matches this table's database module.

    The module name of ``self.cursor``'s class selects a function from
    ``self._schemaModuleDict``, which is then called as
    ``func(self, analyzeSchema)``.

    Raises ValueError when there is no cursor (or it exposes no module
    information) and KeyError when the module has no registered reader.
    """
    try:
        cursor_class = self.cursor.__class__
        modname = cursor_class.__module__
    except AttributeError:
        raise ValueError('no cursor object or module information!')
    readers = self._schemaModuleDict
    if modname not in readers:
        raise KeyError('unknown db module: %s. Use _schemaModuleDict\n'
                       'attribute to supply a method for obtaining table schema\n'
                       'for this module' % modname)
    readers[modname](self, analyzeSchema)  # run the schema function
# Schema reader to use for each module that can define a cursor class.
_schemaModuleDict = dict([
    ('MySQLdb.cursors', mysql_table_schema),
    ('pysqlite2.dbapi2', sqlite_table_schema),
    ('sqlite3', sqlite_table_schema),
])
class SQLTableBase(object, UserDict.DictMixin):
    """Information about one SQL table, exposed as a dict keyed by primary key.

    This base class holds the connection (``self.cursor``), the table
    schema, the attribute-to-column map (``self.data``) and a cache of row
    objects (``self._weakValueDict``).  It provides querying, inserting,
    deleting and key iteration; item lookup itself is left to subclasses.
    """
    _schemaModuleDict = _schemaModuleDict  # default module list
    get_table_schema = get_table_schema

    def __init__(self, name, cursor=None, itemClass=None, attrAlias=None,
                 clusterKey=None, createTable=None, graph=None, maxCache=None,
                 arraysize=1024, itemSliceClass=None, dropIfExists=False,
                 serverInfo=None, autoGC=True, orderBy=None,
                 writeable=False, iterSQL=None, iterColumns=None, **kwargs):
        """Bind to table `name`, optionally creating it first.

        name -- table name; when neither `cursor` nor `serverInfo` is
            given it is parsed by get_name_cursor() together with kwargs.
        cursor -- deprecated; use `serverInfo` instead.
        itemClass, itemSliceClass -- base classes for row / slice objects;
            bound subclasses are created for this table.
        attrAlias -- dict of extra attribute names merged into self.data.
        createTable -- SQL run (after macro expansion) to create the table;
            with `dropIfExists` any existing table is dropped first.
        maxCache, arraysize, clusterKey, graph, orderBy, writeable --
            stored on self when supplied.
        autoGC -- if true, the row cache is a RecentValueDictionary,
            otherwise a plain dict.
        iterSQL, iterColumns -- required when `orderBy` is used with a
            serverInfo whose _serverType is 'mysql' (ValueError otherwise).
        """
        if autoGC:  # automatically garbage collect unused objects
            self._weakValueDict = RecentValueDictionary(autoGC)
        else:
            self._weakValueDict = {}
        self.autoGC = autoGC
        self.orderBy = orderBy
        if orderBy and serverInfo and serverInfo._serverType == 'mysql':
            if iterSQL and iterColumns:  # both required for mysql!
                self.iterSQL, self.iterColumns = iterSQL, iterColumns
            else:
                raise ValueError('For MySQL tables with orderBy, you MUST specify iterSQL and iterColumns as well!')

        self.writeable = writeable
        if cursor is None:
            if serverInfo is not None:  # get cursor from serverInfo
                cursor = serverInfo.cursor()
            else:  # read connection info from name or config file
                name, cursor, serverInfo = get_name_cursor(name, **kwargs)
        else:
            warnings.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
                          DeprecationWarning, stacklevel=2)
        self.cursor = cursor
        if createTable is not None:  # run command to create this table
            if dropIfExists:  # get rid of any existing table
                cursor.execute('drop table if exists ' + name)
            self.get_table_schema(False)  # check dbtype, init _format_query
            sql, params = self._format_query(createTable, ())  # macros
            cursor.execute(sql)  # create the table
        self.name = name
        if graph is not None:
            self.graph = graph
        if maxCache is not None:
            self.maxCache = maxCache
        if arraysize is not None:
            self.arraysize = arraysize
            cursor.arraysize = arraysize
        self.get_table_schema()  # schema of columns to serve as attrs
        self.data = {}  # map of all attributes, including aliases
        for icol, field in enumerate(self.columnName):
            self.data[field] = icol  # 1st add mappings to columns
        try:
            self.data['id'] = self.data[self.primary_key]
        except (KeyError, TypeError):
            # no primary key, or a multi-column (list) key: no 'id' alias
            pass
        if hasattr(self, '_attr_alias'):  # class-level attribute aliases
            self.addAttrAlias(False, **self._attr_alias)
        self.objclass(itemClass)  # need to subclass our item class
        if itemSliceClass is not None:
            self.itemSliceClass = itemSliceClass
            # need to subclass itemSliceClass
            get_bound_subclass(self, 'itemSliceClass', self.name)
        if attrAlias is not None:  # add attribute aliases
            self.attrAlias = attrAlias  # record for pickling purposes
            self.data.update(attrAlias)
        if clusterKey is not None:
            self.clusterKey = clusterKey
        if serverInfo is not None:
            self.serverInfo = serverInfo

    def __len__(self):
        """Return the number of rows, via ``select count(*)``."""
        self._select(selectCols='count(*)')
        return self.cursor.fetchone()[0]

    def __hash__(self):
        return id(self)

    def __cmp__(self, other):
        'only match self and no other!'
        if self is other:
            return 0
        else:
            return cmp(id(self), id(other))

    # attributes handed to standard_getstate for pickling
    _pickleAttrs = dict(name=0, clusterKey=0, maxCache=0, arraysize=0,
                        attrAlias=0, serverInfo=0, autoGC=0, orderBy=0,
                        writeable=0, iterSQL=0, iterColumns=0)
    __getstate__ = standard_getstate

    def __setstate__(self, state):
        """Restore by re-running __init__ with the pickled keyword state."""
        # default cursor provisioning by worldbase is deprecated, so no
        # attempt is made to obtain a cursor when serverInfo is missing
        self.__init__(**state)

    def __repr__(self):
        return '<SQL table ' + self.name + '>'

    def clear_schema(self):
        'reset all schema information for this table'
        self.description = {}
        self.columnName = []
        self.columnType = {}
        self.usesIntID = None
        self.primary_key = None
        self.indexed = {}

    def _attrSQL(self, attr, sqlColumn=False, columnNumber=False):
        """Translate python attribute name to appropriate SQL expression.

        sqlColumn -- return the SQL column name; AttributeError if `attr`
            does not map to a column.
        columnNumber -- return the integer column index; ValueError if
            `attr` does not map to a column.
        Otherwise the aliased expression (or the primary key for 'id') is
        returned.  AttributeError is raised for an unknown attribute.
        """
        try:  # make sure this attribute can be mapped to a db expression
            field = self.data[attr]
        except KeyError:
            raise AttributeError('attribute %s not a valid column or alias in %s'
                                 % (attr, self.name))
        if sqlColumn:  # ensure that this truly maps to a column name
            try:  # check if field is a column number
                return self.columnName[field]
            except TypeError:
                try:  # check if field is itself an SQL column name
                    return self.columnName[self.data[field]]
                except (KeyError, TypeError):
                    raise AttributeError('attribute %s does not map to an SQL column in %s'
                                         % (attr, self.name))
        if columnNumber:
            try:  # check if field is a column number
                return field + 0  # only return an integer
            except TypeError:
                try:  # check if field is itself the SQL column name
                    return self.data[field] + 0  # only return an integer
                except (KeyError, TypeError):
                    raise ValueError('attribute %s does not map to a SQL column!' % attr)
        if isinstance(field, types.StringType):
            attr = field  # use aliased expression for the select instead
        elif attr == 'id':
            attr = self.primary_key
        return attr

    def addAttrAlias(self, saveToPickle=True, **kwargs):
        """Add new attributes as aliases of existing attributes.

        They can be specified either as named args:
            t.addAttrAlias(newattr=oldattr)
        or by passing a dictionary kwargs whose keys are newattr
        and values are oldattr:
            t.addAttrAlias(**kwargs)
        saveToPickle=True forces these aliases to be saved if the object
        is pickled.
        """
        if saveToPickle:
            try:
                self.attrAlias.update(kwargs)
            except AttributeError:
                # no aliases were recorded at construction: start the record
                self.attrAlias = kwargs.copy()
        for key, val in kwargs.items():
            try:  # 1st check whether val is an existing column / alias
                self.data[val] + 0  # does val map to a column number?
                raise KeyError  # yes, a real SQL column: save it directly
            except TypeError:  # val is itself an alias
                self.data[key] = self.data[val]  # map to what it maps to
            except KeyError:  # treat as alias to SQL expression
                self.data[key] = val

    def objclass(self, oclass=None):
        "Create class representing a row in this table by subclassing oclass, adding data"
        if oclass is not None:  # use this as our base itemClass
            self.itemClass = oclass
        if self.writeable:
            self.itemClass = self.itemClass._RWClass  # writeable version
        oclass = get_bound_subclass(self, 'itemClass', self.name,
                                    subclassArgs=dict(db=self))
        if issubclass(oclass, TupleO):
            oclass._attrcol = self.data  # attribute map for tuple rows
        if hasattr(oclass, '_tableclass') \
               and not isinstance(self, oclass._tableclass):
            # row class can override our current table class
            self.__class__ = oclass._tableclass

    def _select(self, whereClause='', params=(), selectCols='t1.*',
                cursor=None, orderBy='', limit=''):
        'execute the specified query but do not fetch'
        sql, params = self._format_query('select %s from %s t1 %s %s %s'
                                         % (selectCols, self.name,
                                            whereClause, orderBy, limit),
                                         params)
        if cursor is None:
            self.cursor.execute(sql, params)
        else:
            cursor.execute(sql, params)

    def select(self, whereClause, params=None, oclass=None,
               selectCols='t1.*'):
        "Generate the list of objects that satisfy the database SELECT"
        if params is None:
            # drivers and SQLFormatDict expect a sequence, never None
            params = ()
        if oclass is None:
            oclass = self.itemClass
        self._select(whereClause, params, selectCols)
        rows = self.cursor.fetchall()
        for t in rows:
            yield self.cacheItem(t, oclass)

    def query(self, **kwargs):
        """Query for the intersection of all specified kwargs, as an iterator.

        A value of None is tested with IS NULL.  With no criteria at all,
        every row of the table is selected.
        """
        criteria = []
        params = []
        for k, v in kwargs.items():  # construct the list of where clauses
            if v is None:  # convert to SQL NULL test
                criteria.append('%s IS NULL' % self._attrSQL(k))
            else:  # test for equality
                criteria.append('%s=%%s' % self._attrSQL(k))
                params.append(v)
        if criteria:
            whereClause = 'where ' + ' and '.join(criteria)
        else:
            whereClause = ''  # avoid emitting a dangling "where"
        return self.select(whereClause, params)

    def _update(self, row_id, col, val):
        'update a single field in the specified row to the specified value'
        sql, params = self._format_query('update %s set %s=%%s where %s=%%s'
                                         % (self.name, col,
                                            self.primary_key),
                                         (val, row_id))
        self.cursor.execute(sql, params)

    def getID(self, t):
        """Return the ID field of row tuple `t` (KeyError if no 'id')."""
        try:
            return t[self.data['id']]  # get ID from tuple
        except TypeError:  # 'id' is an alias of another attribute
            return t[self.data[self.data['id']]]

    def cacheItem(self, t, oclass):
        'get obj from cache if possible, or construct from tuple'
        try:
            row_id = self.getID(t)
        except KeyError:  # no primary key? ignore the cache.
            return oclass(t)
        try:  # if already loaded in our dictionary, just return that entry
            return self._weakValueDict[row_id]
        except KeyError:
            pass
        o = oclass(t)
        self._weakValueDict[row_id] = o  # cache this item
        return o

    def cache_items(self, rows, oclass=None):
        """Yield the cached / newly built object for each row tuple."""
        if oclass is None:
            oclass = self.itemClass
        for t in rows:
            yield self.cacheItem(t, oclass)

    def foreignKey(self, attr, k):
        'get iterator for objects with specified foreign key value'
        return self.select('where %s=%%s' % attr, (k,))

    def limit_cache(self):
        'apply maxCache limit to cache size'
        try:
            if self.maxCache < len(self._weakValueDict):
                self._weakValueDict.clear()
        except AttributeError:  # no maxCache configured
            pass

    def get_new_cursor(self):
        """Return a new cursor object, or None if not possible """
        try:
            new_cursor = self.serverInfo.new_cursor
        except AttributeError:
            return None
        return new_cursor(self.arraysize)

    def generic_iterator(self, cursor=None, fetch_f=None, cache_f=None,
                         map_f=iter, cursorHolder=None):
        """generic iterator that runs fetch, cache and map functions.
        cursorHolder is used only to keep a ref in this function's locals,
        so that if it is prematurely terminated (by deleting its
        iterator), cursorHolder.__del__() will close the cursor."""
        if fetch_f is None:  # just use cursor's preferred chunk size
            if cursor is None:
                fetch_f = self.cursor.fetchmany
            else:  # isolate this iter from other queries
                fetch_f = cursor.fetchmany
        if cache_f is None:
            cache_f = self.cache_items
        while True:
            self.limit_cache()
            rows = fetch_f()  # fetch the next set of rows
            if len(rows) == 0:  # no more data so all done
                break
            for v in map_f(cache_f(rows)):  # cache and generate results
                yield v

    def tuple_from_dict(self, d):
        'transform kwarg dict into tuple for storing in database'
        l = [None] * len(self.description)  # default values are NULL
        for col, icol in self.data.items():
            try:
                l[icol] = d[col]
            except (KeyError, TypeError):
                # value not supplied, or attribute is not a plain column
                pass
        return l

    def tuple_from_obj(self, obj):
        'transform object attributes into tuple for storing in database'
        l = [None] * len(self.description)  # default values are NULL
        for col, icol in self.data.items():
            try:
                l[icol] = getattr(obj, col)
            except (AttributeError, TypeError):
                # attribute missing, or attribute is not a plain column
                pass
        return l

    def _insert(self, l):
        '''insert tuple into the database.  Note this uses the MySQL
        extension REPLACE, which overwrites any duplicate key.'''
        s = '%(REPLACE)s into ' + self.name + ' values (' \
            + ','.join(['%s'] * len(l)) + ')'
        sql, params = self._format_query(s, l)
        self.cursor.execute(sql, params)

    def insert(self, obj):
        '''insert new row by transforming obj to tuple of values'''
        l = self.tuple_from_obj(obj)
        self._insert(l)

    def get_insert_id(self):
        'get the primary key value for the last INSERT'
        try:  # attempt to get assigned ID from db
            auto_id = self.cursor.lastrowid
        except AttributeError:  # cursor doesn't support lastrowid
            raise NotImplementedError('''your db lacks lastrowid support?''')
        if auto_id is None:
            raise ValueError('lastrowid is None so cannot get ID from INSERT!')
        return auto_id

    def new(self, **kwargs):
        'return a new record with the assigned attributes, added to DB'
        if not self.writeable:
            raise ValueError('this database is read only!')
        obj = self.itemClass(None, newRow=True, **kwargs)  # saves itself
        self._weakValueDict[obj.id] = obj  # and save to our local cache
        return obj

    def clear_cache(self):
        'empty the cache'
        self._weakValueDict.clear()

    def __delitem__(self, k):
        """Delete the row with primary key `k` from database and cache."""
        if not self.writeable:
            raise ValueError('this database is read only!')
        sql, params = self._format_query('delete from %s where %s=%%s'
                                         % (self.name, self.primary_key),
                                         (k,))
        self.cursor.execute(sql, params)
        try:
            del self._weakValueDict[k]
        except KeyError:  # row was not cached
            pass

    def getKeys(self, queryOption='', selectCols=None):
        'uses db select; does not force load'
        if selectCols is None:
            selectCols = self.primary_key
        if queryOption == '' and self.orderBy is not None:
            queryOption = self.orderBy  # apply default ordering
        self.cursor.execute('select %s from %s %s'
                            % (selectCols, self.name, queryOption))
        # get all at once, since other calls may reuse this cursor
        return [t[0] for t in self.cursor.fetchall()]

    def iter_keys(self, selectCols=None, orderBy='', map_f=iter,
                  cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
        'guarantee correct iteration insulated from other queries'
        if selectCols is None:
            selectCols = self.primary_key
        if orderBy == '' and self.orderBy is not None:
            orderBy = self.orderBy  # apply default ordering
        cursor = self.get_new_cursor()
        if cursor:  # got our own cursor, guaranteeing query isolation
            if hasattr(self.serverInfo, 'iter_keys') \
                   and self.serverInfo.custom_iter_keys:
                # use custom iter_keys() method from serverInfo
                return self.serverInfo.iter_keys(self, cursor,
                                                 selectCols=selectCols,
                                                 map_f=map_f,
                                                 orderBy=orderBy,
                                                 cache_f=cache_f, **kwargs)
            else:
                self._select(cursor=cursor, selectCols=selectCols,
                             orderBy=orderBy, **kwargs)
                return self.generic_iterator(cursor=cursor,
                                             cache_f=cache_f, map_f=map_f,
                                             cursorHolder=CursorCloser(cursor))
        else:  # must pre-fetch all keys to ensure query isolation
            if get_f is not None:
                return iter(get_f())
            else:
                return iter(self.keys())
class SQLTable(SQLTableBase):
    "Provide on-the-fly access to rows in the database, caching the results in dict"
    itemClass = TupleO # our default itemClass; constructor can override
    keys = getKeys
    __iter__ = iter_keys

    def load(self, oclass=None):
        '''Load all data from the table into the local cache.
        oclass: row class handed to cacheItem() (default: self.itemClass).
        If the table was already loaded, returns the saved _isLoaded flag
        without touching the database.'''
        try: # already loaded, so nothing to do
            return self._isLoaded
        except AttributeError:
            pass
        if oclass is None:
            oclass = self.itemClass
        self.cursor.execute('select * from %s' % self.name)
        rows = self.cursor.fetchall()
        self._weakValueDict = {} # just store the whole dataset in memory
        for row in rows:
            self.cacheItem(row, oclass) # cache it in local dictionary
        self._isLoaded = True # mark this container as fully loaded

    def __getitem__(self, k):
        '''return the object for primary key k: first try the local cache,
        then the database.  Raises KeyError unless exactly one row matches.'''
        try:
            return self._weakValueDict[k] # directly return cached value
        except KeyError: # not cached, so ask the database below
            pass
        query = 'select * from %s where %s=%%s limit 2' \
                % (self.name, self.primary_key)
        sql, params = self._format_query(query, (k,))
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchmany(2) # get at most 2 rows
        if len(rows) != 1:
            raise KeyError('%s not found in %s, or not unique' %(str(k),self.name))
        self.limit_cache()
        return self.cacheItem(rows[0], self.itemClass) # cache and return it

    def __setitem__(self, k, v):
        '''save object v under ID k.  v must be bound to this db (v.db is
        self), else ValueError; also ValueError if the db is read-only.
        Any row stored under the object's previous ID is deleted first.'''
        if not self.writeable:
            raise ValueError('this database is read only!')
        if getattr(v, 'db', None) is not self:
            raise ValueError('object not bound to itemClass for this db!')
        oldID = getattr(v, 'id', None)
        if oldID is not None: # delete row with old ID
            del self[oldID]
        v.cache_id(k) # cache the new ID on the object
        self.insert(v) # save to the relational db server
        self._weakValueDict[k] = v # cache this item in our dictionary

    def items(self):
        'forces load of entire table into memory'
        self.load()
        pairs = []
        for k in self: # iterating self applies the orderBy rules
            pairs.append((k, self[k]))
        return pairs

    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        return iter_keys(self, selectCols='*', cache_f=None,
                         map_f=generate_items, get_f=self.items)

    def values(self):
        'forces load of entire table into memory'
        self.load()
        result = []
        for k in self: # iterating self applies the orderBy rules
            result.append(self[k])
        return result

    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        return iter_keys(self, selectCols='*', cache_f=None,
                         get_f=self.values)
def getClusterKeys(self, queryOption=''):
    '''return list of the distinct values of the self.clusterKey column,
    via a db select; does not force a load of the table.
    queryOption: SQL text appended after the table name.'''
    query = 'select distinct %s from %s %s' \
            % (self.clusterKey, self.name, queryOption)
    self.cursor.execute(query)
    # fetch everything right away, since other calls may reuse this cursor
    rows = self.cursor.fetchall()
    return [row[0] for row in rows]
class SQLTableClustered(SQLTable):
    '''use clusterKey to load a whole cluster of rows at once,
    specifically, all rows that share the same clusterKey value.'''
    def __init__(self, *args, **kwargs):
        'same arguments as SQLTable, except that autoGC is forced to False'
        tableArgs = kwargs.copy() # get a copy we can alter
        tableArgs['autoGC'] = False # don't use WeakValueDictionary
        SQLTable.__init__(self, *args, **tableArgs)

    def keys(self):
        'list of primary keys, ordered by clusterKey'
        return getKeys(self, 'order by %s' % self.clusterKey)

    def clusterkeys(self):
        'list of distinct clusterKey values, in sorted order'
        return getClusterKeys(self, 'order by %s' % self.clusterKey)

    def __getitem__(self, k):
        '''return the object with primary key k.  On a cache miss, every
        row sharing k's clusterKey value is fetched and cached; raises
        KeyError if k is still not in the cache afterwards.'''
        try:
            return self._weakValueDict[k] # directly return cached value
        except KeyError: # not cached, so query the database below
            pass
        query = 'select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s' \
                % (self.name, self.name, self.primary_key,
                   self.clusterKey, self.clusterKey)
        sql, params = self._format_query(query, (k,))
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchall()
        self.limit_cache()
        for row in rows: # load the entire cluster into our local cache
            self.cacheItem(row, self.itemClass)
        return self._weakValueDict[k] # should be in cache, if row k exists

    def itercluster(self, cluster_id):
        'iterate over all items from the specified cluster'
        self.limit_cache()
        whereClause = 'where %s=%%s' % self.clusterKey
        return self.select(whereClause, (cluster_id,))

    def fetch_cluster(self):
        '''use self.cursor.fetchmany to obtain all rows for next cluster.
        Rows read past the end of the cluster are stashed on
        self._fetch_cluster_cache and consumed first by the next call.
        Returns an empty list when no rows remain.'''
        # NOTE(review): this reads from self.cursor, whereas itervalues()
        # and iteritems() run their select on a separate new cursor --
        # confirm which cursor is intended here.
        icol = self._attrSQL(self.clusterKey, columnNumber=True)
        result = []
        try:
            rows = self._fetch_cluster_cache # use saved rows from last call
            del self._fetch_cluster_cache
        except AttributeError:
            rows = self.cursor.fetchmany()
        try:
            cluster_id = rows[0][icol]
        except IndexError: # nothing left to read
            return result
        while len(rows) > 0:
            for i, row in enumerate(rows): # look for start of a new cluster
                if cluster_id != row[icol]:
                    result.extend(rows[:i]) # rows of the current cluster
                    self._fetch_cluster_cache = rows[i:] # save next cluster
                    return result
            result.extend(rows) # whole batch belongs to this cluster
            rows = self.cursor.fetchmany() # get next set of rows
        return result

    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        # NOTE(review): iter_keys() passes CursorCloser(cursor) for this same
        # cursorHolder argument; confirm CursorHolder exists in this module.
        cursor = self.get_new_cursor()
        self._select('order by %s' % self.clusterKey, cursor=cursor)
        holder = CursorHolder(cursor)
        return self.generic_iterator(cursor, self.fetch_cluster,
                                     cursorHolder=holder)

    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        # NOTE(review): see itervalues() regarding CursorHolder.
        cursor = self.get_new_cursor()
        self._select('order by %s' % self.clusterKey, cursor=cursor)
        holder = CursorHolder(cursor)
        return self.generic_iterator(cursor, self.fetch_cluster,
                                     map_f=generate_items,
                                     cursorHolder=holder)
class SQLForeignRelation(object):
    'mapping based on matching a foreign key in an SQL table'
    def __init__(self, table, keyName):
        '''table: object providing select(whereClause, params);
        keyName: name of the foreign key column to match against k.id'''
        self.table = table
        self.keyName = keyName

    def __getitem__(self, k):
        '''get list of objects o with getattr(o,keyName)==k.id.
        Raises KeyError if no object matches.'''
        matches = list(self.table.select('where %s=%%s' % self.keyName,
                                         (k.id,)))
        if len(matches) == 0:
            # This class has no name attribute of its own, so report the
            # table's name (falling back to the table object itself).
            tableName = getattr(self.table, 'name', self.table)
            raise KeyError('%s not found in %s' % (str(k), tableName))
        return matches
class SQLTableNoCache(SQLTableBase):
    '''Provide on-the-fly access to rows in the database;
    values are simply an object interface (SQLRow) to back-end db query.
    Row data are not stored locally, but always accessed by querying the db'''
    itemClass = SQLRow # default object class for rows...
    keys = getKeys
    __iter__ = iter_keys

    def getID(self, t):
        'get the ID from a row tuple: its first element'
        return t[0]

    def select(self, whereClause, params):
        'run SQLTableBase.select(), selecting only the id column'
        # NOTE(review): this passes self.oclass while the rest of the class
        # uses self.itemClass -- confirm that oclass is set on instances.
        idColumn = self._attrSQL('id')
        return SQLTableBase.select(self, whereClause, params, self.oclass,
                                   idColumn)

    def __getitem__(self, k):
        '''return the row object for ID k: first try the local cache, then
        check the database.  Raises KeyError unless exactly one row has
        that ID.'''
        try:
            return self._weakValueDict[k] # directly return cached value
        except KeyError: # not cached, so check the database below
            pass
        self._select('where %s=%%s' % self.primary_key, (k,),
                     self.primary_key)
        rows = self.cursor.fetchmany(2)
        if len(rows) != 1:
            raise KeyError('id %s non-existent or not unique' % k)
        row = self.itemClass(k) # create obj referencing this ID
        self._weakValueDict[k] = row # cache the SQLRow object
        return row

    def __setitem__(self, k, v):
        '''change the ID of row object v to k, in the database and in the
        cache.  v must be bound to this db (v.db is self), else ValueError;
        also ValueError if the db is read-only.'''
        if not self.writeable:
            raise ValueError('this database is read only!')
        if getattr(v, 'db', None) is not self:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            del self[k] # delete row with new ID if any
        except KeyError:
            pass
        try:
            del self._weakValueDict[v.id] # delete from old cache location
        except KeyError:
            pass
        self._update(v.id, self.primary_key, k) # just change its ID in db
        v.cache_id(k) # change the cached ID value
        self._weakValueDict[k] = v # assign to new cache location

    def addAttrAlias(self, **kwargs):
        'alias attribute names (keys) to SQL expressions (values)'
        self.data.update(kwargs)

SQLRow._tableclass = SQLTableNoCache # SQLRow is for non-caching table interface
class SQLTableMultiNoCache(SQLTableBase):
    "Trivial on-the-fly access for table with key that returns multiple rows"
    itemClass = TupleO # default itemClass; constructor can override
    _distinct_key = 'id' # default column to use as key

    def __init__(self, *args, **kwargs):
        '''same arguments as SQLTableBase.  Unless an orderBy was supplied,
        iteration is grouped and ordered by the distinct key column.'''
        SQLTableBase.__init__(self, *args, **kwargs)
        keyColumn = self._attrSQL(self._distinct_key)
        self.distinct_key = keyColumn
        if not self.orderBy: # install default ordering on the distinct key
            self.orderBy = 'GROUP BY %s ORDER BY %s' % (keyColumn, keyColumn)
            self.iterSQL = 'WHERE %s>%%s' % keyColumn
            self.iterColumns = (keyColumn,)

    def keys(self):
        'list of values of the distinct key column'
        return getKeys(self, selectCols=self.distinct_key)

    def __iter__(self):
        'iterate over values of the distinct key column'
        return iter_keys(self, selectCols=self.distinct_key)

    def __getitem__(self, id):
        '''generate an itemClass object for each row whose distinct key
        column equals id'''
        query = 'select * from %s where %s=%%s' \
                % (self.name, self._attrSQL(self._distinct_key))
        sql, params = self._format_query(query, (id,))
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchall() # prefetch all rows, since cursor may be reused
        for row in rows:
            yield self.itemClass(row)

    def addAttrAlias(self, **kwargs):
        'alias attribute names (keys) to SQL expressions (values)'
        self.data.update(kwargs)
class SQLEdges(SQLTableMultiNoCache):
    '''provide iterator over edges as (source,target,edge)
    and getitem[edge] --> [(source,target),...]'''
    _distinct_key = 'edge_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))

    def keys(self):
        '''list of (source,target,edge) tuples, unpacked via self.graph,
        for every row whose target is not null'''
        sourceCol = self._attrSQL('source_id')
        targetCol = self._attrSQL('target_id')
        edgeCol = self._attrSQL('edge_id')
        self.cursor.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
                            % (sourceCol, targetCol, edgeCol, self.name,
                               targetCol, sourceCol, targetCol))
        graph = self.graph
        edges = [] # prefetch all rows, since cursor may be reused
        for source_id, target_id, edge_id in self.cursor.fetchall():
            edges.append((graph.unpack_source(source_id),
                          graph.unpack_target(target_id),
                          graph.unpack_edge(edge_id)))
        return edges
    __call__ = keys

    def __iter__(self):
        'iterate over the (source,target,edge) tuples from keys()'
        return iter(self.keys())

    def __getitem__(self, edge):
        'list of (source,target) pairs that are joined by this edge'
        query = 'select %s,%s from %s where %s=%%s' \
                % (self._attrSQL('source_id'),
                   self._attrSQL('target_id'),
                   self.name,
                   self._attrSQL(self._distinct_key))
        sql, params = self._format_query(query, (self.graph.pack_edge(edge),))
        self.cursor.execute(sql, params)
        pairs = [] # prefetch all rows, since cursor may be reused
        for source_id, target_id in self.cursor.fetchall():
            pairs.append((self.graph.unpack_source(source_id),
                          self.graph.unpack_target(target_id)))
        return pairs
class SQLEdgeDict(object):
    '2nd level graph interface to SQL database'
    def __init__(self, fromNode, table):
        '''fromNode: packed source ID; table: the graph table object.
        Unless the table has an allowMissingNodes attribute, raises
        KeyError if fromNode is not present as a source in the table.'''
        self.fromNode = fromNode
        self.table = table
        if hasattr(self.table, 'allowMissingNodes'):
            return # caller asked us not to verify that the node exists
        query = 'select %s from %s where %s=%%s limit 1' \
                % (self.table.sourceSQL, self.table.name,
                   self.table.sourceSQL)
        sql, params = self.table._format_query(query, (self.fromNode,))
        self.table.cursor.execute(sql, params)
        if len(self.table.cursor.fetchall()) < 1:
            raise KeyError('node not in graph!')

    def __getitem__(self, target):
        '''return the unpacked edge from this node to target; KeyError
        unless exactly one such row exists'''
        tbl = self.table
        query = 'select %s from %s where %s=%%s and %s=%%s limit 2' \
                % (tbl.edgeSQL, tbl.name, tbl.sourceSQL, tbl.targetSQL)
        sql, params = tbl._format_query(query, (self.fromNode,
                                                tbl.pack_target(target)))
        self.table.cursor.execute(sql, params)
        rows = self.table.cursor.fetchmany(2) # get at most two rows
        if len(rows) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        try:
            return self.table.unpack_edge(rows[0][0]) # return edge
        except IndexError:
            raise KeyError('no edge from node to target')

    def __setitem__(self, target, edge):
        '''store the edge from this node to target using "replace into";
        when source and target share one database (or no sourceDB is set),
        target is also added to the graph as a node'''
        tbl = self.table
        query = 'replace into %s values (%%s,%%s,%%s)' % tbl.name
        sql, params = tbl._format_query(query, (self.fromNode,
                                                tbl.pack_target(target),
                                                tbl.pack_edge(edge)))
        self.table.cursor.execute(sql, params)
        if not hasattr(self.table, 'sourceDB') or \
           (hasattr(self.table, 'targetDB') and
            self.table.sourceDB is self.table.targetDB):
            self.table += target # add as node to graph

    def __iadd__(self, target):
        'add target with no edge information'
        self[target] = None
        return self # iadd must return self!

    def __delitem__(self, target):
        'delete the edge to target; KeyError if no row was deleted'
        tbl = self.table
        query = 'delete from %s where %s=%%s and %s=%%s' \
                % (tbl.name, tbl.sourceSQL, tbl.targetSQL)
        sql, params = tbl._format_query(query, (self.fromNode,
                                                tbl.pack_target(target)))
        self.table.cursor.execute(sql, params)
        if self.table.cursor.rowcount < 1: # no rows deleted?
            raise KeyError('no edge from node to target')

    def iterator_query(self):
        '''fetch all (target_id, edge_id) rows for this source node whose
        target is not null'''
        tbl = self.table
        query = 'select %s,%s from %s where %s=%%s and %s is not null' \
                % (tbl.targetSQL, tbl.edgeSQL, tbl.name,
                   tbl.sourceSQL, tbl.targetSQL)
        sql, params = tbl._format_query(query, (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return self.table.cursor.fetchall()

    def keys(self):
        'list of unpacked target nodes'
        return [self.table.unpack_target(row[0])
                for row in self.iterator_query()]

    def values(self):
        'list of unpacked edges'
        return [self.table.unpack_edge(row[1])
                for row in self.iterator_query()]

    def edges(self):
        'list of unpacked (source,target,edge) tuples'
        return [(self.table.unpack_source(self.fromNode),
                 self.table.unpack_target(row[0]),
                 self.table.unpack_edge(row[1]))
                for row in self.iterator_query()]

    def items(self):
        'list of unpacked (target,edge) pairs'
        return [(self.table.unpack_target(row[0]),
                 self.table.unpack_edge(row[1]))
                for row in self.iterator_query()]

    def __iter__(self):
        return iter(self.keys())

    def itervalues(self):
        return iter(self.values())

    def iteritems(self):
        return iter(self.items())

    def __len__(self):
        'number of targets reachable from this node'
        return len(self.keys())

    __cmp__ = graph_cmp
class SQLEdgelessDict(SQLEdgeDict):
    'for SQLGraph tables that lack edge_id column'
    def __getitem__(self, target):
        '''return None (there is no edge info) if exactly one row links
        this node to target; otherwise raise KeyError'''
        tbl = self.table
        query = 'select %s from %s where %s=%%s and %s=%%s limit 2' \
                % (tbl.targetSQL, tbl.name, tbl.sourceSQL, tbl.targetSQL)
        sql, params = tbl._format_query(query, (self.fromNode,
                                                tbl.pack_target(target)))
        self.table.cursor.execute(sql, params)
        rows = self.table.cursor.fetchmany(2)
        if len(rows) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        return None # no edge info!

    def iterator_query(self):
        '''fetch the non-null targets of this source node, returned as
        (target_id, None) pairs'''
        tbl = self.table
        query = 'select %s from %s where %s=%%s and %s is not null' \
                % (tbl.targetSQL, tbl.name, tbl.sourceSQL, tbl.targetSQL)
        sql, params = tbl._format_query(query, (self.fromNode,))
        self.table.cursor.execute(sql, params)
        pairs = []
        for row in self.table.cursor.fetchall():
            pairs.append((row[0], None))
        return pairs

SQLEdgeDict._edgelessClass = SQLEdgelessDict
class SQLGraphEdgeDescriptor(object):
    'provide an SQLEdges interface on demand'
    def __get__(self, obj, objtype):
        '''build an SQLEdges view on obj's table and cursor, passing along
        a copy of obj.attrAlias if obj has that attribute'''
        edgeArgs = dict(graph=obj)
        try:
            edgeArgs['attrAlias'] = obj.attrAlias.copy()
        except AttributeError: # no aliases defined on this graph
            pass
        return SQLEdges(obj.name, obj.cursor, **edgeArgs)
def getColumnTypes(createTable, attrAlias=None, defaultColumnType='int',
                   columnAttrs=('source','target','edge'), **kwargs):
    '''return list of [(colname,coltype),...] for source,target,edge.
    createTable: if it is a mapping containing '<attr>_id', that value is
      used as the column type; any other value (e.g. True) is ignored.
    attrAlias: optional dict mapping '<attr>_id' to the real column name;
      None (the default) is treated as no aliases.
    defaultColumnType: type used when nothing else supplies one.
    columnAttrs: the attribute prefixes to produce columns for.
    kwargs: '<attr>DB' entries name databases whose keys are inspected to
      infer the column type.
    Raises ValueError if a database key is neither an int nor a str.'''
    if attrAlias is None: # avoid a shared mutable default argument
        attrAlias = {}
    columns = []
    for attr in columnAttrs:
        idName = attr + '_id'
        try:
            attrName = attrAlias[idName]
        except KeyError:
            attrName = idName
        try: # see if user specified a desired type
            columns.append((attrName, createTable[idName]))
            continue
        except (KeyError, TypeError):
            pass
        try: # get type info from primary key for that database
            db = kwargs[attr + 'DB']
            if db is None:
                raise KeyError # force it to use default type
        except KeyError:
            pass
        else: # infer the column type from the associated database keys...
            it = iter(db)
            try: # get one identifier from the database
                k = it.next()
            except StopIteration: # table is empty, so read SQL type from db object
                try:
                    columns.append((attrName, db.columnType[db.primary_key]))
                    continue
                except AttributeError:
                    pass
            else: # get the type from this identifier
                if isinstance(k, int) or isinstance(k, long):
                    columns.append((attrName, 'int'))
                    continue
                elif isinstance(k, str):
                    columns.append((attrName, 'varchar(32)'))
                    continue
                else:
                    raise ValueError('SQLGraph node / edge must be int or str!')
        columns.append((attrName, defaultColumnType))
        logger.warn('no type info found for %s, so using default: %s'
                    % (attrName, defaultColumnType))
    return columns
class SQLGraph(SQLTableMultiNoCache):
    '''provide a graph interface via a SQL table.  Key capabilities are:
    - setitem with an empty dictionary: a dummy operation
    - getitem with a key that exists: return a placeholder
    - setitem with non empty placeholder: again a dummy operation
    EXAMPLE TABLE SCHEMA:
    create table mygraph (source_id int not null,target_id int,edge_id int,
           unique(source_id,target_id));
    '''
    _distinct_key = 'source_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(sourceDB=0,targetDB=0,edgeDB=0,allowMissingNodes=0))
    _edgeClass = SQLEdgeDict

    def __init__(self, name, *l, **kwargs):
        '''name: SQL table name; remaining positional args go to
        SQLTableMultiNoCache.  Graph-specific keyword arguments are split
        off from the table arguments; if createTable is given, a
        create-table statement is built from getColumnTypes().'''
        graphArgs, tableArgs = split_kwargs(kwargs,
                    ('attrAlias','defaultColumnType','columnAttrs',
                     'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
                     'edgeDictClass','graph'))
        if 'createTable' in kwargs: # create a schema for this table
            c = getColumnTypes(**kwargs)
            source, target, edge = c[0], c[1], c[2]
            tableArgs['createTable'] = \
                'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
                % (name, source[0], source[1], target[0], target[1],
                   edge[0], edge[1], source[0], target[0])
        if 'allowMissingNodes' in kwargs:
            self.allowMissingNodes = kwargs['allowMissingNodes']
        SQLTableMultiNoCache.__init__(self, name, *l, **tableArgs)
        self.sourceSQL = self._attrSQL('source_id')
        self.targetSQL = self._attrSQL('target_id')
        try:
            self.edgeSQL = self._attrSQL('edge_id')
        except AttributeError: # table has no edge column
            self.edgeSQL = None
            self._edgeClass = self._edgeClass._edgelessClass
        save_graph_db_refs(self, **kwargs)

    def __getitem__(self, k):
        'return the edge dictionary for source node k'
        return self._edgeClass(self.pack_source(k), self)

    def __iadd__(self, k):
        '''add k as a source node with no targets: remove any existing
        null-target row for k, then insert a fresh one'''
        query = 'delete from %s where %s=%%s and %s is null' \
                % (self.name, self.sourceSQL, self.targetSQL)
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        query = 'insert %%(IGNORE)s into %s values (%%s,NULL,NULL)' % self.name
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        return self # iadd must return self!

    def __isub__(self, k):
        'delete all rows for source node k; KeyError if none were deleted'
        query = 'delete from %s where %s=%%s' % (self.name, self.sourceSQL)
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        if self.cursor.rowcount == 0:
            raise KeyError('node not found in graph')
        return self # isub must return self!

    __setitem__ = graph_setitem

    def __contains__(self, k):
        'True if at least one row has k as its source'
        query = 'select * from %s where %s=%%s limit 1' \
                % (self.name, self.sourceSQL)
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchmany(2)
        return len(rows) > 0

    def __invert__(self):
        'get an interface to the inverse graph mapping'
        try: # cached
            return self._inverse
        except AttributeError: # construct interface to inverse mapping
            pass
        attrAlias = dict(source_id=self.targetSQL, # swap source & target
                         target_id=self.sourceSQL)
        if self.edgeSQL is not None: # only alias the edge column if present
            attrAlias['edge_id'] = self.edgeSQL
        inverse = SQLGraph(self.name, self.cursor, attrAlias=attrAlias,
                           **graph_db_inverse_refs(self))
        self._inverse = inverse
        inverse._inverse = self
        return self._inverse

    def __iter__(self):
        'generate the unpacked source nodes'
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self.unpack_source(k)

    def iteritems(self):
        'generate (unpacked source node, edge dictionary) pairs'
        for k in SQLTableMultiNoCache.__iter__(self):
            yield (self.unpack_source(k), self._edgeClass(k, self))

    def itervalues(self):
        'generate the edge dictionary of each source node'
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self._edgeClass(k, self)

    def keys(self):
        'list of unpacked source nodes'
        return [self.unpack_source(k)
                for k in SQLTableMultiNoCache.keys(self)]

    def values(self):
        return list(self.itervalues())

    def items(self):
        return list(self.iteritems())

    edges = SQLGraphEdgeDescriptor()
    update = update_graph

    def __len__(self):
        'get number of source nodes in graph'
        self.cursor.execute('select count(distinct %s) from %s'
                            % (self.sourceSQL, self.name))
        return self.cursor.fetchone()[0]

    __cmp__ = graph_cmp
    override_rich_cmp(locals()) # must override __eq__ etc. to use our __cmp__!
    # (an older, hand-written __cmp__ used to live here as commented-out
    # code; graph_cmp above is the active implementation)

    add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
class SQLIDGraph(SQLGraph):
    # same graph interface, with the trivial pack / unpack methods installed
    add_trivial_packing_methods(locals())

# registered on SQLGraph as its ID-graph variant
SQLGraph._IDGraphClass = SQLIDGraph
class SQLEdgeDictClustered(dict):
    'simple cache for 2nd level dictionary of target_id:edge_id'
    def __init__(self, g, fromNode):
        'g: the owning graph; fromNode: the source node this dict belongs to'
        dict.__init__(self)
        self.g = g
        self.fromNode = fromNode

    def __iadd__(self, l):
        'store every (target_id, edge_id) pair from l in this dict'
        for pair in l:
            target_id, edge_id = pair
            dict.__setitem__(self, target_id, edge_id)
        return self # iadd must return self!
class SQLEdgesClusteredDescr(object):
    'provide an SQLEdgesClustered interface on demand'
    def __get__(self, obj, objtype):
        '''build an edges view keyed by edge_id on the same table as obj,
        and seed it with a copy of obj's current edge cache'''
        edges = SQLEdgesClustered(obj.table, obj.edge_id, obj.source_id,
                                  obj.target_id, graph=obj,
                                  **graph_db_inverse_refs(obj, True))
        for source_id, d in obj.d.iteritems(): # copy edge cache
            rows = []
            for target_id, edge_id in d.iteritems():
                rows.append((edge_id, source_id, target_id))
            edges.load(rows)
        return edges
class SQLGraphClustered(object):
    'SQL graph with clustered caching -- loads an entire cluster at a time'
    _edgeDictClass = SQLEdgeDictClustered

    def __init__(self, table, source_id='source_id', target_id='target_id',
                 edge_id='edge_id', clusterKey=None, **kwargs):
        '''table: either a table object, or a table name, in which case an
        SQLTableClustered is created (clusterKey is then required, else
        ValueError).  source_id, target_id, edge_id: column names.
        If createTable is in kwargs, a create-table statement is built from
        getColumnTypes() before the table interface is constructed.'''
        if isinstance(table, types.StringType): # create the table interface
            if clusterKey is None:
                raise ValueError('you must provide a clusterKey argument!')
            if 'createTable' in kwargs: # create a schema for this table
                aliases = dict(source_id=source_id, target_id=target_id,
                               edge_id=edge_id)
                c = getColumnTypes(attrAlias=aliases, **kwargs)
                source, target, edge = c[0], c[1], c[2]
                kwargs['createTable'] = \
                    'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
                    % (table, source[0], source[1], target[0], target[1],
                       edge[0], edge[1], source[0], target[0])
            table = SQLTableClustered(table, clusterKey=clusterKey, **kwargs)
        self.table = table
        self.source_id = source_id
        self.target_id = target_id
        self.edge_id = edge_id
        self.d = {}
        save_graph_db_refs(self, **kwargs)

    _pickleAttrs = dict(table=0,source_id=0,target_id=0,edge_id=0,sourceDB=0,targetDB=0,
                        edgeDB=0)

    def __getstate__(self):
        'standard pickling state, plus an empty edge cache'
        state = standard_getstate(self)
        state['d'] = {} # unpickle should restore graph with empty cache
        return state

    def __getitem__(self, k):
        'get edgeDict for source node k, from cache or by loading its cluster'
        try: # get directly from cache
            return self.d[k]
        except KeyError:
            if hasattr(self, '_isLoaded'):
                raise # entire graph loaded, so k really not in this graph
        # have to load the entire cluster containing this node
        tbl = self.table
        query = 'select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s' \
                % (self.source_id, self.target_id, self.edge_id,
                   tbl.name, tbl.name, self.source_id,
                   tbl.clusterKey, tbl.clusterKey, tbl.primary_key)
        sql, params = tbl._format_query(query, (self.pack_source(k),))
        self.table.cursor.execute(sql, params)
        self.load(self.table.cursor.fetchall()) # cache this cluster
        return self.d[k] # return edge dict for this node

    def load(self, l=None, unpack=True):
        '''load the specified rows (or all, if None provided) into local
        cache.  l: iterable of (source,target,edge) rows; unpack: whether
        to pass each value through the unpack_* methods first.'''
        if l is None:
            try: # if already loaded, no need to do anything
                return self._isLoaded
            except AttributeError:
                pass
            self.table.cursor.execute('select %s,%s,%s from %s'
                                      % (self.source_id, self.target_id,
                                         self.edge_id, self.table.name))
            l = self.table.cursor.fetchall()
            self._isLoaded = True
            self.d.clear() # clear our cache as load() will replicate everything
        for source, target, edge in l: # save to our cache
            if unpack:
                source = self.unpack_source(source)
                target = self.unpack_target(target)
                edge = self.unpack_edge(edge)
            entry = [(target, edge)]
            try:
                self.d[source] += entry
            except KeyError: # first edge seen for this source
                edgeDict = self._edgeDictClass(self, source)
                edgeDict += entry
                self.d[source] = edgeDict

    def __invert__(self):
        'interface to reverse graph mapping'
        try:
            return self._inverse # inverse map already exists
        except AttributeError:
            pass
        # just create interface with swapped target & source
        inverse = SQLGraphClustered(self.table, self.target_id,
                                    self.source_id, self.edge_id,
                                    **graph_db_inverse_refs(self))
        self._inverse = inverse
        inverse._inverse = self
        for source, d in self.d.iteritems(): # invert our cache
            rows = []
            for target, edge in d.iteritems():
                rows.append((target, source, edge))
            self._inverse.load(rows, unpack=False)
        return self._inverse

    edges = SQLEdgesClusteredDescr() # construct edge interface on demand
    update = update_graph
    add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS

    def __iter__(self): ################# ITERATORS
        'uses db select; does not force load'
        return iter(self.keys())

    def keys(self):
        'uses db select; does not force load'
        self.table.cursor.execute('select distinct(%s) from %s'
                                  % (self.source_id, self.table.name))
        rows = self.table.cursor.fetchall()
        return [self.unpack_source(row[0]) for row in rows]

    methodFactory(['iteritems','items','itervalues','values'],
                  'lambda self:(self.load(),self.d.%s())[1]',locals())

    def __contains__(self, k):
        'True if self[k] can be obtained, False if that raises KeyError'
        try:
            self[k]
        except KeyError:
            return False
        return True
class SQLIDGraphClustered(SQLGraphClustered):
    # same clustered graph, with the trivial pack / unpack methods installed
    add_trivial_packing_methods(locals())

# registered on SQLGraphClustered as its ID-graph variant
SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
class SQLEdgesClustered(SQLGraphClustered):
    'edges interface for SQLGraphClustered'
    _edgeDictClass = list
    _pickleAttrs = SQLGraphClustered._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))

    def keys(self):
        '''force a full load, then return a list of (source,target,edge)
        tuples unpacked via self.graph'''
        self.load()
        graph = self.graph
        edges = []
        for edge_id, pairs in self.d.iteritems():
            for source_id, target_id in pairs:
                edges.append((graph.unpack_source(source_id),
                              graph.unpack_target(target_id),
                              graph.unpack_edge(edge_id)))
        return edges
class ForeignKeyInverse(object):
    'map each key to a single value according to its foreign key'
    def __init__(self, g):
        'g: the forward graph whose mapping this object inverts'
        self.g = g

    def __getitem__(self, obj):
        '''return the sourceDB object that obj's foreign key points to,
        or None if the foreign key is None; KeyError if obj is not from
        the targetDB'''
        self.check_obj(obj)
        source_id = getattr(obj, self.g.keyColumn)
        if source_id is not None:
            return self.g.sourceDB[source_id]
        return None

    def __setitem__(self, obj, source):
        '''point obj at source via the forward graph; if source is None,
        delete any pre-existing edge instead'''
        self.check_obj(obj)
        if source is not None:
            self.g[source][obj] = None # ensures all the right caching operations done
            return
        # delete pre-existing edge if present
        if hasattr(obj, '_localOnly'): # only cache, don't save to database
            return
        old_source = self[obj]
        if old_source is not None:
            del self.g[old_source][obj]

    def check_obj(self, obj):
        'raise KeyError if obj not from this db'
        if getattr(obj, 'db', None) is not self.g.targetDB:
            raise KeyError('key is not from targetDB of this graph!')

    def __contains__(self, obj):
        'True if obj belongs to the targetDB of this graph'
        try:
            self.check_obj(obj)
        except KeyError:
            return False
        return True

    def __iter__(self):
        return self.g.targetDB.itervalues()

    def keys(self):
        return self.g.targetDB.values()

    def iteritems(self):
        'generate (target object, source object or None) pairs'
        for obj in self:
            source_id = getattr(obj, self.g.keyColumn)
            if source_id is not None:
                yield obj, self.g.sourceDB[source_id]
            else:
                yield obj, None

    def items(self):
        return list(self.iteritems())

    def itervalues(self):
        for pair in self.iteritems():
            yield pair[1]

    def values(self):
        return list(self.itervalues())

    def __invert__(self):
        'the inverse of this mapping is the forward graph itself'
        return self.g
class ForeignKeyEdge(dict):
    '''edge interface to a foreign key in an SQL table.
    Caches dict of target nodes in itself; provides dict interface.
    Adds or deletes edges by setting foreign key values in the table'''
    def __init__(self, g, k):
        '''g: the owning graph (provides targetDB and keyColumn);
        k: the source node; its targets are found by selecting the
        targetDB rows whose foreign key column equals k.id'''
        dict.__init__(self)
        self.g = g
        self.src = k
        for v in g.targetDB.select('where %s=%%s' % g.keyColumn, (k.id,)): # search the db
            dict.__setitem__(self, v, None) # save in cache

    def __setitem__(self, dest, v):
        '''add an edge from the source node to dest.  v must be None, since
        no edge information can be stored (ValueError otherwise); dest must
        belong to the graph's targetDB (KeyError otherwise).'''
        if not hasattr(dest, 'db') or dest.db is not self.g.targetDB:
            raise KeyError('dest is not in the targetDB bound to this graph!')
        if v is not None:
            raise ValueError('sorry,this graph cannot store edge information!')
        if not hasattr(dest, '_localOnly'): # only cache, don't save to database
            old_source = self.g._inverse[dest] # check for pre-existing edge
            if old_source is not None: # remove old edge from cache
                dict.__delitem__(self.g[old_source], dest)
        setattr(dest, self.g.keyColumn, self.src.id) # save to db attribute
        dict.__setitem__(self, dest, None) # save in cache

    def __delitem__(self, dest):
        '''remove the edge from the source node to dest by setting dest's
        foreign key attribute to None.  Raises KeyError, without modifying
        dest, if dest is not a target of this source node.'''
        # Validate first: clearing the foreign key of an object that is not
        # one of our targets would silently break its link to another source.
        if not dict.__contains__(self, dest):
            raise KeyError(dest)
        setattr(dest, self.g.keyColumn, None) # save to db attribute
        dict.__delitem__(self, dest) # remove from cache
class ForeignKeyGraph(object, UserDict.DictMixin):
    '''graph interface to a foreign key in an SQL table
    Caches dict of target nodes in itself; provides dict interface.
    '''
    def __init__(self, sourceDB, targetDB, keyColumn, autoGC=True, **kwargs):
        '''sourceDB is any database of source nodes;
        targetDB must be an SQL database of target nodes;
        keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.
        autoGC: if true, cache edge dicts in a RecentValueDictionary,
        otherwise in a plain dict.'''
        if autoGC: # automatically garbage collect unused objects
            cache = RecentValueDictionary(autoGC) # object cache
        else:
            cache = {}
        self._weakValueDict = cache
        self.autoGC = autoGC
        self.sourceDB = sourceDB
        self.targetDB = targetDB
        self.keyColumn = keyColumn
        self._inverse = ForeignKeyInverse(self)

    _pickleAttrs = dict(sourceDB=0, targetDB=0, keyColumn=0, autoGC=0)
    __getstate__ = standard_getstate ########### SUPPORT FOR PICKLING
    __setstate__ = standard_setstate

    def _inverse_schema(self):
        'provide custom schema rule for inverting this graph... just use keyColumn!'
        return dict(invert=True, uniqueMapping=True)

    def __getitem__(self, k):
        '''return the ForeignKeyEdge dict for source node k, creating and
        caching it on first use; KeyError if k is not from sourceDB'''
        if not hasattr(k, 'db') or k.db is not self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this graph!')
        try:
            return self._weakValueDict[k.id] # get from cache
        except KeyError:
            edgeDict = ForeignKeyEdge(self, k)
        self._weakValueDict[k.id] = edgeDict # save in cache
        return edgeDict

    def __setitem__(self, k, v):
        'not supported: always raises KeyError'
        raise KeyError('''do not save as g[k]=v.  Instead follow a graph
interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')

    def __delitem__(self, k):
        'not supported: always raises KeyError'
        raise KeyError('''Instead of del g[k], follow a graph
interface: del g[src][dest]''')

    def keys(self):
        'all objects of the sourceDB'
        return self.sourceDB.values()

    __invert__ = standard_invert
def describeDBTables(name, cursor, idDict):
    """
    Get table info about database <name> via <cursor>, and store primary keys
    in idDict, along with a list of the tables each key indexes.

    Returns a dict mapping 'dbname.tablename' to its SQLTable object.
    Non-primary-key columns ending in '_id' get an empty list in idDict
    if they are not already present.
    """
    cursor.execute('use %s' % name)
    cursor.execute('show tables')
    tableNames = [row[0] for row in cursor.fetchall()]
    tables = {}
    for shortName in tableNames:
        fullName = name + '.' + shortName
        table = SQLTable(fullName, cursor)
        tables[fullName] = table
        for column in table.description:
            if column == table.primary_key:
                idDict.setdefault(column, []).append(table)
                continue
            if column[-3:] == '_id' and column not in idDict:
                idDict[column] = []
    return tables
def indexIDs(tables, idDict=None):
    """Get an index of primary keys in the <tables> dictionary.

    tables maps names to table objects exposing primary_key and
    description.  Returns idDict (a new dict when none is given), in
    which each primary key maps to the list of tables it indexes, and
    every other column ending in '_id' maps to an empty list unless it
    is already present."""
    if idDict is None:  # identity test: never invoke a custom __eq__
        idDict = {}
    for table in tables.values():
        if table.primary_key:
            # keep list of tables with this primary key
            idDict.setdefault(table.primary_key, []).append(table)
        for column in table.description:
            if column[-3:] == '_id' and column not in idDict:
                idDict[column] = []
    return idDict
def suffixSubset(tables, suffix):
    "Return a new dict holding only the entries of tables whose name ends with suffix"
    return dict([(tableName, table) for tableName, table in tables.items()
                 if tableName.endswith(suffix)])
PRIMARY_KEY = 1  # edge info marking a column as the primary key of its table

def graphDBTables(tables, idDict):
    '''Build a dictgraph linking every table to each of its columns, in
    both directions.  The edge info is PRIMARY_KEY when the column is
    the primary key of that table, otherwise None.  idDict is accepted
    but not used.'''
    graph = dictgraph()
    for table in tables.values():
        for column in table.description:
            edgeInfo = None
            if column == table.primary_key:
                edgeInfo = PRIMARY_KEY
            graph.setEdge(column, table, edgeInfo)
            graph.setEdge(table, column, edgeInfo)
    return graph
# Default SQL column type for each supported Python value type; consulted by
# createTableFromRow() when no user-supplied translation covers a column.
SQLTypeTranslation= {types.StringType:'varchar(32)',
                     types.IntType:'int',
                     types.FloatType:'float'}
def createTableFromRepr(rows, tableName, cursor, typeTranslation=None,
                        optionalDict=None, indexDict=()):
    """Save rows into SQL tableName using cursor, with optional
    translations of columns to specific SQL types (specified
    by typeTranslation dict).
    - optionalDict can specify columns that are allowed to be NULL.
    - indexDict can specify columns that must be indexed; columns
      whose names end in _id will be indexed by default.
    - rows must be an iterable (iterator, list, ...) yielding
      dictionaries, each representing a tuple of values (indexed by
      their column names).
    Returns None; nothing is created or stored when rows is empty.
    """
    rows = iter(rows)  # accept any iterable, not just objects with next()
    for row in rows:  # take only the first row, to extract column info
        break
    else:
        return  # rows is empty: nothing to save
    try:
        createTableFromRow(cursor, tableName, row, typeTranslation,
                           optionalDict, indexDict)
    except Exception:
        # NOTE(review): table creation is best-effort and its errors are
        # ignored, presumably so that an existing table can be reused.
        # KeyboardInterrupt / SystemExit are no longer swallowed here.
        pass
    storeRow(cursor, tableName, row)  # save our first row
    for row in rows:  # now save all the remaining rows
        storeRow(cursor, tableName, row)
def createTableFromRow(cursor, tableName, row, typeTranslation=None,
                       optionalDict=None, indexDict=()):
    """Create SQL table tableName (if it does not exist) with one column
    per key of the dictionary row.

    - typeTranslation may map column names to explicit SQL types; other
      columns get their type from SQLTypeTranslation, by exact value
      type first and then by isinstance().
    - columns not listed in optionalDict are declared 'not null'.
    - columns whose names end in _id, or that are in indexDict, are
      indexed.
    Raises TypeError if no SQL type can be found for a column.
    """
    create_defs = []
    for col, val in row.items():  # prepare SQL types for columns
        coltype = None
        if typeTranslation is not None and col in typeTranslation:
            coltype = typeTranslation[col]  # user-supplied translation
        elif type(val) in SQLTypeTranslation:
            coltype = SQLTypeTranslation[type(val)]
        else:  # search for a compatible type
            for t in SQLTypeTranslation:
                if isinstance(val, t):
                    coltype = SQLTypeTranslation[t]
                    break
        if coltype is None:
            raise TypeError("Don't know SQL type to use for %s" % col)
        create_def = '%s %s' % (col, coltype)
        if optionalDict is None or col not in optionalDict:
            create_def += ' not null'
        create_defs.append(create_def)
    for col in row:  # create indexes for ID columns
        if col[-3:] == '_id' or col in indexDict:
            create_defs.append('index(%s)' % col)
    cmd = 'create table if not exists %s (%s)' % (tableName,
                                                  ','.join(create_defs))
    cursor.execute(cmd)  # create the table in the database
def storeRow(cursor, tableName, row):
    """Insert the dictionary row (column name -> value) into tableName.

    The column names are listed explicitly in the INSERT statement so
    that each value lands in its own column regardless of the order in
    which the dictionary happens to iterate."""
    columns = list(row.keys())
    row_format = ','.join(len(columns) * ['%s'])
    cmd = 'insert into %s (%s) values (%s)' % (tableName, ','.join(columns),
                                               row_format)
    cursor.execute(cmd, tuple([row[col] for col in columns]))
def storeRowDelayed(cursor, tableName, row):
    """Insert the dictionary row (column name -> value) into tableName
    using an 'insert delayed' statement.

    The column names are listed explicitly so that each value lands in
    its own column regardless of the order in which the dictionary
    happens to iterate."""
    columns = list(row.keys())
    row_format = ','.join(len(columns) * ['%s'])
    cmd = 'insert delayed into %s (%s) values (%s)' % (tableName,
                                                       ','.join(columns),
                                                       row_format)
    cursor.execute(cmd, tuple([row[col] for col in columns]))
class TableGroup(dict):
    '''provide attribute access to dbname qualified tablenames

    Each keyword argument name=tablename is stored as an item; table
    names without a '.' are prefixed with the database name db.  Stored
    names can then be read either as group[name] or as group.name.'''
    def __init__(self, db='test', suffix=None, **kw):
        dict.__init__(self)
        self.db = db
        if suffix is not None:  # the attribute is only set when given
            self.suffix = suffix
        for k, v in kw.items():
            if v is not None and '.' not in v:
                v = self.db + '.' + v  # add database name as prefix
            self[k] = v

    def __getattr__(self, k):
        '''Fall back to item lookup for unknown attributes.

        Raises AttributeError (not KeyError) for a missing name, as the
        attribute protocol requires; otherwise getattr() defaults,
        hasattr() and pickling break on this object.'''
        try:
            return self[k]
        except KeyError:
            raise AttributeError(k)
def sqlite_connect(*args, **kwargs):
    '''Open a connection with the module returned by import_sqlite(),
    passing all arguments through to its connect(); return the pair
    (connection, cursor).'''
    connection = import_sqlite().connect(*args, **kwargs)
    return connection, connection.cursor()
class DBServerInfo(object):
    '''picklable reference to a database server

    The instance switches itself to the subclass registered for
    moduleName in _DBServerModuleDict.  The connection is opened lazily
    by the subclass _start_connection() method, which must set
    self._connection and self._cursor.'''
    def __init__(self, moduleName='MySQLdb', serverSideCursors=True,
                 blockIterators=True, *args, **kwargs):
        '''moduleName selects the subclass from _DBServerModuleDict
        (ValueError if unknown); remaining args / kwargs are saved as
        the connection arguments.  serverSideCursors=True requires
        blockIterators=True (ValueError otherwise).'''
        try:
            self.__class__ = _DBServerModuleDict[moduleName]
        except KeyError:
            raise ValueError('Module name not found in _DBServerModuleDict: '\
                             + moduleName)
        self.moduleName = moduleName
        self.args = args  # connection arguments
        self.kwargs = kwargs
        self.serverSideCursors = serverSideCursors
        self.custom_iter_keys = blockIterators
        if self.serverSideCursors and not self.custom_iter_keys:
            raise ValueError('serverSideCursors=True requires blockIterators=True!')

    def cursor(self):
        """returns cursor associated with the DB server info (reused)"""
        try:
            return self._cursor
        except AttributeError:  # not connected yet
            self._start_connection()
            return self._cursor

    def new_cursor(self, arraysize=None):
        """returns a NEW cursor; you must close it yourself! """
        if not hasattr(self, '_connection'):
            self._start_connection()
        cursor = self._connection.cursor()
        if arraysize is not None:
            cursor.arraysize = arraysize
        return cursor

    def close(self):
        """Close the shared cursor and connection, if any.

        Safe to call when no connection was ever opened, and safe to
        call more than once."""
        for attr in ('_cursor', '_connection'):
            try:
                resource = getattr(self, attr)
            except AttributeError:
                continue  # never opened, or already closed
            resource.close()
            delattr(self, attr)

    def __getstate__(self):
        """return all picklable arguments"""
        return dict(args=self.args, kwargs=self.kwargs,
                    moduleName=self.moduleName,
                    serverSideCursors=self.serverSideCursors,
                    custom_iter_keys=self.custom_iter_keys)
class MySQLServerInfo(DBServerInfo):
    'customized for MySQLdb SSCursor support via new_cursor()'
    _serverType = 'mysql'

    def _start_connection(self):
        'open the regular connection and its shared cursor'
        self._connection,self._cursor = mysql_connect(*self.args, **self.kwargs)

    def new_cursor(self, arraysize=None):
        '''provide streaming cursor support

        With serverSideCursors off this is the regular
        DBServerInfo.new_cursor().  Otherwise cursors come from a
        separate connection opened with useStreaming=True, which is
        created on first use and then reused.'''
        if not self.serverSideCursors:  # use regular MySQLdb cursor
            return DBServerInfo.new_cursor(self, arraysize)
        try:
            conn = self._conn_sscursor
        except AttributeError:  # first request: open streaming connection
            self._conn_sscursor,cursor = mysql_connect(useStreaming=True,
                                                       *self.args, **self.kwargs)
        else:
            cursor = conn.cursor()
        if arraysize is not None:
            cursor.arraysize = arraysize
        return cursor

    def close(self):
        '''Close the regular connection and the streaming connection.

        The streaming connection is released even when closing the
        regular one raises, so it can no longer be leaked.'''
        try:
            DBServerInfo.close(self)
        finally:
            conn = self.__dict__.pop('_conn_sscursor', None)
            if conn is not None:
                conn.close()

    def iter_keys(self, db, cursor, map_f=iter,
                  cache_f=lambda x:[t[0] for t in x], **kwargs):
        '''Return db.generic_iterator() fetching rows block by block
        through a BlockGenerator built from db, cursor and kwargs.  If
        the generator defines its own cache_f, that one replaces the
        cache_f argument.'''
        block_generator = BlockGenerator(db, cursor, **kwargs)
        try:
            cache_f = block_generator.cache_f
        except AttributeError:
            pass
        return db.generic_iterator(cursor=cursor, cache_f=cache_f,
                                   map_f=map_f, fetch_f=block_generator)
class CursorCloser(object):
    """Holder that calls cursor.close() when this object is deleted.

    For Python 2.5+, this could be replaced by a try... finally clause
    in a generator function such as generic_iterator(); see PEP 342 or
    What's New in Python 2.5. """

    def __init__(self, cursor):
        'remember the cursor to be closed later'
        self.cursor = cursor

    def __del__(self):
        'close the held cursor'
        held = self.cursor
        held.close()
class BlockGenerator(CursorCloser):
    '''workaround for MySQLdb iteration horrible performance

    Calling the object runs a SELECT limited to cursor.arraysize rows
    and returns that block; the next call continues after the last row
    of the previous block.'''
    def __init__(self, db, cursor, selectCols, whereClause='', **kwargs):
        '''kwargs must contain 'orderBy' and are passed on to
        db._select().  With a true orderBy the continuation uses
        db.iterSQL / db.iterColumns, otherwise the primary key.'''
        self.db = db
        self.cursor = cursor
        self.kwargs = kwargs
        self.selectCols = selectCols
        # NOTE(review): the whereClause argument is not used; the first
        # query always runs with an empty clause -- confirm intended.
        self.whereClause = ''
        self.params = ()
        self.done = False
        if not kwargs['orderBy']:  # just use primary key
            self.whereSQL = 'WHERE %s>%%s' % self.db.primary_key
            self.whereParams = (db.data['id'],)
            return
        self.whereSQL = db.iterSQL  # use iterSQL/iterColumns instead
        if selectCols == '*':  # extracting all columns
            self.whereParams = [db.data[col] for col in db.iterColumns]
            return
        # selectCols is a single column: select it with the iterColumns
        columns = list(db.iterColumns)
        if selectCols in columns:  # already one of the iterColumns
            position = columns.index(selectCols)
        else:  # have to append it
            position = len(db.iterColumns)
            columns.append(selectCols)
        self.selectCols = ','.join(columns)
        self.whereParams = range(len(db.iterColumns))
        if position > 0:  # need to extract desired column from each row
            def pick_column(block):
                return [t[position] for t in block]
            self.cache_f = pick_column

    def __call__(self):
        'get the next block of data'
        if self.done:
            return ()
        cursor = self.cursor
        self.db._select(self.whereClause, self.params, cursor=cursor,
                        limit='LIMIT %s' % cursor.arraysize,
                        selectCols=self.selectCols, **self.kwargs)
        block = cursor.fetchall()
        if len(block) < cursor.arraysize:  # short block: iteration complete
            self.done = True
        else:  # extract params for the next query from the last row
            last = block[-1]
            if len(last) > 1:
                self.params = [last[icol] for icol in self.whereParams]
            else:
                self.params = last
            self.whereClause = self.whereSQL
        return block
class SQLiteServerInfo(DBServerInfo):
    """picklable reference to a sqlite database"""
    _serverType = 'sqlite'

    def __init__(self, database, *args, **kwargs):
        """Takes same arguments as sqlite3.connect(); the database
        argument is wrapped in SourceFileName before being saved."""
        kwargs['database'] = SourceFileName(database)  # save abs path!
        DBServerInfo.__init__(self, 'sqlite', *args, **kwargs)

    def _start_connection(self):
        'open the connection and its shared cursor via sqlite_connect()'
        connection, cursor = sqlite_connect(*self.args, **self.kwargs)
        self._connection = connection
        self._cursor = cursor

    def __getstate__(self):
        'like DBServerInfo.__getstate__, but refuse in-memory databases'
        database = self.kwargs.get('database', False)
        if not database:
            database = self.args[0]
        if database == ':memory:':
            raise ValueError('SQLite in-memory database is not picklable!')
        return DBServerInfo.__getstate__(self)
# registry of DBServerInfo subclasses for different modules: maps the
# moduleName given to DBServerInfo() to the subclass that handles it
_DBServerModuleDict = dict(MySQLdb=MySQLServerInfo, sqlite=SQLiteServerInfo)
class MapView(object, UserDict.DictMixin):
    '''general purpose 1:1 mapping defined by any SQL query

    viewSQL is run with the ID of a sourceDB object as its parameter;
    the first column of its single result row is the ID of the mapped
    targetDB object.  The mapping is read-only.'''
    _schemaModuleDict = _schemaModuleDict  # default module list
    get_sql_format = get_table_schema
    _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
                        inverseSQL=0)
    __getstate__ = standard_getstate
    __setstate__ = standard_setstate
    __setitem__ = __delitem__ = clear = pop = popitem = update = \
                  setdefault = read_only_error

    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
                 serverInfo=None, inverseSQL=None, **kwargs):
        '''If no cursor is given it is obtained from serverInfo, or
        failing that from sourceDB.serverInfo; ValueError is raised if
        neither is available.'''
        self.sourceDB = sourceDB
        self.targetDB = targetDB
        self.viewSQL = viewSQL
        self.inverseSQL = inverseSQL
        if cursor is None:
            if serverInfo is None:  # can we get it from our other db?
                try:
                    serverInfo = sourceDB.serverInfo
                except AttributeError:
                    raise ValueError('you must provide serverInfo or cursor!')
            cursor = serverInfo.cursor()
        self.cursor = cursor
        self.serverInfo = serverInfo
        self.get_sql_format(False)  # get sql formatter for this db interface

    def __getitem__(self, k):
        'return the targetDB object mapped to k; KeyError unless exactly one'
        if not (hasattr(k, 'db') and k.db is self.sourceDB):
            raise KeyError('object is not in the sourceDB bound to this map!')
        sql, params = self._format_query(self.viewSQL, (k.id,))
        self.cursor.execute(sql, params)  # formatted for this db interface
        found = self.cursor.fetchmany(2)  # two rows suffice to detect non-unique
        if len(found) != 1:
            raise KeyError('%s not found in MapView, or not unique'
                           % str(k))
        targetID = found[0][0]
        return self.targetDB[targetID]  # get the corresponding object

    def __iter__(self):
        'only yield sourceDB items that are actually in this mapping!'
        for source in self.sourceDB.itervalues():
            try:
                self[source]
                yield source
            except KeyError:
                pass

    def keys(self):
        'list of the sourceDB objects that have a mapping'
        found = []
        for source in self:  # don't use list(self); causes infinite loop!
            found.append(source)
        return found

    def __invert__(self):
        'return the (cached) reverse mapping, which requires inverseSQL'
        try:
            return self._inverse
        except AttributeError:
            if self.inverseSQL is None:
                raise ValueError('this MapView has no inverseSQL!')
            inverse = self.__class__(self.targetDB, self.sourceDB,
                                     self.inverseSQL, self.cursor,
                                     serverInfo=self.serverInfo,
                                     inverseSQL=self.viewSQL)
            self._inverse = inverse
            inverse._inverse = self  # the two views invert each other
            return inverse
class GraphViewEdgeDict(UserDict.DictMixin):
    '''edge dictionary for GraphView: just pre-loaded on init

    Keys are targetDB objects; values are edgeDB objects, or None when
    the graph has no edgeDB.  Read-only.'''
    __setitem__ = __delitem__ = clear = pop = popitem = update = \
                  setdefault = read_only_error

    def __init__(self, g, k):
        'run the view query for source k; KeyError if it returns no rows'
        self.g = g
        self.k = k
        sql, params = g._format_query(g.viewSQL, (k.id,))
        g.cursor.execute(sql, params)  # run the query
        rows = g.cursor.fetchall()  # get results
        if len(rows) <= 0:
            raise KeyError('key %s not in GraphView' % k.id)
        self.targets = [row[0] for row in rows]  # preserve result order
        if g.edgeDB is not None:  # also keep targetID:edgeID mapping
            edgeIDs = dict([(row[0], row[1]) for row in rows])
        else:  # no edge info: every target maps to None
            edgeIDs = dict.fromkeys(self.targets)
        self.targetDict = edgeIDs

    def __len__(self):
        return len(self.targets)

    def __iter__(self):
        'yield the target objects in query result order'
        targetDB = self.g.targetDB
        for targetID in self.targets:
            yield targetDB[targetID]

    def keys(self):
        return list(self)

    def iteritems(self):
        'yield (target, edge) pairs; edge is None when there is no edgeDB'
        if self.g.edgeDB is None:  # just the targets, no edge info
            for targetID in self.targets:
                yield (self.g.targetDB[targetID], None)
        else:
            for targetID in self.targets:
                yield (self.g.targetDB[targetID],
                       self.g.edgeDB[self.targetDict[targetID]])

    def __getitem__(self, o, exitIfFound=False):
        '''for the specified target object, return its associated edge
        object; with exitIfFound only check that o is a valid key'''
        try:
            if o.db is not self.g.targetDB:
                raise KeyError('key is not part of targetDB!')
            edgeID = self.targetDict[o.id]
        except AttributeError:
            raise KeyError('key has no id or db attribute?!')
        if exitIfFound:
            return
        if self.g.edgeDB is None:  # no edge info
            return None
        return self.g.edgeDB[edgeID]  # return the edge object

    def __contains__(self, o):
        try:
            self.__getitem__(o, True)  # raise KeyError if not found
        except KeyError:
            return False
        return True
class GraphView(MapView):
    '''general purpose graph interface defined by any SQL query

    g[src] returns a GraphViewEdgeDict loaded from the view query.'''
    _pickleAttrs = MapView._pickleAttrs.copy()
    _pickleAttrs['edgeDB'] = 0

    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
                 **kwargs):
        'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
        self.edgeDB = edgeDB
        MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)

    def __getitem__(self, k):
        'return the edge dictionary of source object k'
        if not (hasattr(k, 'db') and k.db is self.sourceDB):
            raise KeyError('object is not in the sourceDB bound to this map!')
        return GraphViewEdgeDict(self, k)
2090 # @CTB move to sqlgraph.py?
class SQLSequence(SQLRow, SequenceBase):
    """Transparent access to a DB row representing a sequence.

    Use attrAlias dict to rename 'length' to something else.
    """
    def _init_subclass(cls, db, **kwargs):
        'make db its own seqInfoDict, then defer to SQLRow._init_subclass'
        db.seqInfoDict = db
        SQLRow._init_subclass(db=db, **kwargs)
    _init_subclass = classmethod(_init_subclass)

    def __init__(self, id):
        SQLRow.__init__(self, id)
        SequenceBase.__init__(self)

    def __len__(self):
        'sequence length, read from the length attribute of the row'
        return self.length

    def strslice(self, start, end):
        "Efficient access to slice of a sequence, useful for huge contigs"
        seqSQL = self.db._attrSQL('seq')
        first = start + 1  # passed to SUBSTRING as start+1
        nchars = end - start
        query = ('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
                 % (seqSQL, first, nchars))
        return self._select(query)
class DNASQLSequence(SQLSequence):
    'SQLSequence whose _seqtype is DNA_SEQTYPE'
    _seqtype=DNA_SEQTYPE
class RNASQLSequence(SQLSequence):
    'SQLSequence whose _seqtype is RNA_SEQTYPE'
    _seqtype=RNA_SEQTYPE
class ProteinSQLSequence(SQLSequence):
    'SQLSequence whose _seqtype is PROTEIN_SEQTYPE'
    _seqtype=PROTEIN_SEQTYPE