added tests for MapView iterators and len
[pygr.git] / pygr / sqlgraph.py
blob16b1d1d11d27a2c353be67f1ca064369b3472c1a
3 from __future__ import generators
4 from mapping import *
5 from sequence import SequenceBase, DNA_SEQTYPE, RNA_SEQTYPE, PROTEIN_SEQTYPE
6 import types
7 from classutil import methodFactory,standard_getstate,\
8 override_rich_cmp,generate_items,get_bound_subclass,standard_setstate,\
9 get_valid_path,standard_invert,RecentValueDictionary,read_only_error,\
10 SourceFileName, split_kwargs
11 import os
12 import platform
13 import UserDict
14 import warnings
15 import logger
class TupleDescriptor(object):
    """Read-only descriptor exposing one slot of a row's data tuple.

    The slot position is taken from db.data[attr] once, at construction.
    """

    def __init__(self, db, attr):
        # position of this attribute inside obj._data
        self.icol = db.data[attr]

    def __get__(self, obj, klass):
        row = obj._data
        return row[self.icol]

    def __set__(self, obj, val):
        # plain tuple-backed rows never accept assignment
        raise AttributeError('this database is read-only!')
class TupleIDDescriptor(TupleDescriptor):
    """Descriptor for the row ID: readable, but never directly assignable."""

    def __set__(self, obj, val):
        # the ID can only change through the table's item assignment
        message = '''You cannot change obj.id directly.
        Instead, use db[newID] = obj'''
        raise AttributeError(message)
class TupleDescriptorRW(TupleDescriptor):
    """Writable variant of TupleDescriptor.

    Assignment first updates the database column, then the row's local copy.
    """

    def __init__(self, db, attr):
        self.attr = attr
        # position of this attribute inside the data tuple
        self.icol = db.data[attr]
        # name of the SQL column backing this attribute
        self.attrSQL = db._attrSQL(attr, sqlColumn=True)

    def __set__(self, obj, val):
        table = obj.db
        table._update(obj.id, self.attrSQL, val)  # write through to the db
        obj.save_local(self.attr, val)  # then refresh the cached tuple
class SQLDescriptor(object):
    """Read-only descriptor that fetches its value with a database query.

    The SQL expression for the attribute is resolved once via db._attrSQL();
    each read delegates to obj._select() with that expression.
    """

    def __init__(self, db, attr):
        self.selectSQL = db._attrSQL(attr)

    def __get__(self, obj, klass):
        expression = self.selectSQL
        return obj._select(expression)

    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')
class SQLDescriptorRW(SQLDescriptor):
    """Writable SQLDescriptor: assignment is sent straight to the database."""

    def __set__(self, obj, val):
        # nothing is cached on the row object, so only the db is updated
        table = obj.db
        table._update(obj.id, self.selectSQL, val)
class ReadOnlyDescriptor(object):
    """Expose a private '_<attr>' instance attribute as read-only '<attr>'.

    Used e.g. for the ID attribute.
    """

    def __init__(self, db, attr):
        # the value itself lives on the instance under the underscored name
        self.attr = '_' + attr

    def __get__(self, obj, klass):
        private_name = self.attr
        return getattr(obj, private_name)

    def __set__(self, obj, val):
        raise AttributeError('attribute %s is read-only!' % self.attr)
def select_from_row(row, what):
    """Evaluate the SQL expression `what` for the db row backing `row`.

    Raises KeyError unless exactly one row matches row.id.
    """
    table = row.db
    template = ('select %s from %s where %s=%%s limit 2'
                % (what, table.name, table.primary_key))
    sql, params = table._format_query(template, (row.id,))
    table.cursor.execute(sql, params)
    matches = table.cursor.fetchmany(2)  # two rows suffice to detect dupes
    if len(matches) == 1:
        return matches[0][0]  # the single requested field
    raise KeyError('%s[%s].%s not found, or not unique'
                   % (table.name, str(row.id), what))
def init_row_subclass(cls, db):
    """Attach one descriptor per db attribute to the row class `cls`.

    'id' gets cls._idDescriptor.  Attributes that map to a real column
    number get cls._columnDescriptor; everything else is treated as an SQL
    expression and gets cls._sqlDescriptor.
    """
    for attr in db.data:  # bind all database columns and aliases
        if attr == 'id':  # the ID attribute is handled specially
            setattr(cls, attr, cls._idDescriptor(db, attr))
            continue
        try:  # does this attribute map to an SQL column?
            db._attrSQL(attr, columnNumber=True)
        except (AttributeError, ValueError):
            # _attrSQL(columnNumber=True) reports "not a column" with
            # ValueError, so that must be caught as well as AttributeError
            descriptor_class = cls._sqlDescriptor
        else:  # interface to our stored tuple
            descriptor_class = cls._columnDescriptor
        setattr(cls, attr, descriptor_class(db, attr))
def dir_row(self):
    """Report the table's attribute names (the keys of db.data)."""
    attribute_map = self.db.data
    return attribute_map.keys()
class TupleO(object):
    """Attribute-style, read-only view of a database row kept as a tuple.

    The row is stored exactly as returned by the DB API fetch (a tuple)
    rather than copied value-by-value into __dict__; the original author
    notes this takes about five-fold less memory and is much faster.

    The class uses the 'subclass binding' pattern: attribute access is
    implemented with descriptors instead of __getattr__.
    classutil.get_bound_subclass() creates a subclass bound to a specific
    table and calls its _init_subclass() class method, which installs the
    descriptors for that table's columns.

    See the Pygr Developer Guide section of the docs for a complete
    discussion of the subclass binding pattern.
    """
    # descriptor classes installed by _init_subclass()
    _columnDescriptor = TupleDescriptor
    _idDescriptor = TupleIDDescriptor
    _sqlDescriptor = SQLDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row

    def __init__(self, data):
        # keep a reference to the fetched tuple; no copying
        self._data = data
def insert_and_cache_id(self, l, **kwargs):
    """Store the values `l` as a new db row, then record its ID on self.

    A caller-supplied 'id' keyword wins; otherwise the ID assigned by the
    database (self.db.get_insert_id()) is used.
    """
    table = self.db
    table._insert(l)  # save to database first
    if 'id' in kwargs:
        rowID = kwargs['id']
    else:
        rowID = table.get_insert_id()  # auto-increment value
    self.cache_id(rowID)
class TupleORW(TupleO):
    """Writable counterpart of TupleO.

    Column assignment goes through TupleDescriptorRW; constructing with
    newRow=True inserts a new row built from the keyword arguments.
    """
    _columnDescriptor = TupleDescriptorRW
    insert_and_cache_id = insert_and_cache_id

    def __init__(self, data, newRow=False, **kwargs):
        if newRow:
            # build the row from kwargs and store it in the database
            self._data = self.db.tuple_from_dict(kwargs)
            self.insert_and_cache_id(self._data, **kwargs)
        else:
            # existing row: just keep the data fetched from the database
            self._data = data

    def cache_id(self, row_id):
        """Record row_id as this row's 'id' value in the local data."""
        self.save_local('id', row_id)

    def save_local(self, attr, val):
        """Overwrite the locally cached value of attr with val."""
        position = self._attrcol[attr]
        try:
            self._data[position] = val
        except TypeError:
            # tuples are immutable: switch to a list copy, then store
            self._data = list(self._data)
            self._data[position] = val


TupleO._RWClass = TupleORW  # record this as writeable interface class
class ColumnDescriptor(object):
    """Read-write proxy for a db column, with values cached in obj.__dict__.

    With readOnly=True the instance switches itself to the class registered
    as self._readOnlyClass.
    """

    def __init__(self, db, attr, readOnly = False):
        self.attr = attr
        # SQL column name that this attribute maps to
        self.col = db._attrSQL(attr, sqlColumn=True)
        self.db = db
        if readOnly:
            self.__class__ = self._readOnlyClass

    def __get__(self, obj, objtype):
        cache = obj.__dict__
        try:
            return cache[self.attr]
        except KeyError:  # not cached yet, so ask the database
            table = self.db
            if self.col == table.primary_key:
                # the primary key itself cannot be looked up by primary key
                raise AttributeError
            table._select('where %s=%%s' % table.primary_key, (obj.id,),
                          self.col)
            rows = table.cursor.fetchall()
            if len(rows) != 1:
                raise AttributeError('db row not found or not unique!')
            value = rows[0][0]
            cache[self.attr] = value  # remember it for next time
            return value

    def __set__(self, obj, val):
        # objects flagged with _localOnly are cached only, never saved
        if not hasattr(obj,'_localOnly'):
            self.db._update(obj.id, self.col, val)
        obj.__dict__[self.attr] = val
##         try:
##             m = self.consequences
##         except AttributeError:
##             return
##         m(obj,val) # GENERATE CONSEQUENCES
##     def bind_consequences(self,f):
##         'make function f be run as consequences method whenever value is set'
##         import new
##         self.consequences = new.instancemethod(f,self,self.__class__)
class ReadOnlyColumnDesc(ColumnDescriptor):
    """ColumnDescriptor variant that rejects every assignment."""

    def __set__(self, obj, val):
        message = 'The ID of a database object is not writeable.'
        raise AttributeError(message)


# class that ColumnDescriptor(..., readOnly=True) instances turn into
ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc
class SQLRow(object):
    """Transparent interface to one database row.

    Every attribute read is mapped to a SELECT of the appropriate column;
    no column data is cached on this object (only the row ID is stored).
    """
    # both plain columns and SQL expressions are fetched by query
    _columnDescriptor = _sqlDescriptor = SQLDescriptor
    _idDescriptor = ReadOnlyDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    __dir__ = dir_row

    def __init__(self, rowID):
        # ReadOnlyDescriptor exposes this private value as .id
        self._id = rowID
class SQLRowRW(SQLRow):
    """Writable counterpart of SQLRow.

    Constructing with newRow=True inserts a new row built from the keyword
    arguments; otherwise only the given row ID is recorded.
    """
    _columnDescriptor = SQLDescriptorRW
    insert_and_cache_id = insert_and_cache_id

    def __init__(self, rowID, newRow=False, **kwargs):
        if not newRow:  # existing row: only remember its ID
            return self.cache_id(rowID)
        values = self.db.tuple_from_dict(kwargs)  # kwargs -> column order
        self.insert_and_cache_id(values, **kwargs)

    def cache_id(self, rowID):
        """Record rowID as this row's ID."""
        self._id = rowID


SQLRow._RWClass = SQLRowRW  # record this as writeable interface class
def list_to_dict(names, values):
    """Pair names with values positionally and return them as a dict.

    Pairing stops at the shorter of the two sequences, so names without a
    value are omitted and surplus values are ignored.
    """
    return dict(zip(names, values))
def get_name_cursor(name=None, **kwargs):
    """Return (table name, cursor, serverInfo) for the requested table.

    `name` may be a whitespace-separated string 'table host user passwd';
    the extra fields, if present, are passed as connection arguments.
    Remaining connection settings come from kwargs and are handed to
    DBServerInfo.
    """
    fields = []
    if name is not None:
        fields = name.split()  # treat as whitespace-separated list
    if len(fields) > 1:
        name = fields[0]  # first field is the table name
        connection_args = list_to_dict(('host', 'user', 'passwd'),
                                       fields[1:])
        kwargs = kwargs.copy()  # a copy we can overwrite
        kwargs.update(connection_args)
    serverInfo = DBServerInfo(**kwargs)
    return name, serverInfo.cursor(), serverInfo
