3 from __future__
import generators
5 from sequence
import SequenceBase
, DNA_SEQTYPE
, RNA_SEQTYPE
, PROTEIN_SEQTYPE
7 from classutil
import methodFactory
,standard_getstate
,\
8 override_rich_cmp
,generate_items
,get_bound_subclass
,standard_setstate
,\
9 get_valid_path
,standard_invert
,RecentValueDictionary
,read_only_error
,\
10 SourceFileName
, split_kwargs
class TupleDescriptor(object):
    """Read-only descriptor exposing one slot of a row's data tuple.

    The slot position is looked up once, at construction time, from
    db.data[attr]; every read then indexes obj._data at that position.
    Any attempt to assign raises AttributeError.
    """

    def __init__(self, db, attr):
        column_index = db.data[attr]
        self.icol = column_index  # position of this attribute in the tuple

    def __get__(self, obj, klass):
        row = obj._data
        return row[self.icol]

    def __set__(self, obj, val):
        message = 'this database is read-only!'
        raise AttributeError(message)
class TupleIDDescriptor(TupleDescriptor):
    """Descriptor for the ID column: readable like any tuple column,
    but assignment is always refused with an explanatory message."""

    def __set__(self, obj, val):
        message = '''You cannot change obj.id directly.
Instead, use db[newID] = obj'''
        raise AttributeError(message)
class TupleDescriptorRW(TupleDescriptor):
    """Read-write interface to a named attribute of a row tuple.

    Reads behave as in TupleDescriptor.  Assignment first writes the new
    value to the database through obj.db._update(), then stores it on
    the row object through obj.save_local().
    """

    def __init__(self, db, attr):
        """Bind this descriptor to attribute attr of table db.

        db -- table object providing the data mapping and _attrSQL()
        attr -- Python attribute name this descriptor serves
        """
        # __set__ hands self.attr to obj.save_local(), so the attribute
        # name must be recorded here; without it every write fails with
        # AttributeError after the database has already been updated.
        self.attr = attr
        self.icol = db.data[attr]  # index of this attribute in the tuple
        self.attrSQL = db._attrSQL(attr, sqlColumn=True)  # SQL column name

    def __set__(self, obj, val):
        """Store val for this attribute in the database and on obj."""
        obj.db._update(obj.id, self.attrSQL, val)  # AND UPDATE THE DATABASE
        obj.save_local(self.attr, val)  # then refresh the object's own copy
class SQLDescriptor(object):
    """Read-only descriptor whose value comes from obj._select().

    The SQL expression for the attribute is obtained once from
    db._attrSQL(attr); each read passes that expression to the row
    object's _select() and returns whatever it gives back.
    """

    def __init__(self, db, attr):
        sql_expression = db._attrSQL(attr)
        self.selectSQL = sql_expression  # SQL expression for this attr

    def __get__(self, obj, klass):
        run_select = obj._select
        return run_select(self.selectSQL)

    def __set__(self, obj, val):
        message = 'this database is read-only!'
        raise AttributeError(message)
class SQLDescriptorRW(SQLDescriptor):
    """Writable variant of SQLDescriptor.

    Assignment is forwarded to obj.db._update() for this attribute's SQL
    expression; nothing is stored on the row object itself.
    """

    def __set__(self, obj, val):
        update = obj.db._update
        update(obj.id, self.selectSQL, val)  # just update the database
