first implementation of BlockGenerator iteration mechanism
[pygr.git] / pygr / sqlgraph.py
from __future__ import generators
from mapping import *
from sequence import SequenceBase, DNA_SEQTYPE, RNA_SEQTYPE, PROTEIN_SEQTYPE
import types
from classutil import methodFactory,standard_getstate,\
     override_rich_cmp,generate_items,get_bound_subclass,standard_setstate,\
     get_valid_path,standard_invert,RecentValueDictionary,read_only_error,\
     SourceFileName, split_kwargs
import os
import platform
import UserDict
import warnings
import logger


class TupleDescriptor(object):
    'return tuple entry corresponding to named attribute'
    def __init__(self, db, attr):
        self.icol = db.data[attr] # index of this attribute in the tuple
    def __get__(self, obj, klass):
        return obj._data[self.icol]
    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')

class TupleIDDescriptor(TupleDescriptor):
    def __set__(self, obj, val):
        raise AttributeError('''You cannot change obj.id directly.
        Instead, use db[newID] = obj''')

class TupleDescriptorRW(TupleDescriptor):
    'read-write interface to named attribute'
    def __init__(self, db, attr):
        self.attr = attr
        self.icol = db.data[attr] # index of this attribute in the tuple
        self.attrSQL = db._attrSQL(attr, sqlColumn=True) # SQL column name
    def __set__(self, obj, val):
        obj.db._update(obj.id, self.attrSQL, val) # AND UPDATE THE DATABASE
        obj.save_local(self.attr, val)

class SQLDescriptor(object):
    'return attribute value by querying the database'
    def __init__(self, db, attr):
        self.selectSQL = db._attrSQL(attr) # SQL expression for this attr
    def __get__(self, obj, klass):
        return obj._select(self.selectSQL)
    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')

class SQLDescriptorRW(SQLDescriptor):
    'writeable proxy to corresponding column in the database'
    def __set__(self, obj, val):
        obj.db._update(obj.id, self.selectSQL, val) #just update the database

class ReadOnlyDescriptor(object):
    'enforce read-only policy, e.g. for ID attribute'
    def __init__(self, db, attr):
        self.attr = '_'+attr
    def __get__(self, obj, klass):
        return getattr(obj, self.attr)
    def __set__(self, obj, val):
        raise AttributeError('attribute %s is read-only!' % self.attr)


def select_from_row(row, what):
    "return value from SQL expression applied to this row"
    sql,params = row.db._format_query('select %s from %s where %s=%%s limit 2'
                                      % (what,row.db.name,row.db.primary_key),
                                      (row.id,))
    row.db.cursor.execute(sql, params)
    t = row.db.cursor.fetchmany(2) # get at most two rows
    if len(t) != 1:
        raise KeyError('%s[%s].%s not found, or not unique'
                       % (row.db.name,str(row.id),what))
    return t[0][0] #return the single field we requested
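
# Illustrative note (added, not in the original source): for a hypothetical
# table 'genes' whose primary_key is 'gene_id', select_from_row(row, 'length(seq)')
# formats and runs roughly
#   select length(seq) from genes where gene_id=%s limit 2
# with (row.id,) as the parameter; LIMIT 2 plus fetchmany(2) is just a cheap
# check that the primary key really is unique.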
def init_row_subclass(cls, db):
    'add descriptors for db attributes'
    for attr in db.data: # bind all database columns
        if attr == 'id': # handle ID attribute specially
            setattr(cls, attr, cls._idDescriptor(db, attr))
            continue
        try: # check if this attr maps to an SQL column
            db._attrSQL(attr, columnNumber=True)
        except AttributeError: # treat as SQL expression
            setattr(cls, attr, cls._sqlDescriptor(db, attr))
        else: # treat as interface to our stored tuple
            setattr(cls, attr, cls._columnDescriptor(db, attr))

def dir_row(self):
    """get list of column names as our attributes """
    return self.db.data.keys()

class TupleO(object):
    """Provides attribute interface to a database tuple.
    Storing the data as a tuple instead of a standard Python object
    (which is stored using __dict__) uses about five-fold less
    memory and is also much faster (the tuples returned from the
    DB API fetch are simply referenced by the TupleO, with no
    need to copy their individual values into __dict__).

    This class follows the 'subclass binding' pattern, which
    means that instead of using __getattr__ to process all
    attribute requests (which is un-modular and leads to all
    sorts of trouble), we follow Python's new model for
    customizing attribute access, namely Descriptors.
    We use classutil.get_bound_subclass() to automatically
    create a subclass of this class, calling its _init_subclass()
    class method to add all the descriptors needed for the
    database table to which it is bound.

    See the Pygr Developer Guide section of the docs for a
    complete discussion of the subclass binding pattern."""
    _columnDescriptor = TupleDescriptor
    _idDescriptor = TupleIDDescriptor
    _sqlDescriptor = SQLDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row
    def __init__(self, data):
        self._data = data # save our data tuple
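
# Illustrative sketch (added commentary, not in the original source): when a
# table binds its row class via objclass() below, roughly the following happens
# (names here are hypothetical):
#   RowClass = get_bound_subclass(mytable_db, 'itemClass', mytable_db.name,
#                                 subclassArgs=dict(db=mytable_db))
#   row = RowClass(('ABC1', 1500))       # one tuple straight from fetchmany()
#   row.gene_id                          # TupleDescriptor -> row._data[icol]
# so attribute access is a tuple index, not a per-row __dict__ copy.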
def insert_and_cache_id(self, l, **kwargs):
    'insert tuple into db and cache its rowID on self'
    self.db._insert(l) # save to database
    try:
        rowID = kwargs['id'] # use the ID supplied by user
    except KeyError:
        rowID = self.db.get_insert_id() # get auto-inc ID value
    self.cache_id(rowID) # cache this ID on self

class TupleORW(TupleO):
    'read-write version of TupleO'
    _columnDescriptor = TupleDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, data, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            self._data = data
            return
        self._data = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(self._data, **kwargs)
    def cache_id(self,row_id):
        self.save_local('id',row_id)
    def save_local(self,attr,val):
        icol = self._attrcol[attr]
        try:
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
        except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
            self._data = list(self._data)
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE

TupleO._RWClass = TupleORW # record this as writeable interface class

class ColumnDescriptor(object):
    'read-write interface to column in a database, cached in obj.__dict__'
    def __init__(self, db, attr, readOnly = False):
        self.attr = attr
        self.col = db._attrSQL(attr, sqlColumn=True) # MAP THIS TO SQL COLUMN NAME
        self.db = db
        if readOnly:
            self.__class__ = self._readOnlyClass
    def __get__(self, obj, objtype):
        try:
            return obj.__dict__[self.attr]
        except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
            if self.col==self.db.primary_key:
                raise AttributeError
            self.db._select('where %s=%%s' % self.db.primary_key,(obj.id,),self.col)
            l = self.db.cursor.fetchall()
            if len(l)!=1:
                raise AttributeError('db row not found or not unique!')
            obj.__dict__[self.attr] = l[0][0] # UPDATE THE CACHE
            return l[0][0]
    def __set__(self, obj, val):
        if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
            self.db._update(obj.id, self.col, val) # UPDATE THE DATABASE
        obj.__dict__[self.attr] = val # UPDATE THE CACHE
##         try:
##             m = self.consequences
##         except AttributeError:
##             return
##         m(obj,val) # GENERATE CONSEQUENCES
##     def bind_consequences(self,f):
##         'make function f be run as consequences method whenever value is set'
##         import new
##         self.consequences = new.instancemethod(f,self,self.__class__)

class ReadOnlyColumnDesc(ColumnDescriptor):
    def __set__(self, obj, val):
        raise AttributeError('The ID of a database object is not writeable.')

ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc


class SQLRow(object):
    """Provide transparent interface to a row in the database: attribute access
    will be mapped to SELECT of the appropriate column, but data is not
    cached on this object.
    """
    _columnDescriptor = _sqlDescriptor = SQLDescriptor
    _idDescriptor = ReadOnlyDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row
    def __init__(self, rowID):
        self._id = rowID


class SQLRowRW(SQLRow):
    'read-write version of SQLRow'
    _columnDescriptor = SQLDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, rowID, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            return self.cache_id(rowID)
        l = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(l, **kwargs)
    def cache_id(self, rowID):
        self._id = rowID

SQLRow._RWClass = SQLRowRW


def list_to_dict(names, values):
    'return dictionary of those named args that are present in values[]'
    d={}
    for i,v in enumerate(values):
        try:
            d[names[i]] = v
        except IndexError:
            break
    return d
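
# Illustrative example (added, not in the original source):
#   list_to_dict(('host', 'user', 'passwd'), ['localhost', 'leec'])
#   # -> {'host': 'localhost', 'user': 'leec'}
# i.e. surplus names are simply skipped when fewer values are supplied.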
def get_name_cursor(name=None, **kwargs):
    '''get table name and cursor by parsing name or using configFile.
    If neither provided, will try to get via your MySQL config file.
    If connect is None, will use MySQLdb.connect()'''
    if name is not None:
        argList = name.split() # TREAT AS WS-SEPARATED LIST
        if len(argList)>1:
            name = argList[0] # USE 1ST ARG AS TABLE NAME
            argnames = ('host','user','passwd') # READ ARGS IN THIS ORDER
            kwargs = kwargs.copy() # a copy we can overwrite
            kwargs.update(list_to_dict(argnames, argList[1:]))
    serverInfo = DBServerInfo(**kwargs)
    return name,serverInfo.cursor(),serverInfo

def mysql_connect(connect=None, configFile=None, useStreaming=False, **args):
    """return connection and cursor objects, using .my.cnf if necessary"""
    kwargs = args.copy() # a copy we can modify
    if 'user' not in kwargs and configFile is None: #Find where config file is
        osname = platform.system()
        if osname in('Microsoft', 'Windows'): # Machine is a Windows box
            paths = []
            try: # handle case where WINDIR not defined by Windows...
                windir = os.environ['WINDIR']
                paths += [(windir, 'my.ini'), (windir, 'my.cnf')]
            except KeyError:
                pass
            try:
                sysdrv = os.environ['SYSTEMDRIVE']
                paths += [(sysdrv, os.path.sep + 'my.ini'),
                          (sysdrv, os.path.sep + 'my.cnf')]
            except KeyError:
                pass
            if len(paths) > 0:
                configFile = get_valid_path(*paths)
        else: # treat as normal platform with home directories
            configFile = os.path.join(os.path.expanduser('~'), '.my.cnf')

    # allows for a local mysql configuration file to be read
    # from the current directory
    configFile = configFile or os.path.join( os.getcwd(), 'mysql.cnf' )

    if configFile and os.path.exists(configFile):
        kwargs['read_default_file'] = configFile
        connect = None # force it to use MySQLdb
    if connect is None:
        import MySQLdb
        connect = MySQLdb.connect
        kwargs['compress'] = True
    if useStreaming: # use server side cursors for scalable result sets
        try:
            from MySQLdb import cursors
            kwargs['cursorclass'] = cursors.SSCursor
        except (ImportError, AttributeError):
            pass
    conn = connect(**kwargs)
    cursor = conn.cursor()
    return conn,cursor
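
# Illustrative sketch (added, not in the original source; the credentials are
# hypothetical and the exact MySQLdb keyword arguments are assumptions):
#   conn, cursor = mysql_connect(user='leec', passwd='secret',
#                                host='localhost', db='test')
#   conn, cursor = mysql_connect(useStreaming=True)  # read ~/.my.cnf, try SSCursor
# Any extra keyword arguments are passed straight through to MySQLdb.connect().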
_mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
                    AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
                    SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')

def mysql_table_schema(self, analyzeSchema=True):
    'retrieve table schema from a MySQL database, save on self'
    import MySQLdb
    self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    self.cursor.execute('describe %s' % self.name) # get info about columns
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[0]
        self.columnName.append(field) # list of columns in same order as table
        if c[3] == "PRI": # record as primary key
            if self.primary_key is None:
                self.primary_key = field
            else:
                try:
                    self.primary_key.append(field)
                except AttributeError:
                    self.primary_key = [self.primary_key,field]
            if c[1][:3].lower() == 'int':
                self.usesIntID = True
            else:
                self.usesIntID = False
        elif c[3] == "MUL":
            self.indexed[field] = icol
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[1] # SQL COLUMN TYPE

_sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
                     AUTO_INCREMENT='', SUBSTRING='substr',
                     SUBSTR_FROM=',', SUBSTR_FOR=',')

def import_sqlite():
    'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
    try:
        import sqlite3 as sqlite
    except ImportError:
        from pysqlite2 import dbapi2 as sqlite
    return sqlite

def sqlite_table_schema(self, analyzeSchema=True):
    'retrieve table schema from a sqlite3 database, save on self'
    sqlite = import_sqlite()
    self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    self.cursor.execute('PRAGMA table_info("%s")' % self.name)
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[1]
        self.columnName.append(field) # list of columns in same order as table
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[2] # SQL COLUMN TYPE
    self.cursor.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self.name) # get primary key / unique indexes
    for indexname in self.cursor.fetchall(): # search indexes for primary key
        self.cursor.execute('PRAGMA index_info("%s")' % indexname)
        l = self.cursor.fetchall() # get list of columns in this index
        if len(l) == 1: # assume 1st single-column unique index is primary key!
            self.primary_key = l[0][2]
            break # done searching for primary key!
    if self.primary_key is None: # grrr, INTEGER PRIMARY KEY handled differently
        self.cursor.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self.name)
        sql = self.cursor.fetchall()[0][0]
        for columnSQL in sql[sql.index('(') + 1 :].split(','):
            if 'primary key' in columnSQL.lower(): # must be the primary key!
                col = columnSQL.split()[0] # get column name
                if col in self.columnType:
                    self.primary_key = col
                    break # done searching for primary key!
                else:
                    raise ValueError('unknown primary key %s in table %s'
                                     % (col,self.name))
    if self.primary_key is not None: # check its type
        if self.columnType[self.primary_key] == 'int' or \
           self.columnType[self.primary_key] == 'integer':
            self.usesIntID = True
        else:
            self.usesIntID = False

class SQLFormatDict(object):
    '''Perform SQL keyword replacements for maintaining compatibility across
    a wide range of SQL backends. Uses Python dict-based string format
    function to do simple string replacements, and also to convert
    params list to the paramstyle required for this interface.
    Create by passing the db-api paramstyle and a dict of macros:
    sfd = SQLFormatDict("qmark", substitutionDict)

    Then transform queries+params as follows; input should be "format" style:
    sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
    cursor.execute(sql, params)
    '''
    _paramFormats = dict(pyformat='%%(%d)s', numeric=':%d', named=':%d',
                         qmark='(ignore)', format='(ignore)')
    def __init__(self, paramstyle, substitutionDict={}):
        self.substitutionDict = substitutionDict.copy()
        self.paramstyle = paramstyle
        self.paramFormat = self._paramFormats[paramstyle]
        self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
        if paramstyle == 'qmark': # handle these as simple substitution
            self.substitutionDict['?'] = '?'
        elif paramstyle == 'format':
            self.substitutionDict['?'] = '%s'
    def __getitem__(self, k):
        'apply correct substitution for this SQL interface'
        try:
            return self.substitutionDict[k] # apply our substitutions
        except KeyError:
            pass
        if k == '?': # sequential parameter
            s = self.paramFormat % self.iparam
            self.iparam += 1 # advance to the next parameter
            return s
        raise KeyError('unknown macro: %s' % k)
    def __call__(self, sql, paramList):
        'returns corrected sql,params for this interface'
        self.iparam = 1 # DB-API param indexing begins at 1
        sql = sql.replace('%s', '%(?)s') # convert format into pyformat
        s = sql % self # apply all %(x)s replacements in sql
        if self.makeDict: # construct a params dict
            paramDict = {}
            for i,param in enumerate(paramList):
                paramDict[str(i + 1)] = param # DB-API param indexing begins at 1
            return s,paramDict
        else: # just return the original params list
            return s,paramList
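
# Illustrative example (added, not in the original source) of what the macro /
# paramstyle translation produces for the sqlite backend (paramstyle 'qmark'):
#   sfd = SQLFormatDict('qmark', _sqliteMacros)
#   sql, params = sfd('%(REPLACE)s into foo values (%s,%s)', ('a', 'b'))
#   # sql    == 'insert or replace into foo values (?,?)'
#   # params == ('a', 'b')         (qmark keeps a positional sequence)
# A 'named' paramstyle would instead produce ':1', ':2' placeholders and a
# {'1': 'a', '2': 'b'} params dict.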
def get_table_schema(self, analyzeSchema=True):
    'run the right schema function based on type of db server connection'
    try:
        modname = self.cursor.__class__.__module__
    except AttributeError:
        raise ValueError('no cursor object or module information!')
    try:
        schema_func = self._schemaModuleDict[modname]
    except KeyError:
        raise KeyError('''unknown db module: %s. Use _schemaModuleDict
        attribute to supply a method for obtaining table schema
        for this module''' % modname)
    schema_func(self, analyzeSchema) # run the schema function

_schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
                     'pysqlite2.dbapi2':sqlite_table_schema,
                     'sqlite3':sqlite_table_schema}
class SQLTableBase(object, UserDict.DictMixin):
    "Store information about an SQL table as dict keyed by primary key"
    _schemaModuleDict = _schemaModuleDict # default module list
    get_table_schema = get_table_schema
    def __init__(self,name,cursor=None,itemClass=None,attrAlias=None,
                 clusterKey=None,createTable=None,graph=None,maxCache=None,
                 arraysize=1024, itemSliceClass=None, dropIfExists=False,
                 serverInfo=None, autoGC=True, orderBy=None,
                 writeable=False, **kwargs):
        if autoGC: # automatically garbage collect unused objects
            self._weakValueDict = RecentValueDictionary(autoGC) # object cache
        else:
            self._weakValueDict = {}
        self.autoGC = autoGC
        self.orderBy = orderBy
        self.writeable = writeable
        if cursor is None:
            if serverInfo is not None: # get cursor from serverInfo
                cursor = serverInfo.cursor()
            else: # try to read connection info from name or config file
                name,cursor,serverInfo = get_name_cursor(name,**kwargs)
        else:
            warnings.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
                          DeprecationWarning, stacklevel=2)
        self.cursor = cursor
        if createTable is not None: # RUN COMMAND TO CREATE THIS TABLE
            if dropIfExists: # get rid of any existing table
                cursor.execute('drop table if exists ' + name)
            self.get_table_schema(False) # check dbtype, init _format_query
            sql,params = self._format_query(createTable, ()) # apply macros
            cursor.execute(sql) # create the table
        self.name = name
        if graph is not None:
            self.graph = graph
        if maxCache is not None:
            self.maxCache = maxCache
        if arraysize is not None:
            self.arraysize = arraysize
            cursor.arraysize = arraysize
        self.get_table_schema() # get schema of columns to serve as attrs
        self.data = {} # map of all attributes, including aliases
        for icol,field in enumerate(self.columnName):
            self.data[field] = icol # 1st add mappings to columns
        try:
            self.data['id']=self.data[self.primary_key]
        except (KeyError,TypeError):
            pass
        if hasattr(self,'_attr_alias'): # apply attribute aliases for this class
            self.addAttrAlias(False,**self._attr_alias)
        self.objclass(itemClass) # NEED TO SUBCLASS OUR ITEM CLASS
        if itemSliceClass is not None:
            self.itemSliceClass = itemSliceClass
            get_bound_subclass(self, 'itemSliceClass', self.name) # need to subclass itemSliceClass
        if attrAlias is not None: # ADD ATTRIBUTE ALIASES
            self.attrAlias = attrAlias # RECORD FOR PICKLING PURPOSES
            self.data.update(attrAlias)
        if clusterKey is not None:
            self.clusterKey=clusterKey
        if serverInfo is not None:
            self.serverInfo = serverInfo
    def __len__(self):
        self._select(selectCols='count(*)')
        return self.cursor.fetchone()[0]
    def __hash__(self):
        return id(self)
    _pickleAttrs = dict(name=0, clusterKey=0, maxCache=0, arraysize=0,
                        attrAlias=0, serverInfo=0, autoGC=0, orderBy=0,
                        writeable=0)
    __getstate__ = standard_getstate
    def __setstate__(self,state):
        # default cursor provisioning by worldbase is deprecated!
##         if 'serverInfo' not in state: # hmm, no address for db server?
##             try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
##                 from Data import getResource
##                 state['cursor'] = getResource.getTableCursor(state['name'])
##             except ImportError:
##                 pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
        self.__init__(**state)
    def __repr__(self):
        return '<SQL table '+self.name+'>'

    def clear_schema(self):
        'reset all schema information for this table'
        self.description={}
        self.columnName = []
        self.columnType = {}
        self.usesIntID = None
        self.primary_key = None
        self.indexed = {}
    def _attrSQL(self,attr,sqlColumn=False,columnNumber=False):
        "Translate python attribute name to appropriate SQL expression"
        try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
            field=self.data[attr]
        except KeyError:
            raise AttributeError('attribute %s not a valid column or alias in %s'
                                 % (attr,self.name))
        if sqlColumn: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
            try: # CHECK IF field IS COLUMN NUMBER
                return self.columnName[field] # RETURN SQL COLUMN NAME
            except TypeError:
                try: # CHECK IF field IS SQL COLUMN NAME
                    return self.columnName[self.data[field]] # THIS WILL JUST RETURN field...
                except (KeyError,TypeError):
                    raise AttributeError('attribute %s does not map to an SQL column in %s'
                                         % (attr,self.name))
        if columnNumber:
            try: # CHECK IF field IS A COLUMN NUMBER
                return field+0 # ONLY RETURN AN INTEGER
            except TypeError:
                try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
                    return self.data[field]+0 # ONLY RETURN AN INTEGER
                except (KeyError,TypeError):
                    raise ValueError('attribute %s does not map to a SQL column!' % attr)
        if isinstance(field,types.StringType):
            attr=field # USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
        elif attr=='id':
            attr=self.primary_key
        return attr
    def addAttrAlias(self,saveToPickle=True,**kwargs):
        """Add new attributes as aliases of existing attributes.
        They can be specified either as named args:
            t.addAttrAlias(newattr=oldattr)
        or by passing a dictionary kwargs whose keys are newattr
        and values are oldattr:
            t.addAttrAlias(**kwargs)
        saveToPickle=True forces these aliases to be saved if object is pickled.
        """
        if saveToPickle:
            self.attrAlias.update(kwargs)
        for key,val in kwargs.items():
            try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
                self.data[val]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
                raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
            except TypeError: # val IS ITSELF AN ALIAS
                self.data[key] = self.data[val] # SO MAP TO WHAT IT MAPS TO
            except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
                self.data[key] = val
    def objclass(self,oclass=None):
        "Create class representing a row in this table by subclassing oclass, adding data"
        if oclass is not None: # use this as our base itemClass
            self.itemClass = oclass
        if self.writeable:
            self.itemClass = self.itemClass._RWClass # use its writeable version
        oclass = get_bound_subclass(self, 'itemClass', self.name,
                                    subclassArgs=dict(db=self)) # bind itemClass
        if issubclass(oclass, TupleO):
            oclass._attrcol = self.data # BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
        if hasattr(oclass,'_tableclass') and not isinstance(self,oclass._tableclass):
            self.__class__=oclass._tableclass # ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
    def _select(self, whereClause='', params=(), selectCols='t1.*',
                cursor=None, orderBy='', limit=''):
        'execute the specified query but do not fetch'
        sql,params = self._format_query('select %s from %s t1 %s %s %s'
                                        % (selectCols, self.name, whereClause,
                                           orderBy, limit), params)
        if cursor is None:
            self.cursor.execute(sql, params)
        else:
            cursor.execute(sql, params)
    def select(self,whereClause,params=None,oclass=None,selectCols='t1.*'):
        "Generate the list of objects that satisfy the database SELECT"
        if oclass is None:
            oclass=self.itemClass
        self._select(whereClause,params,selectCols)
        l=self.cursor.fetchall()
        for t in l:
            yield self.cacheItem(t,oclass)
    def query(self,**kwargs):
        'query for intersection of all specified kwargs, returned as iterator'
        criteria = []
        params = []
        for k,v in kwargs.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
            if v is None: # CONVERT TO SQL NULL TEST
                criteria.append('%s IS NULL' % self._attrSQL(k))
            else: # TEST FOR EQUALITY
                criteria.append('%s=%%s' % self._attrSQL(k))
                params.append(v)
        return self.select('where '+' and '.join(criteria),params)
    def _update(self,row_id,col,val):
        'update a single field in the specified row to the specified value'
        sql,params = self._format_query('update %s set %s=%%s where %s=%%s'
                                        %(self.name,col,self.primary_key),
                                        (val,row_id))
        self.cursor.execute(sql, params)
    def getID(self,t):
        try:
            return t[self.data['id']] # GET ID FROM TUPLE
        except TypeError: # treat as alias
            return t[self.data[self.data['id']]]
    def cacheItem(self,t,oclass):
        'get obj from cache if possible, or construct from tuple'
        try:
            id=self.getID(t)
        except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
            return oclass(t)
        try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
            return self._weakValueDict[id]
        except KeyError:
            pass
        o = oclass(t)
        self._weakValueDict[id] = o # CACHE THIS ITEM IN OUR DICTIONARY
        return o
    def cache_items(self,rows,oclass=None):
        if oclass is None:
            oclass=self.itemClass
        for t in rows:
            yield self.cacheItem(t,oclass)
    def foreignKey(self,attr,k):
        'get iterator for objects with specified foreign key value'
        return self.select('where %s=%%s'%attr,(k,))
    def limit_cache(self):
        'APPLY maxCache LIMIT TO CACHE SIZE'
        try:
            if self.maxCache<len(self._weakValueDict):
                self._weakValueDict.clear()
        except AttributeError:
            pass

    def get_new_cursor(self):
        """Return a new cursor object, or None if not possible """
        try:
            new_cursor = self.serverInfo.new_cursor
        except AttributeError:
            return None
        return new_cursor(self.arraysize)

    def generic_iterator(self, cursor=None, fetch_f=None, cache_f=None,
                         map_f=iter):
        'generic iterator that runs fetch, cache and map functions'
        if fetch_f is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
            if cursor is None:
                fetch_f = self.cursor.fetchmany
            else: # isolate this iter from other queries
                fetch_f = cursor.fetchmany
        if cache_f is None:
            cache_f = self.cache_items
        while True:
            self.limit_cache()
            rows = fetch_f() # FETCH THE NEXT SET OF ROWS
            if len(rows)==0: # NO MORE DATA SO ALL DONE
                break
            for v in map_f(cache_f(rows)): # CACHE AND GENERATE RESULTS
                yield v
        if cursor is not None: # close iterator now that we're done
            cursor.close()
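
    # Illustrative sketch (added commentary, not in the original source) of the
    # fetch / cache / map pipeline generic_iterator implements:
    #   cursor = mytable.get_new_cursor()    # isolated cursor, if serverInfo supports it
    #   mytable._select(cursor=cursor)       # issue the query; nothing fetched yet
    #   for obj in mytable.generic_iterator(cursor=cursor):
    #       ...   # rows arrive fetchmany() at a time and are wrapped by cache_items()
    # Passing map_f=generate_items instead yields (id, obj) pairs, which is how
    # SQLTable.iteritems() reuses this machinery.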
    def tuple_from_dict(self, d):
        'transform kwarg dict into tuple for storing in database'
        l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
        for col,icol in self.data.items():
            try:
                l[icol] = d[col]
            except (KeyError,TypeError):
                pass
        return l
    def tuple_from_obj(self, obj):
        'transform object attributes into tuple for storing in database'
        l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
        for col,icol in self.data.items():
            try:
                l[icol] = getattr(obj,col)
            except (AttributeError,TypeError):
                pass
        return l
    def _insert(self, l):
        '''insert tuple into the database. Note this uses the MySQL
        extension REPLACE, which overwrites any duplicate key.'''
        s = '%(REPLACE)s into ' + self.name + ' values (' \
            + ','.join(['%s']*len(l)) + ')'
        sql,params = self._format_query(s, l)
        self.cursor.execute(sql, params)
    def insert(self, obj):
        '''insert new row by transforming obj to tuple of values'''
        l = self.tuple_from_obj(obj)
        self._insert(l)
    def get_insert_id(self):
        'get the primary key value for the last INSERT'
        try: # ATTEMPT TO GET ASSIGNED ID FROM DB
            auto_id = self.cursor.lastrowid
        except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
            raise NotImplementedError('''your db lacks lastrowid support?''')
        if auto_id is None:
            raise ValueError('lastrowid is None so cannot get ID from INSERT!')
        return auto_id
    def new(self, **kwargs):
        'return a new record with the assigned attributes, added to DB'
        if not self.writeable:
            raise ValueError('this database is read only!')
        obj = self.itemClass(None, newRow=True, **kwargs) # saves itself to db
        self._weakValueDict[obj.id] = obj # AND SAVE TO OUR LOCAL DICT CACHE
        return obj
    def clear_cache(self):
        'empty the cache'
        self._weakValueDict.clear()
    def __delitem__(self, k):
        if not self.writeable:
            raise ValueError('this database is read only!')
        sql,params = self._format_query('delete from %s where %s=%%s'
                                        % (self.name,self.primary_key),(k,))
        self.cursor.execute(sql, params)
        try:
            del self._weakValueDict[k]
        except KeyError:
            pass

def getKeys(self,queryOption='', selectCols=None):
    'uses db select; does not force load'
    if selectCols is None:
        selectCols=self.primary_key
    if queryOption=='' and self.orderBy is not None:
        queryOption = self.orderBy # apply default ordering
    self.cursor.execute('select %s from %s %s'
                        %(selectCols,self.name,queryOption))
    return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...

def iter_keys(self, selectCols=None, orderBy='', map_f=iter,
              cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
    'guarantee correct iteration insulated from other queries'
    if selectCols is None:
        selectCols=self.primary_key
    if orderBy=='' and self.orderBy is not None:
        orderBy = self.orderBy # apply default ordering
    cursor = self.get_new_cursor()
    if cursor: # got our own cursor, guaranteeing query isolation
        self._select(cursor=cursor, selectCols=selectCols, orderBy=orderBy,
                     **kwargs)
        return self.generic_iterator(cursor=cursor, cache_f=cache_f, map_f=map_f)
    else: # must pre-fetch all keys to ensure query isolation
        if get_f is not None:
            return iter(get_f())
        else:
            return iter(self.keys())
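
# Illustrative note (added, not in the original source): iter_keys prefers an
# isolated server-side cursor when serverInfo can supply one; otherwise it
# falls back to pre-fetching everything via get_f (or keys()) so iteration
# cannot be corrupted by other queries reusing the shared cursor. Roughly:
#   for k in mytable:                   # __iter__ = iter_keys -> streams primary keys
#       ...
#   for k, obj in mytable.iteritems():  # same machinery with selectCols='*'
#       ...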
class SQLTable(SQLTableBase):
    "Provide on-the-fly access to rows in the database, caching the results in dict"
    itemClass = TupleO # our default itemClass; constructor can override
    keys=getKeys
    __iter__ = iter_keys
    def load(self,oclass=None):
        "Load all data from the table"
        try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
            return self._isLoaded
        except AttributeError:
            pass
        if oclass is None:
            oclass=self.itemClass
        self.cursor.execute('select * from %s' % self.name)
        l=self.cursor.fetchall()
        self._weakValueDict = {} # just store the whole dataset in memory
        for t in l:
            self.cacheItem(t,oclass) # CACHE IT IN LOCAL DICTIONARY
        self._isLoaded=True # MARK THIS CONTAINER AS FULLY LOADED

    def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            sql,params = self._format_query('select * from %s where %s=%%s limit 2'
                                            % (self.name,self.primary_key),(k,))
            self.cursor.execute(sql, params)
            l = self.cursor.fetchmany(2) # get at most 2 rows
            if len(l) != 1:
                raise KeyError('%s not found in %s, or not unique' %(str(k),self.name))
            self.limit_cache()
            return self.cacheItem(l[0],self.itemClass) # CACHE IT IN LOCAL DICTIONARY
    def __setitem__(self, k, v):
        if not self.writeable:
            raise ValueError('this database is read only!')
        try:
            if v.db != self:
                raise AttributeError
        except AttributeError:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            oldID = v.id
            if oldID is None:
                raise AttributeError
        except AttributeError:
            pass
        else: # delete row with old ID
            del self[v.id]
        v.cache_id(k) # cache the new ID on the object
        self.insert(v) # SAVE TO THE RELATIONAL DB SERVER
        self._weakValueDict[k] = v # CACHE THIS ITEM IN OUR DICTIONARY
    def items(self):
        'forces load of entire table into memory'
        self.load()
        return [(k,self[k]) for k in self] # apply orderBy rules...
    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        return iter_keys(self, selectCols='*', cache_f=None,
                         map_f=generate_items, get_f=self.items)
    def values(self):
        'forces load of entire table into memory'
        self.load()
        return [self[k] for k in self] # apply orderBy rules...
    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        return iter_keys(self, selectCols='*', cache_f=None, get_f=self.values)
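
# Illustrative sketch (added, not in the original source; the table and column
# names are hypothetical):
#   t = SQLTable('test.genes', serverInfo=my_server_info)
#   gene = t['ABC1']            # SELECT ... WHERE primary_key=%s, result cached
#   gene.seq_length             # TupleO descriptor: plain tuple lookup, no query
#   for g in t.query(chromosome='chr1'):   # WHERE chromosome=%s, as an iterator
#       print g.id
# With writeable=True the same table hands out TupleORW rows and supports
# t.new(...), t[k] = obj and del t[k].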
def getClusterKeys(self,queryOption=''):
    'uses db select; does not force load'
    self.cursor.execute('select distinct %s from %s %s'
                        %(self.clusterKey,self.name,queryOption))
    return [t[0] for t in self.cursor.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...


class SQLTableClustered(SQLTable):
    '''use clusterKey to load a whole cluster of rows at once,
    specifically, all rows that share the same clusterKey value.'''
    def __init__(self, *args, **kwargs):
        kwargs = kwargs.copy() # get a copy we can alter
        kwargs['autoGC'] = False # don't use WeakValueDictionary
        SQLTable.__init__(self, *args, **kwargs)
    def keys(self):
        return getKeys(self,'order by %s' %self.clusterKey)
    def clusterkeys(self):
        return getClusterKeys(self, 'order by %s' %self.clusterKey)
    def __getitem__(self,k):
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            sql,params = self._format_query('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
                                            % (self.name,self.name,self.primary_key,
                                               self.clusterKey,self.clusterKey),(k,))
            self.cursor.execute(sql, params)
            l=self.cursor.fetchall()
            self.limit_cache()
            for t in l: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
                self.cacheItem(t,self.itemClass)
            return self._weakValueDict[k] # should be in cache, if row k exists
    def itercluster(self,cluster_id):
        'iterate over all items from the specified cluster'
        self.limit_cache()
        return self.select('where %s=%%s'%self.clusterKey,(cluster_id,))
    def fetch_cluster(self):
        'use self.cursor.fetchmany to obtain all rows for next cluster'
        icol = self._attrSQL(self.clusterKey,columnNumber=True)
        result = []
        try:
            rows = self._fetch_cluster_cache # USE SAVED ROWS FROM PREVIOUS CALL
            del self._fetch_cluster_cache
        except AttributeError:
            rows = self.cursor.fetchmany()
        try:
            cluster_id = rows[0][icol]
        except IndexError:
            return result
        while len(rows)>0:
            for i,t in enumerate(rows): # CHECK THAT ALL ROWS FROM THIS CLUSTER
                if cluster_id != t[icol]: # START OF A NEW CLUSTER
                    result += rows[:i] # RETURN ROWS OF CURRENT CLUSTER
                    self._fetch_cluster_cache = rows[i:] # SAVE NEXT CLUSTER
                    return result
            result += rows
            rows = self.cursor.fetchmany() # GET NEXT SET OF ROWS
        return result
    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select('order by %s' %self.clusterKey, cursor=cursor)
        return self.generic_iterator(cursor, self.fetch_cluster)
    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select('order by %s' %self.clusterKey, cursor=cursor)
        return self.generic_iterator(cursor, self.fetch_cluster,
                                     map_f=generate_items)
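
# Illustrative sketch (added, not in the original source; 'exons' and 'gene_id'
# are hypothetical names). Clustering trades one query per row for one query
# per cluster, which pays off when rows sharing a clusterKey are used together:
#   t = SQLTableClustered('test.exons', clusterKey='gene_id',
#                         serverInfo=my_server_info)
#   exon = t[101]                 # also caches every row with the same gene_id
#   rows = list(t.itercluster('ENSG0001'))   # all rows for one cluster value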
class SQLForeignRelation(object):
    'mapping based on matching a foreign key in an SQL table'
    def __init__(self,table,keyName):
        self.table=table
        self.keyName=keyName
    def __getitem__(self,k):
        'get list of objects o with getattr(o,keyName)==k.id'
        l=[]
        for o in self.table.select('where %s=%%s'%self.keyName,(k.id,)):
            l.append(o)
        if len(l)==0:
            raise KeyError('%s not found in %s' %(str(k),self.table.name))
        return l


class SQLTableNoCache(SQLTableBase):
    '''Provide on-the-fly access to rows in the database;
    values are simply an object interface (SQLRow) to back-end db query.
    Row data are not stored locally, but always accessed by querying the db'''
    itemClass=SQLRow # DEFAULT OBJECT CLASS FOR ROWS...
    keys=getKeys
    __iter__ = iter_keys
    def getID(self,t): return t[0] # GET ID FROM TUPLE
    def select(self,whereClause,params):
        return SQLTableBase.select(self,whereClause,params,self.oclass,
                                   self._attrSQL('id'))
    def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            self._select('where %s=%%s' % self.primary_key, (k,),
                         self.primary_key)
            t = self.cursor.fetchmany(2)
            if len(t) != 1:
                raise KeyError('id %s non-existent or not unique' % k)
            o = self.itemClass(k) # create obj referencing this ID
            self._weakValueDict[k] = o # cache the SQLRow object
            return o
    def __setitem__(self, k, v):
        if not self.writeable:
            raise ValueError('this database is read only!')
        try:
            if v.db != self:
                raise AttributeError
        except AttributeError:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            del self[k] # delete row with new ID if any
        except KeyError:
            pass
        try:
            del self._weakValueDict[v.id] # delete from old cache location
        except KeyError:
            pass
        self._update(v.id, self.primary_key, k) # just change its ID in db
        v.cache_id(k) # change the cached ID value
        self._weakValueDict[k] = v # assign to new cache location
    def addAttrAlias(self,**kwargs):
        self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES
SQLRow._tableclass=SQLTableNoCache # SQLRow IS FOR NON-CACHING TABLE INTERFACE


class SQLTableMultiNoCache(SQLTableBase):
    "Trivial on-the-fly access for table with key that returns multiple rows"
    itemClass = TupleO # default itemClass; constructor can override
    _distinct_key='id' # DEFAULT COLUMN TO USE AS KEY
    def keys(self):
        return getKeys(self, selectCols='distinct(%s)'
                       % self._attrSQL(self._distinct_key))
    def __iter__(self):
        return iter_keys(self, 'distinct(%s)' % self._attrSQL(self._distinct_key))
    def __getitem__(self,id):
        sql,params = self._format_query('select * from %s where %s=%%s'
                                        %(self.name,self._attrSQL(self._distinct_key)),(id,))
        self.cursor.execute(sql, params)
        l=self.cursor.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for row in l:
            yield self.itemClass(row)
    def addAttrAlias(self,**kwargs):
        self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES


class SQLEdges(SQLTableMultiNoCache):
    '''provide iterator over edges as (source,target,edge)
    and getitem[edge] --> [(source,target),...]'''
    _distinct_key='edge_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))
    def keys(self):
        self.cursor.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
                            %(self._attrSQL('source_id'),self._attrSQL('target_id'),
                              self._attrSQL('edge_id'),self.name,
                              self._attrSQL('target_id'),self._attrSQL('source_id'),
                              self._attrSQL('target_id')))
        l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for source_id,target_id,edge_id in self.cursor.fetchall():
            l.append((self.graph.unpack_source(source_id),
                      self.graph.unpack_target(target_id),
                      self.graph.unpack_edge(edge_id)))
        return l
    __call__=keys
    def __iter__(self):
        return iter(self.keys())
    def __getitem__(self,edge):
        sql,params = self._format_query('select %s,%s from %s where %s=%%s'
                                        %(self._attrSQL('source_id'),
                                          self._attrSQL('target_id'),
                                          self.name,
                                          self._attrSQL(self._distinct_key)),
                                        (self.graph.pack_edge(edge),))
        self.cursor.execute(sql, params)
        l = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for source_id,target_id in self.cursor.fetchall():
            l.append((self.graph.unpack_source(source_id),
                      self.graph.unpack_target(target_id)))
        return l


class SQLEdgeDict(object):
    '2nd level graph interface to SQL database'
    def __init__(self,fromNode,table):
        self.fromNode=fromNode
        self.table=table
        if not hasattr(self.table,'allowMissingNodes'):
            sql,params = self.table._format_query('select %s from %s where %s=%%s limit 1'
                                                  %(self.table.sourceSQL,
                                                    self.table.name,
                                                    self.table.sourceSQL),
                                                  (self.fromNode,))
            self.table.cursor.execute(sql, params)
            if len(self.table.cursor.fetchall())<1:
                raise KeyError('node not in graph!')

    def __getitem__(self,target):
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
                                              %(self.table.edgeSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        l = self.table.cursor.fetchmany(2) # get at most two rows
        if len(l) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        try:
            return self.table.unpack_edge(l[0][0]) # RETURN EDGE
        except IndexError:
            raise KeyError('no edge from node to target')
    def __setitem__(self,target,edge):
        sql,params = self.table._format_query('replace into %s values (%%s,%%s,%%s)'
                                              %self.table.name,
                                              (self.fromNode,
                                               self.table.pack_target(target),
                                               self.table.pack_edge(edge)))
        self.table.cursor.execute(sql, params)
        if not hasattr(self.table,'sourceDB') or \
           (hasattr(self.table,'targetDB') and self.table.sourceDB==self.table.targetDB):
            self.table += target # ADD AS NODE TO GRAPH
    def __iadd__(self,target):
        self[target] = None
        return self # iadd MUST RETURN self!
    def __delitem__(self,target):
        sql,params = self.table._format_query('delete from %s where %s=%%s and %s=%%s'
                                              %(self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        if self.table.cursor.rowcount < 1: # no rows deleted?
            raise KeyError('no edge from node to target')

    def iterator_query(self):
        sql,params = self.table._format_query('select %s,%s from %s where %s=%%s and %s is not null'
                                              %(self.table.targetSQL,
                                                self.table.edgeSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return self.table.cursor.fetchall()
    def keys(self):
        return [self.table.unpack_target(target_id)
                for target_id,edge_id in self.iterator_query()]
    def values(self):
        return [self.table.unpack_edge(edge_id)
                for target_id,edge_id in self.iterator_query()]
    def edges(self):
        return [(self.table.unpack_source(self.fromNode),self.table.unpack_target(target_id),
                 self.table.unpack_edge(edge_id))
                for target_id,edge_id in self.iterator_query()]
    def items(self):
        return [(self.table.unpack_target(target_id),self.table.unpack_edge(edge_id))
                for target_id,edge_id in self.iterator_query()]
    def __iter__(self): return iter(self.keys())
    def itervalues(self): return iter(self.values())
    def iteritems(self): return iter(self.items())
    def __len__(self):
        return len(self.keys())
    __cmp__ = graph_cmp
class SQLEdgelessDict(SQLEdgeDict):
    'for SQLGraph tables that lack edge_id column'
    def __getitem__(self,target):
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
                                              %(self.table.targetSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        l = self.table.cursor.fetchmany(2)
        if len(l) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        return None # no edge info!
    def iterator_query(self):
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s is not null'
                                              %(self.table.targetSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return [(t[0],None) for t in self.table.cursor.fetchall()]

SQLEdgeDict._edgelessClass = SQLEdgelessDict


class SQLGraphEdgeDescriptor(object):
    'provide an SQLEdges interface on demand'
    def __get__(self,obj,objtype):
        try:
            attrAlias=obj.attrAlias.copy()
        except AttributeError:
            return SQLEdges(obj.name, obj.cursor, graph=obj)
        else:
            return SQLEdges(obj.name, obj.cursor, attrAlias=attrAlias,
                            graph=obj)

def getColumnTypes(createTable,attrAlias={},defaultColumnType='int',
                   columnAttrs=('source','target','edge'),**kwargs):
    'return list of [(colname,coltype),...] for source,target,edge'
    l = []
    for attr in columnAttrs:
        try:
            attrName = attrAlias[attr+'_id']
        except KeyError:
            attrName = attr+'_id'
        try: # SEE IF USER SPECIFIED A DESIRED TYPE
            l.append((attrName,createTable[attr+'_id']))
            continue
        except (KeyError,TypeError):
            pass
        try: # get type info from primary key for that database
            db = kwargs[attr+'DB']
            if db is None:
                raise KeyError # FORCE IT TO USE DEFAULT TYPE
        except KeyError:
            pass
        else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
            it = iter(db)
            try: # GET ONE IDENTIFIER FROM THE DATABASE
                k = it.next()
            except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
                try:
                    l.append((attrName,db.columnType[db.primary_key]))
                    continue
                except AttributeError:
                    pass
            else: # GET THE TYPE FROM THIS IDENTIFIER
                if isinstance(k,int) or isinstance(k,long):
                    l.append((attrName,'int'))
                    continue
                elif isinstance(k,str):
                    l.append((attrName,'varchar(32)'))
                    continue
                else:
                    raise ValueError('SQLGraph node / edge must be int or str!')
        l.append((attrName,defaultColumnType))
        logger.warn('no type info found for %s, so using default: %s'
                    % (attrName, defaultColumnType))
    return l
class SQLGraph(SQLTableMultiNoCache):
    '''provide a graph interface via a SQL table. Key capabilities are:
    - setitem with an empty dictionary: a dummy operation
    - getitem with a key that exists: return a placeholder
    - setitem with non empty placeholder: again a dummy operation
    EXAMPLE TABLE SCHEMA:
    create table mygraph (source_id int not null,target_id int,edge_id int,
                          unique(source_id,target_id));
    '''
    _distinct_key='source_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(sourceDB=0,targetDB=0,edgeDB=0,allowMissingNodes=0))
    _edgeClass = SQLEdgeDict
    def __init__(self,name,*l,**kwargs):
        graphArgs,tableArgs = split_kwargs(kwargs,
                                           ('attrAlias','defaultColumnType','columnAttrs',
                                            'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
                                            'edgeDictClass','graph'))
        if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
            c = getColumnTypes(**kwargs)
            tableArgs['createTable'] = \
                'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
                % (name,c[0][0],c[0][1],c[1][0],c[1][1],c[2][0],c[2][1],c[0][0],c[1][0])
        try:
            self.allowMissingNodes = kwargs['allowMissingNodes']
        except KeyError:
            pass
        SQLTableMultiNoCache.__init__(self,name,*l,**tableArgs)
        self.sourceSQL = self._attrSQL('source_id')
        self.targetSQL = self._attrSQL('target_id')
        try:
            self.edgeSQL = self._attrSQL('edge_id')
        except AttributeError:
            self.edgeSQL = None
            self._edgeClass = self._edgeClass._edgelessClass
        save_graph_db_refs(self,**kwargs)
    def __getitem__(self,k):
        return self._edgeClass(self.pack_source(k),self)
    def __iadd__(self,k):
        sql,params = self._format_query('delete from %s where %s=%%s and %s is null'
                                        % (self.name,self.sourceSQL,self.targetSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        sql,params = self._format_query('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
                                        % self.name,(self.pack_source(k),))
        self.cursor.execute(sql, params)
        return self # iadd MUST RETURN SELF!
    def __isub__(self,k):
        sql,params = self._format_query('delete from %s where %s=%%s'
                                        % (self.name,self.sourceSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        if self.cursor.rowcount == 0:
            raise KeyError('node not found in graph')
        return self # isub MUST RETURN SELF!
    __setitem__ = graph_setitem
    def __contains__(self,k):
        sql,params = self._format_query('select * from %s where %s=%%s limit 1'
                                        %(self.name,self.sourceSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        l = self.cursor.fetchmany(2)
        return len(l) > 0
    def __invert__(self):
        'get an interface to the inverse graph mapping'
        try: # CACHED
            return self._inverse
        except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
            attrAlias = dict(source_id=self.targetSQL, # SWAP SOURCE & TARGET
                             target_id=self.sourceSQL,
                             edge_id=self.edgeSQL)
            if self.edgeSQL is None: # no edge interface
                del attrAlias['edge_id']
            self._inverse=SQLGraph(self.name,self.cursor,
                                   attrAlias=attrAlias,
                                   **graph_db_inverse_refs(self))
            self._inverse._inverse=self
            return self._inverse
    def __iter__(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self.unpack_source(k)
    def iteritems(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield (self.unpack_source(k), self._edgeClass(k, self))
    def itervalues(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self._edgeClass(k, self)
    def keys(self):
        return [self.unpack_source(k) for k in SQLTableMultiNoCache.keys(self)]
    def values(self): return list(self.itervalues())
    def items(self): return list(self.iteritems())
    edges=SQLGraphEdgeDescriptor()
    update = update_graph
    def __len__(self):
        'get number of source nodes in graph'
        self.cursor.execute('select count(distinct %s) from %s'
                            %(self.sourceSQL,self.name))
        return self.cursor.fetchone()[0]
    __cmp__ = graph_cmp
    override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
##     def __cmp__(self,other):
##         node = ()
##         n = 0
##         d = None
##         it = iter(self.edges)
##         while True:
##             try:
##                 source,target,edge = it.next()
##             except StopIteration:
##                 source = None
##             if source!=node:
##                 if d is not None:
##                     diff = cmp(n_target,len(d))
##                     if diff!=0:
##                         return diff
##                 if source is None:
##                     break
##                 node = source
##                 n += 1 # COUNT SOURCE NODES
##                 n_target = 0
##                 try:
##                     d = other[node]
##                 except KeyError:
##                     return 1
##             try:
##                 diff = cmp(edge,d[target])
##             except KeyError:
##                 return 1
##             if diff!=0:
##                 return diff
##             n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
##         return cmp(n,len(other))
    add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
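
# Illustrative sketch (added, not in the original source), using the example
# schema from the SQLGraph docstring; the table name, serverInfo object and
# node / edge values are hypothetical:
#   g = SQLGraph('test.mygraph', serverInfo=my_server_info, simpleKeys=True)
#   g += 1              # add node 1 (row with NULL target_id / edge_id)
#   g[1][2] = 10        # add edge 1 -> 2 labelled 10
#   g[1].keys()         # -> [2]
#   (~g)[2]             # inverse mapping: sources with an edge into node 2
# When sourceDB / targetDB / edgeDB are supplied instead of simpleKeys, node
# and edge objects are packed to / unpacked from their database IDs.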
1339 class SQLIDGraph(SQLGraph):
1340 add_trivial_packing_methods(locals())
1341 SQLGraph._IDGraphClass = SQLIDGraph
1345 class SQLEdgeDictClustered(dict):
1346 'simple cache for 2nd level dictionary of target_id:edge_id'
1347 def __init__(self,g,fromNode):
1348 self.g=g
1349 self.fromNode=fromNode
1350 dict.__init__(self)
1351 def __iadd__(self,l):
1352 for target_id,edge_id in l:
1353 dict.__setitem__(self,target_id,edge_id)
1354 return self # iadd MUST RETURN SELF!
1356 class SQLEdgesClusteredDescr(object):
1357 def __get__(self,obj,objtype):
1358 e=SQLEdgesClustered(obj.table,obj.edge_id,obj.source_id,obj.target_id,
1359 graph=obj,**graph_db_inverse_refs(obj,True))
1360 for source_id,d in obj.d.iteritems(): # COPY EDGE CACHE
1361 e.load([(edge_id,source_id,target_id)
1362 for (target_id,edge_id) in d.iteritems()])
1363 return e
1365 class SQLGraphClustered(object):
1366 'SQL graph with clustered caching -- loads an entire cluster at a time'
1367 _edgeDictClass=SQLEdgeDictClustered
1368 def __init__(self,table,source_id='source_id',target_id='target_id',
1369 edge_id='edge_id',clusterKey=None,**kwargs):
1370 import types
1371 if isinstance(table,types.StringType): # CREATE THE TABLE INTERFACE
1372 if clusterKey is None:
1373 raise ValueError('you must provide a clusterKey argument!')
1374 if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
1375 c = getColumnTypes(attrAlias=dict(source_id=source_id,target_id=target_id,
1376 edge_id=edge_id),**kwargs)
1377 kwargs['createTable'] = \
1378 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1379 % (table,c[0][0],c[0][1],c[1][0],c[1][1],
1380 c[2][0],c[2][1],c[0][0],c[1][0])
1381 table = SQLTableClustered(table,clusterKey=clusterKey,**kwargs)
1382 self.table=table
1383 self.source_id=source_id
1384 self.target_id=target_id
1385 self.edge_id=edge_id
1386 self.d={}
1387 save_graph_db_refs(self,**kwargs)
1388 _pickleAttrs = dict(table=0,source_id=0,target_id=0,edge_id=0,sourceDB=0,targetDB=0,
1389 edgeDB=0)
1390 def __getstate__(self):
1391 state = standard_getstate(self)
1392 state['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1393 return state
1394 def __getitem__(self,k):
1395 'get edgeDict for source node k, from cache or by loading its cluster'
1396 try: # GET DIRECTLY FROM CACHE
1397 return self.d[k]
1398 except KeyError:
1399 if hasattr(self,'_isLoaded'):
1400 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1401 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1402 sql,params = self.table._format_query('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1403 %(self.source_id,self.target_id,
1404 self.edge_id,self.table.name,
1405 self.table.name,self.source_id,
1406 self.table.clusterKey,self.table.clusterKey,
1407 self.table.primary_key),
1408 (self.pack_source(k),))
1409 self.table.cursor.execute(sql, params)
1410 self.load(self.table.cursor.fetchall()) # CACHE THIS CLUSTER
1411 return self.d[k] # RETURN EDGE DICT FOR THIS NODE
1412 def load(self,l=None,unpack=True):
1413 'load the specified rows (or all, if None provided) into local cache'
1414 if l is None:
1415 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1416 return self._isLoaded
1417 except AttributeError:
1418 pass
1419 self.table.cursor.execute('select %s,%s,%s from %s'
1420 %(self.source_id,self.target_id,
1421 self.edge_id,self.table.name))
1422 l=self.table.cursor.fetchall()
1423 self._isLoaded=True
1424 self.d.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1425 for source,target,edge in l: # SAVE TO OUR CACHE
1426 if unpack:
1427 source = self.unpack_source(source)
1428 target = self.unpack_target(target)
1429 edge = self.unpack_edge(edge)
1430 try:
1431 self.d[source] += [(target,edge)]
1432 except KeyError:
1433 d = self._edgeDictClass(self,source)
1434 d += [(target,edge)]
1435 self.d[source] = d
1436 def __invert__(self):
1437 'interface to reverse graph mapping'
1438 try:
1439 return self._inverse # INVERSE MAP ALREADY EXISTS
1440 except AttributeError:
1441 pass
1442 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1443 self._inverse=SQLGraphClustered(self.table,self.target_id,self.source_id,
1444 self.edge_id,**graph_db_inverse_refs(self))
1445 self._inverse._inverse=self
1446 for source,d in self.d.iteritems(): # INVERT OUR CACHE
1447 self._inverse.load([(target,source,edge)
1448 for (target,edge) in d.iteritems()],unpack=False)
1449 return self._inverse
1450 edges=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1451 update = update_graph
1452 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1453 def __iter__(self): ################# ITERATORS
1454 'uses db select; does not force load'
1455 return iter(self.keys())
1456 def keys(self):
1457 'uses db select; does not force load'
1458 self.table.cursor.execute('select distinct(%s) from %s'
1459 %(self.source_id,self.table.name))
1460 return [self.unpack_source(t[0])
1461 for t in self.table.cursor.fetchall()]
1462 methodFactory(['iteritems','items','itervalues','values'],
1463 'lambda self:(self.load(),self.d.%s())[1]',locals())
1464 def __contains__(self,k):
1465 try:
1466 x=self[k]
1467 return True
1468 except KeyError:
1469 return False
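# Illustrative usage sketch (not part of the original module): how a clustered
# graph might be built and queried.  The table name, column names, clusterKey
# and the serverInfo argument are hypothetical placeholders.
def _example_clustered_graph(serverInfo):
    'minimal SQLGraphClustered sketch, assuming a test.splices table exists'
    g = SQLGraphClustered('test.splices', source_id='left_exon_id',
                          target_id='right_exon_id', edge_id='splice_id',
                          clusterKey='gene_id', serverInfo=serverInfo)
    for source in g:  # iterates over a SELECT DISTINCT; does not force a full load
        for target, edge in g[source].iteritems():  # loads one cluster into cache
            pass  # work with (source, target, edge) here
    return ~g  # reverse mapping over the same table, with the cache inverted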
1471 class SQLIDGraphClustered(SQLGraphClustered):
1472 add_trivial_packing_methods(locals())
1473 SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
1475 class SQLEdgesClustered(SQLGraphClustered):
1476 'edges interface for SQLGraphClustered'
1477 _edgeDictClass = list
1478 _pickleAttrs = SQLGraphClustered._pickleAttrs.copy()
1479 _pickleAttrs.update(dict(graph=0))
1480 def keys(self):
1481 self.load()
1482 result = []
1483 for edge_id,l in self.d.iteritems():
1484 for source_id,target_id in l:
1485 result.append((self.graph.unpack_source(source_id),
1486 self.graph.unpack_target(target_id),
1487 self.graph.unpack_edge(edge_id)))
1488 return result
1490 class ForeignKeyInverse(object):
1491 'map each key to a single value according to its foreign key'
1492 def __init__(self,g):
1493 self.g = g
1494 def __getitem__(self,obj):
1495 self.check_obj(obj)
1496 source_id = getattr(obj,self.g.keyColumn)
1497 if source_id is None:
1498 return None
1499 return self.g.sourceDB[source_id]
1500 def __setitem__(self,obj,source):
1501 self.check_obj(obj)
1502 if source is not None:
1503 self.g[source][obj] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1504 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1505 if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1506 old_source = self[obj]
1507 if old_source is not None:
1508 del self.g[old_source][obj]
1509 def check_obj(self,obj):
1510 'raise KeyError if obj not from this db'
1511 try:
1512 if obj.db != self.g.targetDB:
1513 raise AttributeError
1514 except AttributeError:
1515 raise KeyError('key is not from targetDB of this graph!')
1516 def __contains__(self,obj):
1517 try:
1518 self.check_obj(obj)
1519 return True
1520 except KeyError:
1521 return False
1522 def __iter__(self):
1523 return self.g.targetDB.itervalues()
1524 def keys(self):
1525 return self.g.targetDB.values()
1526 def iteritems(self):
1527 for obj in self:
1528 source_id = getattr(obj,self.g.keyColumn)
1529 if source_id is None:
1530 yield obj,None
1531 else:
1532 yield obj,self.g.sourceDB[source_id]
1533 def items(self):
1534 return list(self.iteritems())
1535 def itervalues(self):
1536 for obj,val in self.iteritems():
1537 yield val
1538 def values(self):
1539 return list(self.itervalues())
1540 def __invert__(self):
1541 return self.g
1543 class ForeignKeyEdge(dict):
1544 '''edge interface to a foreign key in an SQL table.
1545 Caches dict of target nodes in itself; provides dict interface.
1546 Adds or deletes edges by setting foreign key values in the table'''
1547 def __init__(self,g,k):
1548 dict.__init__(self)
1549 self.g = g
1550 self.src = k
1551 for v in g.targetDB.select('where %s=%%s' % g.keyColumn,(k.id,)): # SEARCH THE DB
1552 dict.__setitem__(self,v,None) # SAVE IN CACHE
1553 def __setitem__(self,dest,v):
1554 if not hasattr(dest,'db') or dest.db != self.g.targetDB:
1555 raise KeyError('dest is not in the targetDB bound to this graph!')
1556 if v is not None:
1557 raise ValueError('sorry, this graph cannot store edge information!')
1558 if not hasattr(dest,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1559 old_source = self.g._inverse[dest] # CHECK FOR PRE-EXISTING EDGE
1560 if old_source is not None: # REMOVE OLD EDGE FROM CACHE
1561 dict.__delitem__(self.g[old_source],dest)
1562 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1563 setattr(dest,self.g.keyColumn,self.src.id) # SAVE TO DB ATTRIBUTE
1564 dict.__setitem__(self,dest,None) # SAVE IN CACHE
1565 def __delitem__(self,dest):
1566 #self.g.targetDB._update(dest.id,self.g.keyColumn,None) # REMOVE FOREIGN KEY VALUE
1567 setattr(dest,self.g.keyColumn,None) # SAVE TO DB ATTRIBUTE
1568 dict.__delitem__(self,dest) # REMOVE FROM CACHE
1570 class ForeignKeyGraph(object, UserDict.DictMixin):
1571 '''graph interface to a foreign key in an SQL table
1572 Caches dict of target nodes in itself; provides dict interface.
1573 '''
1574 def __init__(self, sourceDB, targetDB, keyColumn, autoGC=True, **kwargs):
1575 '''sourceDB is any database of source nodes;
1576 targetDB must be an SQL database of target nodes;
1577 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1578 if autoGC: # automatically garbage collect unused objects
1579 self._weakValueDict = RecentValueDictionary(autoGC) # object cache
1580 else:
1581 self._weakValueDict = {}
1582 self.autoGC = autoGC
1583 self.sourceDB = sourceDB
1584 self.targetDB = targetDB
1585 self.keyColumn = keyColumn
1586 self._inverse = ForeignKeyInverse(self)
1587 _pickleAttrs = dict(sourceDB=0, targetDB=0, keyColumn=0, autoGC=0)
1588 __getstate__ = standard_getstate ########### SUPPORT FOR PICKLING
1589 __setstate__ = standard_setstate
1590 def _inverse_schema(self):
1591 'provide custom schema rule for inverting this graph... just use keyColumn!'
1592 return dict(invert=True,uniqueMapping=True)
1593 def __getitem__(self,k):
1594 if not hasattr(k,'db') or k.db != self.sourceDB:
1595 raise KeyError('object is not in the sourceDB bound to this graph!')
1596 try:
1597 return self._weakValueDict[k.id] # get from cache
1598 except KeyError:
1599 pass
1600 d = ForeignKeyEdge(self,k)
1601 self._weakValueDict[k.id] = d # save in cache
1602 return d
1603 def __setitem__(self, k, v):
1604 raise KeyError('''do not save as g[k]=v. Instead follow a graph
1605 interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
1606 def __delitem__(self, k):
1607 raise KeyError('''Instead of del g[k], follow a graph
1608 interface: del g[src][dest]''')
1609 def keys(self):
1610 return self.sourceDB.values()
1611 __invert__ = standard_invert
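# Illustrative usage sketch (not part of the original module): a typical
# one-to-many foreign key exposed as a graph.  geneDB / exonDB, the 'gene_id'
# column and the example keys are hypothetical placeholders.
def _example_foreign_key_graph(geneDB, exonDB):
    'map each gene to its exons via the gene_id foreign key column of exonDB'
    g = ForeignKeyGraph(geneDB, exonDB, 'gene_id')
    gene = geneDB[1]   # any source row object
    exon = exonDB[10]  # any target row object
    g[gene][exon] = None      # sets exon.gene_id = gene.id; no edge data allowed
    parent = (~g)[exon]       # inverse map: the gene this exon points to (or None)
    del g[gene][exon]         # clears exon.gene_id again
    return parent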
1613 def describeDBTables(name,cursor,idDict):
1614 """
1615 Get table info about database <name> via <cursor>, and store primary keys
1616 in idDict, along with a list of the tables each key indexes.
1617 """
1618 cursor.execute('use %s' % name)
1619 cursor.execute('show tables')
1620 tables={}
1621 l=[c[0] for c in cursor.fetchall()]
1622 for t in l:
1623 tname=name+'.'+t
1624 o=SQLTable(tname,cursor)
1625 tables[tname]=o
1626 for f in o.description:
1627 if f==o.primary_key:
1628 idDict.setdefault(f, []).append(o)
1629 elif f[-3:]=='_id' and f not in idDict:
1630 idDict[f]=[]
1631 return tables
1635 def indexIDs(tables,idDict=None):
1636 "Get an index of primary keys in the <tables> dictionary."
1637 if idDict==None:
1638 idDict={}
1639 for o in tables.values():
1640 if o.primary_key:
1641 if o.primary_key not in idDict:
1642 idDict[o.primary_key]=[]
1643 idDict[o.primary_key].append(o) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1644 for f in o.description:
1645 if f[-3:]=='_id' and f not in idDict:
1646 idDict[f]=[]
1647 return idDict
1651 def suffixSubset(tables,suffix):
1652 "Filter table index for those matching a specific suffix"
1653 subset={}
1654 for name,t in tables.items():
1655 if name.endswith(suffix):
1656 subset[name]=t
1657 return subset
1660 PRIMARY_KEY=1
1662 def graphDBTables(tables,idDict):
1663 g=dictgraph()
1664 for t in tables.values():
1665 for f in t.description:
1666 if f==t.primary_key:
1667 edgeInfo=PRIMARY_KEY
1668 else:
1669 edgeInfo=None
1670 g.setEdge(f,t,edgeInfo)
1671 g.setEdge(t,f,edgeInfo)
1672 return g
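# Illustrative usage sketch (not part of the original module): chaining the
# schema-introspection helpers above.  The database name 'test' and the
# '_mrna' suffix are hypothetical placeholders; describeDBTables() assumes a
# MySQL-style cursor (it issues 'use <db>' and 'show tables').
def _example_schema_graph(cursor):
    'build a graph connecting tables to their key columns'
    idDict = {}
    tables = describeDBTables('test', cursor, idDict)
    mrnaTables = suffixSubset(tables, '_mrna')  # keep only tables ending in _mrna
    indexIDs(mrnaTables, idDict)
    return graphDBTables(mrnaTables, idDict)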
1674 SQLTypeTranslation= {types.StringType:'varchar(32)',
1675 types.IntType:'int',
1676 types.FloatType:'float'}
1678 def createTableFromRepr(rows,tableName,cursor,typeTranslation=None,
1679 optionalDict=None,indexDict=()):
1680 """Save rows into SQL tableName using cursor, with optional
1681 translations of columns to specific SQL types (specified
1682 by typeTranslation dict).
1683 - optionalDict can specify columns that are allowed to be NULL.
1684 - indexDict can specify columns that must be indexed; columns
1685 whose names end in _id will be indexed by default.
1686 - rows must be an iterator which in turn returns dictionaries,
1687 each representing a tuple of values (indexed by their column
1688 names).
1689 """
1690 try:
1691 row=rows.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1692 except StopIteration:
1693 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1694 try:
1695 createTableFromRow(cursor, tableName,row,typeTranslation,
1696 optionalDict,indexDict)
1697 except Exception: # table probably already exists; reuse it and just insert rows
1698 pass
1699 storeRow(cursor,tableName,row) # SAVE OUR FIRST ROW
1700 for row in rows: # NOW SAVE ALL THE ROWS
1701 storeRow(cursor,tableName,row)
1703 def createTableFromRow(cursor, tableName, row,typeTranslation=None,
1704 optionalDict=None,indexDict=()):
1705 create_defs=[]
1706 for col,val in row.items(): # PREPARE SQL TYPES FOR COLUMNS
1707 coltype=None
1708 if typeTranslation!=None and col in typeTranslation:
1709 coltype=typeTranslation[col] # USER-SUPPLIED TRANSLATION
1710 elif type(val) in SQLTypeTranslation:
1711 coltype=SQLTypeTranslation[type(val)]
1712 else: # SEARCH FOR A COMPATIBLE TYPE
1713 for t in SQLTypeTranslation:
1714 if isinstance(val,t):
1715 coltype=SQLTypeTranslation[t]
1716 break
1717 if coltype==None:
1718 raise TypeError("Don't know SQL type to use for %s" % col)
1719 create_def='%s %s' %(col,coltype)
1720 if optionalDict==None or col not in optionalDict:
1721 create_def+=' not null'
1722 create_defs.append(create_def)
1723 for col in row: # CREATE INDEXES FOR ID COLUMNS
1724 if col[-3:]=='_id' or col in indexDict:
1725 create_defs.append('index(%s)' % col)
1726 cmd='create table if not exists %s (%s)' % (tableName,','.join(create_defs))
1727 cursor.execute(cmd) # CREATE THE TABLE IN THE DATABASE
1730 def storeRow(cursor, tableName, row):
1731 row_format=','.join(len(row)*['%s'])
1732 cmd='insert into %s values (%s)' % (tableName,row_format)
1733 cursor.execute(cmd,tuple(row.values()))
1735 def storeRowDelayed(cursor, tableName, row):
1736 row_format=','.join(len(row)*['%s'])
1737 cmd='insert delayed into %s values (%s)' % (tableName,row_format)
1738 cursor.execute(cmd,tuple(row.values()))
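# Illustrative usage sketch (not part of the original module): saving an
# iterator of row dictionaries with createTableFromRepr().  The table name,
# column names and values are hypothetical placeholders.
def _example_save_rows(cursor):
    'create test.mytable from the first row, then insert both rows'
    rows = iter([dict(foo_id=1, name='alpha', score=0.5),
                 dict(foo_id=2, name='beta', score=1.5)])
    createTableFromRepr(rows, 'test.mytable', cursor,
                        typeTranslation=dict(name='varchar(64)'),
                        optionalDict=dict(score=None))  # score may be NULL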
1741 class TableGroup(dict):
1742 'provide attribute access to dbname qualified tablenames'
1743 def __init__(self,db='test',suffix=None,**kw):
1744 dict.__init__(self)
1745 self.db=db
1746 if suffix is not None:
1747 self.suffix=suffix
1748 for k,v in kw.items():
1749 if v is not None and '.' not in v:
1750 v=self.db+'.'+v # ADD DATABASE NAME AS PREFIX
1751 self[k]=v
1752 def __getattr__(self,k):
1753 return self[k]
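# Illustrative usage sketch (not part of the original module): TableGroup just
# prefixes unqualified table names with the database name.  All names below
# are hypothetical placeholders.
def _example_table_group():
    t = TableGroup(db='splicedb', exons='exons', genes='hg17.genes')
    return t.exons, t.genes  # -> ('splicedb.exons', 'hg17.genes')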
1755 def sqlite_connect(*args, **kwargs):
1756 sqlite = import_sqlite()
1757 connection = sqlite.connect(*args, **kwargs)
1758 cursor = connection.cursor()
1759 return connection, cursor
1761 class DBServerInfo(object):
1762 'picklable reference to a database server'
1763 def __init__(self, moduleName='MySQLdb', *args, **kwargs):
1764 try:
1765 self.__class__ = _DBServerModuleDict[moduleName]
1766 except KeyError:
1767 raise ValueError('Module name not found in _DBServerModuleDict: '\
1768 + moduleName)
1769 self.moduleName = moduleName
1770 self.args = args
1771 self.kwargs = kwargs # connection arguments
1773 def cursor(self):
1774 """returns cursor associated with the DB server info (reused)"""
1775 try:
1776 return self._cursor
1777 except AttributeError:
1778 self._start_connection()
1779 return self._cursor
1781 def new_cursor(self, arraysize=None):
1782 """returns a NEW cursor; you must close it yourself! """
1783 if not hasattr(self, '_connection'):
1784 self._start_connection()
1785 cursor = self._connection.cursor()
1786 if arraysize is not None:
1787 cursor.arraysize = arraysize
1788 return cursor
1790 def close(self):
1791 """Close file containing this database"""
1792 self._cursor.close()
1793 self._connection.close()
1794 del self._cursor
1795 del self._connection
1797 def __getstate__(self):
1798 """return all picklable arguments"""
1799 return dict(args=self.args, kwargs=self.kwargs,
1800 moduleName=self.moduleName)
1803 class MySQLServerInfo(DBServerInfo):
1804 'customized for MySQLdb SSCursor support via new_cursor()'
1805 def _start_connection(self):
1806 self._connection,self._cursor = mysql_connect(*self.args, **self.kwargs)
1807 def new_cursor(self, arraysize=None):
1808 'provide streaming cursor support'
1809 try:
1810 conn = self._conn_sscursor
1811 except AttributeError:
1812 self._conn_sscursor,cursor = mysql_connect(useStreaming=True,
1813 *self.args, **self.kwargs)
1814 else:
1815 cursor = self._conn_sscursor.cursor()
1816 if arraysize is not None:
1817 cursor.arraysize = arraysize
1818 return cursor
1819 def close(self):
1820 DBServerInfo.close(self)
1821 try:
1822 self._conn_sscursor.close()
1823 del self._conn_sscursor
1824 except AttributeError:
1825 pass
1826 def iter_keys(self, db, selectCols=None, orderBy='', map_f=iter,
1827 cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
1828 cursor = self.new_cursor()
1829 block_generator = BlockGenerator(db, self, cursor, selectCols=selectCols,
1830 orderBy=orderBy, **kwargs)
1831 return db.generic_iterator(cursor=cursor, cache_f=cache_f,
1832 map_f=map_f, fetch_f=block_generator)
1835 class BlockGenerator(object):
1836 def __init__(self, db, serverInfo, cursor, whereClause='', **kwargs):
1837 self.db = db
1838 self.serverInfo = serverInfo
1839 self.cursor = cursor
1840 self.kwargs = kwargs
1841 self.blockSize = 10000
1842 self.whereClause = whereClause
1843 #self.__iter__() # start me up!
1845 ## def __iter__(self):
1846 ## 'initialize this iterator'
1847 ## self.db._select(cursor=cursor, selectCols='min(%s),max(%s),count(*)'
1848 ## % (self.db.name, self.db.name))
1849 ## l = self.cursor.fetchall()
1850 ## self.minID, self.maxID, self.count = l[0]
1851 ## self.start = self.minID - 1 # only works for int
1852 ## return self
1854 ## def next(self):
1855 ## 'get the next start position'
1856 ## if self.start >= self.maxID:
1857 ## raise StopIteration
1858 ## return start
1860 def __call__(self):
1861 'get the next block of data'
1862 ## try:
1863 ## start = self.next()
1864 ## except StopIteration:
1865 ## return ()
1866 self.db._select(cursor=self.cursor, whereClause=self.whereClause,
1867 limit='LIMIT %s' % self.blockSize, **self.kwargs)
1868 rows = self.cursor.fetchall()
1869 if not rows: # no rows left: tell generic_iterator() we are done
1870 return ()
1871 lastrow = rows[-1]
1872 if len(lastrow) > 1: # extract the last ID value in this block
1873 start = lastrow[self.db.data['id']]
1874 else:
1875 start = lastrow[0]
1876 self.whereClause = '%s>%s' % (self.db.primary_key, start)
1877 return rows
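# Illustrative sketch (not part of the original module) of the keyset
# pagination pattern BlockGenerator is built on, written against a plain
# DB-API cursor.  Table / column names are hypothetical placeholders and the
# MySQLdb '%s' paramstyle is assumed; note that an ORDER BY on the key column
# is what makes the "WHERE key > last" restart correct.
def _example_block_iteration(cursor, table='test.mytable', keyCol='id',
                             blockSize=10000):
    'yield every row of table, fetching blockSize rows per query'
    lastKey = None
    while True:
        if lastKey is None:
            cursor.execute('select * from %s order by %s limit %d'
                           % (table, keyCol, blockSize))
        else:
            cursor.execute('select * from %s where %s>%%s order by %s limit %d'
                           % (table, keyCol, keyCol, blockSize), (lastKey,))
        rows = cursor.fetchall()
        if not rows:  # ran out of rows: stop the generator
            return
        for row in rows:
            yield row
        lastKey = rows[-1][0]  # assumes the key is the first selected column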
1879 class SQLiteServerInfo(DBServerInfo):
1880 """picklable reference to a sqlite database"""
1881 def __init__(self, database, *args, **kwargs):
1882 """Takes same arguments as sqlite3.connect()"""
1883 DBServerInfo.__init__(self, 'sqlite',
1884 SourceFileName(database), # save abs path!
1885 *args, **kwargs)
1886 def _start_connection(self):
1887 self._connection,self._cursor = sqlite_connect(*self.args, **self.kwargs)
1888 def __getstate__(self):
1889 if self.args[0] == ':memory:':
1890 raise ValueError('SQLite in-memory database is not picklable!')
1891 return DBServerInfo.__getstate__(self)
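# Illustrative usage sketch (not part of the original module): DBServerInfo
# objects are picklable recipes for reconnecting.  The file name below is a
# hypothetical placeholder.
def _example_sqlite_server_info():
    'open an SQLite database file and obtain cursors from the serverInfo'
    serverInfo = SQLiteServerInfo('splicedb.sqlite')
    cursor = serverInfo.cursor()     # shared cursor, created on first use
    extra = serverInfo.new_cursor()  # independent cursor; close it yourself
    extra.close()
    return serverInfo, cursor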
1893 # list of DBServerInfo subclasses for different modules
1894 _DBServerModuleDict = dict(MySQLdb=MySQLServerInfo, sqlite=SQLiteServerInfo)
1897 class MapView(object, UserDict.DictMixin):
1898 'general purpose 1:1 mapping defined by any SQL query'
1899 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
1900 serverInfo=None, inverseSQL=None, **kwargs):
1901 self.sourceDB = sourceDB
1902 self.targetDB = targetDB
1903 self.viewSQL = viewSQL
1904 self.inverseSQL = inverseSQL
1905 if cursor is None:
1906 if serverInfo is not None: # get cursor from serverInfo
1907 cursor = serverInfo.cursor()
1908 else:
1909 try: # can we get it from our other db?
1910 serverInfo = sourceDB.serverInfo
1911 except AttributeError:
1912 raise ValueError('you must provide serverInfo or cursor!')
1913 else:
1914 cursor = serverInfo.cursor()
1915 self.cursor = cursor
1916 self.serverInfo = serverInfo
1917 self.get_sql_format(False) # get sql formatter for this db interface
1918 _schemaModuleDict = _schemaModuleDict # default module list
1919 get_sql_format = get_table_schema
1920 def __getitem__(self, k):
1921 if not hasattr(k,'db') or k.db != self.sourceDB:
1922 raise KeyError('object is not in the sourceDB bound to this map!')
1923 sql,params = self._format_query(self.viewSQL, (k.id,))
1924 self.cursor.execute(sql, params) # formatted for this db interface
1925 t = self.cursor.fetchmany(2) # get at most two rows
1926 if len(t) != 1:
1927 raise KeyError('%s not found in MapView, or not unique'
1928 % str(k))
1929 return self.targetDB[t[0][0]] # get the corresponding object
1930 _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
1931 inverseSQL=0)
1932 __getstate__ = standard_getstate
1933 __setstate__ = standard_setstate
1934 __setitem__ = __delitem__ = clear = pop = popitem = update = \
1935 setdefault = read_only_error
1936 def __iter__(self):
1937 'only yield sourceDB items that are actually in this mapping!'
1938 for k in self.sourceDB.itervalues():
1939 try:
1940 self[k]
1941 yield k
1942 except KeyError:
1943 pass
1944 def keys(self):
1945 return [k for k in self] # don't use list(self); causes infinite loop!
1946 def __invert__(self):
1947 try:
1948 return self._inverse
1949 except AttributeError:
1950 if self.inverseSQL is None:
1951 raise ValueError('this MapView has no inverseSQL!')
1952 self._inverse = self.__class__(self.targetDB, self.sourceDB,
1953 self.inverseSQL, self.cursor,
1954 serverInfo=self.serverInfo,
1955 inverseSQL=self.viewSQL)
1956 self._inverse._inverse = self
1957 return self._inverse
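# Illustrative usage sketch (not part of the original module): a 1:1 MapView
# defined by an SQL query.  The databases, the test.translation table and the
# example key are hypothetical placeholders; viewSQL must return exactly one
# target ID for the substituted source ID.
def _example_map_view(mrnaDB, proteinDB, serverInfo):
    'map each mRNA row to the protein row it encodes'
    m = MapView(mrnaDB, proteinDB,
                'select protein_id from test.translation where mrna_id=%s',
                serverInfo=serverInfo,
                inverseSQL='select mrna_id from test.translation where protein_id=%s')
    mrna = mrnaDB['NM_000551']  # hypothetical source key
    protein = m[mrna]           # runs viewSQL with mrna.id substituted in
    mrna2 = (~m)[protein]       # inverse mapping uses inverseSQL
    return protein, mrna2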
1959 class GraphViewEdgeDict(UserDict.DictMixin):
1960 'edge dictionary for GraphView: just pre-loaded on init'
1961 def __init__(self, g, k):
1962 self.g = g
1963 self.k = k
1964 sql,params = self.g._format_query(self.g.viewSQL, (k.id,))
1965 self.g.cursor.execute(sql, params) # run the query
1966 l = self.g.cursor.fetchall() # get results
1967 if len(l) <= 0:
1968 raise KeyError('key %s not in GraphView' % k.id)
1969 self.targets = [t[0] for t in l] # preserve order of the results
1970 d = {} # also keep targetID:edgeID mapping
1971 if self.g.edgeDB is not None: # save with edge info
1972 for t in l:
1973 d[t[0]] = t[1]
1974 else:
1975 for t in l:
1976 d[t[0]] = None
1977 self.targetDict = d
1978 def __len__(self):
1979 return len(self.targets)
1980 def __iter__(self):
1981 for k in self.targets:
1982 yield self.g.targetDB[k]
1983 def keys(self):
1984 return list(self)
1985 def iteritems(self):
1986 if self.g.edgeDB is not None: # save with edge info
1987 for k in self.targets:
1988 yield (self.g.targetDB[k], self.g.edgeDB[self.targetDict[k]])
1989 else: # just save the list of targets, no edge info
1990 for k in self.targets:
1991 yield (self.g.targetDB[k], None)
1992 def __getitem__(self, o, exitIfFound=False):
1993 'for the specified target object, return its associated edge object'
1994 try:
1995 if o.db is not self.g.targetDB:
1996 raise KeyError('key is not part of targetDB!')
1997 edgeID = self.targetDict[o.id]
1998 except AttributeError:
1999 raise KeyError('key has no id or db attribute?!')
2000 if exitIfFound:
2001 return
2002 if self.g.edgeDB is not None: # return the edge object
2003 return self.g.edgeDB[edgeID]
2004 else: # no edge info
2005 return None
2006 def __contains__(self, o):
2007 try:
2008 self.__getitem__(o, True) # raise KeyError if not found
2009 return True
2010 except KeyError:
2011 return False
2012 __setitem__ = __delitem__ = clear = pop = popitem = update = \
2013 setdefault = read_only_error
2015 class GraphView(MapView):
2016 'general purpose graph interface defined by any SQL query'
2017 def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
2018 **kwargs):
2019 'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
2020 self.edgeDB = edgeDB
2021 MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)
2022 def __getitem__(self, k):
2023 if not hasattr(k,'db') or k.db != self.sourceDB:
2024 raise KeyError('object is not in the sourceDB bound to this map!')
2025 return GraphViewEdgeDict(self, k)
2026 _pickleAttrs = MapView._pickleAttrs.copy()
2027 _pickleAttrs.update(dict(edgeDB=0))
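# Illustrative usage sketch (not part of the original module): GraphView for a
# one-to-many relation carrying edge data.  All names are hypothetical
# placeholders; because edgeDB is given, viewSQL must return
# (targetID, edgeID) rows.
def _example_graph_view(geneDB, exonDB, spliceDB, serverInfo):
    'map each gene to its exons, with the splice row as edge information'
    g = GraphView(geneDB, exonDB,
                  'select exon_id,splice_id from test.gene_exon where gene_id=%s',
                  serverInfo=serverInfo, edgeDB=spliceDB)
    gene = geneDB[42]  # hypothetical source key
    for exon, splice in g[gene].iteritems():
        pass           # exon comes from exonDB, splice from spliceDB
    return len(g[gene])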
2029 # @CTB move to sqlgraph.py?
2031 class SQLSequence(SQLRow, SequenceBase):
2032 """Transparent access to a DB row representing a sequence.
2034 Use attrAlias dict to rename 'length' to something else.
2035 """
2036 def _init_subclass(cls, db, **kwargs):
2037 db.seqInfoDict = db # db will act as its own seqInfoDict
2038 SQLRow._init_subclass(db=db, **kwargs)
2039 _init_subclass = classmethod(_init_subclass)
2040 def __init__(self, id):
2041 SQLRow.__init__(self, id)
2042 SequenceBase.__init__(self)
2043 def __len__(self):
2044 return self.length
2045 def strslice(self,start,end):
2046 "Efficient access to slice of a sequence, useful for huge contigs"
2047 return self._select('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
2048 %(self.db._attrSQL('seq'),start+1,end-start))
2050 class DNASQLSequence(SQLSequence):
2051 _seqtype=DNA_SEQTYPE
2053 class RNASQLSequence(SQLSequence):
2054 _seqtype=RNA_SEQTYPE
2056 class ProteinSQLSequence(SQLSequence):
2057 _seqtype=PROTEIN_SEQTYPE