def mysql_connect(connect=None, configFile=None, **args):
    """Open a database connection and return (connection, cursor).

    When no 'user' argument and no configFile are given, a MySQL config
    file is looked for in the platform's usual place; failing that, a
    'mysql.cnf' in the current directory is tried.  If a config file is
    found, MySQLdb.connect is always used, even if `connect` was supplied.
    """
    kwargs = args.copy()  # a copy we can modify
    search_for_config = 'user' not in kwargs and configFile is None
    if search_for_config:
        if platform.system() in ('Microsoft', 'Windows'):
            candidates = []
            # either environment variable may be missing on Windows
            if 'WINDIR' in os.environ:
                windir = os.environ['WINDIR']
                candidates += [(windir, 'my.ini'), (windir, 'my.cnf')]
            if 'SYSTEMDRIVE' in os.environ:
                sysdrv = os.environ['SYSTEMDRIVE']
                candidates += [(sysdrv, os.path.sep + 'my.ini'),
                               (sysdrv, os.path.sep + 'my.cnf')]
            if len(candidates) > 0:
                configFile = get_valid_path(*candidates)
        else:  # platforms with home directories
            home = os.path.expanduser('~')
            configFile = os.path.join(home, '.my.cnf')

    # allow a local configuration file in the current directory
    if not configFile:
        configFile = os.path.join(os.getcwd(), 'mysql.cnf')

    if configFile and os.path.exists(configFile):
        kwargs['read_default_file'] = configFile
        connect = None  # force it to use MySQLdb
    if connect is None:
        import MySQLdb
        connect = MySQLdb.connect
        kwargs['compress'] = True
    conn = connect(**kwargs)
    return conn, conn.cursor()
287 _mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
288 AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
289 SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')
def mysql_table_schema(self, analyzeSchema=True):
    """Set up query formatting for MySQL and, optionally, read the schema.

    Always installs self._format_query.  With analyzeSchema true, the
    table's columns are read with DESCRIBE and recorded on self
    (columnName, columnType, description, primary_key, usesIntID, indexed).
    """
    import MySQLdb
    self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
    if not analyzeSchema:
        return
    self.clear_schema()  # reset settings and dictionaries
    cursor = self.cursor
    cursor.execute('describe %s' % self.name)  # get info about columns
    columns = cursor.fetchall()
    # run a trivial select so cursor.description describes every column
    cursor.execute('select * from %s limit 1' % self.name)
    for icol, column in enumerate(columns):
        field = column[0]
        key_kind = column[3]
        sql_type = column[1]
        self.columnName.append(field)  # same order as in the table
        if key_kind == "PRI":  # part of the primary key
            if self.primary_key is None:
                self.primary_key = field
            else:
                # second or later key column: primary_key becomes a list
                try:
                    self.primary_key.append(field)
                except AttributeError:
                    self.primary_key = [self.primary_key, field]
            self.usesIntID = (sql_type[:3].lower() == 'int')
        elif key_kind == "MUL":
            self.indexed[field] = icol
        self.description[field] = cursor.description[icol]
        self.columnType[field] = sql_type  # SQL column type
321 _sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
322 AUTO_INCREMENT='', SUBSTRING='substr',
323 SUBSTR_FROM=',', SUBSTR_FOR=',')
def import_sqlite():
    """Return the sqlite DB-API module.

    Prefers the standard-library sqlite3 (Python 2.5+) and falls back to
    pysqlite2 on earlier Python versions.
    """
    try:
        import sqlite3
    except ImportError:
        from pysqlite2 import dbapi2
        return dbapi2
    return sqlite3
def sqlite_table_schema(self, analyzeSchema=True):
    """Set up query formatting for sqlite and, optionally, read the schema.

    Always installs self._format_query.  With analyzeSchema true, column
    names/types are read with PRAGMA table_info and the primary key is
    inferred, first from single-column automatic indexes, then from the
    CREATE TABLE statement.  Results are recorded on self (columnName,
    columnType, description, primary_key, usesIntID).

    Raises ValueError if the CREATE TABLE text names a primary key column
    that is not among the table's columns.
    """
    sqlite = import_sqlite()
    self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
    if not analyzeSchema:
        return
    self.clear_schema()  # reset settings and dictionaries
    self.cursor.execute('PRAGMA table_info("%s")' % self.name)
    columns = self.cursor.fetchall()
    # run a trivial select so cursor.description describes every column
    self.cursor.execute('select * from %s limit 1' % self.name)
    for icol, c in enumerate(columns):
        field = c[1]
        self.columnName.append(field)  # same order as in the table
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[2]  # SQL column type
    # automatic indexes (sql is null) back primary key / unique constraints
    self.cursor.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self.name)
    for indexname in self.cursor.fetchall():  # each row is a 1-tuple
        self.cursor.execute('PRAGMA index_info("%s")' % indexname)
        l = self.cursor.fetchall()  # columns in this index
        if len(l) == 1:  # assume 1st single-column unique index is the key
            self.primary_key = l[0][2]
            break
    if self.primary_key is None:
        # INTEGER PRIMARY KEY creates no index; parse the table definition
        self.cursor.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self.name)
        sql = self.cursor.fetchall()[0][0]
        for columnSQL in sql[sql.index('(') + 1:].split(','):
            if 'primary key' in columnSQL.lower():
                col = columnSQL.split()[0]  # column name comes first
                if col in self.columnType:
                    self.primary_key = col
                    break
                else:
                    raise ValueError('unknown primary key %s in table %s'
                                     % (col, self.name))
    if self.primary_key is not None:  # check its type
        # compare case-insensitively (as the MySQL schema reader does), so
        # that a column declared INTEGER is recognised as an integer ID
        key_type = (self.columnType[self.primary_key] or '').lower()
        self.usesIntID = key_type in ('int', 'integer')
class SQLFormatDict(object):
    '''Perform SQL keyword replacements for maintaining compatibility across
    a wide range of SQL backends.  Uses Python dict-based string format
    function to do simple string replacements, and also to convert
    params list to the paramstyle required for this interface.
    Create by passing the db-api paramstyle and a dict of macros:
    sfd = SQLFormatDict("qmark", substitutionDict)

    Then transform queries+params as follows; input should be "format" style:
    sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
    cursor.execute(sql, params)
    '''
    # template for the i-th positional parameter in each DB-API paramstyle;
    # qmark and format are handled by plain substitution in __init__
    _paramFormats = dict(pyformat='%%(%d)s', numeric=':%d', named=':%d',
                         qmark='(ignore)', format='(ignore)')

    def __init__(self, paramstyle, substitutionDict=None):
        'paramstyle: a DB-API paramstyle name; substitutionDict: macro map'
        if substitutionDict is None:  # avoid a shared mutable default
            self.substitutionDict = {}
        else:
            self.substitutionDict = substitutionDict.copy()
        self.paramstyle = paramstyle
        self.paramFormat = self._paramFormats[paramstyle]
        # these styles need the params passed as a dict rather than a list
        self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
        if paramstyle == 'qmark':  # handle these as simple substitution
            self.substitutionDict['?'] = '?'
        elif paramstyle == 'format':
            self.substitutionDict['?'] = '%s'

    def __getitem__(self, k):
        'apply correct substitution for this SQL interface'
        try:
            return self.substitutionDict[k]  # apply our substitutions
        except KeyError:
            pass
        if k == '?':  # sequential parameter
            s = self.paramFormat % self.iparam
            self.iparam += 1  # advance to the next parameter
            return s
        raise KeyError('unknown macro: %s' % k)

    def __call__(self, sql, paramList):
        'returns corrected sql,params for this interface'
        self.iparam = 1  # DB-API param indexing begins at 1
        sql = sql.replace('%s', '%(?)s')  # convert format into pyformat
        s = sql % self  # apply all %(x)s replacements via __getitem__
        if self.makeDict:  # construct a params dict keyed '1', '2', ...
            paramDict = {}
            for i, param in enumerate(paramList):
                paramDict[str(i + 1)] = param
            return s, paramDict
        else:  # just return the original params list
            return s, paramList
def get_table_schema(self, analyzeSchema=True):
    """Dispatch to the schema reader registered for this cursor's module.

    The module name of self.cursor's class is looked up in
    self._schemaModuleDict.  Raises ValueError when no cursor/module
    information is available and KeyError for an unregistered module.
    """
    try:
        modname = self.cursor.__class__.__module__
    except AttributeError:
        raise ValueError('no cursor object or module information!')
    readers = self._schemaModuleDict
    if modname not in readers:
        raise KeyError('''unknown db module: %s. Use _schemaModuleDict
        attribute to supply a method for obtaining table schema
        for this module''' % modname)
    schema_func = readers[modname]
    schema_func(self, analyzeSchema)  # run the schema function