55 class ReadOnlyDescriptor(object):
56 'enforce read-only policy, e.g. for ID attribute'
57 def __init__(self
, db
, attr
):
59 def __get__(self
, obj
, klass
):
60 return getattr(obj
, self
.attr
)
61 def __set__(self
, obj
, val
):
62 raise AttributeError('attribute %s is read-only!' % self
.attr
)
65 def select_from_row(row
, what
):
66 "return value from SQL expression applied to this row"
67 sql
,params
= row
.db
._format
_query
('select %s from %s where %s=%%s limit 2'
68 % (what
,row
.db
.name
,row
.db
.primary_key
),
70 row
.db
.cursor
.execute(sql
, params
)
71 t
= row
.db
.cursor
.fetchmany(2) # get at most two rows
73 raise KeyError('%s[%s].%s not found, or not unique'
74 % (row
.db
.name
,str(row
.id),what
))
75 return t
[0][0] #return the single field we requested
77 def init_row_subclass(cls
, db
):
78 'add descriptors for db attributes'
79 for attr
in db
.data
: # bind all database columns
80 if attr
== 'id': # handle ID attribute specially
81 setattr(cls
, attr
, cls
._idDescriptor
(db
, attr
))
83 try: # check if this attr maps to an SQL column
84 db
._attrSQL
(attr
, columnNumber
=True)
85 except AttributeError: # treat as SQL expression
86 setattr(cls
, attr
, cls
._sqlDescriptor
(db
, attr
))
87 else: # treat as interface to our stored tuple
88 setattr(cls
, attr
, cls
._columnDescriptor
(db
, attr
))
91 """get list of column names as our attributes """
92 return self
.db
.data
.keys()
95 """Provides attribute interface to a database tuple.
96 Storing the data as a tuple instead of a standard Python object
97 (which is stored using __dict__) uses about five-fold less
98 memory and is also much faster (the tuples returned from the
99 DB API fetch are simply referenced by the TupleO, with no
100 need to copy their individual values into __dict__).
102 This class follows the 'subclass binding' pattern, which
103 means that instead of using __getattr__ to process all
104 attribute requests (which is un-modular and leads to all
105 sorts of trouble), we follow Python's new model for
106 customizing attribute access, namely Descriptors.
107 We use classutil.get_bound_subclass() to automatically
108 create a subclass of this class, calling its _init_subclass()
109 class method to add all the descriptors needed for the
110 database table to which it is bound.
112 See the Pygr Developer Guide section of the docs for a
113 complete discussion of the subclass binding pattern."""
114 _columnDescriptor
= TupleDescriptor
115 _idDescriptor
= TupleIDDescriptor
116 _sqlDescriptor
= SQLDescriptor
117 _init_subclass
= classmethod(init_row_subclass
)
118 _select
= select_from_row
120 def __init__(self
, data
):
121 self
._data
= data
# save our data tuple
123 def insert_and_cache_id(self
, l
, **kwargs
):
124 'insert tuple into db and cache its rowID on self'
125 self
.db
._insert
(l
) # save to database
127 rowID
= kwargs
['id'] # use the ID supplied by user
129 rowID
= self
.db
.get_insert_id() # get auto-inc ID value
130 self
.cache_id(rowID
) # cache this ID on self
132 class TupleORW(TupleO
):
133 'read-write version of TupleO'
134 _columnDescriptor
= TupleDescriptorRW
135 insert_and_cache_id
= insert_and_cache_id
136 def __init__(self
, data
, newRow
=False, **kwargs
):
137 if not newRow
: # just cache data from the database
140 self
._data
= self
.db
.tuple_from_dict(kwargs
) # convert to tuple
141 self
.insert_and_cache_id(self
._data
, **kwargs
)
142 def cache_id(self
,row_id
):
143 self
.save_local('id',row_id
)
144 def save_local(self
,attr
,val
):
145 icol
= self
._attrcol
[attr
]
147 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
148 except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
149 self
._data
= list(self
._data
)
150 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
152 TupleO
._RWClass
= TupleORW
# record this as writeable interface class
154 class ColumnDescriptor(object):
155 'read-write interface to column in a database, cached in obj.__dict__'
156 def __init__(self
, db
, attr
, readOnly
= False):
158 self
.col
= db
._attrSQL
(attr
, sqlColumn
=True) # MAP THIS TO SQL COLUMN NAME
161 self
.__class
__ = self
._readOnlyClass
162 def __get__(self
, obj
, objtype
):
164 return obj
.__dict
__[self
.attr
]
165 except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
166 if self
.col
==self
.db
.primary_key
:
168 self
.db
._select
('where %s=%%s' % self
.db
.primary_key
,(obj
.id,),self
.col
)
169 l
= self
.db
.cursor
.fetchall()
171 raise AttributeError('db row not found or not unique!')
172 obj
.__dict
__[self
.attr
] = l
[0][0] # UPDATE THE CACHE
174 def __set__(self
, obj
, val
):
175 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
176 self
.db
._update
(obj
.id, self
.col
, val
) # UPDATE THE DATABASE
177 obj
.__dict
__[self
.attr
] = val
# UPDATE THE CACHE
179 ## m = self.consequences
180 ## except AttributeError:
182 ## m(obj,val) # GENERATE CONSEQUENCES
183 ## def bind_consequences(self,f):
184 ## 'make function f be run as consequences method whenever value is set'
186 ## self.consequences = new.instancemethod(f,self,self.__class__)
class ReadOnlyColumnDesc(ColumnDescriptor):
    """ColumnDescriptor variant that refuses every assignment."""

    def __set__(self, obj, val):
        message = 'The ID of a database object is not writeable.'
        raise AttributeError(message)
190 ColumnDescriptor
._readOnlyClass
= ReadOnlyColumnDesc
194 class SQLRow(object):
195 """Provide transparent interface to a row in the database: attribute access
196 will be mapped to SELECT of the appropriate column, but data is not
197 cached on this object.
199 _columnDescriptor
= _sqlDescriptor
= SQLDescriptor
200 _idDescriptor
= ReadOnlyDescriptor
201 _init_subclass
= classmethod(init_row_subclass
)
202 _select
= select_from_row
204 def __init__(self
, rowID
):
208 class SQLRowRW(SQLRow
):
209 'read-write version of SQLRow'
210 _columnDescriptor
= SQLDescriptorRW
211 insert_and_cache_id
= insert_and_cache_id
212 def __init__(self
, rowID
, newRow
=False, **kwargs
):
213 if not newRow
: # just cache data from the database
214 return self
.cache_id(rowID
)
215 l
= self
.db
.tuple_from_dict(kwargs
) # convert to tuple
216 self
.insert_and_cache_id(l
, **kwargs
)
217 def cache_id(self
, rowID
):
220 SQLRow
._RWClass
= SQLRowRW
224 def list_to_dict(names
, values
):
225 'return dictionary of those named args that are present in values[]'
227 for i
,v
in enumerate(values
):
235 def get_name_cursor(name
=None, **kwargs
):
236 '''get table name and cursor by parsing name or using configFile.
237 If neither provided, will try to get via your MySQL config file.
238 If connect is None, will use MySQLdb.connect()'''
240 argList
= name
.split() # TREAT AS WS-SEPARATED LIST
242 name
= argList
[0] # USE 1ST ARG AS TABLE NAME
243 argnames
= ('host','user','passwd') # READ ARGS IN THIS ORDER
244 kwargs
= kwargs
.copy() # a copy we can overwrite
245 kwargs
.update(list_to_dict(argnames
, argList
[1:]))
246 serverInfo
= DBServerInfo(**kwargs
)
247 return name
,serverInfo
.cursor(),serverInfo
249 def mysql_connect(connect
=None, configFile
=None, **args
):
250 """return connection and cursor objects, using .my.cnf if necessary"""
251 kwargs
= args
.copy() # a copy we can modify
252 if 'user' not in kwargs
and configFile
is None: #Find where config file is
253 osname
= platform
.system()
254 if osname
in('Microsoft', 'Windows'): # Machine is a Windows box
256 try: # handle case where WINDIR not defined by Windows...
257 windir
= os
.environ
['WINDIR']
258 paths
+= [(windir
, 'my.ini'), (windir
, 'my.cnf')]
262 sysdrv
= os
.environ
['SYSTEMDRIVE']
263 paths
+= [(sysdrv
, os
.path
.sep
+ 'my.ini'),
264 (sysdrv
, os
.path
.sep
+ 'my.cnf')]
268 configFile
= get_valid_path(*paths
)
269 else: # treat as normal platform with home directories
270 configFile
= os
.path
.join(os
.path
.expanduser('~'), '.my.cnf')
272 # allows for a local mysql local configuration file to be read
273 # from the current directory
274 configFile
= configFile
or os
.path
.join( os
.getcwd(), 'mysql.cnf' )
276 if configFile
and os
.path
.exists(configFile
):
277 kwargs
['read_default_file'] = configFile
278 connect
= None # force it to use MySQLdb
281 connect
= MySQLdb
.connect
282 kwargs
['compress'] = True
283 conn
= connect(**kwargs
)
284 cursor
= conn
.cursor()
# Keyword substitutions handed to SQLFormatDict by mysql_table_schema();
# each key is a macro name usable as %(NAME)s in a query template.
_mysqlMacros = {
    'IGNORE': 'ignore',
    'REPLACE': 'replace',
    'AUTO_INCREMENT': 'AUTO_INCREMENT',
    'SUBSTRING': 'substring',
    'SUBSTR_FROM': 'FROM',
    'SUBSTR_FOR': 'FOR',
}
291 def mysql_table_schema(self
, analyzeSchema
=True):
292 'retrieve table schema from a MySQL database, save on self'
294 self
._format
_query
= SQLFormatDict(MySQLdb
.paramstyle
, _mysqlMacros
)
295 if not analyzeSchema
:
297 self
.clear_schema() # reset settings and dictionaries
298 self
.cursor
.execute('describe %s' % self
.name
) # get info about columns
299 columns
= self
.cursor
.fetchall()
300 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
301 for icol
,c
in enumerate(columns
):
303 self
.columnName
.append(field
) # list of columns in same order as table
304 if c
[3] == "PRI": # record as primary key
305 if self
.primary_key
is None:
306 self
.primary_key
= field
309 self
.primary_key
.append(field
)
310 except AttributeError:
311 self
.primary_key
= [self
.primary_key
,field
]
312 if c
[1][:3].lower() == 'int':
313 self
.usesIntID
= True
315 self
.usesIntID
= False
317 self
.indexed
[field
] = icol
318 self
.description
[field
] = self
.cursor
.description
[icol
]
319 self
.columnType
[field
] = c
[1] # SQL COLUMN TYPE
# Keyword substitutions handed to SQLFormatDict by sqlite_table_schema();
# each key is a macro name usable as %(NAME)s in a query template.
_sqliteMacros = {
    'IGNORE': 'or ignore',
    'REPLACE': 'insert or replace',
    'AUTO_INCREMENT': '',
    'SUBSTRING': 'substr',
    'SUBSTR_FROM': ',',
    'SUBSTR_FOR': ',',
}
326 'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
328 import sqlite3
as sqlite
330 from pysqlite2
import dbapi2
as sqlite
333 def sqlite_table_schema(self
, analyzeSchema
=True):
334 'retrieve table schema from a sqlite3 database, save on self'
335 sqlite
= import_sqlite()
336 self
._format
_query
= SQLFormatDict(sqlite
.paramstyle
, _sqliteMacros
)
337 if not analyzeSchema
:
339 self
.clear_schema() # reset settings and dictionaries
340 self
.cursor
.execute('PRAGMA table_info("%s")' % self
.name
)
341 columns
= self
.cursor
.fetchall()
342 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
343 for icol
,c
in enumerate(columns
):
345 self
.columnName
.append(field
) # list of columns in same order as table
346 self
.description
[field
] = self
.cursor
.description
[icol
]
347 self
.columnType
[field
] = c
[2] # SQL COLUMN TYPE
348 self
.cursor
.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self
.name
) # get primary key / unique indexes
349 for indexname
in self
.cursor
.fetchall(): # search indexes for primary key
350 self
.cursor
.execute('PRAGMA index_info("%s")' % indexname
)
351 l
= self
.cursor
.fetchall() # get list of columns in this index
352 if len(l
) == 1: # assume 1st single-column unique index is primary key!
353 self
.primary_key
= l
[0][2]
354 break # done searching for primary key!
355 if self
.primary_key
is None: # grrr, INTEGER PRIMARY KEY handled differently
356 self
.cursor
.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self
.name
)
357 sql
= self
.cursor
.fetchall()[0][0]
358 for columnSQL
in sql
[sql
.index('(') + 1 :].split(','):
359 if 'primary key' in columnSQL
.lower(): # must be the primary key!
360 col
= columnSQL
.split()[0] # get column name
361 if col
in self
.columnType
:
362 self
.primary_key
= col
363 break # done searching for primary key!
365 raise ValueError('unknown primary key %s in table %s'
367 if self
.primary_key
is not None: # check its type
368 if self
.columnType
[self
.primary_key
] == 'int' or \
369 self
.columnType
[self
.primary_key
] == 'integer':
370 self
.usesIntID
= True
372 self
.usesIntID
= False
374 class SQLFormatDict(object):
375 '''Perform SQL keyword replacements for maintaining compatibility across
376 a wide range of SQL backends. Uses Python dict-based string format
377 function to do simple string replacements, and also to convert
378 params list to the paramstyle required for this interface.
379 Create by passing a dict of macros and the db-api paramstyle:
380 sfd = SQLFormatDict("qmark", substitutionDict)
382 Then transform queries+params as follows; input should be "format" style:
383 sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
384 cursor.execute(sql, params)
386 _paramFormats
= dict(pyformat
='%%(%d)s', numeric
=':%d', named
=':%d',
387 qmark
='(ignore)', format
='(ignore)')
388 def __init__(self
, paramstyle
, substitutionDict
={}):
389 self
.substitutionDict
= substitutionDict
.copy()
390 self
.paramstyle
= paramstyle
391 self
.paramFormat
= self
._paramFormats
[paramstyle
]
392 self
.makeDict
= (paramstyle
== 'pyformat' or paramstyle
== 'named')
393 if paramstyle
== 'qmark': # handle these as simple substitution
394 self
.substitutionDict
['?'] = '?'
395 elif paramstyle
== 'format':
396 self
.substitutionDict
['?'] = '%s'
397 def __getitem__(self
, k
):
398 'apply correct substitution for this SQL interface'
400 return self
.substitutionDict
[k
] # apply our substitutions
403 if k
== '?': # sequential parameter
404 s
= self
.paramFormat
% self
.iparam
405 self
.iparam
+= 1 # advance to the next parameter
407 raise KeyError('unknown macro: %s' % k
)
408 def __call__(self
, sql
, paramList
):
409 'returns corrected sql,params for this interface'
410 self
.iparam
= 1 # DB-ABI param indexing begins at 1
411 sql
= sql
.replace('%s', '%(?)s') # convert format into pyformat
412 s
= sql
% self
# apply all %(x)s replacements in sql
413 if self
.makeDict
: # construct a params dict
415 for i
,param
in enumerate(paramList
):
416 paramDict
[str(i
+ 1)] = param
#DB-ABI param indexing begins at 1
418 else: # just return the original params list
421 def get_table_schema(self
, analyzeSchema
=True):
422 'run the right schema function based on type of db server connection'
424 modname
= self
.cursor
.__class
__.__module
__
425 except AttributeError:
426 raise ValueError('no cursor object or module information!')
428 schema_func
= self
._schemaModuleDict
[modname
]
430 raise KeyError('''unknown db module: %s. Use _schemaModuleDict
431 attribute to supply a method for obtaining table schema
432 for this module''' % modname
)
433 schema_func(self
, analyzeSchema
) # run the schema function
# Default lookup table used by get_table_schema(): the key is the module
# name of the cursor's class (cursor.__class__.__module__), the value is
# the function that reads the table schema for that database module.
_schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
                     'pysqlite2.dbapi2':sqlite_table_schema,
                     'sqlite3':sqlite_table_schema}
440 class SQLTableBase(object, UserDict
.DictMixin
):
441 "Store information about an SQL table as dict keyed by primary key"
442 _schemaModuleDict
= _schemaModuleDict
# default module list
443 get_table_schema
= get_table_schema
444 def __init__(self
,name
,cursor
=None,itemClass
=None,attrAlias
=None,
445 clusterKey
=None,createTable
=None,graph
=None,maxCache
=None,
446 arraysize
=1024, itemSliceClass
=None, dropIfExists
=False,
447 serverInfo
=None, autoGC
=True, orderBy
=None,
448 writeable
=False, **kwargs
):
449 if autoGC
: # automatically garbage collect unused objects
450 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
452 self
._weakValueDict
= {}
454 self
.orderBy
= orderBy
455 self
.writeable
= writeable
457 if serverInfo
is not None: # get cursor from serverInfo
458 cursor
= serverInfo
.cursor()
459 else: # try to read connection info from name or config file
460 name
,cursor
,serverInfo
= get_name_cursor(name
,**kwargs
)
462 warnings
.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
463 DeprecationWarning, stacklevel
=2)
465 if createTable
is not None: # RUN COMMAND TO CREATE THIS TABLE
466 if dropIfExists
: # get rid of any existing table
467 cursor
.execute('drop table if exists ' + name
)
468 self
.get_table_schema(False) # check dbtype, init _format_query
469 sql
,params
= self
._format
_query
(createTable
, ()) # apply macros
470 cursor
.execute(sql
) # create the table
472 if graph
is not None:
474 if maxCache
is not None:
475 self
.maxCache
= maxCache
476 if arraysize
is not None:
477 self
.arraysize
= arraysize
478 cursor
.arraysize
= arraysize
479 self
.get_table_schema() # get schema of columns to serve as attrs
480 self
.data
= {} # map of all attributes, including aliases
481 for icol
,field
in enumerate(self
.columnName
):
482 self
.data
[field
] = icol
# 1st add mappings to columns
484 self
.data
['id']=self
.data
[self
.primary_key
]
485 except (KeyError,TypeError):
487 if hasattr(self
,'_attr_alias'): # apply attribute aliases for this class
488 self
.addAttrAlias(False,**self
._attr
_alias
)
489 self
.objclass(itemClass
) # NEED TO SUBCLASS OUR ITEM CLASS
490 if itemSliceClass
is not None:
491 self
.itemSliceClass
= itemSliceClass
492 get_bound_subclass(self
, 'itemSliceClass', self
.name
) # need to subclass itemSliceClass
493 if attrAlias
is not None: # ADD ATTRIBUTE ALIASES
494 self
.attrAlias
= attrAlias
# RECORD FOR PICKLING PURPOSES
495 self
.data
.update(attrAlias
)
496 if clusterKey
is not None:
497 self
.clusterKey
=clusterKey
498 if serverInfo
is not None:
499 self
.serverInfo
= serverInfo
502 self
._select
(selectCols
='count(*)')
503 return self
.cursor
.fetchone()[0]
506 _pickleAttrs
= dict(name
=0, clusterKey
=0, maxCache
=0, arraysize
=0,
507 attrAlias
=0, serverInfo
=0, autoGC
=0, orderBy
=0,
509 __getstate__
= standard_getstate
510 def __setstate__(self
,state
):
511 # default cursor provisioning by worldbase is deprecated!
512 ## if 'serverInfo' not in state: # hmm, no address for db server?
513 ## try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
514 ## from Data import getResource
515 ## state['cursor'] = getResource.getTableCursor(state['name'])
516 ## except ImportError:
517 ## pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
518 self
.__init
__(**state
)
520 return '<SQL table '+self
.name
+'>'
522 def clear_schema(self
):
523 'reset all schema information for this table'
527 self
.usesIntID
= None
528 self
.primary_key
= None
530 def _attrSQL(self
,attr
,sqlColumn
=False,columnNumber
=False):
531 "Translate python attribute name to appropriate SQL expression"
532 try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
533 field
=self
.data
[attr
]
535 raise AttributeError('attribute %s not a valid column or alias in %s'
537 if sqlColumn
: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
538 try: # CHECK IF field IS COLUMN NUMBER
539 return self
.columnName
[field
] # RETURN SQL COLUMN NAME
541 try: # CHECK IF field IS SQL COLUMN NAME
542 return self
.columnName
[self
.data
[field
]] # THIS WILL JUST RETURN field...
543 except (KeyError,TypeError):
544 raise AttributeError('attribute %s does not map to an SQL column in %s'
547 try: # CHECK IF field IS A COLUMN NUMBER
548 return field
+0 # ONLY RETURN AN INTEGER
550 try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
551 return self
.data
[field
]+0 # ONLY RETURN AN INTEGER
552 except (KeyError,TypeError):
553 raise ValueError('attribute %s does not map to a SQL column!' % attr
)
554 if isinstance(field
,types
.StringType
):
555 attr
=field
# USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
557 attr
=self
.primary_key
559 def addAttrAlias(self
,saveToPickle
=True,**kwargs
):
560 """Add new attributes as aliases of existing attributes.
561 They can be specified either as named args:
562 t.addAttrAlias(newattr=oldattr)
563 or by passing a dictionary kwargs whose keys are newattr
564 and values are oldattr:
565 t.addAttrAlias(**kwargs)
566 saveToPickle=True forces these aliases to be saved if object is pickled.
569 self
.attrAlias
.update(kwargs
)
570 for key
,val
in kwargs
.items():
571 try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
572 self
.data
[val
]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
573 raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
574 except TypeError: # val IS ITSELF AN ALIAS
575 self
.data
[key
] = self
.data
[val
] # SO MAP TO WHAT IT MAPS TO
576 except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
578 def objclass(self
,oclass
=None):
579 "Create class representing a row in this table by subclassing oclass, adding data"
580 if oclass
is not None: # use this as our base itemClass
581 self
.itemClass
= oclass
583 self
.itemClass
= self
.itemClass
._RWClass
# use its writeable version
584 oclass
= get_bound_subclass(self
, 'itemClass', self
.name
,
585 subclassArgs
=dict(db
=self
)) # bind itemClass
586 if issubclass(oclass
, TupleO
):
587 oclass
._attrcol
= self
.data
# BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
588 if hasattr(oclass
,'_tableclass') and not isinstance(self
,oclass
._tableclass
):
589 self
.__class
__=oclass
._tableclass
# ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
590 def _select(self
, whereClause
='', params
=(), selectCols
='t1.*',
592 'execute the specified query but do not fetch'
593 sql
,params
= self
._format
_query
('select %s from %s t1 %s'
594 % (selectCols
,self
.name
,whereClause
), params
)
596 self
.cursor
.execute(sql
, params
)
598 cursor
.execute(sql
, params
)
599 def select(self
,whereClause
,params
=None,oclass
=None,selectCols
='t1.*'):
600 "Generate the list of objects that satisfy the database SELECT"
602 oclass
=self
.itemClass
603 self
._select
(whereClause
,params
,selectCols
)
604 l
=self
.cursor
.fetchall()
606 yield self
.cacheItem(t
,oclass
)
607 def query(self
,**kwargs
):
608 'query for intersection of all specified kwargs, returned as iterator'
611 for k
,v
in kwargs
.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
612 if v
is None: # CONVERT TO SQL NULL TEST
613 criteria
.append('%s IS NULL' % self
._attrSQL
(k
))
614 else: # TEST FOR EQUALITY
615 criteria
.append('%s=%%s' % self
._attrSQL
(k
))
617 return self
.select('where '+' and '.join(criteria
),params
)
618 def _update(self
,row_id
,col
,val
):
619 'update a single field in the specified row to the specified value'
620 sql
,params
= self
._format
_query
('update %s set %s=%%s where %s=%%s'
621 %(self
.name
,col
,self
.primary_key
),
623 self
.cursor
.execute(sql
, params
)
626 return t
[self
.data
['id']] # GET ID FROM TUPLE
627 except TypeError: # treat as alias
628 return t
[self
.data
[self
.data
['id']]]
629 def cacheItem(self
,t
,oclass
):
630 'get obj from cache if possible, or construct from tuple'
633 except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
635 try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
636 return self
._weakValueDict
[id]
640 self
._weakValueDict
[id] = o
# CACHE THIS ITEM IN OUR DICTIONARY
642 def cache_items(self
,rows
,oclass
=None):
644 oclass
=self
.itemClass
646 yield self
.cacheItem(t
,oclass
)
def foreignKey(self,attr,k):
    """Return the result of select() for rows whose column attr equals k.

    attr is placed directly into the WHERE clause; k is passed as the
    single query parameter.
    """
    where_clause = 'where %s=%%s' % attr
    params = (k,)
    return self.select(where_clause, params)
650 def limit_cache(self
):
651 'APPLY maxCache LIMIT TO CACHE SIZE'
653 if self
.maxCache
<len(self
._weakValueDict
):
654 self
._weakValueDict
.clear()
655 except AttributeError:
658 def get_new_cursor(self
):
659 """Return a new cursor object, or None if not possible """
661 new_cursor
= self
.serverInfo
.new_cursor
662 except AttributeError:
666 def generic_iterator(self
, cursor
=None, fetch_f
=None, cache_f
=None,
668 'generic iterator that runs fetch, cache and map functions'
669 if fetch_f
is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
671 fetch_f
= self
.cursor
.fetchmany
672 else: # isolate this iter from other queries
673 fetch_f
= cursor
.fetchmany
675 cache_f
= self
.cache_items
678 rows
= fetch_f() # FETCH THE NEXT SET OF ROWS
679 if len(rows
)==0: # NO MORE DATA SO ALL DONE
681 for v
in map_f(cache_f(rows
)): # CACHE AND GENERATE RESULTS
683 if cursor
is not None: # close iterator now that we're done
685 def tuple_from_dict(self
, d
):
686 'transform kwarg dict into tuple for storing in database'
687 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
688 for col
,icol
in self
.data
.items():
691 except (KeyError,TypeError):
694 def tuple_from_obj(self
, obj
):
695 'transform object attributes into tuple for storing in database'
696 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
697 for col
,icol
in self
.data
.items():
699 l
[icol
] = getattr(obj
,col
)
700 except (AttributeError,TypeError):
703 def _insert(self
, l
):
704 '''insert tuple into the database. Note this uses the MySQL
705 extension REPLACE, which overwrites any duplicate key.'''
706 s
= '%(REPLACE)s into ' + self
.name
+ ' values (' \
707 + ','.join(['%s']*len(l
)) + ')'
708 sql
,params
= self
._format
_query
(s
, l
)
709 self
.cursor
.execute(sql
, params
)
def insert(self, obj):
    """Insert a new row by transforming obj to a tuple of values.

    obj is converted with self.tuple_from_obj() and the resulting values
    are written to the table with self._insert().  Returns None.
    """
    l = self.tuple_from_obj(obj)
    # The converted values must actually be sent to the database;
    # callers such as __setitem__ rely on insert() to save the row.
    self._insert(l)
def get_insert_id(self):
    """Return the primary key value for the last INSERT.

    The value is read from self.cursor.lastrowid.

    Raises NotImplementedError if the lastrowid attribute cannot be read
    from the cursor, and ValueError if lastrowid is None.
    """
    try: # ATTEMPT TO GET ASSIGNED ID FROM DB
        auto_id = self.cursor.lastrowid
    except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
        raise NotImplementedError('''your db lacks lastrowid support?''')
    # Only a missing ID is an error; a real ID must be handed back to the
    # caller, which caches it on the newly inserted row object.
    if auto_id is None:
        raise ValueError('lastrowid is None so cannot get ID from INSERT!')
    return auto_id
def new(self, **kwargs):
    """Return a new record with the assigned attributes, added to DB.

    kwargs are passed to self.itemClass together with newRow=True; the
    resulting object is stored in the local cache under its id and
    returned.

    Raises ValueError if this table is not writeable.
    """
    if not self.writeable:
        raise ValueError('this database is read only!')
    obj = self.itemClass(None, newRow=True, **kwargs) # saves itself to db
    self._weakValueDict[obj.id] = obj # AND SAVE TO OUR LOCAL DICT CACHE
    # Hand the record back, as the docstring promises; otherwise the
    # caller has no reference to the row it just created.
    return obj
def clear_cache(self):
    'remove every entry from the local object cache (self._weakValueDict)'
    self._weakValueDict.clear()
733 def __delitem__(self
, k
):
734 if not self
.writeable
:
735 raise ValueError('this database is read only!')
736 sql
,params
= self
._format
_query
('delete from %s where %s=%%s'
737 % (self
.name
,self
.primary_key
),(k
,))
738 self
.cursor
.execute(sql
, params
)
740 del self
._weakValueDict
[k
]
def getKeys(self,queryOption='', selectCols=None):
    """Return the first column of every row selected from the table.

    selectCols defaults to the primary key.  queryOption is appended to
    the SELECT; when it is empty and self.orderBy is set, self.orderBy
    is used instead.  All rows are fetched at once, because other calls
    may reuse self.cursor.
    """
    columns = selectCols
    if columns is None:
        columns = self.primary_key
    suffix = queryOption
    if suffix == '' and self.orderBy is not None:
        suffix = self.orderBy # apply default ordering
    statement = 'select %s from %s %s' % (columns, self.name, suffix)
    self.cursor.execute(statement)
    rows = self.cursor.fetchall()
    return [row[0] for row in rows]
def iter_keys(self, selectCols=None):
    """Return an iterator over keys that other queries cannot disturb.

    selectCols defaults to the primary key.  When get_new_cursor()
    supplies a cursor, the select runs on it and generic_iterator()
    yields the first column of each fetched row.  Otherwise all keys are
    fetched up front through self.keys().
    """
    columns = selectCols
    if columns is None:
        columns = self.primary_key
    own_cursor = self.get_new_cursor()
    if not own_cursor: # must pre-fetch all keys to ensure query isolation
        return iter(self.keys())
    # got our own cursor, guaranteeing query isolation
    self._select(cursor=own_cursor, selectCols=columns)
    first_column = lambda rows: [row[0] for row in rows]
    return self.generic_iterator(cursor=own_cursor, cache_f=first_column)
766 class SQLTable(SQLTableBase
):
767 "Provide on-the-fly access to rows in the database, caching the results in dict"
768 itemClass
= TupleO
# our default itemClass; constructor can override
771 def load(self
,oclass
=None):
772 "Load all data from the table"
773 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
774 return self
._isLoaded
775 except AttributeError:
778 oclass
=self
.itemClass
779 self
.cursor
.execute('select * from %s' % self
.name
)
780 l
=self
.cursor
.fetchall()
781 self
._weakValueDict
= {} # just store the whole dataset in memory
783 self
.cacheItem(t
,oclass
) # CACHE IT IN LOCAL DICTIONARY
784 self
._isLoaded
=True # MARK THIS CONTAINER AS FULLY LOADED
786 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
788 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
789 except KeyError: # NOT FOUND, SO TRY THE DATABASE
790 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 2'
791 % (self
.name
,self
.primary_key
),(k
,))
792 self
.cursor
.execute(sql
, params
)
793 l
= self
.cursor
.fetchmany(2) # get at most 2 rows
795 raise KeyError('%s not found in %s, or not unique' %(str(k
),self
.name
))
797 return self
.cacheItem(l
[0],self
.itemClass
) # CACHE IT IN LOCAL DICTIONARY
798 def __setitem__(self
, k
, v
):
799 if not self
.writeable
:
800 raise ValueError('this database is read only!')
804 except AttributeError:
805 raise ValueError('object not bound to itemClass for this db!')
810 except AttributeError:
812 else: # delete row with old ID
814 v
.cache_id(k
) # cache the new ID on the object
815 self
.insert(v
) # SAVE TO THE RELATIONAL DB SERVER
816 self
._weakValueDict
[k
] = v
# CACHE THIS ITEM IN OUR DICTIONARY
818 'forces load of entire table into memory'
820 return self
._weakValueDict
.items()
822 'uses arraysize / maxCache and fetchmany() to manage data transfer'
823 cursor
= self
.get_new_cursor()
824 self
._select
(cursor
=cursor
)
825 return self
.generic_iterator(cursor
=cursor
, map_f
=generate_items
)
827 'forces load of entire table into memory'
829 return self
._weakValueDict
.values()
830 def itervalues(self
):
831 'uses arraysize / maxCache and fetchmany() to manage data transfer'
832 cursor
= self
.get_new_cursor()
833 self
._select
(cursor
=cursor
)
834 return self
.generic_iterator(cursor
=cursor
)
def getClusterKeys(self,queryOption=''):
    """Return the distinct values of the clusterKey column.

    queryOption is appended to the SELECT DISTINCT statement.  All rows
    are fetched at once, because other calls may reuse self.cursor.
    """
    statement = 'select distinct %s from %s %s' % (self.clusterKey,
                                                   self.name, queryOption)
    self.cursor.execute(statement)
    rows = self.cursor.fetchall()
    return [row[0] for row in rows]
843 class SQLTableClustered(SQLTable
):
844 '''use clusterKey to load a whole cluster of rows at once,
845 specifically, all rows that share the same clusterKey value.'''
def __init__(self, *args, **kwargs):
    """Initialise as an SQLTable with autoGC forced to False.

    The caller's keyword arguments are copied before being altered, so
    the dictionary passed in is left untouched.
    """
    options = dict(kwargs) # get a copy we can alter
    options['autoGC'] = False # don't use WeakValueDictionary
    SQLTable.__init__(self, *args, **options)
851 return getKeys(self
,'order by %s' %self
.clusterKey
)
def clusterkeys(self):
    """Return the distinct clusterKey values, ordered by clusterKey."""
    ordering = 'order by %s' % self.clusterKey
    return getClusterKeys(self, ordering)
854 def __getitem__(self
,k
):
856 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
857 except KeyError: # NOT FOUND, SO TRY THE DATABASE
858 sql
,params
= self
._format
_query
('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
859 % (self
.name
,self
.name
,self
.primary_key
,
860 self
.clusterKey
,self
.clusterKey
),(k
,))
861 self
.cursor
.execute(sql
, params
)
862 l
=self
.cursor
.fetchall()
864 for t
in l
: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
865 self
.cacheItem(t
,self
.itemClass
)
866 return self
._weakValueDict
[k
] # should be in cache, if row k exists
867 def itercluster(self
,cluster_id
):
868 'iterate over all items from the specified cluster'
870 return self
.select('where %s=%%s'%self
.clusterKey
,(cluster_id
,))
871 def fetch_cluster(self
):
872 'use self.cursor.fetchmany to obtain all rows for next cluster'
873 icol
= self
._attrSQL
(self
.clusterKey
,columnNumber
=True)
876 rows
= self
._fetch
_cluster
_cache
# USE SAVED ROWS FROM PREVIOUS CALL
877 del self
._fetch
_cluster
_cache
878 except AttributeError:
879 rows
= self
.cursor
.fetchmany()
881 cluster_id
= rows
[0][icol
]
885 for i
,t
in enumerate(rows
): # CHECK THAT ALL ROWS FROM THIS CLUSTER
886 if cluster_id
!= t
[icol
]: # START OF A NEW CLUSTER
887 result
+= rows
[:i
] # RETURN ROWS OF CURRENT CLUSTER
888 self
._fetch
_cluster
_cache
= rows
[i
:] # SAVE NEXT CLUSTER
891 rows
= self
.cursor
.fetchmany() # GET NEXT SET OF ROWS
def itervalues(self):
    """Return an iterator over rows, selected in clusterKey order.

    Uses arraysize / maxCache and fetchmany() to manage data transfer.
    The select runs on a new cursor, and fetch_cluster is passed to
    generic_iterator() as its second argument.
    """
    fresh_cursor = self.get_new_cursor()
    ordering = 'order by %s' % self.clusterKey
    self._select(ordering, cursor=fresh_cursor)
    return self.generic_iterator(fresh_cursor, self.fetch_cluster)
899 'uses arraysize / maxCache and fetchmany() to manage data transfer'
900 cursor
= self
.get_new_cursor()
901 self
._select
('order by %s' %self
.clusterKey
, cursor
=cursor
)
902 return self
.generic_iterator(cursor
, self
.fetch_cluster
,
903 map_f
=generate_items
)
905 class SQLForeignRelation(object):
906 'mapping based on matching a foreign key in an SQL table'
907 def __init__(self
,table
,keyName
):
910 def __getitem__(self
,k
):
911 'get list of objects o with getattr(o,keyName)==k.id'
913 for o
in self
.table
.select('where %s=%%s'%self
.keyName
,(k
.id,)):
916 raise KeyError('%s not found in %s' %(str(k
),self
.name
))
920 class SQLTableNoCache(SQLTableBase
):
921 '''Provide on-the-fly access to rows in the database;
922 values are simply an object interface (SQLRow) to back-end db query.
923 Row data are not stored locally, but always accessed by querying the db'''
924 itemClass
=SQLRow
# DEFAULT OBJECT CLASS FOR ROWS...
def getID(self, t):
    'GET ID FROM TUPLE: the ID is the first element of t'
    row_id = t[0]
    return row_id
928 def select(self
,whereClause
,params
):
929 return SQLTableBase
.select(self
,whereClause
,params
,self
.oclass
,
931 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
933 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
934 except KeyError: # NOT FOUND, SO TRY THE DATABASE
935 self
._select
('where %s=%%s' % self
.primary_key
, (k
,),
937 t
= self
.cursor
.fetchmany(2)
939 raise KeyError('id %s non-existent or not unique' % k
)
940 o
= self
.itemClass(k
) # create obj referencing this ID
941 self
._weakValueDict
[k
] = o
# cache the SQLRow object
943 def __setitem__(self
, k
, v
):
944 if not self
.writeable
:
945 raise ValueError('this database is read only!')
949 except AttributeError:
950 raise ValueError('object not bound to itemClass for this db!')
952 del self
[k
] # delete row with new ID if any
956 del self
._weakValueDict
[v
.id] # delete from old cache location
959 self
._update
(v
.id, self
.primary_key
, k
) # just change its ID in db
960 v
.cache_id(k
) # change the cached ID value
961 self
._weakValueDict
[k
] = v
# assign to new cache location
def addAttrAlias(self, **kwargs):
    """ALIAS KEYS TO EXPRESSION VALUES.

    Each keyword name is stored in self.data, mapped to the value given.
    """
    aliases = kwargs
    self.data.update(aliases)
965 SQLRow
._tableclass
=SQLTableNoCache
# SQLRow IS FOR NON-CACHING TABLE INTERFACE
968 class SQLTableMultiNoCache(SQLTableBase
):
969 "Trivial on-the-fly access for table with key that returns multiple rows"
970 itemClass
= TupleO
# default itemClass; constructor can override
971 _distinct_key
='id' # DEFAULT COLUMN TO USE AS KEY
973 return getKeys(self
, selectCols
='distinct(%s)'
974 % self
._attrSQL
(self
._distinct
_key
))
976 return iter_keys(self
, 'distinct(%s)' % self
._attrSQL
(self
._distinct
_key
))
977 def __getitem__(self
,id):
978 sql
,params
= self
._format
_query
('select * from %s where %s=%%s'
979 %(self
.name
,self
._attrSQL
(self
._distinct
_key
)),(id,))
980 self
.cursor
.execute(sql
, params
)
981 l
=self
.cursor
.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
983 yield self
.itemClass(row
)
def addAttrAlias(self, **kwargs):
    """ALIAS KEYS TO EXPRESSION VALUES.

    Each keyword name is stored in self.data, mapped to the value given.
    """
    aliases = kwargs
    self.data.update(aliases)
989 class SQLEdges(SQLTableMultiNoCache
):
990 '''provide iterator over edges as (source,target,edge)
991 and getitem[edge] --> [(source,target),...]'''
992 _distinct_key
='edge_id'
993 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
994 _pickleAttrs
.update(dict(graph
=0))
996 self
.cursor
.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
997 %(self
._attrSQL
('source_id'),self
._attrSQL
('target_id'),
998 self
._attrSQL
('edge_id'),self
.name
,
999 self
._attrSQL
('target_id'),self
._attrSQL
('source_id'),
1000 self
._attrSQL
('target_id')))
1001 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1002 for source_id
,target_id
,edge_id
in self
.cursor
.fetchall():
1003 l
.append((self
.graph
.unpack_source(source_id
),
1004 self
.graph
.unpack_target(target_id
),
1005 self
.graph
.unpack_edge(edge_id
)))
1009 return iter(self
.keys())
1010 def __getitem__(self
,edge
):
1011 sql
,params
= self
._format
_query
('select %s,%s from %s where %s=%%s'
1012 %(self
._attrSQL
('source_id'),
1013 self
._attrSQL
('target_id'),
1015 self
._attrSQL
(self
._distinct
_key
)),
1016 (self
.graph
.pack_edge(edge
),))
1017 self
.cursor
.execute(sql
, params
)
1018 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1019 for source_id
,target_id
in self
.cursor
.fetchall():
1020 l
.append((self
.graph
.unpack_source(source_id
),
1021 self
.graph
.unpack_target(target_id
)))
1025 class SQLEdgeDict(object):
1026 '2nd level graph interface to SQL database'
1027 def __init__(self
,fromNode
,table
):
1028 self
.fromNode
=fromNode
1030 if not hasattr(self
.table
,'allowMissingNodes'):
1031 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s limit 1'
1032 %(self
.table
.sourceSQL
,
1034 self
.table
.sourceSQL
),
1036 self
.table
.cursor
.execute(sql
, params
)
1037 if len(self
.table
.cursor
.fetchall())<1:
1038 raise KeyError('node not in graph!')
1040 def __getitem__(self
,target
):
1041 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1042 %(self
.table
.edgeSQL
,
1044 self
.table
.sourceSQL
,
1045 self
.table
.targetSQL
),
1047 self
.table
.pack_target(target
)))
1048 self
.table
.cursor
.execute(sql
, params
)
1049 l
= self
.table
.cursor
.fetchmany(2) # get at most two rows
1051 raise KeyError('either no edge from source to target or not unique!')
1053 return self
.table
.unpack_edge(l
[0][0]) # RETURN EDGE
1055 raise KeyError('no edge from node to target')
1056 def __setitem__(self
,target
,edge
):
1057 sql
,params
= self
.table
._format
_query
('replace into %s values (%%s,%%s,%%s)'
1060 self
.table
.pack_target(target
),
1061 self
.table
.pack_edge(edge
)))
1062 self
.table
.cursor
.execute(sql
, params
)
1063 if not hasattr(self
.table
,'sourceDB') or \
1064 (hasattr(self
.table
,'targetDB') and self
.table
.sourceDB
==self
.table
.targetDB
):
1065 self
.table
+= target
# ADD AS NODE TO GRAPH
1066 def __iadd__(self
,target
):
1068 return self
# iadd MUST RETURN self!
1069 def __delitem__(self
,target
):
1070 sql
,params
= self
.table
._format
_query
('delete from %s where %s=%%s and %s=%%s'
1072 self
.table
.sourceSQL
,
1073 self
.table
.targetSQL
),
1075 self
.table
.pack_target(target
)))
1076 self
.table
.cursor
.execute(sql
, params
)
1077 if self
.table
.cursor
.rowcount
< 1: # no rows deleted?
1078 raise KeyError('no edge from node to target')
1080 def iterator_query(self
):
1081 sql
,params
= self
.table
._format
_query
('select %s,%s from %s where %s=%%s and %s is not null'
1082 %(self
.table
.targetSQL
,
1085 self
.table
.sourceSQL
,
1086 self
.table
.targetSQL
),
1088 self
.table
.cursor
.execute(sql
, params
)
1089 return self
.table
.cursor
.fetchall()
1091 return [self
.table
.unpack_target(target_id
)
1092 for target_id
,edge_id
in self
.iterator_query()]
1094 return [self
.table
.unpack_edge(edge_id
)
1095 for target_id
,edge_id
in self
.iterator_query()]
1097 return [(self
.table
.unpack_source(self
.fromNode
),self
.table
.unpack_target(target_id
),
1098 self
.table
.unpack_edge(edge_id
))
1099 for target_id
,edge_id
in self
.iterator_query()]
1101 return [(self
.table
.unpack_target(target_id
),self
.table
.unpack_edge(edge_id
))
1102 for target_id
,edge_id
in self
.iterator_query()]
def __iter__(self):
    'iterate over the result of keys()'
    key_list = self.keys()
    return iter(key_list)
def itervalues(self):
    'iterate over the result of values()'
    value_list = self.values()
    return iter(value_list)
def iteritems(self):
    'iterate over the result of items()'
    item_list = self.items()
    return iter(item_list)
1107 return len(self
.keys())
1110 class SQLEdgelessDict(SQLEdgeDict
):
1111 'for SQLGraph tables that lack edge_id column'
1112 def __getitem__(self
,target
):
1113 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1114 %(self
.table
.targetSQL
,
1116 self
.table
.sourceSQL
,
1117 self
.table
.targetSQL
),
1119 self
.table
.pack_target(target
)))
1120 self
.table
.cursor
.execute(sql
, params
)
1121 l
= self
.table
.cursor
.fetchmany(2)
1123 raise KeyError('either no edge from source to target or not unique!')
1124 return None # no edge info!
1125 def iterator_query(self
):
1126 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s is not null'
1127 %(self
.table
.targetSQL
,
1129 self
.table
.sourceSQL
,
1130 self
.table
.targetSQL
),
1132 self
.table
.cursor
.execute(sql
, params
)
1133 return [(t
[0],None) for t
in self
.table
.cursor
.fetchall()]
1135 SQLEdgeDict
._edgelessClass
= SQLEdgelessDict
1137 class SQLGraphEdgeDescriptor(object):
1138 'provide an SQLEdges interface on demand'
1139 def __get__(self
,obj
,objtype
):
1141 attrAlias
=obj
.attrAlias
.copy()
1142 except AttributeError:
1143 return SQLEdges(obj
.name
, obj
.cursor
, graph
=obj
)
1145 return SQLEdges(obj
.name
, obj
.cursor
, attrAlias
=attrAlias
,
1148 def getColumnTypes(createTable
,attrAlias
={},defaultColumnType
='int',
1149 columnAttrs
=('source','target','edge'),**kwargs
):
1150 'return list of [(colname,coltype),...] for source,target,edge'
1152 for attr
in columnAttrs
:
1154 attrName
= attrAlias
[attr
+'_id']
1156 attrName
= attr
+'_id'
1157 try: # SEE IF USER SPECIFIED A DESIRED TYPE
1158 l
.append((attrName
,createTable
[attr
+'_id']))
1160 except (KeyError,TypeError):
1162 try: # get type info from primary key for that database
1163 db
= kwargs
[attr
+'DB']
1165 raise KeyError # FORCE IT TO USE DEFAULT TYPE
1168 else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
1170 try: # GET ONE IDENTIFIER FROM THE DATABASE
1172 except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
1174 l
.append((attrName
,db
.columnType
[db
.primary_key
]))
1176 except AttributeError:
1178 else: # GET THE TYPE FROM THIS IDENTIFIER
1179 if isinstance(k
,int) or isinstance(k
,long):
1180 l
.append((attrName
,'int'))
1182 elif isinstance(k
,str):
1183 l
.append((attrName
,'varchar(32)'))
1186 raise ValueError('SQLGraph node / edge must be int or str!')
1187 l
.append((attrName
,defaultColumnType
))
1188 logger
.warn('no type info found for %s, so using default: %s'
1189 % (attrName
, defaultColumnType
))
1195 class SQLGraph(SQLTableMultiNoCache
):
1196 '''provide a graph interface via a SQL table. Key capabilities are:
1197 - setitem with an empty dictionary: a dummy operation
1198 - getitem with a key that exists: return a placeholder
1199 - setitem with non empty placeholder: again a dummy operation
1200 EXAMPLE TABLE SCHEMA:
1201 create table mygraph (source_id int not null,target_id int,edge_id int,
1202 unique(source_id,target_id));
1204 _distinct_key
='source_id'
1205 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1206 _pickleAttrs
.update(dict(sourceDB
=0,targetDB
=0,edgeDB
=0,allowMissingNodes
=0))
1207 _edgeClass
= SQLEdgeDict
1208 def __init__(self
,name
,*l
,**kwargs
):
1209 graphArgs
,tableArgs
= split_kwargs(kwargs
,
1210 ('attrAlias','defaultColumnType','columnAttrs',
1211 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
1212 'edgeDictClass','graph'))
1213 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1214 c
= getColumnTypes(**kwargs
)
1215 tableArgs
['createTable'] = \
1216 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1217 % (name
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1219 self
.allowMissingNodes
= kwargs
['allowMissingNodes']
1220 except KeyError: pass
1221 SQLTableMultiNoCache
.__init
__(self
,name
,*l
,**tableArgs
)
1222 self
.sourceSQL
= self
._attrSQL
('source_id')
1223 self
.targetSQL
= self
._attrSQL
('target_id')
1225 self
.edgeSQL
= self
._attrSQL
('edge_id')
1226 except AttributeError:
1228 self
._edgeClass
= self
._edgeClass
._edgelessClass
1229 save_graph_db_refs(self
,**kwargs
)
def __getitem__(self, k):
    """Return an edge-dictionary object for source node k.

    The object is built by self._edgeClass from the packed form of k and
    this graph.
    """
    edge_factory = self._edgeClass
    return edge_factory(self.pack_source(k), self)
def __iadd__(self, k):
    """Add k as a source node (g += k).

    First deletes any row for k whose target column is null, then inserts
    a row (k, NULL, NULL).  The %(IGNORE)s placeholder is left in the
    statement handed to _format_query -- presumably expanded there; verify
    against _format_query.

    Returns self, as required of __iadd__.
    """
    purge = ('delete from %s where %s=%%s and %s is null'
             % (self.name, self.sourceSQL, self.targetSQL))
    create = 'insert %%(IGNORE)s into %s values (%%s,NULL,NULL)' % self.name
    for template in (purge, create):
        sql, params = self._format_query(template, (self.pack_source(k),))
        self.cursor.execute(sql, params)
    return self  # iadd MUST RETURN SELF!
def __isub__(self, k):
    """Remove source node k from the graph (g -= k).

    Deletes every row whose source column equals the packed form of k.

    Raises KeyError if the cursor reports that no rows were deleted.
    Returns self, as required of __isub__.
    """
    template = 'delete from %s where %s=%%s' % (self.name, self.sourceSQL)
    sql, params = self._format_query(template, (self.pack_source(k),))
    self.cursor.execute(sql, params)
    if self.cursor.rowcount != 0:
        return self  # isub MUST RETURN SELF!
    raise KeyError('node not found in graph')
1249 __setitem__
= graph_setitem
1250 def __contains__(self
,k
):
1251 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 1'
1252 %(self
.name
,self
.sourceSQL
),
1253 (self
.pack_source(k
),))
1254 self
.cursor
.execute(sql
, params
)
1255 l
= self
.cursor
.fetchmany(2)
1257 def __invert__(self
):
1258 'get an interface to the inverse graph mapping'
1260 return self
._inverse
1261 except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
1262 attrAlias
= dict(source_id
=self
.targetSQL
, # SWAP SOURCE & TARGET
1263 target_id
=self
.sourceSQL
,
1264 edge_id
=self
.edgeSQL
)
1265 if self
.edgeSQL
is None: # no edge interface
1266 del attrAlias
['edge_id']
1267 self
._inverse
=SQLGraph(self
.name
,self
.cursor
,
1268 attrAlias
=attrAlias
,
1269 **graph_db_inverse_refs(self
))
1270 self
._inverse
._inverse
=self
1271 return self
._inverse
1273 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1274 yield self
.unpack_source(k
)
def iteritems(self):
    """Generate (source node, edge dictionary) pairs.

    Keys come from SQLTableMultiNoCache.__iter__; each is unpacked with
    unpack_source for the first element and passed as-is to _edgeClass
    for the second.
    """
    for packed_id in SQLTableMultiNoCache.__iter__(self):
        source_node = self.unpack_source(packed_id)
        edge_dict = self._edgeClass(packed_id, self)
        yield (source_node, edge_dict)
def itervalues(self):
    """Generate one edge dictionary per key.

    Keys come from SQLTableMultiNoCache.__iter__ and are passed as-is to
    _edgeClass together with this graph.
    """
    for packed_id in SQLTableMultiNoCache.__iter__(self):
        edge_dict = self._edgeClass(packed_id, self)
        yield edge_dict
1282 return [self
.unpack_source(k
) for k
in SQLTableMultiNoCache
.keys(self
)]
def values(self):
    'return everything produced by itervalues() as a list'
    return [edge_dict for edge_dict in self.itervalues()]
def items(self):
    'return everything produced by iteritems() as a list'
    return [pair for pair in self.iteritems()]
1285 edges
=SQLGraphEdgeDescriptor()
1286 update
= update_graph
1288 'get number of source nodes in graph'
1289 self
.cursor
.execute('select count(distinct %s) from %s'
1290 %(self
.sourceSQL
,self
.name
))
1291 return self
.cursor
.fetchone()[0]
1293 override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
1294 ## def __cmp__(self,other):
1298 ## it = iter(self.edges)
1301 ## source,target,edge = it.next()
1302 ## except StopIteration:
1305 ## if d is not None:
1306 ## diff = cmp(n_target,len(d))
1309 ## if source is None:
1312 ## n += 1 # COUNT SOURCE NODES
1319 ## diff = cmp(edge,d[target])
1324 ## n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
1325 ## return cmp(n,len(other))
1327 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1329 class SQLIDGraph(SQLGraph
):
1330 add_trivial_packing_methods(locals())
1331 SQLGraph
._IDGraphClass
= SQLIDGraph
1335 class SQLEdgeDictClustered(dict):
1336 'simple cache for 2nd level dictionary of target_id:edge_id'
1337 def __init__(self
,g
,fromNode
):
1339 self
.fromNode
=fromNode
def __iadd__(self, l):
    """Store each (target_id, edge_id) pair from l in this dict.

    Entries are written with dict.__setitem__ directly.  Returns self, as
    required of __iadd__.
    """
    store = dict.__setitem__
    for pair in l:
        target_id, edge_id = pair
        store(self, target_id, edge_id)
    return self  # iadd MUST RETURN SELF!
1346 class SQLEdgesClusteredDescr(object):
1347 def __get__(self
,obj
,objtype
):
1348 e
=SQLEdgesClustered(obj
.table
,obj
.edge_id
,obj
.source_id
,obj
.target_id
,
1349 graph
=obj
,**graph_db_inverse_refs(obj
,True))
1350 for source_id
,d
in obj
.d
.iteritems(): # COPY EDGE CACHE
1351 e
.load([(edge_id
,source_id
,target_id
)
1352 for (target_id
,edge_id
) in d
.iteritems()])
1355 class SQLGraphClustered(object):
1356 'SQL graph with clustered caching -- loads an entire cluster at a time'
1357 _edgeDictClass
=SQLEdgeDictClustered
1358 def __init__(self
,table
,source_id
='source_id',target_id
='target_id',
1359 edge_id
='edge_id',clusterKey
=None,**kwargs
):
1361 if isinstance(table
,types
.StringType
): # CREATE THE TABLE INTERFACE
1362 if clusterKey
is None:
1363 raise ValueError('you must provide a clusterKey argument!')
1364 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1365 c
= getColumnTypes(attrAlias
=dict(source_id
=source_id
,target_id
=target_id
,
1366 edge_id
=edge_id
),**kwargs
)
1367 kwargs
['createTable'] = \
1368 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1369 % (table
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],
1370 c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1371 table
= SQLTableClustered(table
,clusterKey
=clusterKey
,**kwargs
)
1373 self
.source_id
=source_id
1374 self
.target_id
=target_id
1375 self
.edge_id
=edge_id
1377 save_graph_db_refs(self
,**kwargs
)
1378 _pickleAttrs
= dict(table
=0,source_id
=0,target_id
=0,edge_id
=0,sourceDB
=0,targetDB
=0,
1380 def __getstate__(self
):
1381 state
= standard_getstate(self
)
1382 state
['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1384 def __getitem__(self
,k
):
1385 'get edgeDict for source node k, from cache or by loading its cluster'
1386 try: # GET DIRECTLY FROM CACHE
1389 if hasattr(self
,'_isLoaded'):
1390 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1391 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1392 sql
,params
= self
.table
._format
_query
('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1393 %(self
.source_id
,self
.target_id
,
1394 self
.edge_id
,self
.table
.name
,
1395 self
.table
.name
,self
.source_id
,
1396 self
.table
.clusterKey
,self
.table
.clusterKey
,
1397 self
.table
.primary_key
),
1398 (self
.pack_source(k
),))
1399 self
.table
.cursor
.execute(sql
, params
)
1400 self
.load(self
.table
.cursor
.fetchall()) # CACHE THIS CLUSTER
1401 return self
.d
[k
] # RETURN EDGE DICT FOR THIS NODE
1402 def load(self
,l
=None,unpack
=True):
1403 'load the specified rows (or all, if None provided) into local cache'
1405 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1406 return self
._isLoaded
1407 except AttributeError:
1409 self
.table
.cursor
.execute('select %s,%s,%s from %s'
1410 %(self
.source_id
,self
.target_id
,
1411 self
.edge_id
,self
.table
.name
))
1412 l
=self
.table
.cursor
.fetchall()
1414 self
.d
.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1415 for source
,target
,edge
in l
: # SAVE TO OUR CACHE
1417 source
= self
.unpack_source(source
)
1418 target
= self
.unpack_target(target
)
1419 edge
= self
.unpack_edge(edge
)
1421 self
.d
[source
] += [(target
,edge
)]
1423 d
= self
._edgeDictClass
(self
,source
)
1424 d
+= [(target
,edge
)]
1426 def __invert__(self
):
1427 'interface to reverse graph mapping'
1429 return self
._inverse
# INVERSE MAP ALREADY EXISTS
1430 except AttributeError:
1432 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1433 self
._inverse
=SQLGraphClustered(self
.table
,self
.target_id
,self
.source_id
,
1434 self
.edge_id
,**graph_db_inverse_refs(self
))
1435 self
._inverse
._inverse
=self
1436 for source
,d
in self
.d
.iteritems(): # INVERT OUR CACHE
1437 self
._inverse
.load([(target
,source
,edge
)
1438 for (target
,edge
) in d
.iteritems()],unpack
=False)
1439 return self
._inverse
1440 edges
=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1441 update
= update_graph
1442 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
def __iter__(self):  ################# ITERATORS
    'iterate over the list returned by keys()'
    key_list = self.keys()
    return iter(key_list)
1447 'uses db select; does not force load'
1448 self
.table
.cursor
.execute('select distinct(%s) from %s'
1449 %(self
.source_id
,self
.table
.name
))
1450 return [self
.unpack_source(t
[0])
1451 for t
in self
.table
.cursor
.fetchall()]
1452 methodFactory(['iteritems','items','itervalues','values'],
1453 'lambda self:(self.load(),self.d.%s())[1]',locals())
1454 def __contains__(self
,k
):
1461 class SQLIDGraphClustered(SQLGraphClustered
):
1462 add_trivial_packing_methods(locals())
1463 SQLGraphClustered
._IDGraphClass
= SQLIDGraphClustered
1465 class SQLEdgesClustered(SQLGraphClustered
):
1466 'edges interface for SQLGraphClustered'
1467 _edgeDictClass
= list
1468 _pickleAttrs
= SQLGraphClustered
._pickleAttrs
.copy()
1469 _pickleAttrs
.update(dict(graph
=0))
1473 for edge_id
,l
in self
.d
.iteritems():
1474 for source_id
,target_id
in l
:
1475 result
.append((self
.graph
.unpack_source(source_id
),
1476 self
.graph
.unpack_target(target_id
),
1477 self
.graph
.unpack_edge(edge_id
)))
1480 class ForeignKeyInverse(object):
1481 'map each key to a single value according to its foreign key'
1482 def __init__(self
,g
):
1484 def __getitem__(self
,obj
):
1486 source_id
= getattr(obj
,self
.g
.keyColumn
)
1487 if source_id
is None:
1489 return self
.g
.sourceDB
[source_id
]
1490 def __setitem__(self
,obj
,source
):
1492 if source
is not None:
1493 self
.g
[source
][obj
] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1494 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1495 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1496 old_source
= self
[obj
]
1497 if old_source
is not None:
1498 del self
.g
[old_source
][obj
]
1499 def check_obj(self
,obj
):
1500 'raise KeyError if obj not from this db'
1502 if obj
.db
!= self
.g
.targetDB
:
1503 raise AttributeError
1504 except AttributeError:
1505 raise KeyError('key is not from targetDB of this graph!')
1506 def __contains__(self
,obj
):
1513 return self
.g
.targetDB
.itervalues()
1515 return self
.g
.targetDB
.values()
1516 def iteritems(self
):
1518 source_id
= getattr(obj
,self
.g
.keyColumn
)
1519 if source_id
is None:
1522 yield obj
,self
.g
.sourceDB
[source_id
]
1524 return list(self
.iteritems())
1525 def itervalues(self
):
1526 for obj
,val
in self
.iteritems():
1529 return list(self
.itervalues())
1530 def __invert__(self
):
1533 class ForeignKeyEdge(dict):
1534 '''edge interface to a foreign key in an SQL table.
1535 Caches dict of target nodes in itself; provides dict interface.
1536 Adds or deletes edges by setting foreign key values in the table'''
1537 def __init__(self
,g
,k
):
1541 for v
in g
.targetDB
.select('where %s=%%s' % g
.keyColumn
,(k
.id,)): # SEARCH THE DB
1542 dict.__setitem
__(self
,v
,None) # SAVE IN CACHE
1543 def __setitem__(self
,dest
,v
):
1544 if not hasattr(dest
,'db') or dest
.db
!= self
.g
.targetDB
:
1545 raise KeyError('dest is not in the targetDB bound to this graph!')
1547 raise ValueError('sorry,this graph cannot store edge information!')
1548 if not hasattr(dest
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1549 old_source
= self
.g
._inverse
[dest
] # CHECK FOR PRE-EXISTING EDGE
1550 if old_source
is not None: # REMOVE OLD EDGE FROM CACHE
1551 dict.__delitem
__(self
.g
[old_source
],dest
)
1552 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1553 setattr(dest
,self
.g
.keyColumn
,self
.src
.id) # SAVE TO DB ATTRIBUTE
1554 dict.__setitem
__(self
,dest
,None) # SAVE IN CACHE
def __delitem__(self, dest):
    """Delete the edge to dest (del g[src][dest]).

    Sets the attribute of dest named by self.g.keyColumn to None, then
    removes dest from this cached dict.

    Raises KeyError if dest is not a key of this edge dict; in that case
    dest is left untouched.
    """
    # Check membership before touching dest: clearing the attribute first
    # would wipe a foreign key that may belong to a different source node,
    # and only afterwards fail with KeyError.
    if not dict.__contains__(self, dest):
        raise KeyError(dest)
    setattr(dest, self.g.keyColumn, None)  # SAVE TO DB ATTRIBUTE
    dict.__delitem__(self, dest)  # REMOVE FROM CACHE
1560 class ForeignKeyGraph(object, UserDict
.DictMixin
):
1561 '''graph interface to a foreign key in an SQL table
1562 Caches dict of target nodes in itself; provides dict interface.
1564 def __init__(self
, sourceDB
, targetDB
, keyColumn
, autoGC
=True, **kwargs
):
1565 '''sourceDB is any database of source nodes;
1566 targetDB must be an SQL database of target nodes;
1567 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1568 if autoGC
: # automatically garbage collect unused objects
1569 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
1571 self
._weakValueDict
= {}
1572 self
.autoGC
= autoGC
1573 self
.sourceDB
= sourceDB
1574 self
.targetDB
= targetDB
1575 self
.keyColumn
= keyColumn
1576 self
._inverse
= ForeignKeyInverse(self
)
1577 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, keyColumn
=0, autoGC
=0)
1578 __getstate__
= standard_getstate
########### SUPPORT FOR PICKLING
1579 __setstate__
= standard_setstate
def _inverse_schema(self):
    'provide custom schema rule for inverting this graph... just use keyColumn!'
    schema_rule = {'invert': True, 'uniqueMapping': True}
    return schema_rule
1583 def __getitem__(self
,k
):
1584 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
1585 raise KeyError('object is not in the sourceDB bound to this graph!')
1587 return self
._weakValueDict
[k
.id] # get from cache
1590 d
= ForeignKeyEdge(self
,k
)
1591 self
._weakValueDict
[k
.id] = d
# save in cache
1593 def __setitem__(self
, k
, v
):
1594 raise KeyError('''do not save as g[k]=v. Instead follow a graph
1595 interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
1596 def __delitem__(self
, k
):
1597 raise KeyError('''Instead of del g[k], follow a graph
1598 interface: del g[src][dest]''')
1600 return self
.sourceDB
.values()
1601 __invert__
= standard_invert
1603 def describeDBTables(name
,cursor
,idDict
):
1605 Get table info about database <name> via <cursor>, and store primary keys
1606 in idDict, along with a list of the tables each key indexes.
1608 cursor
.execute('use %s' % name
)
1609 cursor
.execute('show tables')
1611 l
=[c
[0] for c
in cursor
.fetchall()]
1614 o
=SQLTable(tname
,cursor
)
1616 for f
in o
.description
:
1617 if f
==o
.primary_key
:
1618 idDict
.setdefault(f
, []).append(o
)
1619 elif f
[-3:]=='_id' and f
not in idDict
:
1625 def indexIDs(tables
,idDict
=None):
1626 "Get an index of primary keys in the <tables> dictionary."
1629 for o
in tables
.values():
1631 if o
.primary_key
not in idDict
:
1632 idDict
[o
.primary_key
]=[]
1633 idDict
[o
.primary_key
].append(o
) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1634 for f
in o
.description
:
1635 if f
[-3:]=='_id' and f
not in idDict
:
1641 def suffixSubset(tables
,suffix
):
1642 "Filter table index for those matching a specific suffix"
1644 for name
,t
in tables
.items():
1645 if name
.endswith(suffix
):
1652 def graphDBTables(tables
,idDict
):
1654 for t
in tables
.values():
1655 for f
in t
.description
:
1656 if f
==t
.primary_key
:
1657 edgeInfo
=PRIMARY_KEY
1660 g
.setEdge(f
,t
,edgeInfo
)
1661 g
.setEdge(t
,f
,edgeInfo
)
1664 SQLTypeTranslation
= {types
.StringType
:'varchar(32)',
1665 types
.IntType
:'int',
1666 types
.FloatType
:'float'}
1668 def createTableFromRepr(rows
,tableName
,cursor
,typeTranslation
=None,
1669 optionalDict
=None,indexDict
=()):
1670 """Save rows into SQL tableName using cursor, with optional
1671 translations of columns to specific SQL types (specified
1672 by typeTranslation dict).
1673 - optionDict can specify columns that are allowed to be NULL.
1674 - indexDict can specify columns that must be indexed; columns
1675 whose names end in _id will be indexed by default.
1676 - rows must be an iterator which in turn returns dictionaries,
1677 each representing a tuple of values (indexed by their column
1681 row
=rows
.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1682 except StopIteration:
1683 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1685 createTableFromRow(cursor
, tableName
,row
,typeTranslation
,
1686 optionalDict
,indexDict
)
1689 storeRow(cursor
,tableName
,row
) # SAVE OUR FIRST ROW
1690 for row
in rows
: # NOW SAVE ALL THE ROWS
1691 storeRow(cursor
,tableName
,row
)
1693 def createTableFromRow(cursor
, tableName
, row
,typeTranslation
=None,
1694 optionalDict
=None,indexDict
=()):
1696 for col
,val
in row
.items(): # PREPARE SQL TYPES FOR COLUMNS
1698 if typeTranslation
!=None and col
in typeTranslation
:
1699 coltype
=typeTranslation
[col
] # USER-SUPPLIED TRANSLATION
1700 elif type(val
) in SQLTypeTranslation
:
1701 coltype
=SQLTypeTranslation
[type(val
)]
1702 else: # SEARCH FOR A COMPATIBLE TYPE
1703 for t
in SQLTypeTranslation
:
1704 if isinstance(val
,t
):
1705 coltype
=SQLTypeTranslation
[t
]
1708 raise TypeError("Don't know SQL type to use for %s" % col
)
1709 create_def
='%s %s' %(col
,coltype
)
1710 if optionalDict
==None or col
not in optionalDict
:
1711 create_def
+=' not null'
1712 create_defs
.append(create_def
)
1713 for col
in row
: # CREATE INDEXES FOR ID COLUMNS
1714 if col
[-3:]=='_id' or col
in indexDict
:
1715 create_defs
.append('index(%s)' % col
)
1716 cmd
='create table if not exists %s (%s)' % (tableName
,','.join(create_defs
))
1717 cursor
.execute(cmd
) # CREATE THE TABLE IN THE DATABASE
def storeRow(cursor, tableName, row):
    """Insert one row into tableName through cursor.

    row must provide len() and values(); one '%s' placeholder is emitted
    per value, and the values are passed as a tuple in values() order.
    """
    placeholders = ','.join(['%s'] * len(row))
    statement = 'insert into %s values (%s)' % (tableName, placeholders)
    parameters = tuple(row.values())
    cursor.execute(statement, parameters)
def storeRowDelayed(cursor, tableName, row):
    """Insert one row into tableName using 'insert delayed'.

    row must provide len() and values(); one '%s' placeholder is emitted
    per value, and the values are passed as a tuple in values() order.
    """
    placeholders = ','.join(['%s'] * len(row))
    statement = 'insert delayed into %s values (%s)' % (tableName,
                                                        placeholders)
    parameters = tuple(row.values())
    cursor.execute(statement, parameters)
1731 class TableGroup(dict):
1732 'provide attribute access to dbname qualified tablenames'
1733 def __init__(self
,db
='test',suffix
=None,**kw
):
1736 if suffix
is not None:
1738 for k
,v
in kw
.items():
1739 if v
is not None and '.' not in v
:
1740 v
=self
.db
+'.'+v
# ADD DATABASE NAME AS PREFIX
1742 def __getattr__(self
,k
):
def sqlite_connect(*args, **kwargs):
    """Open a sqlite connection and return (connection, cursor).

    All arguments are forwarded to the connect() function of the module
    returned by import_sqlite().
    """
    sqlite_module = import_sqlite()
    conn = sqlite_module.connect(*args, **kwargs)
    return conn, conn.cursor()
1751 # list of database connection functions DBServerInfo knows how to use
1752 _DBServerModuleDict
= dict(MySQLdb
=mysql_connect
, sqlite
=sqlite_connect
)
1755 class DBServerInfo(object):
1756 'picklable reference to a database server'
1757 def __init__(self
, moduleName
='MySQLdb', *args
, **kwargs
):
1758 if moduleName
not in _DBServerModuleDict
:
1759 raise ValueError('Module name not found in _DBServerModuleDict: '\
1761 self
.moduleName
= moduleName
1763 self
.kwargs
= kwargs
# connection arguments
1766 """returns cursor associated with the DB server info (reused)"""
1769 except AttributeError:
1770 self
._start
_connection
()
def new_cursor(self):
    """returns a NEW cursor; you must close it yourself!

    If no connection has been stored on this object yet,
    _start_connection() is called first.
    """
    try:
        connection = self._connection
    except AttributeError:  # not connected yet
        self._start_connection()
        connection = self._connection
    return connection.cursor()
def _start_connection(self):
    """Open the connection and the shared cursor for this server.

    The connect function registered under this instance's module name in
    _DBServerModuleDict is called with the saved ``args`` and ``kwargs``;
    its two results are stored as ``self._connection`` and
    ``self._cursor``.
    """
    # instances that carry no moduleName attribute are treated as MySQLdb
    module_name = getattr(self, 'moduleName', 'MySQLdb')
    connector = _DBServerModuleDict[module_name]
    connection, cursor = connector(*self.args, **self.kwargs)
    self._connection = connection
    self._cursor = cursor
1788 """Close file containing this database"""
1789 self
._cursor
.close()
1790 self
._connection
.close()
1792 del self
._connection
def __getstate__(self):
    """Return the picklable state: connection arguments and module name.

    Only ``args``, ``kwargs`` and ``moduleName`` are returned; the live
    connection and cursor are left out.  An instance that has no
    ``moduleName`` attribute is reported as 'MySQLdb', the same default
    that _start_connection() applies, so an instance that can connect can
    also be pickled.
    """
    return dict(args=self.args, kwargs=self.kwargs,
                moduleName=getattr(self, 'moduleName', 'MySQLdb'))
1800 class SQLiteServerInfo(DBServerInfo
):
1801 """picklable reference to a sqlite database"""
1802 def __init__(self
, database
, *args
, **kwargs
):
1803 """Takes same arguments as sqlite3.connect()"""
1804 DBServerInfo
.__init
__(self
, 'sqlite',
1805 SourceFileName(database
), # save abs path!
def __getstate__(self):
    """Return the picklable state, refusing in-memory databases.

    Raises ValueError when the first connection argument is ':memory:';
    otherwise defers to DBServerInfo.__getstate__().
    """
    database = self.args[0]
    if database == ':memory:':
        raise ValueError('SQLite in-memory database is not picklable!')
    return DBServerInfo.__getstate__(self)
1813 class MapView(object, UserDict
.DictMixin
):
1814 'general purpose 1:1 mapping defined by any SQL query'
1815 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None,
1816 serverInfo
=None, inverseSQL
=None, **kwargs
):
1817 self
.sourceDB
= sourceDB
1818 self
.targetDB
= targetDB
1819 self
.viewSQL
= viewSQL
1820 self
.inverseSQL
= inverseSQL
1822 if serverInfo
is not None: # get cursor from serverInfo
1823 cursor
= serverInfo
.cursor()
1825 try: # can we get it from our other db?
1826 serverInfo
= sourceDB
.serverInfo
1827 except AttributeError:
1828 raise ValueError('you must provide serverInfo or cursor!')
1830 cursor
= serverInfo
.cursor()
1831 self
.cursor
= cursor
1832 self
.serverInfo
= serverInfo
1833 self
.get_sql_format(False) # get sql formatter for this db interface
# Copy the module-level _schemaModuleDict into the class namespace so it
# is reachable as an attribute of the instance.
_schemaModuleDict = _schemaModuleDict # default module list
# Bind the module-level get_table_schema function as the get_sql_format
# method; __init__ calls it as self.get_sql_format(False).
get_sql_format = get_table_schema
1836 def __getitem__(self
, k
):
1837 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
1838 raise KeyError('object is not in the sourceDB bound to this map!')
1839 sql
,params
= self
._format
_query
(self
.viewSQL
, (k
.id,))
1840 self
.cursor
.execute(sql
, params
) # formatted for this db interface
1841 t
= self
.cursor
.fetchmany(2) # get at most two rows
1843 raise KeyError('%s not found in MapView, or not unique'
1845 return self
.targetDB
[t
[0][0]] # get the corresponding object
1846 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, viewSQL
=0, serverInfo
=0,
# Pickling is delegated to the generic helpers imported from classutil.
__getstate__ = standard_getstate
__setstate__ = standard_setstate
# Every mutating mapping method is bound to classutil's read_only_error.
__setitem__ = __delitem__ = clear = pop = popitem = update = \
              setdefault = read_only_error
1853 'only yield sourceDB items that are actually in this mapping!'
1854 for k
in self
.sourceDB
.itervalues():
1861 return [k
for k
in self
] # don't use list(self); causes infinite loop!
def __invert__(self):
    """Return the reverse mapping (targetDB to sourceDB), built on first use.

    The inverse is a MapView driven by ``inverseSQL``.  It is cached on
    ``self._inverse`` and linked back to this object, so inverting it
    returns this object again.  Raises ValueError if this view has no
    ``inverseSQL``.
    """
    try:
        return self._inverse  # already built
    except AttributeError:
        pass
    if self.inverseSQL is None:
        raise ValueError('this MapView has no inverseSQL!')
    reverse = MapView(self.targetDB, self.sourceDB, self.inverseSQL,
                      self.cursor, serverInfo=self.serverInfo,
                      inverseSQL=self.viewSQL)
    reverse._inverse = self  # inverting the inverse gives this object back
    self._inverse = reverse
    return reverse
1875 class GraphViewEdgeDict(UserDict
.DictMixin
):
1876 'edge dictionary for GraphView: just pre-loaded on init'
1877 def __init__(self
, g
, k
):
1880 sql
,params
= self
.g
._format
_query
(self
.g
.viewSQL
, (k
.id,))
1881 self
.g
.cursor
.execute(sql
, params
) # run the query
1882 l
= self
.g
.cursor
.fetchall() # get results
1883 self
.targets
= [t
[0] for t
in l
] # preserve order of the results
1884 d
= {} # also keep targetID:edgeID mapping
1885 if self
.g
.edgeDB
is not None: # save with edge info
1893 return len(self
.targets
)
1895 for k
in self
.targets
:
1896 yield self
.g
.targetDB
[k
]
def iteritems(self):
    """Yield ``(target object, edge object)`` pairs in stored target order.

    Target objects are looked up in the graph's targetDB.  When the graph
    has an edgeDB, the edge object is fetched through ``targetDict``
    (target ID to edge ID); otherwise the edge slot is None.
    """
    graph = self.g
    if graph.edgeDB is None:  # no edge info: pair every target with None
        for target_id in self.targets:
            yield (graph.targetDB[target_id], None)
    else:
        for target_id in self.targets:
            yield (graph.targetDB[target_id],
                   graph.edgeDB[self.targetDict[target_id]])
1906 def __getitem__(self
, o
, exitIfFound
=False):
1907 'for the specified target object, return its associated edge object'
1909 if o
.db
is not self
.g
.targetDB
:
1910 raise KeyError('key is not part of targetDB!')
1911 edgeID
= self
.targetDict
[o
.id]
1912 except AttributeError:
1913 raise KeyError('key has no id or db attribute?!')
1916 if self
.g
.edgeDB
is not None: # return the edge object
1917 return self
.g
.edgeDB
[edgeID
]
1918 else: # no edge info
1920 def __contains__(self
, o
):
1922 self
.__getitem
__(o
, True) # raise KeyError if not found
# Every mutating mapping method is bound to classutil's read_only_error.
__setitem__ = __delitem__ = clear = pop = popitem = update = \
              setdefault = read_only_error
1929 class GraphView(MapView
):
1930 'general purpose graph interface defined by any SQL query'
1931 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None, edgeDB
=None,
1933 'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
1934 self
.edgeDB
= edgeDB
1935 MapView
.__init
__(self
, sourceDB
, targetDB
, viewSQL
, cursor
, **kwargs
)
def __getitem__(self, k):
    """Return the edge dictionary for source node ``k``.

    Raises KeyError unless ``k`` has a ``db`` attribute that compares
    equal to this graph's sourceDB.
    """
    foreign = not hasattr(k, 'db') or k.db != self.sourceDB
    if foreign:
        raise KeyError('object is not in the sourceDB bound to this map!')
    return GraphViewEdgeDict(self, k)
# Pickle the same attributes as MapView, plus edgeDB.
_pickleAttrs = MapView._pickleAttrs.copy()
_pickleAttrs['edgeDB'] = 0
1943 # @CTB move to sqlgraph.py?
1945 class SQLSequence(SQLRow
, SequenceBase
):
1946 """Transparent access to a DB row representing a sequence.
1948 Use attrAlias dict to rename 'length' to something else.
def _init_subclass(cls, db, **kwargs):
    """Class-level setup: make ``db`` its own seqInfoDict.

    The remaining setup is delegated to SQLRow._init_subclass(), which
    receives ``db`` and every extra keyword argument.
    """
    db.seqInfoDict = db  # the database doubles as its own seqInfoDict
    kwargs['db'] = db
    SQLRow._init_subclass(**kwargs)
_init_subclass = classmethod(_init_subclass)
def __init__(self, id):
    """Initialise both base classes for the row identified by ``id``."""
    SQLRow.__init__(self, id)  # row side: receives the ID
    SequenceBase.__init__(self)  # sequence side: takes no arguments
def strslice(self, start, end):
    """Fetch only the ``start:end`` slice of the sequence from the database.

    Builds a substring expression over the 'seq' column and passes it to
    ``self._select``, so a huge sequence need not be loaded whole.  The
    doubled ``%%`` markers come out of the formatting below as
    ``%(SUBSTRING)s``-style placeholders; presumably these are filled in
    later for the specific database backend (confirm in _select).
    """
    seq_column = self.db._attrSQL('seq')
    first = start + 1  # slice start shifted by one for the SQL expression
    count = end - start
    template = '%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
    return self._select(template % (seq_column, first, count))
class DNASQLSequence(SQLSequence):
    """SQLSequence whose sequence type is DNA_SEQTYPE."""
    _seqtype = DNA_SEQTYPE
class RNASQLSequence(SQLSequence):
    """SQLSequence whose sequence type is RNA_SEQTYPE."""
    _seqtype = RNA_SEQTYPE
class ProteinSQLSequence(SQLSequence):
    """SQLSequence whose sequence type is PROTEIN_SEQTYPE."""
    _seqtype = PROTEIN_SEQTYPE