# cursor-class module name -> function that reads the table schema
_schemaModuleDict = dict([
    ('MySQLdb.cursors', mysql_table_schema),
    ('pysqlite2.dbapi2', sqlite_table_schema),
    ('sqlite3', sqlite_table_schema),
])
class SQLTableBase(object, UserDict.DictMixin):
    """Store information about an SQL table as dict keyed by primary key.

    Holds the cursor, the table schema, the attribute -> column mapping
    (self.data) and a cache of row objects (self._weakValueDict).
    """
    _schemaModuleDict = _schemaModuleDict  # default module list
    get_table_schema = get_table_schema

    def __init__(self, name, cursor=None, itemClass=None, attrAlias=None,
                 clusterKey=None, createTable=None, graph=None, maxCache=None,
                 arraysize=1024, itemSliceClass=None, dropIfExists=False,
                 serverInfo=None, autoGC=True, orderBy=None,
                 writeable=False, **kwargs):
        """Bind to table `name`, optionally creating it first.

        The cursor comes from `cursor` (deprecated), else from `serverInfo`,
        else from get_name_cursor(name, **kwargs).  `createTable`, if given,
        is an SQL statement (macros allowed) run before the schema is read.
        """
        if autoGC:  # automatically garbage collect unused objects
            self._weakValueDict = RecentValueDictionary(autoGC)  # obj cache
        else:
            self._weakValueDict = {}
        self.autoGC = autoGC
        self.orderBy = orderBy
        self.writeable = writeable
        if cursor is None:
            if serverInfo is not None:  # get cursor from serverInfo
                cursor = serverInfo.cursor()
            else:  # try to read connection info from name or config file
                name, cursor, serverInfo = get_name_cursor(name, **kwargs)
        else:
            warnings.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
                          DeprecationWarning, stacklevel=2)
        self.cursor = cursor
        if createTable is not None:  # run command to create this table
            if dropIfExists:  # get rid of any existing table
                cursor.execute('drop table if exists ' + name)
            self.get_table_schema(False)  # check dbtype, init _format_query
            sql, params = self._format_query(createTable, ())  # apply macros
            cursor.execute(sql)  # create the table
        self.name = name
        if graph is not None:
            self.graph = graph
        if maxCache is not None:
            self.maxCache = maxCache
        if arraysize is not None:
            self.arraysize = arraysize
            cursor.arraysize = arraysize
        self.get_table_schema()  # get schema of columns to serve as attrs
        self.data = {}  # map of all attributes, including aliases
        for icol, field in enumerate(self.columnName):
            self.data[field] = icol  # 1st add mappings to columns
        try:
            self.data['id'] = self.data[self.primary_key]
        except (KeyError, TypeError):
            # no primary key, or a multi-column (list) key: no 'id' alias
            pass
        if hasattr(self, '_attr_alias'):  # aliases declared on the class
            self.addAttrAlias(False, **self._attr_alias)
        self.objclass(itemClass)  # need to subclass our item class
        if itemSliceClass is not None:
            self.itemSliceClass = itemSliceClass
            # need to subclass itemSliceClass
            get_bound_subclass(self, 'itemSliceClass', self.name)
        if attrAlias is not None:  # add attribute aliases
            self.attrAlias = attrAlias  # record for pickling purposes
            self.data.update(attrAlias)
        if clusterKey is not None:
            self.clusterKey = clusterKey
        if serverInfo is not None:
            self.serverInfo = serverInfo

    def __len__(self):
        'number of rows, obtained with select count(*)'
        self._select(selectCols='count(*)')
        return self.cursor.fetchone()[0]

    def __hash__(self):
        'hash by object identity'
        return id(self)

    # constructor arguments saved / restored by pickling
    _pickleAttrs = dict(name=0, clusterKey=0, maxCache=0, arraysize=0,
                        attrAlias=0, serverInfo=0, autoGC=0, orderBy=0,
                        writeable=0)
    __getstate__ = standard_getstate

    def __setstate__(self, state):
        're-run the constructor with the pickled keyword arguments'
        # default cursor provisioning by worldbase is deprecated!
        ## if 'serverInfo' not in state: # hmm, no address for db server?
        ##     try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
        ##         from Data import getResource
        ##         state['cursor'] = getResource.getTableCursor(state['name'])
        ##     except ImportError:
        ##         pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
        self.__init__(**state)

    def __repr__(self):
        return '<SQL table '+self.name+'>'

    def clear_schema(self):
        'reset all schema information for this table'
        self.description = {}
        self.columnName = []
        self.columnType = {}
        self.usesIntID = None
        self.primary_key = None
        self.indexed = {}

    def _attrSQL(self, attr, sqlColumn=False, columnNumber=False):
        """Translate python attribute name to appropriate SQL expression.

        sqlColumn=True returns the SQL column name, raising AttributeError
        if attr does not map to a column.  columnNumber=True returns the
        column index, raising ValueError if attr does not map to a column.
        Unknown attributes always raise AttributeError.
        """
        try:  # make sure this attribute can be mapped to a db expression
            field = self.data[attr]
        except KeyError:
            raise AttributeError('attribute %s not a valid column or alias in %s'
                                 % (attr, self.name))
        if sqlColumn:  # ensure that this truly maps to a column name
            try:  # check if field is column number
                return self.columnName[field]  # return SQL column name
            except TypeError:
                try:  # check if field is SQL column name
                    return self.columnName[self.data[field]]
                except (KeyError, TypeError):
                    raise AttributeError('attribute %s does not map to an SQL column in %s'
                                         % (attr, self.name))
        if columnNumber:
            try:  # check if field is a column number
                return field+0  # only return an integer
            except TypeError:
                try:  # check if field is itself the SQL column name
                    return self.data[field]+0  # only return an integer
                except (KeyError, TypeError):
                    raise ValueError('attribute %s does not map to a SQL column!' % attr)
        if isinstance(field, types.StringType):
            attr = field  # use aliased expression for db select
        elif attr == 'id':
            attr = self.primary_key
        return attr

    def addAttrAlias(self, saveToPickle=True, **kwargs):
        """Add new attributes as aliases of existing attributes.
        They can be specified either as named args:
        t.addAttrAlias(newattr=oldattr)
        or by passing a dictionary kwargs whose keys are newattr
        and values are oldattr:
        t.addAttrAlias(**kwargs)
        saveToPickle=True forces these aliases to be saved if object is pickled.
        """
        if saveToPickle:
            try:
                self.attrAlias.update(kwargs)
            except AttributeError:
                # constructed without attrAlias: start the record now
                self.attrAlias = kwargs.copy()
        for key, val in kwargs.items():
            try:  # 1st check whether val is an existing column / alias
                self.data[val]+0  # does val map to a column number?
                # yes, val is an actual SQL column name: jump to the
                # KeyError handler below, which saves val directly
                raise KeyError
            except TypeError:  # val is itself an alias
                self.data[key] = self.data[val]  # map to what it maps to
            except KeyError:  # treat as alias to SQL expression
                self.data[key] = val

    def objclass(self, oclass=None):
        "Create class representing a row in this table by subclassing oclass, adding data"
        if oclass is not None:  # use this as our base itemClass
            self.itemClass = oclass
        if self.writeable:
            self.itemClass = self.itemClass._RWClass  # writeable version
        oclass = get_bound_subclass(self, 'itemClass', self.name,
                                    subclassArgs=dict(db=self))  # bind it
        if issubclass(oclass, TupleO):
            oclass._attrcol = self.data  # bind attribute list to TupleO
        if hasattr(oclass, '_tableclass') and \
               not isinstance(self, oclass._tableclass):
            # row class can override our current table class
            self.__class__ = oclass._tableclass

    def _select(self, whereClause='', params=(), selectCols='t1.*',
                cursor=None):
        'execute the specified query but do not fetch'
        sql, params = self._format_query('select %s from %s t1 %s'
                                         % (selectCols, self.name, whereClause),
                                         params)
        if cursor is None:
            self.cursor.execute(sql, params)
        else:
            cursor.execute(sql, params)

    def select(self, whereClause, params=None, oclass=None,
               selectCols='t1.*'):
        "Generate the list of objects that satisfy the database SELECT"
        if oclass is None:
            oclass = self.itemClass
        self._select(whereClause, params, selectCols)
        l = self.cursor.fetchall()
        for t in l:
            yield self.cacheItem(t, oclass)

    def query(self, **kwargs):
        """Query for intersection of all specified kwargs, returned as iterator.

        A value of None is tested with IS NULL.  With no kwargs at all no
        WHERE clause is generated, i.e. every row is selected.
        """
        criteria = []
        params = []
        for k, v in kwargs.items():  # construct the list of where clauses
            if v is None:  # convert to SQL NULL test
                criteria.append('%s IS NULL' % self._attrSQL(k))
            else:  # test for equality
                criteria.append('%s=%%s' % self._attrSQL(k))
                params.append(v)
        if criteria:
            whereClause = 'where ' + ' and '.join(criteria)
        else:
            whereClause = ''  # a bare 'where ' would be invalid SQL
        return self.select(whereClause, params)

    def _update(self, row_id, col, val):
        'update a single field in the specified row to the specified value'
        sql, params = self._format_query('update %s set %s=%%s where %s=%%s'
                                         % (self.name, col, self.primary_key),
                                         (val, row_id))
        self.cursor.execute(sql, params)

    def getID(self, t):
        'extract the ID value from row tuple t'
        try:
            return t[self.data['id']]  # get ID from tuple
        except TypeError:  # data['id'] is not an index: treat as alias
            return t[self.data[self.data['id']]]

    def cacheItem(self, t, oclass):
        'get obj from cache if possible, or construct from tuple'
        try:
            row_id = self.getID(t)
        except KeyError:  # no primary key? ignore the cache.
            return oclass(t)
        try:  # if already loaded in our dictionary, just return that entry
            return self._weakValueDict[row_id]
        except KeyError:
            pass
        o = oclass(t)
        self._weakValueDict[row_id] = o  # cache this item in our dictionary
        return o

    def cache_items(self, rows, oclass=None):
        'generate the (cached) row object for each tuple in rows'
        if oclass is None:
            oclass = self.itemClass
        for t in rows:
            yield self.cacheItem(t, oclass)

    def foreignKey(self, attr, k):
        'get iterator for objects with specified foreign key value'
        return self.select('where %s=%%s' % attr, (k,))

    def limit_cache(self):
        'apply maxCache limit to cache size'
        try:
            if self.maxCache < len(self._weakValueDict):
                self._weakValueDict.clear()
        except AttributeError:  # no maxCache set: nothing to enforce
            pass

    def get_new_cursor(self):
        """Return a new cursor object, or None if not possible """
        try:
            new_cursor = self.serverInfo.new_cursor
        except AttributeError:
            return None
        return new_cursor()

    def generic_iterator(self, cursor=None, fetch_f=None, cache_f=None,
                         map_f=iter):
        'generic iterator that runs fetch, cache and map functions'
        if fetch_f is None:  # just use cursor's preferred chunk size
            if cursor is None:
                fetch_f = self.cursor.fetchmany
            else:  # isolate this iter from other queries
                fetch_f = cursor.fetchmany
        if cache_f is None:
            cache_f = self.cache_items
        while True:
            self.limit_cache()
            rows = fetch_f()  # fetch the next set of rows
            if len(rows) == 0:  # no more data so all done
                break
            for v in map_f(cache_f(rows)):  # cache and generate results
                yield v
        if cursor is not None:  # close iterator now that we're done
            cursor.close()

    def tuple_from_dict(self, d):
        'transform kwarg dict into tuple for storing in database'
        l = [None]*len(self.description)  # default column values are NULL
        for col, icol in self.data.items():
            try:
                l[icol] = d[col]
            except (KeyError, TypeError):
                # value not supplied, or icol is an alias rather than an index
                pass
        return l

    def tuple_from_obj(self, obj):
        'transform object attributes into tuple for storing in database'
        l = [None]*len(self.description)  # default column values are NULL
        for col, icol in self.data.items():
            try:
                l[icol] = getattr(obj, col)
            except (AttributeError, TypeError):
                # attribute missing, or icol is an alias rather than an index
                pass
        return l

    def _insert(self, l):
        '''insert tuple into the database.  Note this uses the MySQL
        extension REPLACE, which overwrites any duplicate key.'''
        s = '%(REPLACE)s into ' + self.name + ' values (' \
            + ','.join(['%s']*len(l)) + ')'
        sql, params = self._format_query(s, l)
        self.cursor.execute(sql, params)

    def insert(self, obj):
        '''insert new row by transforming obj to tuple of values'''
        l = self.tuple_from_obj(obj)
        self._insert(l)

    def get_insert_id(self):
        'get the primary key value for the last INSERT'
        try:  # attempt to get assigned ID from db
            auto_id = self.cursor.lastrowid
        except AttributeError:  # cursor doesn't support lastrowid
            raise NotImplementedError('''your db lacks lastrowid support?''')
        if auto_id is None:
            raise ValueError('lastrowid is None so cannot get ID from INSERT!')
        return auto_id

    def new(self, **kwargs):
        'return a new record with the assigned attributes, added to DB'
        if not self.writeable:
            raise ValueError('this database is read only!')
        obj = self.itemClass(None, newRow=True, **kwargs)  # saves itself
        self._weakValueDict[obj.id] = obj  # and save to our local cache
        return obj

    def clear_cache(self):
        'empty the cache'
        self._weakValueDict.clear()

    def __delitem__(self, k):
        'delete the row with primary key k from the db and from the cache'
        if not self.writeable:
            raise ValueError('this database is read only!')
        sql, params = self._format_query('delete from %s where %s=%%s'
                                         % (self.name, self.primary_key),
                                         (k,))
        self.cursor.execute(sql, params)
        try:
            del self._weakValueDict[k]
        except KeyError:  # row was not cached
            pass
def getKeys(self, queryOption='', selectCols=None):
    """Return a list of key values read with a db select (no forced load).

    selectCols defaults to the primary key.  When queryOption is empty and
    self.orderBy is set, self.orderBy is used as the query option.
    """
    if selectCols is None:
        selectCols = self.primary_key
    if queryOption == '' and self.orderBy is not None:
        queryOption = self.orderBy  # apply default ordering
    statement = 'select %s from %s %s' % (selectCols, self.name, queryOption)
    self.cursor.execute(statement)
    # fetch everything at once, since other calls may reuse this cursor
    keys = []
    for row in self.cursor.fetchall():
        keys.append(row[0])
    return keys
def iter_keys(self, selectCols=None):
    """Iterate over keys in a way that other queries cannot disturb.

    With a dedicated cursor the keys are streamed from it; otherwise all
    keys are fetched up front through self.keys().
    NOTE(review): the fallback path does not pass selectCols on to
    self.keys() -- confirm that is intended.
    """
    if selectCols is None:
        selectCols = self.primary_key
    cursor = self.get_new_cursor()
    if not cursor:  # must pre-fetch all keys to ensure query isolation
        return iter(self.keys())
    # got our own cursor, guaranteeing query isolation
    self._select(cursor=cursor, selectCols=selectCols)
    first_column = lambda rows: [row[0] for row in rows]
    return self.generic_iterator(cursor=cursor, cache_f=first_column)
class SQLTable(SQLTableBase):
    """Dictionary-style access to table rows, fetched on demand and cached.

    Row objects are created through cacheItem() and looked up in
    self._weakValueDict before the database is queried.
    """
    itemClass = TupleO  # default row class; the constructor can override it
    keys = getKeys
    __iter__ = iter_keys

    def load(self, oclass=None):
        """Pull every row of the table into the local cache.

        oclass: row class handed to cacheItem(); self.itemClass when None.
        If the table was loaded before, the stored _isLoaded flag is
        returned and nothing else happens.
        """
        try:
            return self._isLoaded  # already loaded: nothing left to do
        except AttributeError:
            pass
        if oclass is None:
            oclass = self.itemClass
        self.cursor.execute('select * from %s' % self.name)
        all_rows = self.cursor.fetchall()
        self._weakValueDict = {}  # plain dict keeps the full data set in memory
        for row in all_rows:
            self.cacheItem(row, oclass)
        self._isLoaded = True  # the cache now mirrors the whole table

    def __getitem__(self, k):
        """Return the row object whose primary key is k.

        The cache is consulted first; on a miss the row is selected from
        the database and cached.  Raises KeyError unless exactly one row
        matches.
        """
        try:
            return self._weakValueDict[k]
        except KeyError:
            pass  # cache miss: ask the database
        query = ('select * from %s where %s=%%s limit 2'
                 % (self.name, self.primary_key))
        sql, params = self._format_query(query, (k,))
        self.cursor.execute(sql, params)
        found = self.cursor.fetchmany(2)  # two rows suffice to detect duplicates
        if len(found) != 1:
            raise KeyError('%s not found in %s, or not unique' %(str(k),self.name))
        self.limit_cache()
        return self.cacheItem(found[0], self.itemClass)

    def __setitem__(self, k, v):
        """Store row object v under primary key k, in the db and the cache.

        Raises ValueError if the table is not writeable or v is not bound
        to this table.  If v already has a non-None id, the row stored
        under that id is deleted first.
        """
        if not self.writeable:
            raise ValueError('this database is read only!')
        try:
            if v.db != self:
                raise AttributeError
        except AttributeError:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            previous_id = v.id
            if previous_id is None:
                raise AttributeError
        except AttributeError:
            pass  # no earlier row to remove
        else:
            del self[v.id]  # drop the row saved under the old ID
        v.cache_id(k)  # record the new ID on the object
        self.insert(v)  # write the row to the database server
        self._weakValueDict[k] = v

    def items(self):
        'forces load of entire table into memory'
        self.load()
        return self._weakValueDict.items()

    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        own_cursor = self.get_new_cursor()
        self._select(cursor=own_cursor)
        return self.generic_iterator(cursor=own_cursor, map_f=generate_items)

    def values(self):
        'forces load of entire table into memory'
        self.load()
        return self._weakValueDict.values()

    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        own_cursor = self.get_new_cursor()
        self._select(cursor=own_cursor)
        return self.generic_iterator(cursor=own_cursor)
def getClusterKeys(self, queryOption=''):
    """Return the distinct values of self.clusterKey found in the table.

    queryOption: extra SQL text appended after the table name.
    Uses a db select and does not force a load of the table.
    """
    sql = ('select distinct %s from %s %s'
           % (self.clusterKey, self.name, queryOption))
    self.cursor.execute(sql)
    # fetch everything right away: other calls may reuse this cursor
    fetched = self.cursor.fetchall()
    return [row[0] for row in fetched]
class SQLTableClustered(SQLTable):
    '''Load a whole cluster of rows at once: all rows sharing the same
    clusterKey value are fetched and cached together.'''

    def __init__(self, *args, **kwargs):
        options = kwargs.copy()  # private copy so the override stays local
        options['autoGC'] = False  # don't use WeakValueDictionary
        SQLTable.__init__(self, *args, **options)

    def keys(self):
        'primary keys, selected in clusterKey order'
        return getKeys(self, 'order by %s' % self.clusterKey)

    def clusterkeys(self):
        'distinct clusterKey values, in order'
        return getClusterKeys(self, 'order by %s' % self.clusterKey)

    def __getitem__(self, k):
        """Return the row for primary key k, caching its whole cluster.

        On a cache miss every row sharing k's clusterKey value is selected
        and cached; KeyError is raised if k is still absent afterwards.
        """
        try:
            return self._weakValueDict[k]
        except KeyError:
            pass  # cache miss: load the cluster from the database
        query = ('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
                 % (self.name, self.name, self.primary_key,
                    self.clusterKey, self.clusterKey))
        sql, params = self._format_query(query, (k,))
        self.cursor.execute(sql, params)
        cluster_rows = self.cursor.fetchall()
        self.limit_cache()
        for row in cluster_rows:  # put the entire cluster in the cache
            self.cacheItem(row, self.itemClass)
        return self._weakValueDict[k]  # present now, provided row k exists

    def itercluster(self, cluster_id):
        'iterate over all items from the specified cluster'
        self.limit_cache()
        return self.select('where %s=%%s'%self.clusterKey, (cluster_id,))

    def fetch_cluster(self):
        """Collect the rows of the next cluster via self.cursor.fetchmany().

        Rows are consumed until the clusterKey column changes; rows that
        already belong to the following cluster are parked in
        self._fetch_cluster_cache for the next call.  Returns a list of
        rows, empty once nothing is left.
        """
        key_col = self._attrSQL(self.clusterKey, columnNumber=True)
        cluster = []
        try:
            batch = self._fetch_cluster_cache  # leftovers from the last call
            del self._fetch_cluster_cache
        except AttributeError:
            batch = self.cursor.fetchmany()
        try:
            current_id = batch[0][key_col]
        except IndexError:  # no rows at all
            return cluster
        while batch:
            for pos, row in enumerate(batch):
                if row[key_col] != current_id:  # the next cluster begins here
                    cluster.extend(batch[:pos])
                    self._fetch_cluster_cache = batch[pos:]
                    return cluster
            cluster.extend(batch)  # whole batch belongs to this cluster
            batch = self.cursor.fetchmany()
        return cluster

    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        own_cursor = self.get_new_cursor()
        self._select('order by %s' % self.clusterKey, cursor=own_cursor)
        return self.generic_iterator(own_cursor, self.fetch_cluster)

    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        own_cursor = self.get_new_cursor()
        self._select('order by %s' % self.clusterKey, cursor=own_cursor)
        return self.generic_iterator(own_cursor, self.fetch_cluster,
                                     map_f=generate_items)
class SQLForeignRelation(object):
    'mapping based on matching a foreign key in an SQL table'

    def __init__(self, table, keyName):
        """table: object providing select(whereClause, params) and a name;
        keyName: column of that table holding the foreign key."""
        self.table = table
        self.keyName = keyName

    def __getitem__(self, k):
        """Return the list of objects o with getattr(o,keyName)==k.id.

        Raises KeyError when no row references k.
        """
        matches = list(self.table.select('where %s=%%s' % self.keyName,
                                         (k.id,)))
        if not matches:
            # This object has no name of its own; report the table's name
            # (self.name used to raise AttributeError instead of KeyError).
            raise KeyError('%s not found in %s' % (str(k), self.table.name))
        return matches
class SQLTableNoCache(SQLTableBase):
    '''On-the-fly access to rows in the database.
    Values are just an object interface (SQLRow) to a back-end db query:
    row data are not stored locally, but fetched by querying the db.'''
    itemClass = SQLRow  # default object class for rows
    keys = getKeys
    __iter__ = iter_keys

    def getID(self, t):
        'the ID is the first element of row tuple t'
        return t[0]

    def select(self, whereClause, params):
        'run SQLTableBase.select(), selecting only the id column'
        id_column = self._attrSQL('id')
        return SQLTableBase.select(self, whereClause, params, self.oclass,
                                   id_column)

    def __getitem__(self, k):
        """Return the row object for primary key k.

        The cache is tried first; otherwise the database is asked whether
        exactly one row has this key (KeyError if not) and a fresh
        itemClass(k) object is created and cached.
        """
        try:
            return self._weakValueDict[k]
        except KeyError:
            pass  # not cached: check the database
        self._select('where %s=%%s' % self.primary_key, (k,),
                     self.primary_key)
        matches = self.cursor.fetchmany(2)
        if len(matches) != 1:
            raise KeyError('id %s non-existent or not unique' % k)
        row_obj = self.itemClass(k)  # object that refers to this ID
        self._weakValueDict[k] = row_obj
        return row_obj

    def __setitem__(self, k, v):
        """Re-key row object v to ID k, in the database and in the cache.

        Raises ValueError if the table is read-only or v is not bound to
        this table.  Any row already stored under k is deleted first.
        """
        if not self.writeable:
            raise ValueError('this database is read only!')
        try:
            if v.db != self:
                raise AttributeError
        except AttributeError:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            del self[k]  # remove a row that already uses the new ID
        except KeyError:
            pass
        try:
            del self._weakValueDict[v.id]  # forget the old cache slot
        except KeyError:
            pass
        self._update(v.id, self.primary_key, k)  # only the ID changes in the db
        v.cache_id(k)
        self._weakValueDict[k] = v

    def addAttrAlias(self, **kwargs):
        'alias keys to expression values'
        self.data.update(kwargs)

SQLRow._tableclass = SQLTableNoCache  # SQLRow is for the non-caching table interface
class SQLTableMultiNoCache(SQLTableBase):
    "Trivial on-the-fly access for table with key that returns multiple rows"
    itemClass = TupleO  # default row class; the constructor can override it
    _distinct_key = 'id'  # default column to use as key

    def keys(self):
        'list the distinct values of the key column'
        key_sql = self._attrSQL(self._distinct_key)
        return getKeys(self, selectCols='distinct(%s)' % key_sql)

    def __iter__(self):
        'iterate over the distinct values of the key column'
        key_sql = self._attrSQL(self._distinct_key)
        return iter_keys(self, 'distinct(%s)' % key_sql)

    def __getitem__(self, id):
        'generate an itemClass object for every row whose key column equals id'
        query = ('select * from %s where %s=%%s'
                 % (self.name, self._attrSQL(self._distinct_key)))
        sql, params = self._format_query(query, (id,))
        self.cursor.execute(sql, params)
        matching = self.cursor.fetchall()  # prefetch: the cursor may be reused
        for row in matching:
            yield self.itemClass(row)

    def addAttrAlias(self, **kwargs):
        'alias keys to expression values'
        self.data.update(kwargs)
class SQLEdges(SQLTableMultiNoCache):
    '''Edge view of a graph table: iteration yields (source,target,edge)
    tuples and getitem[edge] --> [(source,target),...]'''
    _distinct_key = 'edge_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))

    def keys(self):
        """Return all edges as a list of (source, target, edge) tuples.

        Rows with a null target are excluded; the IDs are unpacked through
        self.graph.
        """
        source_col = self._attrSQL('source_id')
        target_col = self._attrSQL('target_id')
        edge_col = self._attrSQL('edge_id')
        self.cursor.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
                            % (source_col, target_col, edge_col, self.name,
                               target_col, source_col, target_col))
        # prefetch every row, since the cursor may be reused
        return [(self.graph.unpack_source(source_id),
                 self.graph.unpack_target(target_id),
                 self.graph.unpack_edge(edge_id))
                for source_id, target_id, edge_id in self.cursor.fetchall()]

    __call__ = keys

    def __iter__(self):
        return iter(self.keys())

    def __getitem__(self, edge):
        'return the list of (source, target) pairs joined by this edge'
        query = ('select %s,%s from %s where %s=%%s'
                 % (self._attrSQL('source_id'),
                    self._attrSQL('target_id'),
                    self.name,
                    self._attrSQL(self._distinct_key)))
        sql, params = self._format_query(query,
                                         (self.graph.pack_edge(edge),))
        self.cursor.execute(sql, params)
        # prefetch every row, since the cursor may be reused
        return [(self.graph.unpack_source(source_id),
                 self.graph.unpack_target(target_id))
                for source_id, target_id in self.cursor.fetchall()]
class SQLEdgeDict(object):
    """2nd level graph interface to SQL database.

    Represents the edges leaving one (packed) source node, fromNode, of
    the graph table object `table`.
    """

    def __init__(self, fromNode, table):
        """Bind to source node fromNode of table.

        Unless table has an allowMissingNodes attribute, KeyError is raised
        when no row of the table has fromNode as its source.
        """
        self.fromNode = fromNode
        self.table = table
        if hasattr(self.table, 'allowMissingNodes'):
            return  # existence check not wanted
        query = ('select %s from %s where %s=%%s limit 1'
                 % (self.table.sourceSQL,
                    self.table.name,
                    self.table.sourceSQL))
        sql, params = self.table._format_query(query, (self.fromNode,))
        self.table.cursor.execute(sql, params)
        if len(self.table.cursor.fetchall()) < 1:
            raise KeyError('node not in graph!')

    def __getitem__(self, target):
        'return the edge to target; KeyError if absent or not unique'
        query = ('select %s from %s where %s=%%s and %s=%%s limit 2'
                 % (self.table.edgeSQL,
                    self.table.name,
                    self.table.sourceSQL,
                    self.table.targetSQL))
        sql, params = self.table._format_query(
            query, (self.fromNode, self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        rows = self.table.cursor.fetchmany(2)  # at most two rows are needed
        if len(rows) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        try:
            return self.table.unpack_edge(rows[0][0])
        except IndexError:
            raise KeyError('no edge from node to target')

    def __setitem__(self, target, edge):
        """Save the edge fromNode -> target with 'replace into'.

        The target is also added to the graph as a node when the table has
        no sourceDB attribute, or when it has a targetDB equal to its
        sourceDB.
        """
        query = 'replace into %s values (%%s,%%s,%%s)' % self.table.name
        sql, params = self.table._format_query(
            query, (self.fromNode,
                    self.table.pack_target(target),
                    self.table.pack_edge(edge)))
        self.table.cursor.execute(sql, params)
        if hasattr(self.table, 'sourceDB'):
            target_is_node = (hasattr(self.table, 'targetDB')
                              and self.table.sourceDB == self.table.targetDB)
        else:
            target_is_node = True
        if target_is_node:
            self.table += target  # add as node to graph

    def __iadd__(self, target):
        'add target with no edge information'
        self[target] = None
        return self  # iadd must return self!

    def __delitem__(self, target):
        'delete the edge to target; KeyError if no row was deleted'
        query = ('delete from %s where %s=%%s and %s=%%s'
                 % (self.table.name,
                    self.table.sourceSQL,
                    self.table.targetSQL))
        sql, params = self.table._format_query(
            query, (self.fromNode, self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        if self.table.cursor.rowcount < 1:  # nothing was deleted
            raise KeyError('no edge from node to target')

    def iterator_query(self):
        'fetch all (target_id, edge_id) rows of this node with non-null target'
        query = ('select %s,%s from %s where %s=%%s and %s is not null'
                 % (self.table.targetSQL,
                    self.table.edgeSQL,
                    self.table.name,
                    self.table.sourceSQL,
                    self.table.targetSQL))
        sql, params = self.table._format_query(query, (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return self.table.cursor.fetchall()

    def keys(self):
        'list of unpacked target nodes'
        targets = []
        for target_id, edge_id in self.iterator_query():
            targets.append(self.table.unpack_target(target_id))
        return targets

    def values(self):
        'list of unpacked edges'
        found = []
        for target_id, edge_id in self.iterator_query():
            found.append(self.table.unpack_edge(edge_id))
        return found

    def edges(self):
        'list of unpacked (source, target, edge) tuples'
        triples = []
        for target_id, edge_id in self.iterator_query():
            triples.append((self.table.unpack_source(self.fromNode),
                            self.table.unpack_target(target_id),
                            self.table.unpack_edge(edge_id)))
        return triples

    def items(self):
        'list of unpacked (target, edge) pairs'
        pairs = []
        for target_id, edge_id in self.iterator_query():
            pairs.append((self.table.unpack_target(target_id),
                          self.table.unpack_edge(edge_id)))
        return pairs

    def __iter__(self):
        return iter(self.keys())

    def itervalues(self):
        return iter(self.values())

    def iteritems(self):
        return iter(self.items())

    def __len__(self):
        return len(self.keys())

    __cmp__ = graph_cmp
class SQLEdgelessDict(SQLEdgeDict):
    'for SQLGraph tables that lack edge_id column'

    def __getitem__(self, target):
        """Check that exactly one fromNode -> target row exists.

        Returns None (there is no edge info to return); raises KeyError
        if the row is absent or not unique.
        """
        query = ('select %s from %s where %s=%%s and %s=%%s limit 2'
                 % (self.table.targetSQL,
                    self.table.name,
                    self.table.sourceSQL,
                    self.table.targetSQL))
        sql, params = self.table._format_query(
            query, (self.fromNode, self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        rows = self.table.cursor.fetchmany(2)
        if len(rows) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        return None  # no edge info!

    def iterator_query(self):
        'fetch the targets of this node as (target_id, None) pairs'
        query = ('select %s from %s where %s=%%s and %s is not null'
                 % (self.table.targetSQL,
                    self.table.name,
                    self.table.sourceSQL,
                    self.table.targetSQL))
        sql, params = self.table._format_query(query, (self.fromNode,))
        self.table.cursor.execute(sql, params)
        pairs = []
        for row in self.table.cursor.fetchall():
            pairs.append((row[0], None))
        return pairs

SQLEdgeDict._edgelessClass = SQLEdgelessDict
class SQLGraphEdgeDescriptor(object):
    'provide an SQLEdges interface on demand'

    def __get__(self, obj, objtype):
        """Build an SQLEdges view over obj's table and cursor.

        A copy of obj.attrAlias is passed along when obj has that
        attribute.
        """
        extra = dict(graph=obj)
        try:
            extra['attrAlias'] = obj.attrAlias.copy()
        except AttributeError:
            pass  # no aliases defined on this graph
        return SQLEdges(obj.name, obj.cursor, **extra)
def getColumnTypes(createTable, attrAlias={}, defaultColumnType='int',
                   columnAttrs=('source', 'target', 'edge'), **kwargs):
    """Return list of [(colname,coltype),...] for source,target,edge.

    For each attr in columnAttrs the column name is attrAlias[attr+'_id']
    if present, else attr+'_id'.  The column type is chosen from, in order:
      1. createTable[attr+'_id'], when createTable supports that lookup;
      2. the database passed as kwargs[attr+'DB'] (if not None): the type
         of its first key (int/long -> 'int', str -> 'varchar(32)',
         anything else raises ValueError), or db.columnType[db.primary_key]
         when the database is empty;
      3. defaultColumnType, with a logged warning.
    """
    columns = []
    for attr in columnAttrs:
        key = attr + '_id'
        colName = attrAlias.get(key, key)
        try:  # did the user specify a type for this column?
            columns.append((colName, createTable[key]))
        except (KeyError, TypeError):
            pass
        else:
            continue
        db = kwargs.get(attr + 'DB')  # missing and None both mean "no db"
        if db is not None:  # infer the column type from the database keys
            it = iter(db)
            try:  # get one identifier from the database
                k = it.next()
            except StopIteration:  # empty table: read SQL type from db object
                try:
                    columns.append((colName, db.columnType[db.primary_key]))
                    continue
                except AttributeError:
                    pass
            else:  # derive the type from this identifier
                if isinstance(k, int) or isinstance(k, long):
                    sqlType = 'int'
                elif isinstance(k, str):
                    sqlType = 'varchar(32)'
                else:
                    raise ValueError('SQLGraph node / edge must be int or str!')
                columns.append((colName, sqlType))
                continue
        columns.append((colName, defaultColumnType))
        logger.warn('no type info found for %s, so using default: %s'
                    % (colName, defaultColumnType))
    return columns
class SQLGraph(SQLTableMultiNoCache):
    '''Graph interface on top of a single SQL table.  Key capabilities are:
      - setitem with an empty dictionary: a dummy operation
      - getitem with a key that exists: return a placeholder
      - setitem with non empty placeholder: again a dummy operation
    EXAMPLE TABLE SCHEMA:
    create table mygraph (source_id int not null,target_id int,edge_id int,
          unique(source_id,target_id));
    '''
    _distinct_key = 'source_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(sourceDB=0,targetDB=0,edgeDB=0,allowMissingNodes=0))
    _edgeClass = SQLEdgeDict

    def __init__(self, name, *l, **kwargs):
        """Open (or create) the graph table called name.

        Graph-specific keyword arguments are separated from those meant
        for SQLTableMultiNoCache.__init__().  When 'createTable' is given,
        a create-table statement is generated from getColumnTypes().  If
        the table has no edge_id column, the edgeless edge-dict class is
        used.
        """
        graphArgs, tableArgs = split_kwargs(kwargs,
                    ('attrAlias','defaultColumnType','columnAttrs',
                     'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
                     'edgeDictClass','graph'))
        if 'createTable' in kwargs:  # create a schema for this table
            cols = getColumnTypes(**kwargs)
            source_col, target_col, edge_col = cols[0], cols[1], cols[2]
            tableArgs['createTable'] = \
                'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
                % (name, source_col[0], source_col[1],
                   target_col[0], target_col[1],
                   edge_col[0], edge_col[1],
                   source_col[0], target_col[0])
        if 'allowMissingNodes' in kwargs:
            self.allowMissingNodes = kwargs['allowMissingNodes']
        SQLTableMultiNoCache.__init__(self, name, *l, **tableArgs)
        self.sourceSQL = self._attrSQL('source_id')
        self.targetSQL = self._attrSQL('target_id')
        try:
            self.edgeSQL = self._attrSQL('edge_id')
        except AttributeError:  # table without edge column
            self.edgeSQL = None
            self._edgeClass = self._edgeClass._edgelessClass
        save_graph_db_refs(self, **kwargs)

    def __getitem__(self, k):
        'return the edge dictionary of source node k'
        return self._edgeClass(self.pack_source(k), self)

    def __iadd__(self, k):
        """Add k as a source node of the graph.

        Any existing row for k with a null target is deleted, then a
        (k, NULL, NULL) row is inserted.
        """
        query = ('delete from %s where %s=%%s and %s is null'
                 % (self.name, self.sourceSQL, self.targetSQL))
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        query = ('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
                 % self.name)
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        return self  # iadd must return self!

    def __isub__(self, k):
        'delete all rows of source node k; KeyError if none existed'
        query = ('delete from %s where %s=%%s'
                 % (self.name, self.sourceSQL))
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        if self.cursor.rowcount == 0:
            raise KeyError('node not found in graph')
        return self  # isub must return self!

    __setitem__ = graph_setitem

    def __contains__(self, k):
        'True if at least one row has k as its source'
        query = ('select * from %s where %s=%%s limit 1'
                 % (self.name, self.sourceSQL))
        sql, params = self._format_query(query, (self.pack_source(k),))
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchmany(2)
        return len(rows) > 0

    def __invert__(self):
        'get an interface to the inverse graph mapping'
        try:
            return self._inverse  # cached
        except AttributeError:
            pass
        # construct the inverse interface by swapping source and target
        attrAlias = dict(source_id=self.targetSQL,
                         target_id=self.sourceSQL,
                         edge_id=self.edgeSQL)
        if self.edgeSQL is None:  # no edge interface
            del attrAlias['edge_id']
        self._inverse = SQLGraph(self.name, self.cursor,
                                 attrAlias=attrAlias,
                                 **graph_db_inverse_refs(self))
        self._inverse._inverse = self
        return self._inverse

    def __iter__(self):
        'generate the unpacked source nodes'
        for packed in SQLTableMultiNoCache.__iter__(self):
            yield self.unpack_source(packed)

    def iteritems(self):
        'generate (source node, edge dictionary) pairs'
        for packed in SQLTableMultiNoCache.__iter__(self):
            yield (self.unpack_source(packed), self._edgeClass(packed, self))

    def itervalues(self):
        'generate the edge dictionary of every source node'
        for packed in SQLTableMultiNoCache.__iter__(self):
            yield self._edgeClass(packed, self)

    def keys(self):
        'list of unpacked source nodes'
        packed_keys = SQLTableMultiNoCache.keys(self)
        return [self.unpack_source(packed) for packed in packed_keys]

    def values(self):
        return list(self.itervalues())

    def items(self):
        return list(self.iteritems())

    edges = SQLGraphEdgeDescriptor()
    update = update_graph

    def __len__(self):
        'get number of source nodes in graph'
        self.cursor.execute('select count(distinct %s) from %s'
                            % (self.sourceSQL, self.name))
        count_row = self.cursor.fetchone()
        return count_row[0]

    __cmp__ = graph_cmp
    override_rich_cmp(locals())  # must override __eq__ etc. to use our __cmp__!

    add_standard_packing_methods(locals())  # pack / unpack methods
# SQLGraph variant whose pack / unpack methods are installed by
# add_trivial_packing_methods() instead of the standard ones.
# NOTE(review): presumably nodes and edges are stored as plain IDs here --
# confirm against add_trivial_packing_methods.
class SQLIDGraph(SQLGraph):
    add_trivial_packing_methods(locals())
# register the ID-based variant on the base class
SQLGraph._IDGraphClass = SQLIDGraph
class SQLEdgeDictClustered(dict):
    'simple cache for 2nd level dictionary of target_id:edge_id'

    def __init__(self, g, fromNode):
        'start empty, remembering the owning graph g and the source node'
        dict.__init__(self)
        self.g = g
        self.fromNode = fromNode

    def __iadd__(self, l):
        'store every (target_id, edge_id) pair of l in this dictionary'
        for pair in l:
            target_id, edge_id = pair
            dict.__setitem__(self, target_id, edge_id)
        return self  # iadd must return self!
class SQLEdgesClusteredDescr(object):
    'construct an SQLEdgesClustered interface on demand'

    def __get__(self, obj, objtype):
        """Return an edges interface for graph obj.

        The new interface is pre-loaded with the edges already held in
        obj's cache, obj.d.
        """
        edge_view = SQLEdgesClustered(obj.table, obj.edge_id, obj.source_id,
                                      obj.target_id, graph=obj,
                                      **graph_db_inverse_refs(obj, True))
        for source_id, targets in obj.d.iteritems():  # copy edge cache
            rows = []
            for (target_id, edge_id) in targets.iteritems():
                rows.append((edge_id, source_id, target_id))
            edge_view.load(rows)
        return edge_view
class SQLGraphClustered(object):
    'SQL graph with clustered caching -- loads an entire cluster at a time'
    _edgeDictClass = SQLEdgeDictClustered

    def __init__(self, table, source_id='source_id', target_id='target_id',
                 edge_id='edge_id', clusterKey=None, **kwargs):
        """Wrap table as a graph.

        table: an existing table object, or a table name (string), in
        which case an SQLTableClustered is created and clusterKey is
        required (ValueError otherwise).
        source_id, target_id, edge_id: column names of the graph table.
        With a table name and 'createTable' in kwargs, a create-table
        statement is generated from getColumnTypes().
        """
        import types
        if isinstance(table, types.StringType):  # create the table interface
            if clusterKey is None:
                raise ValueError('you must provide a clusterKey argument!')
            if 'createTable' in kwargs:  # create a schema for this table
                aliases = dict(source_id=source_id, target_id=target_id,
                               edge_id=edge_id)
                cols = getColumnTypes(attrAlias=aliases, **kwargs)
                source_col, target_col, edge_col = cols[0], cols[1], cols[2]
                kwargs['createTable'] = \
                    'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
                    % (table, source_col[0], source_col[1],
                       target_col[0], target_col[1],
                       edge_col[0], edge_col[1],
                       source_col[0], target_col[0])
            table = SQLTableClustered(table, clusterKey=clusterKey, **kwargs)
        self.table = table
        self.source_id = source_id
        self.target_id = target_id
        self.edge_id = edge_id
        self.d = {}  # cache: source node -> edge dictionary
        save_graph_db_refs(self, **kwargs)

    _pickleAttrs = dict(table=0,source_id=0,target_id=0,edge_id=0,sourceDB=0,targetDB=0,
                        edgeDB=0)

    def __getstate__(self):
        'standard state, but with an empty cache'
        state = standard_getstate(self)
        state['d'] = {}  # unpickle should restore graph with empty cache
        return state

    def __getitem__(self, k):
        'get edgeDict for source node k, from cache or by loading its cluster'
        try:
            return self.d[k]  # straight from the cache
        except KeyError:
            if hasattr(self, '_isLoaded'):
                raise  # entire graph loaded, so k really is not in this graph
        # the whole cluster containing this node has to be loaded
        query = ('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
                 % (self.source_id, self.target_id,
                    self.edge_id, self.table.name,
                    self.table.name, self.source_id,
                    self.table.clusterKey, self.table.clusterKey,
                    self.table.primary_key))
        sql, params = self.table._format_query(query, (self.pack_source(k),))
        self.table.cursor.execute(sql, params)
        self.load(self.table.cursor.fetchall())  # cache this cluster
        return self.d[k]

    def load(self, l=None, unpack=True):
        """Load the specified rows (or all, if None provided) into local cache.

        l: iterable of (source, target, edge) rows.  When None, the whole
        table is read, unless that was already done, in which case the
        _isLoaded flag is returned.
        unpack: whether to run the values through the unpack_* methods.
        """
        if l is None:
            try:
                return self._isLoaded  # already loaded: nothing to do
            except AttributeError:
                pass
            self.table.cursor.execute('select %s,%s,%s from %s'
                                      % (self.source_id, self.target_id,
                                         self.edge_id, self.table.name))
            l = self.table.cursor.fetchall()
            self._isLoaded = True
            self.d.clear()  # the full load below replicates everything
        for source, target, edge in l:  # save to our cache
            if unpack:
                source = self.unpack_source(source)
                target = self.unpack_target(target)
                edge = self.unpack_edge(edge)
            entry = [(target, edge)]
            try:
                self.d[source] += entry
            except KeyError:  # first edge seen for this source
                edge_dict = self._edgeDictClass(self, source)
                edge_dict += entry
                self.d[source] = edge_dict

    def __invert__(self):
        'interface to reverse graph mapping'
        try:
            return self._inverse  # inverse map already exists
        except AttributeError:
            pass
        # just create an interface with swapped target & source
        self._inverse = SQLGraphClustered(self.table, self.target_id,
                                          self.source_id, self.edge_id,
                                          **graph_db_inverse_refs(self))
        self._inverse._inverse = self
        for source, edge_dict in self.d.iteritems():  # invert our cache
            swapped = [(target, source, edge)
                       for (target, edge) in edge_dict.iteritems()]
            self._inverse.load(swapped, unpack=False)
        return self._inverse

    edges = SQLEdgesClusteredDescr()  # construct edge interface on demand
    update = update_graph
    add_standard_packing_methods(locals())  # pack / unpack methods

    def __iter__(self):
        'uses db select; does not force load'
        return iter(self.keys())

    def keys(self):
        'uses db select; does not force load'
        self.table.cursor.execute('select distinct(%s) from %s'
                                  % (self.source_id, self.table.name))
        rows = self.table.cursor.fetchall()
        return [self.unpack_source(row[0]) for row in rows]

    # these four force a full load, then delegate to the cache dictionary
    methodFactory(['iteritems','items','itervalues','values'],
                  'lambda self:(self.load(),self.d.%s())[1]',locals())

    def __contains__(self, k):
        'True if an edge dictionary can be obtained for k'
        try:
            self[k]
        except KeyError:
            return False
        return True
# SQLGraphClustered variant whose pack / unpack methods are installed by
# add_trivial_packing_methods() instead of the standard ones.
# NOTE(review): presumably nodes and edges are stored as plain IDs here --
# confirm against add_trivial_packing_methods.
class SQLIDGraphClustered(SQLGraphClustered):
    add_trivial_packing_methods(locals())
# register the ID-based variant on the base class
SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
class _ClusteredEdgeList(list):
    'list whose constructor tolerates the (graph, key) call made by load()'

    def __init__(self, *args):
        if len(args) == 1:  # ordinary list(iterable) construction
            list.__init__(self, args[0])
        else:  # called as _edgeDictClass(graph, key): start out empty
            list.__init__(self)


class SQLEdgesClustered(SQLGraphClustered):
    """Edges interface for SQLGraphClustered.

    The cache self.d maps each edge_id to a list of (source_id, target_id)
    pairs.
    """
    # The inherited load() creates cache entries with
    # self._edgeDictClass(self, key).  The builtin list rejects two
    # arguments with TypeError, so a tolerant list subclass is used.
    _edgeDictClass = _ClusteredEdgeList
    _pickleAttrs = SQLGraphClustered._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))

    def keys(self):
        """Return every edge as an unpacked (source, target, edge) tuple.

        Forces a full load of the table first.
        """
        self.load()
        result = []
        for edge_id, pairs in self.d.iteritems():
            for source_id, target_id in pairs:
                result.append((self.graph.unpack_source(source_id),
                               self.graph.unpack_target(target_id),
                               self.graph.unpack_edge(edge_id)))
        return result
class ForeignKeyInverse(object):
    """Map each key to a single value according to its foreign key.

    Keys are objects of g.targetDB; the value of a key is the g.sourceDB
    entry named by its g.keyColumn attribute (None if that is None).
    """

    def __init__(self, g):
        self.g = g

    def _source_of(self, obj):
        'look up the source object referenced by obj, or None'
        source_id = getattr(obj, self.g.keyColumn)
        if source_id is None:
            return None
        return self.g.sourceDB[source_id]

    def __getitem__(self, obj):
        self.check_obj(obj)
        return self._source_of(obj)

    def __setitem__(self, obj, source):
        """Point obj at source, or remove its existing edge if source is None.

        Removal is skipped for objects carrying a _localOnly attribute.
        """
        self.check_obj(obj)
        if source is not None:
            self.g[source][obj] = None  # ensures all the right caching operations done
            return
        # source is None: delete pre-existing edge if present
        if hasattr(obj, '_localOnly'):  # only cache, don't save to database
            return
        old_source = self[obj]
        if old_source is not None:
            del self.g[old_source][obj]

    def check_obj(self, obj):
        'raise KeyError if obj not from this db'
        try:
            if obj.db != self.g.targetDB:
                raise AttributeError
        except AttributeError:
            raise KeyError('key is not from targetDB of this graph!')

    def __contains__(self, obj):
        try:
            self.check_obj(obj)
        except KeyError:
            return False
        return True

    def __iter__(self):
        return self.g.targetDB.itervalues()

    def keys(self):
        return self.g.targetDB.values()

    def iteritems(self):
        'generate (target object, source object or None) pairs'
        for obj in self:
            yield obj, self._source_of(obj)

    def items(self):
        return list(self.iteritems())

    def itervalues(self):
        for pair in self.iteritems():
            yield pair[1]

    def values(self):
        return list(self.itervalues())

    def __invert__(self):
        return self.g
class ForeignKeyEdge(dict):
    '''Edge interface to a foreign key in an SQL table.
    Caches dict of target nodes in itself; provides dict interface.
    Adds or deletes edges by setting foreign key values in the table'''

    def __init__(self, g, k):
        'collect every targetDB object whose foreign key equals k.id'
        dict.__init__(self)
        self.g = g
        self.src = k
        where = 'where %s=%%s' % g.keyColumn
        for target in g.targetDB.select(where, (k.id,)):  # search the db
            dict.__setitem__(self, target, None)  # save in cache

    def __setitem__(self, dest, v):
        """Add the edge self.src -> dest.

        Raises KeyError if dest does not belong to the graph's targetDB and
        ValueError if v is not None (no edge info can be stored).
        """
        if not hasattr(dest, 'db') or dest.db != self.g.targetDB:
            raise KeyError('dest is not in the targetDB bound to this graph!')
        if v is not None:
            raise ValueError('sorry,this graph cannot store edge information!')
        if not hasattr(dest, '_localOnly'):  # only cache, don't save to database
            previous = self.g._inverse[dest]  # check for pre-existing edge
            if previous is not None:  # remove old edge from cache
                dict.__delitem__(self.g[previous], dest)
        # NOTE(review): the original indentation of the next statement was
        # not recoverable; it is placed at method level, like the
        # unconditional setattr in __delitem__ -- confirm.
        setattr(dest, self.g.keyColumn, self.src.id)  # save to db attribute
        dict.__setitem__(self, dest, None)  # save in cache

    def __delitem__(self, dest):
        'remove the edge to dest by clearing its foreign key attribute'
        setattr(dest, self.g.keyColumn, None)  # save to db attribute
        dict.__delitem__(self, dest)  # remove from cache
class ForeignKeyGraph(object, UserDict.DictMixin):
    '''Graph interface to a foreign key in an SQL table.
    Caches dict of target nodes in itself; provides dict interface.
    '''

    def __init__(self, sourceDB, targetDB, keyColumn, autoGC=True, **kwargs):
        '''sourceDB is any database of source nodes;
        targetDB must be an SQL database of target nodes;
        keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
        if not autoGC:
            self._weakValueDict = {}
        else:  # automatically garbage collect unused objects
            self._weakValueDict = RecentValueDictionary(autoGC)  # object cache
        self.autoGC = autoGC
        self.sourceDB = sourceDB
        self.targetDB = targetDB
        self.keyColumn = keyColumn
        self._inverse = ForeignKeyInverse(self)

    _pickleAttrs = dict(sourceDB=0, targetDB=0, keyColumn=0, autoGC=0)
    __getstate__ = standard_getstate  # support for pickling
    __setstate__ = standard_setstate

    def _inverse_schema(self):
        'provide custom schema rule for inverting this graph... just use keyColumn!'
        return dict(invert=True, uniqueMapping=True)

    def __getitem__(self, k):
        """Return the ForeignKeyEdge dictionary of source object k.

        Raises KeyError if k does not belong to sourceDB.  Edge
        dictionaries are cached under k.id.
        """
        if not hasattr(k, 'db') or k.db != self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this graph!')
        try:
            return self._weakValueDict[k.id]  # get from cache
        except KeyError:
            pass
        edge_dict = ForeignKeyEdge(self, k)
        self._weakValueDict[k.id] = edge_dict  # save in cache
        return edge_dict

    def __setitem__(self, k, v):
        'direct assignment is not supported: always raises KeyError'
        raise KeyError('do not save as g[k]=v. Instead follow a graph\n'
                       'interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)')

    def __delitem__(self, k):
        'direct deletion is not supported: always raises KeyError'
        raise KeyError('Instead of del g[k], follow a graph\n'
                       'interface: del g[src][dest]')

    def keys(self):
        return self.sourceDB.values()

    __invert__ = standard_invert
def describeDBTables(name, cursor, idDict):
    """
    Get table info about database <name> via <cursor>, and store primary keys
    in idDict, along with a list of the tables each key indexes.

    Returns a dictionary mapping '<name>.<table>' to an SQLTable object.
    Non-key columns whose names end in _id get an empty list in idDict if
    they are not present yet.
    """
    cursor.execute('use %s' % name)
    cursor.execute('show tables')
    tables = {}
    table_names = [row[0] for row in cursor.fetchall()]
    for short_name in table_names:
        full_name = name + '.' + short_name
        table = SQLTable(full_name, cursor)
        tables[full_name] = table
        for field in table.description:
            if field == table.primary_key:
                idDict.setdefault(field, []).append(table)
            elif field[-3:] == '_id' and field not in idDict:
                idDict[field] = []
    return tables
def indexIDs(tables, idDict=None):
    """Get an index of primary keys in the <tables> dictionary.

    tables: dictionary whose values have primary_key and description
    attributes.
    idDict: optional dictionary to update in place; a new one is created
    when it is None.

    Returns idDict, mapping each primary key name to the list of tables
    that use it.  Any other column whose name ends in _id gets an empty
    list if it is not in the index yet.
    """
    if idDict is None:  # identity test: never confused by a custom __eq__
        idDict = {}
    for table in tables.values():
        if table.primary_key:
            # keep list of tables with this primary key
            idDict.setdefault(table.primary_key, []).append(table)
        for field in table.description:
            if field[-3:] == '_id' and field not in idDict:
                idDict[field] = []
    return idDict
def suffixSubset(tables, suffix):
    "Filter table index for those matching a specific suffix"
    matching = [(name, table) for name, table in tables.items()
                if name.endswith(suffix)]
    return dict(matching)
PRIMARY_KEY = 1  # edge info marking a column that is its table's primary key


def graphDBTables(tables, idDict):
    """Build a graph linking every table with each of its columns.

    For each table in tables.values() and each column in its description,
    edges are stored in both directions (column -> table and table ->
    column).  The edge info is PRIMARY_KEY when the column is the table's
    primary key, otherwise None.  idDict is accepted but not used here.
    """
    graph = dictgraph()
    for table in tables.values():
        for column in table.description:
            edgeInfo = None
            if column == table.primary_key:
                edgeInfo = PRIMARY_KEY
            graph.setEdge(column, table, edgeInfo)
            graph.setEdge(table, column, edgeInfo)
    return graph
# default SQL column type for each Python value type; the builtins used as
# keys are the very same objects as types.StringType / IntType / FloatType
SQLTypeTranslation = {
    str: 'varchar(32)',
    int: 'int',
    float: 'float',
}
def createTableFromRepr(rows, tableName, cursor, typeTranslation=None,
                        optionalDict=None, indexDict=()):
    """Save rows into SQL tableName using cursor, with optional
    translations of columns to specific SQL types (specified
    by typeTranslation dict).
    - optionalDict can specify columns that are allowed to be NULL.
    - indexDict can specify columns that must be indexed; columns
      whose names end in _id will be indexed by default.
    - rows must be an iterable (iterator, generator or sequence) yielding
      dictionaries, each representing a tuple of values (indexed by their
      column names).

    The first row supplies the column info used to create the table.
    If rows is empty nothing is saved.  Table creation is best-effort:
    a failure is reported with warnings.warn() and the rows are still
    inserted, in case the table is usable anyway.
    """
    import sys  # local import: only needed to report a creation failure
    table_attempted = False
    for row in rows:
        if not table_attempted:  # USE 1ST ROW TO EXTRACT COLUMN INFO
            table_attempted = True
            try:
                createTableFromRow(cursor, tableName, row, typeTranslation,
                                   optionalDict, indexDict)
            except Exception:
                # do not swallow the failure silently, but keep going
                warnings.warn('could not create table %s (%s); '
                              'trying to insert rows anyway'
                              % (tableName, sys.exc_info()[1]))
        storeRow(cursor, tableName, row)  # SAVE EVERY ROW, INCLUDING THE 1ST
def createTableFromRow(cursor, tableName, row, typeTranslation=None,
                       optionalDict=None, indexDict=()):
    """Create SQL table tableName (if it does not exist) from a sample row.

    row: dictionary mapping column names to example values.
    typeTranslation: optional dict of column name -> SQL type; it takes
      precedence over the defaults in SQLTypeTranslation.
    optionalDict: optional container of column names allowed to be NULL;
      every other column is declared 'not null'.
    indexDict: container of column names to index, in addition to the
      columns whose names end in '_id'.

    Raises TypeError if no SQL type can be determined for a column.
    """
    create_defs = []
    for col, val in row.items():  # PREPARE SQL TYPES FOR COLUMNS
        coltype = None
        if typeTranslation is not None and col in typeTranslation:
            coltype = typeTranslation[col]  # USER-SUPPLIED TRANSLATION
        elif type(val) in SQLTypeTranslation:
            coltype = SQLTypeTranslation[type(val)]
        else:  # SEARCH FOR A COMPATIBLE TYPE
            for t in SQLTypeTranslation:
                if isinstance(val, t):
                    coltype = SQLTypeTranslation[t]
                    break
        if coltype is None:
            raise TypeError("Don't know SQL type to use for %s" % col)
        create_def = '%s %s' % (col, coltype)
        if optionalDict is None or col not in optionalDict:
            create_def += ' not null'
        create_defs.append(create_def)
    for col in row:  # CREATE INDEXES FOR ID COLUMNS
        if col.endswith('_id') or col in indexDict:
            create_defs.append('index(%s)' % col)
    cmd = 'create table if not exists %s (%s)' % (tableName,
                                                  ','.join(create_defs))
    cursor.execute(cmd)  # CREATE THE TABLE IN THE DATABASE
def storeRow(cursor, tableName, row):
    """Insert one row (a dictionary of column name -> value) into tableName.

    The column names are listed explicitly in the INSERT statement, so
    each value is bound to its own column regardless of the iteration
    order of the dictionary or the column order of the table.
    """
    items = list(row.items())  # snapshot: names and values in one order
    columns = ','.join([col for col, val in items])
    row_format = ','.join(len(items) * ['%s'])
    cmd = 'insert into %s (%s) values (%s)' % (tableName, columns, row_format)
    cursor.execute(cmd, tuple([val for col, val in items]))
def storeRowDelayed(cursor, tableName, row):
    """Like storeRow, but issues an 'insert delayed' statement.

    row is a dictionary of column name -> value.  The column names are
    listed explicitly in the statement, so each value is bound to its
    own column regardless of the iteration order of the dictionary or
    the column order of the table.
    """
    items = list(row.items())  # snapshot: names and values in one order
    columns = ','.join([col for col, val in items])
    row_format = ','.join(len(items) * ['%s'])
    cmd = 'insert delayed into %s (%s) values (%s)' % (tableName, columns,
                                                       row_format)
    cursor.execute(cmd, tuple([val for col, val in items]))
class TableGroup(dict):
    """provide attribute access to dbname qualified tablenames

    Each keyword argument name=tablename is stored as an entry of the
    dictionary; a table name that contains no '.' is prefixed with the
    database name.  Entries can then be read either as group[name] or
    as group.name.
    """
    def __init__(self, db='test', suffix=None, **kw):
        """db: database name used as the prefix for unqualified table names.
        suffix: if not None, saved as the suffix attribute.
        kw: name=tablename pairs; a value of None is stored unchanged.
        """
        dict.__init__(self)
        self.db = db
        if suffix is not None:
            self.suffix = suffix
        for k, v in kw.items():
            if v is not None and '.' not in v:
                v = self.db + '.' + v  # ADD DATABASE NAME AS PREFIX
            self[k] = v

    def __getattr__(self, k):
        """Return the table name stored under key k.

        Only called when normal attribute lookup fails.  A missing key
        must be reported as AttributeError (not KeyError) so that
        getattr() with a default, hasattr() and the copy / pickle
        protocol lookups behave correctly.
        """
        try:
            return self[k]
        except KeyError:
            raise AttributeError(k)
def sqlite_connect(*args, **kwargs):
    """Open a sqlite connection and return (connection, cursor).

    All arguments are passed unchanged to the connect() function of the
    module returned by import_sqlite().
    """
    connection = import_sqlite().connect(*args, **kwargs)
    return connection, connection.cursor()
# database connection functions that DBServerInfo knows how to use,
# keyed by the module name given to DBServerInfo
_DBServerModuleDict = {
    'MySQLdb': mysql_connect,
    'sqlite': sqlite_connect,
}
class DBServerInfo(object):
    """picklable reference to a database server

    Only the connection arguments are saved; the connection and its
    shared cursor are created on first use and never pickled.
    """
    def __init__(self, moduleName='MySQLdb', *args, **kwargs):
        """moduleName: key of _DBServerModuleDict selecting the connection
        function; args and kwargs are the arguments passed to it.
        Raises ValueError for an unknown moduleName.
        """
        if moduleName not in _DBServerModuleDict:
            raise ValueError('Module name not found in _DBServerModuleDict: '\
                             + moduleName)
        self.moduleName = moduleName
        self.args = args
        self.kwargs = kwargs  # connection arguments

    def cursor(self):
        """returns cursor associated with the DB server info (reused)"""
        try:
            return self._cursor
        except AttributeError:
            self._start_connection()
            return self._cursor

    def new_cursor(self):
        """returns a NEW cursor; you must close it yourself! """
        if not hasattr(self, '_connection'):
            self._start_connection()
        return self._connection.cursor()

    def _start_connection(self):
        'open the connection and save it together with its cursor'
        try:
            moduleName = self.moduleName
        except AttributeError:  # object lacks moduleName: use the default
            moduleName = 'MySQLdb'
        connect = _DBServerModuleDict[moduleName]
        self._connection, self._cursor = connect(*self.args, **self.kwargs)

    def close(self):
        """Close the cursor and connection to this database, if any.

        Safe to call when no connection was ever opened, or more than
        once.  The saved handles are dropped before closing, so a
        failing close() cannot leave stale handles behind, and the
        connection is closed even if closing the cursor raises.
        """
        cursor = self.__dict__.pop('_cursor', None)
        connection = self.__dict__.pop('_connection', None)
        try:
            if cursor is not None:
                cursor.close()
        finally:
            if connection is not None:
                connection.close()

    def __getstate__(self):
        """return all picklable arguments"""
        # same fallback as _start_connection() for a missing moduleName
        moduleName = getattr(self, 'moduleName', 'MySQLdb')
        return dict(args=self.args, kwargs=self.kwargs,
                    moduleName=moduleName)
class SQLiteServerInfo(DBServerInfo):
    """picklable reference to a sqlite database"""
    def __init__(self, database, *args, **kwargs):
        """Takes same arguments as sqlite3.connect()"""
        db_path = SourceFileName(database)  # save abs path!
        DBServerInfo.__init__(self, 'sqlite', db_path, *args, **kwargs)

    def __getstate__(self):
        """Return the picklable state.

        Raises ValueError for an in-memory database.
        """
        database = self.args[0]
        if database == ':memory:':
            raise ValueError('SQLite in-memory database is not picklable!')
        return DBServerInfo.__getstate__(self)
class MapView(object, UserDict.DictMixin):
    """general purpose 1:1 mapping defined by any SQL query

    Keys are objects of sourceDB; viewSQL is run with the key's id as its
    only parameter, and must yield exactly one row whose first column is
    the ID of the matching object in targetDB.
    """
    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
                 serverInfo=None, inverseSQL=None, **kwargs):
        """The cursor is taken, in order of preference, from the cursor
        argument, from serverInfo, or from sourceDB.serverInfo; ValueError
        is raised if none of them is available.  inverseSQL, if given, is
        the query used by the inverse mapping (~self).
        """
        self.sourceDB = sourceDB
        self.targetDB = targetDB
        self.viewSQL = viewSQL
        self.inverseSQL = inverseSQL
        if cursor is None:
            if serverInfo is None:
                try:  # can we get it from our other db?
                    serverInfo = sourceDB.serverInfo
                except AttributeError:
                    raise ValueError('you must provide serverInfo or cursor!')
            cursor = serverInfo.cursor()  # get cursor from serverInfo
        self.cursor = cursor
        self.serverInfo = serverInfo
        self.get_sql_format(False)  # get sql formatter for this db interface

    _schemaModuleDict = _schemaModuleDict  # default module list
    get_sql_format = get_table_schema

    def __getitem__(self, k):
        """Return the targetDB object that source object k maps to.

        Raises KeyError if k is not bound to sourceDB, or if the query
        does not return exactly one row.
        """
        if not hasattr(k, 'db') or k.db != self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this map!')
        sql, params = self._format_query(self.viewSQL, (k.id,))
        self.cursor.execute(sql, params)  # formatted for this db interface
        rows = self.cursor.fetchmany(2)  # get at most two rows
        if len(rows) != 1:
            raise KeyError('%s not found in MapView, or not unique'
                           % str(k))
        target_id = rows[0][0]
        return self.targetDB[target_id]  # get the corresponding object

    _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
                        inverseSQL=0)
    __getstate__ = standard_getstate
    __setstate__ = standard_setstate
    __setitem__ = __delitem__ = clear = pop = popitem = update = \
                  setdefault = read_only_error

    def __iter__(self):
        'only yield sourceDB items that are actually in this mapping!'
        for source in self.sourceDB.itervalues():
            try:
                self[source]  # lookup fails for items outside the mapping
                yield source
            except KeyError:
                pass

    def keys(self):
        'list of the sourceDB objects that are actually in this mapping'
        # don't use list(self); causes infinite loop!
        found = []
        for source in self:
            found.append(source)
        return found

    def __invert__(self):
        """Return the inverse mapping (targetDB -> sourceDB).

        It is created on first use from inverseSQL and then cached;
        raises ValueError if this MapView has no inverseSQL.
        """
        try:
            return self._inverse
        except AttributeError:
            pass
        if self.inverseSQL is None:
            raise ValueError('this MapView has no inverseSQL!')
        inverse = MapView(self.targetDB, self.sourceDB, self.inverseSQL,
                          self.cursor, serverInfo=self.serverInfo,
                          inverseSQL=self.viewSQL)
        self._inverse = inverse
        inverse._inverse = self  # the inverse of the inverse is this map
        return inverse
class GraphViewEdgeDict(UserDict.DictMixin):
    'edge dictionary for GraphView: just pre-loaded on init'
    def __init__(self, g, k):
        """Run the view query of graph g for source object k and keep
        the results: self.targets lists the target IDs in query order,
        self.targetDict maps each target ID to its edge ID (or None if
        the graph has no edgeDB).
        """
        self.g = g
        self.k = k
        sql, params = self.g._format_query(self.g.viewSQL, (k.id,))
        self.g.cursor.execute(sql, params)  # run the query
        rows = self.g.cursor.fetchall()  # get results
        self.targets = [row[0] for row in rows]  # preserve result order
        # also keep targetID:edgeID mapping
        if self.g.edgeDB is not None:  # save with edge info
            self.targetDict = dict([(row[0], row[1]) for row in rows])
        else:
            self.targetDict = dict([(row[0], None) for row in rows])

    def __len__(self):
        'number of rows returned by the query'
        return len(self.targets)

    def __iter__(self):
        'yield the target objects in the order returned by the query'
        for target_id in self.targets:
            yield self.g.targetDB[target_id]

    def keys(self):
        'list of the target objects'
        return list(self)

    def iteritems(self):
        'yield (target object, edge object) pairs; edge is None if no edgeDB'
        if self.g.edgeDB is not None:  # save with edge info
            for target_id in self.targets:
                edge_id = self.targetDict[target_id]
                yield (self.g.targetDB[target_id], self.g.edgeDB[edge_id])
        else:  # just save the list of targets, no edge info
            for target_id in self.targets:
                yield (self.g.targetDB[target_id], None)

    def __getitem__(self, o, exitIfFound=False):
        """for the specified target object, return its associated edge object

        Raises KeyError if o is not part of targetDB, is not a target of
        this node, or has no id / db attribute.  With exitIfFound true,
        only the membership test is performed and None is returned.
        """
        try:
            if o.db is not self.g.targetDB:
                raise KeyError('key is not part of targetDB!')
            edgeID = self.targetDict[o.id]
        except AttributeError:
            raise KeyError('key has no id or db attribute?!')
        if exitIfFound or self.g.edgeDB is None:  # nothing to look up
            return None
        return self.g.edgeDB[edgeID]  # return the edge object

    def __contains__(self, o):
        'true if o is one of the targets of this node'
        try:
            self.__getitem__(o, True)  # raise KeyError if not found
        except KeyError:
            return False
        return True

    __setitem__ = __delitem__ = clear = pop = popitem = update = \
                  setdefault = read_only_error
class GraphView(MapView):
    'general purpose graph interface defined by any SQL query'
    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
                 **kwargs):
        'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
        self.edgeDB = edgeDB  # set first, then let MapView do the rest
        MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)

    def __getitem__(self, k):
        """Return the edge dictionary of source object k.

        Raises KeyError if k is not bound to sourceDB.
        """
        if not hasattr(k, 'db') or k.db != self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this map!')
        edges = GraphViewEdgeDict(self, k)
        return edges

    # pickle the same attributes as MapView, plus edgeDB
    _pickleAttrs = dict(MapView._pickleAttrs, edgeDB=0)
1943 # @CTB move to sqlgraph.py?
class SQLSequence(SQLRow, SequenceBase):
    """Transparent access to a DB row representing a sequence.

    Use attrAlias dict to rename 'length' to something else.
    """
    def _init_subclass(cls, db, **kwargs):
        'bind db to this class; db also serves as its own seqInfoDict'
        db.seqInfoDict = db  # db will act as its own seqInfoDict
        SQLRow._init_subclass(db=db, **kwargs)
    _init_subclass = classmethod(_init_subclass)

    def __init__(self, id):
        'initialize both the row part and the sequence part of the object'
        SQLRow.__init__(self, id)
        SequenceBase.__init__(self)

    def __len__(self):
        'the sequence length is read from the length attribute'
        return self.length

    def strslice(self, start, end):
        "Efficient access to slice of a sequence, useful for huge contigs"
        seq_column = self.db._attrSQL('seq')
        first = start + 1  # presumably because SQL substring offsets are 1-based
        count = end - start
        template = '%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
        return self._select(template % (seq_column, first, count))
class DNASQLSequence(SQLSequence):
    'SQLSequence whose sequence type is DNA'
    _seqtype = DNA_SEQTYPE
class RNASQLSequence(SQLSequence):
    'SQLSequence whose sequence type is RNA'
    _seqtype = RNA_SEQTYPE
class ProteinSQLSequence(SQLSequence):
    'SQLSequence whose sequence type is protein'
    _seqtype = PROTEIN_SEQTYPE