3 from __future__
import generators
5 from sequence
import SequenceBase
, DNA_SEQTYPE
, RNA_SEQTYPE
, PROTEIN_SEQTYPE
7 from classutil
import methodFactory
,standard_getstate
,\
8 override_rich_cmp
,generate_items
,get_bound_subclass
,standard_setstate
,\
9 get_valid_path
,standard_invert
,RecentValueDictionary
,read_only_error
,\
10 SourceFileName
, split_kwargs
class TupleDescriptor(object):
    """Read-only accessor: return the tuple entry for a named attribute.

    db.data maps attribute name -> index into the row object's _data
    tuple; the index is looked up once, when the descriptor is created.
    """
    def __init__(self, db, attr):
        self.icol = db.data[attr] # index of this attribute in the tuple
    def __get__(self, obj, klass):
        if obj is None: # accessed on the class itself: return the descriptor
            return self
        return obj._data[self.icol]
    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')
class TupleIDDescriptor(TupleDescriptor):
    """Accessor for the row ID stored in the tuple.

    Reading behaves like TupleDescriptor; direct assignment is refused,
    because the ID may only be changed through the table (db[newID] = obj).
    """
    def __set__(self, obj, val):
        raise AttributeError('''You cannot change obj.id directly.
        Instead, use db[newID] = obj''')
class TupleDescriptorRW(TupleDescriptor):
    """Read-write interface to a named attribute stored in the row tuple.

    Assignment writes the new value to the database (db._update) and then
    to the row's local cache (obj.save_local).
    """
    def __init__(self, db, attr):
        self.attr = attr # attribute name, passed to obj.save_local() in __set__
        self.icol = db.data[attr] # index of this attribute in the tuple
        self.attrSQL = db._attrSQL(attr, sqlColumn=True) # SQL column name
    def __set__(self, obj, val):
        obj.db._update(obj.id, self.attrSQL, val) # AND UPDATE THE DATABASE
        obj.save_local(self.attr, val)
class SQLDescriptor(object):
    """Read-only accessor that fetches an attribute by querying the database.

    The SQL expression for the attribute is resolved once via db._attrSQL();
    each read calls obj._select(expression).
    """
    def __init__(self, db, attr):
        self.selectSQL = db._attrSQL(attr) # SQL expression for this attr
    def __get__(self, obj, klass):
        if obj is None: # accessed on the class itself: return the descriptor
            return self
        return obj._select(self.selectSQL)
    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')
class SQLDescriptorRW(SQLDescriptor):
    """Writeable proxy to the corresponding column in the database.

    Nothing is cached locally: assignment only issues db._update().
    """
    def __set__(self, obj, val):
        obj.db._update(obj.id, self.selectSQL, val) #just update the database
class ReadOnlyDescriptor(object):
    """Enforce read-only policy, e.g. for the ID attribute.

    The value is read from a private, underscore-prefixed attribute on the
    row object ('_' + attr).  It must differ from attr itself, otherwise
    getattr() in __get__ would re-enter this descriptor forever.
    """
    def __init__(self, db, attr):
        self.attr = '_' + attr # private storage name on the row object
    def __get__(self, obj, klass):
        if obj is None: # accessed on the class itself: return the descriptor
            return self
        return getattr(obj, self.attr)
    def __set__(self, obj, val):
        raise AttributeError('attribute %s is read-only!' % self.attr)
def select_from_row(row, what):
    """Return the value of SQL expression `what` evaluated for this row.

    Raises KeyError unless exactly one row matches row.id (zero rows means
    not found; two rows means the primary key is not unique).
    """
    db = row.db
    sql,params = db._format_query('select %s from %s where %s=%%s limit 2'
                                  % (what, db.name, db.primary_key),
                                  (row.id,))
    db.cursor.execute(sql, params)
    t = db.cursor.fetchmany(2) # get at most two rows
    if len(t) != 1:
        raise KeyError('%s[%s].%s not found, or not unique'
                       % (db.name, str(row.id), what))
    return t[0][0] #return the single field we requested
def init_row_subclass(cls, db):
    """Add one descriptor per db attribute to the row class `cls`.

    'id' gets cls._idDescriptor.  Attributes that map to a stored column
    get cls._columnDescriptor.  Anything else (an alias to an SQL
    expression) gets cls._sqlDescriptor.
    """
    for attr in db.data: # bind all database columns
        if attr == 'id': # handle ID attribute specially
            setattr(cls, attr, cls._idDescriptor(db, attr))
            continue # don't let the column logic below overwrite it
        try: # check if this attr maps to an SQL column
            db._attrSQL(attr, columnNumber=True)
        except (AttributeError, ValueError):
            # _attrSQL(columnNumber=True) signals "not a column" with
            # ValueError, so treat this attr as an SQL expression
            setattr(cls, attr, cls._sqlDescriptor(db, attr))
        else: # treat as interface to our stored tuple
            setattr(cls, attr, cls._columnDescriptor(db, attr))
def dir_row(self):
    """Return the names of all attributes (columns and aliases) in self.db.data."""
    return list(self.db.data.keys())
class TupleO(object):
    """Provides attribute interface to a database tuple.
    Storing the data as a tuple instead of a standard Python object
    (which is stored using __dict__) uses about one fifth as much
    memory and is also much faster (the tuples returned from the
    DB API fetch are simply referenced by the TupleO, with no
    need to copy their individual values into __dict__).

    This class follows the 'subclass binding' pattern, which
    means that instead of using __getattr__ to process all
    attribute requests (which is un-modular and leads to all
    sorts of trouble), we follow Python's new model for
    customizing attribute access, namely Descriptors.
    We use classutil.get_bound_subclass() to automatically
    create a subclass of this class, calling its _init_subclass()
    class method to add all the descriptors needed for the
    database table to which it is bound.

    See the Pygr Developer Guide section of the docs for a
    complete discussion of the subclass binding pattern."""
    _columnDescriptor = TupleDescriptor
    _idDescriptor = TupleIDDescriptor
    _sqlDescriptor = SQLDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row

    def __init__(self, data):
        self._data = data # save our data tuple
def insert_and_cache_id(self, l, **kwargs):
    """Insert tuple `l` into self.db and cache the new row's ID on self.

    An explicit id=... keyword wins; otherwise the ID is taken from the
    database via db.get_insert_id().
    """
    self.db._insert(l) # save to database
    if 'id' in kwargs:
        rowID = kwargs['id'] # use the ID supplied by user
    else:
        rowID = self.db.get_insert_id() # get auto-inc ID value
    self.cache_id(rowID) # cache this ID on self
class TupleORW(TupleO):
    """Read-write version of TupleO.

    Column writes go to the database and to the local row cache.  With
    newRow=True the constructor builds a new row from keyword arguments
    and inserts it into the database.
    """
    _columnDescriptor = TupleDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, data, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            self._data = data
            return
        self._data = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(self._data, **kwargs)
    def cache_id(self,row_id):
        'store the row ID in the local cache only'
        self.save_local('id',row_id)
    def save_local(self,attr,val):
        'update the cached value of attr without touching the database'
        icol = self._attrcol[attr]
        try:
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
        except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
            self._data = list(self._data)
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE

TupleO._RWClass = TupleORW # record this as writeable interface class
class ColumnDescriptor(object):
    """Read-write interface to a column in a database, cached in obj.__dict__.

    Reads are served from obj.__dict__ when present, otherwise fetched with
    a SELECT by primary key and cached.  Writes update the database (unless
    obj has a _localOnly attribute) and the cache.  readOnly=True turns the
    instance into a ReadOnlyColumnDesc, which refuses assignment.
    """
    def __init__(self, db, attr, readOnly = False):
        self.attr = attr
        self.col = db._attrSQL(attr, sqlColumn=True) # MAP THIS TO SQL COLUMN NAME
        self.db = db
        if readOnly:
            self.__class__ = self._readOnlyClass
    def __get__(self, obj, objtype):
        if obj is None: # accessed on the class itself: return the descriptor
            return self
        try:
            return obj.__dict__[self.attr]
        except KeyError: # NOT IN CACHE.  TRY TO GET IT FROM DATABASE
            pass
        if self.col==self.db.primary_key:
            # looking up the primary key would itself need obj.id
            raise AttributeError('%s is not cached and cannot be looked up'
                                 % self.attr)
        self.db._select('where %s=%%s' % self.db.primary_key,(obj.id,),self.col)
        l = self.db.cursor.fetchall()
        if len(l)!=1:
            raise AttributeError('db row not found or not unique!')
        obj.__dict__[self.attr] = l[0][0] # UPDATE THE CACHE
        return l[0][0]
    def __set__(self, obj, val):
        if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
            self.db._update(obj.id, self.col, val) # UPDATE THE DATABASE
        obj.__dict__[self.attr] = val # UPDATE THE CACHE

class ReadOnlyColumnDesc(ColumnDescriptor):
    'ColumnDescriptor variant whose value cannot be assigned'
    def __set__(self, obj, val):
        raise AttributeError('The ID of a database object is not writeable.')
ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc
class SQLRow(object):
    """Provide transparent interface to a row in the database: attribute access
    will be mapped to SELECT of the appropriate column, but data is not
    cached on this object.  Only the row ID is stored, in self._id.
    """
    _columnDescriptor = _sqlDescriptor = SQLDescriptor
    _idDescriptor = ReadOnlyDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row

    def __init__(self, rowID):
        self._id = rowID # private storage for the read-only ID attribute
class SQLRowRW(SQLRow):
    """Read-write version of SQLRow.

    Column assignment updates the database directly.  With newRow=True the
    constructor builds a row from keyword arguments and inserts it.
    """
    _columnDescriptor = SQLDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, rowID, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            self.cache_id(rowID)
            return
        l = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(l, **kwargs)
    def cache_id(self, rowID):
        'remember the row ID locally'
        self._id = rowID

SQLRow._RWClass = SQLRowRW # record this as writeable interface class
def list_to_dict(names, values):
    """Return a dict pairing names[i] with values[i].

    Only as many pairs as both sequences can supply are produced: surplus
    values are ignored, and names with no corresponding value are omitted.
    """
    return dict(zip(names, values))
def get_name_cursor(name=None, **kwargs):
    '''Get table name, cursor and server info by parsing name or using a config file.

    If name contains whitespace-separated fields, it is read as
    "table [host [user [passwd]]]" and those fields override kwargs.
    All remaining connection settings are passed to DBServerInfo.
    Returns (name, cursor, serverInfo).'''
    if name is not None:
        argList = name.split() # TREAT AS WS-SEPARATED LIST
        if len(argList) > 1:
            name = argList[0] # USE 1ST ARG AS TABLE NAME
            argnames = ('host','user','passwd') # READ ARGS IN THIS ORDER
            kwargs = kwargs.copy() # a copy we can overwrite
            kwargs.update(list_to_dict(argnames, argList[1:]))
    serverInfo = DBServerInfo(**kwargs)
    return name,serverInfo.cursor(),serverInfo
def mysql_connect(connect=None, configFile=None, useStreaming=False, **args):
    """Return (connection, cursor), using a MySQL option file if necessary.

    If no user and no configFile are given, look for the platform's
    default option file (my.ini / my.cnf on Windows, ~/.my.cnf elsewhere).
    As a last resort, try ./mysql.cnf.  If an option file exists, or
    connect is None, MySQLdb.connect is used (with compression, and
    optionally a server-side streaming cursor); otherwise the supplied
    connect callable is called with the remaining keyword arguments.
    """
    kwargs = args.copy() # a copy we can modify
    if 'user' not in kwargs and configFile is None: #Find where config file is
        osname = platform.system()
        if osname in('Microsoft', 'Windows'): # Machine is a Windows box
            paths = []
            try: # handle case where WINDIR not defined by Windows...
                windir = os.environ['WINDIR']
                paths += [(windir, 'my.ini'), (windir, 'my.cnf')]
            except KeyError:
                pass
            try: # likewise SYSTEMDRIVE may be undefined
                sysdrv = os.environ['SYSTEMDRIVE']
                paths += [(sysdrv, os.path.sep + 'my.ini'),
                          (sysdrv, os.path.sep + 'my.cnf')]
            except KeyError:
                pass
            if paths:
                configFile = get_valid_path(*paths)
        else: # treat as normal platform with home directories
            configFile = os.path.join(os.path.expanduser('~'), '.my.cnf')

    # allows for a local mysql local configuration file to be read
    # from the current directory
    configFile = configFile or os.path.join(os.getcwd(), 'mysql.cnf')

    if configFile and os.path.exists(configFile):
        kwargs['read_default_file'] = configFile
        connect = None # force it to use MySQLdb
    if connect is None:
        import MySQLdb
        connect = MySQLdb.connect
        kwargs['compress'] = True
        if useStreaming: # use server side cursors for scalable result sets
            try:
                from MySQLdb import cursors
                kwargs['cursorclass'] = cursors.SSCursor
            except (ImportError, AttributeError): # fall back to default cursor
                pass
    conn = connect(**kwargs)
    cursor = conn.cursor()
    return conn,cursor
293 _mysqlMacros
= dict(IGNORE
='ignore', REPLACE
='replace',
294 AUTO_INCREMENT
='AUTO_INCREMENT', SUBSTRING
='substring',
295 SUBSTR_FROM
='FROM', SUBSTR_FOR
='FOR')
def mysql_table_schema(self, analyzeSchema=True):
    """Retrieve table schema from a MySQL database, save on self.

    Always installs self._format_query for MySQLdb's paramstyle and macros.
    When analyzeSchema is true, also fills in columnName, description,
    columnType, primary_key (a list for composite keys), usesIntID and
    indexed from DESCRIBE output.
    """
    import MySQLdb
    self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    # DESCRIBE rows: (Field, Type, Null, Key, Default, Extra)
    self.cursor.execute('describe %s' % self.name) # get info about columns
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[0]
        self.columnName.append(field) # list of columns in same order as table
        if c[3] == "PRI": # record as primary key
            if self.primary_key is None:
                self.primary_key = field
            else: # composite key: accumulate column names in a list
                try:
                    self.primary_key.append(field)
                except AttributeError:
                    self.primary_key = [self.primary_key,field]
            if c[1][:3].lower() == 'int':
                self.usesIntID = True
            else:
                self.usesIntID = False
        elif c[3] == "MUL":
            self.indexed[field] = icol
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[1] # SQL COLUMN TYPE
327 _sqliteMacros
= dict(IGNORE
='or ignore', REPLACE
='insert or replace',
328 AUTO_INCREMENT
='', SUBSTRING
='substr',
329 SUBSTR_FROM
=',', SUBSTR_FOR
=',')
def import_sqlite():
    'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
    try:
        import sqlite3 as sqlite
    except ImportError:
        from pysqlite2 import dbapi2 as sqlite
    return sqlite
def sqlite_table_schema(self, analyzeSchema=True):
    """Retrieve table schema from a sqlite3 database, save on self.

    Always installs self._format_query for SQLite's paramstyle and macros.
    When analyzeSchema is true, also fills in columnName, description,
    columnType, primary_key and usesIntID for table self.name.
    """
    sqlite = import_sqlite()
    self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    # table_info rows: (cid, name, type, notnull, dflt_value, pk)
    self.cursor.execute('PRAGMA table_info("%s")' % self.name)
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[1]
        self.columnName.append(field) # list of columns in same order as table
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[2] # SQL COLUMN TYPE
    # Implicit indexes (sql IS NULL) back PRIMARY KEY / UNIQUE constraints.
    # The table name is a bound parameter: double-quoted "strings" are SQL
    # identifiers and only work in SQLite as a legacy fallback.
    self.cursor.execute("select name from sqlite_master where tbl_name=?"
                        " and type='index' and sql is null", (self.name,))
    for row in self.cursor.fetchall(): # search indexes for primary key
        self.cursor.execute('PRAGMA index_info("%s")' % row[0])
        l = self.cursor.fetchall() # get list of columns in this index
        if len(l) == 1: # assume 1st single-column unique index is primary key!
            self.primary_key = l[0][2]
            break # done searching for primary key!
    if self.primary_key is None: # INTEGER PRIMARY KEY has no separate index
        # use table_info's pk flag instead of parsing the CREATE TABLE text,
        # which broke on table-level "PRIMARY KEY (col)" clauses
        pkCols = [c[1] for c in columns if c[5]]
        if len(pkCols) == 1:
            self.primary_key = pkCols[0]
    if self.primary_key is not None: # check its type
        # declared types keep the user's case, e.g. 'INTEGER'
        pkType = (self.columnType[self.primary_key] or '').lower()
        if pkType == 'int' or pkType == 'integer':
            self.usesIntID = True
        else:
            self.usesIntID = False
class SQLFormatDict(object):
    '''Perform SQL keyword replacements for maintaining compatibility across
    a wide range of SQL backends.  Uses Python dict-based string format
    function to do simple string replacements, and also to convert
    params list to the paramstyle required for this interface.
    Create by passing a dict of macros and the db-api paramstyle:
    sfd = SQLFormatDict("qmark", substitutionDict)

    Then transform queries+params as follows; input should be "format" style:
    sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
    cursor.execute(sql, params)

    Macros are written as %(NAME)s in the query; an unknown macro raises
    KeyError.
    '''
    _paramFormats = dict(pyformat='%%(%d)s', numeric=':%d', named=':%d',
                         qmark='(ignore)', format='(ignore)')
    def __init__(self, paramstyle, substitutionDict=None):
        if substitutionDict is None:
            substitutionDict = {}
        self.substitutionDict = substitutionDict.copy()
        self.paramstyle = paramstyle
        self.paramFormat = self._paramFormats[paramstyle] # KeyError if unsupported
        # pyformat / named styles take a mapping of params, not a sequence
        self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
        self.iparam = 1 # DB-API param indexing begins at 1
        if paramstyle == 'qmark': # handle these as simple substitution
            self.substitutionDict['?'] = '?'
        elif paramstyle == 'format':
            self.substitutionDict['?'] = '%s'
    def __getitem__(self, k):
        'apply correct substitution for this SQL interface'
        try:
            return self.substitutionDict[k] # apply our substitutions
        except KeyError:
            pass
        if k == '?': # sequential parameter
            s = self.paramFormat % self.iparam
            self.iparam += 1 # advance to the next parameter
            return s
        raise KeyError('unknown macro: %s' % k)
    def __call__(self, sql, paramList):
        'returns corrected sql,params for this interface'
        self.iparam = 1 # DB-API param indexing begins at 1
        sql = sql.replace('%s', '%(?)s') # convert format into pyformat
        s = sql % self # apply all %(x)s replacements in sql
        if self.makeDict: # construct a params dict
            paramDict = {}
            for i,param in enumerate(paramList):
                paramDict[str(i + 1)] = param #DB-API param indexing begins at 1
            return s,paramDict
        else: # just return the original params list
            return s,paramList
def get_table_schema(self, analyzeSchema=True):
    """Run the right schema function based on type of db server connection.

    The function is chosen from self._schemaModuleDict, keyed by the
    module name of self.cursor's class.  Raises ValueError if there is no
    cursor, and KeyError if the cursor's module is not registered.
    """
    try:
        modname = self.cursor.__class__.__module__
    except AttributeError:
        raise ValueError('no cursor object or module information!')
    try:
        schema_func = self._schemaModuleDict[modname]
    except KeyError:
        raise KeyError('''unknown db module: %s. Use _schemaModuleDict
        attribute to supply a method for obtaining table schema
        for this module''' % modname)
    schema_func(self, analyzeSchema) # run the schema function
# cursor-class module name -> function that reads the table schema
_schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
                     'pysqlite2.dbapi2':sqlite_table_schema,
                     'sqlite3':sqlite_table_schema}
class SQLTableBase(object, UserDict.DictMixin):
    "Store information about an SQL table as dict keyed by primary key"
    _schemaModuleDict = _schemaModuleDict # default module list
    get_table_schema = get_table_schema

    def __init__(self,name,cursor=None,itemClass=None,attrAlias=None,
                 clusterKey=None,createTable=None,graph=None,maxCache=None,
                 arraysize=1024, itemSliceClass=None, dropIfExists=False,
                 serverInfo=None, autoGC=True, orderBy=None,
                 writeable=False, **kwargs):
        """Bind this object to table `name` and read its schema.

        cursor is deprecated in favour of serverInfo; if neither is given,
        connection info is obtained via get_name_cursor(name, **kwargs).
        createTable (SQL, macros allowed) is executed first if given,
        after dropping any existing table when dropIfExists is true.
        attrAlias maps extra attribute names to columns / SQL expressions.
        """
        if autoGC: # automatically garbage collect unused objects
            self._weakValueDict = RecentValueDictionary(autoGC) # object cache
        else:
            self._weakValueDict = {}
        self.autoGC = autoGC # kept for pickling (see _pickleAttrs)
        self.orderBy = orderBy
        self.writeable = writeable
        if cursor is None:
            if serverInfo is not None: # get cursor from serverInfo
                cursor = serverInfo.cursor()
            else: # try to read connection info from name or config file
                name,cursor,serverInfo = get_name_cursor(name,**kwargs)
        else:
            warnings.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
                          DeprecationWarning, stacklevel=2)
        self.cursor = cursor
        if createTable is not None: # RUN COMMAND TO CREATE THIS TABLE
            if dropIfExists: # get rid of any existing table
                cursor.execute('drop table if exists ' + name)
            self.get_table_schema(False) # check dbtype, init _format_query
            sql,params = self._format_query(createTable, ()) # apply macros
            cursor.execute(sql) # create the table
        self.name = name
        if graph is not None:
            self.graph = graph
        if maxCache is not None:
            self.maxCache = maxCache
        if arraysize is not None:
            self.arraysize = arraysize
            cursor.arraysize = arraysize
        self.get_table_schema() # get schema of columns to serve as attrs
        self.data = {} # map of all attributes, including aliases
        for icol,field in enumerate(self.columnName):
            self.data[field] = icol # 1st add mappings to columns
        try: # make 'id' an alias of the primary key column
            self.data['id']=self.data[self.primary_key]
        except (KeyError,TypeError): # no primary key (None) or composite (list)
            pass
        if hasattr(self,'_attr_alias'): # apply attribute aliases for this class
            self.addAttrAlias(False,**self._attr_alias)
        self.objclass(itemClass) # NEED TO SUBCLASS OUR ITEM CLASS
        if itemSliceClass is not None:
            self.itemSliceClass = itemSliceClass
            get_bound_subclass(self, 'itemSliceClass', self.name) # need to subclass itemSliceClass
        if attrAlias is not None: # ADD ATTRIBUTE ALIASES
            # RECORD FOR PICKLING PURPOSES; copied so that later addAttrAlias()
            # calls don't mutate the caller's dict
            self.attrAlias = dict(attrAlias)
            self.data.update(attrAlias)
        if clusterKey is not None:
            self.clusterKey=clusterKey
        if serverInfo is not None:
            self.serverInfo = serverInfo

    def __len__(self):
        'number of rows in the table, via SELECT count(*)'
        self._select(selectCols='count(*)')
        return self.cursor.fetchone()[0]
    def __hash__(self):
        'hash by object identity'
        return id(self)
    _pickleAttrs = dict(name=0, clusterKey=0, maxCache=0, arraysize=0,
                        attrAlias=0, serverInfo=0, autoGC=0, orderBy=0,
                        writeable=0)
    __getstate__ = standard_getstate
    def __setstate__(self,state):
        'rebuild the table by calling __init__ with the pickled attributes'
        # default cursor provisioning by worldbase is deprecated!
        self.__init__(**state)
    def __repr__(self):
        return '<SQL table '+self.name+'>'

    def clear_schema(self):
        'reset all schema information for this table'
        self.description = {}
        self.columnName = []
        self.columnType = {}
        self.usesIntID = None
        self.primary_key = None
        self.indexed = {}

    def _attrSQL(self,attr,sqlColumn=False,columnNumber=False):
        """Translate python attribute name to appropriate SQL expression.

        sqlColumn=True: return the SQL column name, or raise AttributeError.
        columnNumber=True: return the column index, or raise ValueError.
        Otherwise return the SQL expression to SELECT for attr.
        """
        try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
            field=self.data[attr]
        except KeyError:
            raise AttributeError('attribute %s not a valid column or alias in %s'
                                 % (attr,self.name))
        if sqlColumn: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
            try: # CHECK IF field IS COLUMN NUMBER
                return self.columnName[field] # RETURN SQL COLUMN NAME
            except TypeError:
                try: # CHECK IF field IS SQL COLUMN NAME
                    return self.columnName[self.data[field]] # THIS WILL JUST RETURN field...
                except (KeyError,TypeError):
                    raise AttributeError('attribute %s does not map to an SQL column in %s'
                                         % (attr,self.name))
        if columnNumber:
            try: # CHECK IF field IS A COLUMN NUMBER
                return field+0 # ONLY RETURN AN INTEGER
            except TypeError:
                try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
                    return self.data[field]+0 # ONLY RETURN AN INTEGER
                except (KeyError,TypeError):
                    raise ValueError('attribute %s does not map to a SQL column!' % attr)
        if isinstance(field,types.StringType):
            attr=field # USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
        elif attr=='id':
            attr=self.primary_key
        return attr

    def addAttrAlias(self,saveToPickle=True,**kwargs):
        """Add new attributes as aliases of existing attributes.
        They can be specified either as named args:
        t.addAttrAlias(newattr=oldattr)
        or by passing a dictionary kwargs whose keys are newattr
        and values are oldattr:
        t.addAttrAlias(**kwargs)
        saveToPickle=True forces these aliases to be saved if object is pickled.
        """
        if saveToPickle:
            try:
                self.attrAlias.update(kwargs)
            except AttributeError: # no attrAlias given to __init__ yet
                self.attrAlias = dict(kwargs)
        for key,val in kwargs.items():
            try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
                self.data[val]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
                raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
            except TypeError: # val IS ITSELF AN ALIAS
                self.data[key] = self.data[val] # SO MAP TO WHAT IT MAPS TO
            except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
                self.data[key] = val

    def objclass(self,oclass=None):
        "Create class representing a row in this table by subclassing oclass, adding data"
        if oclass is not None: # use this as our base itemClass
            self.itemClass = oclass
        if self.writeable:
            self.itemClass = self.itemClass._RWClass # use its writeable version
        oclass = get_bound_subclass(self, 'itemClass', self.name,
                                    subclassArgs=dict(db=self)) # bind itemClass
        if issubclass(oclass, TupleO):
            oclass._attrcol = self.data # BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
        if hasattr(oclass,'_tableclass') and not isinstance(self,oclass._tableclass):
            self.__class__=oclass._tableclass # ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS

    def _select(self, whereClause='', params=(), selectCols='t1.*',
                cursor=None, orderBy=''):
        'execute the specified query but do not fetch'
        sql,params = self._format_query('select %s from %s t1 %s %s'
                                        % (selectCols, self.name, whereClause, orderBy),
                                        params)
        if cursor is None:
            self.cursor.execute(sql, params)
        else: # caller-supplied cursor, isolated from self.cursor
            cursor.execute(sql, params)

    def select(self,whereClause,params=None,oclass=None,selectCols='t1.*'):
        "Generate the list of objects that satisfy the database SELECT"
        if oclass is None:
            oclass=self.itemClass
        if params is None:
            params = () # drivers such as sqlite3 reject None as a param sequence
        self._select(whereClause,params,selectCols)
        l=self.cursor.fetchall() # fetch all now: the cursor may be reused
        for t in l:
            yield self.cacheItem(t,oclass)

    def query(self,**kwargs):
        'query for intersection of all specified kwargs, returned as iterator'
        criteria = []
        params = []
        for k,v in kwargs.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
            if v is None: # CONVERT TO SQL NULL TEST
                criteria.append('%s IS NULL' % self._attrSQL(k))
            else: # TEST FOR EQUALITY
                criteria.append('%s=%%s' % self._attrSQL(k))
                params.append(v)
        if criteria:
            whereClause = 'where ' + ' and '.join(criteria)
        else: # no criteria: match every row (a bare "where" is invalid SQL)
            whereClause = ''
        return self.select(whereClause,params)

    def _update(self,row_id,col,val):
        'update a single field in the specified row to the specified value'
        sql,params = self._format_query('update %s set %s=%%s where %s=%%s'
                                        %(self.name,col,self.primary_key),
                                        (val,row_id))
        self.cursor.execute(sql, params)

    def getID(self,t):
        'return the primary key value from row tuple t'
        try:
            return t[self.data['id']] # GET ID FROM TUPLE
        except TypeError: # treat as alias
            return t[self.data[self.data['id']]]

    def cacheItem(self,t,oclass):
        'get obj from cache if possible, or construct from tuple'
        try:
            rowID = self.getID(t)
        except KeyError: # NO PRIMARY KEY?  IGNORE THE CACHE.
            return oclass(t)
        try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
            return self._weakValueDict[rowID]
        except KeyError:
            pass
        o = oclass(t)
        self._weakValueDict[rowID] = o # CACHE THIS ITEM IN OUR DICTIONARY
        return o

    def cache_items(self,rows,oclass=None):
        'generate row objects for each tuple in rows, via the cache'
        if oclass is None:
            oclass=self.itemClass
        for t in rows:
            yield self.cacheItem(t,oclass)

    def foreignKey(self,attr,k):
        'get iterator for objects with specified foreign key value'
        return self.select('where %s=%%s'%attr,(k,))

    def limit_cache(self):
        'APPLY maxCache LIMIT TO CACHE SIZE'
        try:
            if self.maxCache<len(self._weakValueDict):
                self._weakValueDict.clear()
        except AttributeError: # no maxCache set: unlimited
            pass

    def get_new_cursor(self):
        """Return a new cursor object, or None if not possible """
        try:
            new_cursor = self.serverInfo.new_cursor
        except AttributeError:
            return None
        return new_cursor(self.arraysize)

    def generic_iterator(self, cursor=None, fetch_f=None, cache_f=None,
                         map_f=iter):
        'generic iterator that runs fetch, cache and map functions'
        if fetch_f is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
            if cursor is None:
                fetch_f = self.cursor.fetchmany
            else: # isolate this iter from other queries
                fetch_f = cursor.fetchmany
        if cache_f is None:
            cache_f = self.cache_items
        while True:
            self.limit_cache()
            rows = fetch_f() # FETCH THE NEXT SET OF ROWS
            if len(rows)==0: # NO MORE DATA SO ALL DONE
                break
            for v in map_f(cache_f(rows)): # CACHE AND GENERATE RESULTS
                yield v
        if cursor is not None: # close iterator now that we're done
            cursor.close()

    def tuple_from_dict(self, d):
        'transform kwarg dict into tuple for storing in database'
        l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
        for col,icol in self.data.items():
            try:
                l[icol] = d[col]
            except (KeyError,TypeError): # not supplied, or icol is an alias
                pass
        return l

    def tuple_from_obj(self, obj):
        'transform object attributes into tuple for storing in database'
        l = [None]*len(self.description) # DEFAULT COLUMN VALUES ARE NULL
        for col,icol in self.data.items():
            try:
                l[icol] = getattr(obj,col)
            except (AttributeError,TypeError): # missing, or icol is an alias
                pass
        return l

    def _insert(self, l):
        '''insert tuple into the database.  Note this uses the MySQL
        extension REPLACE, which overwrites any duplicate key.'''
        s = '%(REPLACE)s into ' + self.name + ' values (' \
            + ','.join(['%s']*len(l)) + ')'
        sql,params = self._format_query(s, l)
        self.cursor.execute(sql, params)

    def insert(self, obj):
        '''insert new row by transforming obj to tuple of values'''
        l = self.tuple_from_obj(obj)
        self._insert(l)

    def get_insert_id(self):
        'get the primary key value for the last INSERT'
        try: # ATTEMPT TO GET ASSIGNED ID FROM DB
            auto_id = self.cursor.lastrowid
        except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
            raise NotImplementedError('''your db lacks lastrowid support?''')
        if auto_id is None:
            raise ValueError('lastrowid is None so cannot get ID from INSERT!')
        return auto_id

    def new(self, **kwargs):
        'return a new record with the assigned attributes, added to DB'
        if not self.writeable:
            raise ValueError('this database is read only!')
        obj = self.itemClass(None, newRow=True, **kwargs) # saves itself to db
        self._weakValueDict[obj.id] = obj # AND SAVE TO OUR LOCAL DICT CACHE
        return obj

    def clear_cache(self):
        'empty the cache'
        self._weakValueDict.clear()

    def __delitem__(self, k):
        'delete row k from the database and the cache'
        if not self.writeable:
            raise ValueError('this database is read only!')
        sql,params = self._format_query('delete from %s where %s=%%s'
                                        % (self.name,self.primary_key),(k,))
        self.cursor.execute(sql, params)
        try:
            del self._weakValueDict[k]
        except KeyError: # row was not cached
            pass
def getKeys(self,queryOption='', selectCols=None):
    """Return a list of keys via a db select; does not force load.

    selectCols defaults to the primary key.  queryOption is SQL appended
    after the table name, and defaults to self.orderBy when that is set.
    """
    if selectCols is None:
        selectCols=self.primary_key
    if queryOption=='' and self.orderBy is not None:
        queryOption = self.orderBy # apply default ordering
    self.cursor.execute('select %s from %s %s'
                        %(selectCols,self.name,queryOption))
    # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
    return [t[0] for t in self.cursor.fetchall()]
def iter_keys(self, selectCols=None, orderBy='', map_f=iter,
              cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
    """Return an iterator over keys, insulated from other queries.

    With a private cursor from self.get_new_cursor(), rows are streamed
    through self.generic_iterator().  Otherwise all keys are pre-fetched,
    using get_f() if given or else self.keys().  Extra kwargs are passed
    to self._select().
    """
    if selectCols is None:
        selectCols=self.primary_key
    if orderBy=='' and self.orderBy is not None:
        orderBy = self.orderBy # apply default ordering
    cursor = self.get_new_cursor()
    if cursor is not None: # got our own cursor, guaranteeing query isolation
        self._select(cursor=cursor, selectCols=selectCols, orderBy=orderBy,
                     **kwargs)
        return self.generic_iterator(cursor=cursor, cache_f=cache_f,
                                     map_f=map_f)
    # must pre-fetch all keys to ensure query isolation
    if get_f is not None:
        return iter(get_f())
    return iter(self.keys())
class SQLTable(SQLTableBase):
    "Provide on-the-fly access to rows in the database, caching the results in dict"
    itemClass = TupleO # our default itemClass; constructor can override
    keys = getKeys # list of primary keys via SELECT (module-level helper)
    __iter__ = iter_keys # query-isolated key iterator (module-level helper)

    def load(self,oclass=None):
        "Load all data from the table"
        try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
            return self._isLoaded
        except AttributeError:
            pass
        if oclass is None:
            oclass=self.itemClass
        self.cursor.execute('select * from %s' % self.name)
        l=self.cursor.fetchall()
        self._weakValueDict = {} # just store the whole dataset in memory
        for t in l:
            self.cacheItem(t,oclass) # CACHE IT IN LOCAL DICTIONARY
        self._isLoaded=True # MARK THIS CONTAINER AS FULLY LOADED

    def __getitem__(self,k): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            pass
        sql,params = self._format_query('select * from %s where %s=%%s limit 2'
                                        % (self.name,self.primary_key),(k,))
        self.cursor.execute(sql, params)
        l = self.cursor.fetchmany(2) # get at most 2 rows
        if len(l) != 1:
            raise KeyError('%s not found in %s, or not unique' %(str(k),self.name))
        self.limit_cache()
        return self.cacheItem(l[0],self.itemClass) # CACHE IT IN LOCAL DICTIONARY

    def __setitem__(self, k, v):
        """Store row object v under primary key k.

        v must belong to this table (v.db is self).  Any row stored under
        v's previous ID is deleted first.
        """
        if not self.writeable:
            raise ValueError('this database is read only!')
        if getattr(v, 'db', None) is not self:
            raise ValueError('object not bound to itemClass for this db!')
        oldID = getattr(v, 'id', None)
        if oldID is not None: # delete row with old ID
            del self[oldID]
        v.cache_id(k) # cache the new ID on the object
        self.insert(v) # SAVE TO THE RELATIONAL DB SERVER
        self._weakValueDict[k] = v # CACHE THIS ITEM IN OUR DICTIONARY

    def items(self):
        'forces load of entire table into memory'
        self.load()
        return self._weakValueDict.items()

    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select(cursor=cursor)
        return self.generic_iterator(cursor=cursor, map_f=generate_items)

    def values(self):
        'forces load of entire table into memory'
        self.load()
        return self._weakValueDict.values()

    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select(cursor=cursor)
        return self.generic_iterator(cursor=cursor)
def getClusterKeys(self,queryOption=''):
    """Return the distinct values of self.clusterKey via a db select.

    Does not force load.  queryOption is SQL appended after the table name.
    """
    self.cursor.execute('select distinct %s from %s %s'
                        %(self.clusterKey,self.name,queryOption))
    # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
    return [t[0] for t in self.cursor.fetchall()]
856 class SQLTableClustered(SQLTable
):
857 '''use clusterKey to load a whole cluster of rows at once,
858 specifically, all rows that share the same clusterKey value.'''
859 def __init__(self
, *args
, **kwargs
):
860 kwargs
= kwargs
.copy() # get a copy we can alter
861 kwargs
['autoGC'] = False # don't use WeakValueDictionary
862 SQLTable
.__init
__(self
, *args
, **kwargs
)
864 return getKeys(self
,'order by %s' %self
.clusterKey
)
def clusterkeys(self):
    'return the distinct clusterKey values, with an order by clusterKey clause'
    ordering = 'order by %s' % self.clusterKey
    return getClusterKeys(self, ordering)
867 def __getitem__(self
,k
):
869 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
870 except KeyError: # NOT FOUND, SO TRY THE DATABASE
871 sql
,params
= self
._format
_query
('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
872 % (self
.name
,self
.name
,self
.primary_key
,
873 self
.clusterKey
,self
.clusterKey
),(k
,))
874 self
.cursor
.execute(sql
, params
)
875 l
=self
.cursor
.fetchall()
877 for t
in l
: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
878 self
.cacheItem(t
,self
.itemClass
)
879 return self
._weakValueDict
[k
] # should be in cache, if row k exists
880 def itercluster(self
,cluster_id
):
881 'iterate over all items from the specified cluster'
883 return self
.select('where %s=%%s'%self
.clusterKey
,(cluster_id
,))
884 def fetch_cluster(self
):
885 'use self.cursor.fetchmany to obtain all rows for next cluster'
886 icol
= self
._attrSQL
(self
.clusterKey
,columnNumber
=True)
889 rows
= self
._fetch
_cluster
_cache
# USE SAVED ROWS FROM PREVIOUS CALL
890 del self
._fetch
_cluster
_cache
891 except AttributeError:
892 rows
= self
.cursor
.fetchmany()
894 cluster_id
= rows
[0][icol
]
898 for i
,t
in enumerate(rows
): # CHECK THAT ALL ROWS FROM THIS CLUSTER
899 if cluster_id
!= t
[icol
]: # START OF A NEW CLUSTER
900 result
+= rows
[:i
] # RETURN ROWS OF CURRENT CLUSTER
901 self
._fetch
_cluster
_cache
= rows
[i
:] # SAVE NEXT CLUSTER
904 rows
= self
.cursor
.fetchmany() # GET NEXT SET OF ROWS
def itervalues(self):
    """Iterate over row objects, selecting rows ordered by clusterKey.

    A cursor from get_new_cursor() runs the ordered select; fetch_cluster
    is handed to generic_iterator (presumably as its row-fetching hook),
    and arraysize / maxCache and fetchmany() manage data transfer.
    """
    order_clause = 'order by %s' % self.clusterKey
    cur = self.get_new_cursor()
    self._select(order_clause, cursor=cur)
    return self.generic_iterator(cur, self.fetch_cluster)
912 'uses arraysize / maxCache and fetchmany() to manage data transfer'
913 cursor
= self
.get_new_cursor()
914 self
._select
('order by %s' %self
.clusterKey
, cursor
=cursor
)
915 return self
.generic_iterator(cursor
, self
.fetch_cluster
,
916 map_f
=generate_items
)
918 class SQLForeignRelation(object):
919 'mapping based on matching a foreign key in an SQL table'
920 def __init__(self
,table
,keyName
):
923 def __getitem__(self
,k
):
924 'get list of objects o with getattr(o,keyName)==k.id'
926 for o
in self
.table
.select('where %s=%%s'%self
.keyName
,(k
.id,)):
929 raise KeyError('%s not found in %s' %(str(k
),self
.name
))
933 class SQLTableNoCache(SQLTableBase
):
934 '''Provide on-the-fly access to rows in the database;
935 values are simply an object interface (SQLRow) to back-end db query.
936 Row data are not stored locally, but always accessed by querying the db'''
937 itemClass
=SQLRow
# DEFAULT OBJECT CLASS FOR ROWS...
def getID(self, t):
    'Return the ID of row tuple t, which is its first element.'
    row_id = t[0]
    return row_id
941 def select(self
,whereClause
,params
):
942 return SQLTableBase
.select(self
,whereClause
,params
,self
.oclass
,
944 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
946 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
947 except KeyError: # NOT FOUND, SO TRY THE DATABASE
948 self
._select
('where %s=%%s' % self
.primary_key
, (k
,),
950 t
= self
.cursor
.fetchmany(2)
952 raise KeyError('id %s non-existent or not unique' % k
)
953 o
= self
.itemClass(k
) # create obj referencing this ID
954 self
._weakValueDict
[k
] = o
# cache the SQLRow object
956 def __setitem__(self
, k
, v
):
957 if not self
.writeable
:
958 raise ValueError('this database is read only!')
962 except AttributeError:
963 raise ValueError('object not bound to itemClass for this db!')
965 del self
[k
] # delete row with new ID if any
969 del self
._weakValueDict
[v
.id] # delete from old cache location
972 self
._update
(v
.id, self
.primary_key
, k
) # just change its ID in db
973 v
.cache_id(k
) # change the cached ID value
974 self
._weakValueDict
[k
] = v
# assign to new cache location
def addAttrAlias(self, **kwargs):
    """Add attribute aliases to self.data.

    Each keyword argument maps an alias key to an expression value;
    existing keys are overwritten.
    """
    new_entries = dict(kwargs)
    self.data.update(new_entries)
978 SQLRow
._tableclass
=SQLTableNoCache
# SQLRow IS FOR NON-CACHING TABLE INTERFACE
981 class SQLTableMultiNoCache(SQLTableBase
):
982 "Trivial on-the-fly access for table with key that returns multiple rows"
983 itemClass
= TupleO
# default itemClass; constructor can override
984 _distinct_key
='id' # DEFAULT COLUMN TO USE AS KEY
986 return getKeys(self
, selectCols
='distinct(%s)'
987 % self
._attrSQL
(self
._distinct
_key
))
989 return iter_keys(self
, 'distinct(%s)' % self
._attrSQL
(self
._distinct
_key
))
990 def __getitem__(self
,id):
991 sql
,params
= self
._format
_query
('select * from %s where %s=%%s'
992 %(self
.name
,self
._attrSQL
(self
._distinct
_key
)),(id,))
993 self
.cursor
.execute(sql
, params
)
994 l
=self
.cursor
.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
996 yield self
.itemClass(row
)
def addAttrAlias(self, **kwargs):
    """Add attribute aliases to self.data.

    Each keyword argument maps an alias key to an expression value;
    existing keys are overwritten.
    """
    new_entries = dict(kwargs)
    self.data.update(new_entries)
1002 class SQLEdges(SQLTableMultiNoCache
):
1003 '''provide iterator over edges as (source,target,edge)
1004 and getitem[edge] --> [(source,target),...]'''
1005 _distinct_key
='edge_id'
1006 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1007 _pickleAttrs
.update(dict(graph
=0))
1009 self
.cursor
.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
1010 %(self
._attrSQL
('source_id'),self
._attrSQL
('target_id'),
1011 self
._attrSQL
('edge_id'),self
.name
,
1012 self
._attrSQL
('target_id'),self
._attrSQL
('source_id'),
1013 self
._attrSQL
('target_id')))
1014 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1015 for source_id
,target_id
,edge_id
in self
.cursor
.fetchall():
1016 l
.append((self
.graph
.unpack_source(source_id
),
1017 self
.graph
.unpack_target(target_id
),
1018 self
.graph
.unpack_edge(edge_id
)))
1022 return iter(self
.keys())
1023 def __getitem__(self
,edge
):
1024 sql
,params
= self
._format
_query
('select %s,%s from %s where %s=%%s'
1025 %(self
._attrSQL
('source_id'),
1026 self
._attrSQL
('target_id'),
1028 self
._attrSQL
(self
._distinct
_key
)),
1029 (self
.graph
.pack_edge(edge
),))
1030 self
.cursor
.execute(sql
, params
)
1031 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1032 for source_id
,target_id
in self
.cursor
.fetchall():
1033 l
.append((self
.graph
.unpack_source(source_id
),
1034 self
.graph
.unpack_target(target_id
)))
1038 class SQLEdgeDict(object):
1039 '2nd level graph interface to SQL database'
1040 def __init__(self
,fromNode
,table
):
1041 self
.fromNode
=fromNode
1043 if not hasattr(self
.table
,'allowMissingNodes'):
1044 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s limit 1'
1045 %(self
.table
.sourceSQL
,
1047 self
.table
.sourceSQL
),
1049 self
.table
.cursor
.execute(sql
, params
)
1050 if len(self
.table
.cursor
.fetchall())<1:
1051 raise KeyError('node not in graph!')
1053 def __getitem__(self
,target
):
1054 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1055 %(self
.table
.edgeSQL
,
1057 self
.table
.sourceSQL
,
1058 self
.table
.targetSQL
),
1060 self
.table
.pack_target(target
)))
1061 self
.table
.cursor
.execute(sql
, params
)
1062 l
= self
.table
.cursor
.fetchmany(2) # get at most two rows
1064 raise KeyError('either no edge from source to target or not unique!')
1066 return self
.table
.unpack_edge(l
[0][0]) # RETURN EDGE
1068 raise KeyError('no edge from node to target')
1069 def __setitem__(self
,target
,edge
):
1070 sql
,params
= self
.table
._format
_query
('replace into %s values (%%s,%%s,%%s)'
1073 self
.table
.pack_target(target
),
1074 self
.table
.pack_edge(edge
)))
1075 self
.table
.cursor
.execute(sql
, params
)
1076 if not hasattr(self
.table
,'sourceDB') or \
1077 (hasattr(self
.table
,'targetDB') and self
.table
.sourceDB
==self
.table
.targetDB
):
1078 self
.table
+= target
# ADD AS NODE TO GRAPH
1079 def __iadd__(self
,target
):
1081 return self
# iadd MUST RETURN self!
1082 def __delitem__(self
,target
):
1083 sql
,params
= self
.table
._format
_query
('delete from %s where %s=%%s and %s=%%s'
1085 self
.table
.sourceSQL
,
1086 self
.table
.targetSQL
),
1088 self
.table
.pack_target(target
)))
1089 self
.table
.cursor
.execute(sql
, params
)
1090 if self
.table
.cursor
.rowcount
< 1: # no rows deleted?
1091 raise KeyError('no edge from node to target')
1093 def iterator_query(self
):
1094 sql
,params
= self
.table
._format
_query
('select %s,%s from %s where %s=%%s and %s is not null'
1095 %(self
.table
.targetSQL
,
1098 self
.table
.sourceSQL
,
1099 self
.table
.targetSQL
),
1101 self
.table
.cursor
.execute(sql
, params
)
1102 return self
.table
.cursor
.fetchall()
1104 return [self
.table
.unpack_target(target_id
)
1105 for target_id
,edge_id
in self
.iterator_query()]
1107 return [self
.table
.unpack_edge(edge_id
)
1108 for target_id
,edge_id
in self
.iterator_query()]
1110 return [(self
.table
.unpack_source(self
.fromNode
),self
.table
.unpack_target(target_id
),
1111 self
.table
.unpack_edge(edge_id
))
1112 for target_id
,edge_id
in self
.iterator_query()]
1114 return [(self
.table
.unpack_target(target_id
),self
.table
.unpack_edge(edge_id
))
1115 for target_id
,edge_id
in self
.iterator_query()]
def __iter__(self):
    'iterate over the list returned by self.keys()'
    key_list = self.keys()
    return iter(key_list)
def itervalues(self):
    'iterate over the list returned by self.values()'
    value_list = self.values()
    return iter(value_list)
def iteritems(self):
    'iterate over the list returned by self.items()'
    pair_list = self.items()
    return iter(pair_list)
1120 return len(self
.keys())
1123 class SQLEdgelessDict(SQLEdgeDict
):
1124 'for SQLGraph tables that lack edge_id column'
1125 def __getitem__(self
,target
):
1126 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1127 %(self
.table
.targetSQL
,
1129 self
.table
.sourceSQL
,
1130 self
.table
.targetSQL
),
1132 self
.table
.pack_target(target
)))
1133 self
.table
.cursor
.execute(sql
, params
)
1134 l
= self
.table
.cursor
.fetchmany(2)
1136 raise KeyError('either no edge from source to target or not unique!')
1137 return None # no edge info!
1138 def iterator_query(self
):
1139 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s is not null'
1140 %(self
.table
.targetSQL
,
1142 self
.table
.sourceSQL
,
1143 self
.table
.targetSQL
),
1145 self
.table
.cursor
.execute(sql
, params
)
1146 return [(t
[0],None) for t
in self
.table
.cursor
.fetchall()]
# Edge-dict class used by SQLGraph when its table has no edge_id column
# (SQLGraph.__init__ falls back to _edgeClass._edgelessClass).
SQLEdgeDict._edgelessClass = SQLEdgelessDict
1150 class SQLGraphEdgeDescriptor(object):
1151 'provide an SQLEdges interface on demand'
1152 def __get__(self
,obj
,objtype
):
1154 attrAlias
=obj
.attrAlias
.copy()
1155 except AttributeError:
1156 return SQLEdges(obj
.name
, obj
.cursor
, graph
=obj
)
1158 return SQLEdges(obj
.name
, obj
.cursor
, attrAlias
=attrAlias
,
1161 def getColumnTypes(createTable
,attrAlias
={},defaultColumnType
='int',
1162 columnAttrs
=('source','target','edge'),**kwargs
):
1163 'return list of [(colname,coltype),...] for source,target,edge'
1165 for attr
in columnAttrs
:
1167 attrName
= attrAlias
[attr
+'_id']
1169 attrName
= attr
+'_id'
1170 try: # SEE IF USER SPECIFIED A DESIRED TYPE
1171 l
.append((attrName
,createTable
[attr
+'_id']))
1173 except (KeyError,TypeError):
1175 try: # get type info from primary key for that database
1176 db
= kwargs
[attr
+'DB']
1178 raise KeyError # FORCE IT TO USE DEFAULT TYPE
1181 else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
1183 try: # GET ONE IDENTIFIER FROM THE DATABASE
1185 except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
1187 l
.append((attrName
,db
.columnType
[db
.primary_key
]))
1189 except AttributeError:
1191 else: # GET THE TYPE FROM THIS IDENTIFIER
1192 if isinstance(k
,int) or isinstance(k
,long):
1193 l
.append((attrName
,'int'))
1195 elif isinstance(k
,str):
1196 l
.append((attrName
,'varchar(32)'))
1199 raise ValueError('SQLGraph node / edge must be int or str!')
1200 l
.append((attrName
,defaultColumnType
))
1201 logger
.warn('no type info found for %s, so using default: %s'
1202 % (attrName
, defaultColumnType
))
1208 class SQLGraph(SQLTableMultiNoCache
):
1209 '''provide a graph interface via a SQL table. Key capabilities are:
1210 - setitem with an empty dictionary: a dummy operation
1211 - getitem with a key that exists: return a placeholder
1212 - setitem with non empty placeholder: again a dummy operation
1213 EXAMPLE TABLE SCHEMA:
1214 create table mygraph (source_id int not null,target_id int,edge_id int,
1215 unique(source_id,target_id));
1217 _distinct_key
='source_id'
1218 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1219 _pickleAttrs
.update(dict(sourceDB
=0,targetDB
=0,edgeDB
=0,allowMissingNodes
=0))
1220 _edgeClass
= SQLEdgeDict
1221 def __init__(self
,name
,*l
,**kwargs
):
1222 graphArgs
,tableArgs
= split_kwargs(kwargs
,
1223 ('attrAlias','defaultColumnType','columnAttrs',
1224 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
1225 'edgeDictClass','graph'))
1226 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1227 c
= getColumnTypes(**kwargs
)
1228 tableArgs
['createTable'] = \
1229 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1230 % (name
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1232 self
.allowMissingNodes
= kwargs
['allowMissingNodes']
1233 except KeyError: pass
1234 SQLTableMultiNoCache
.__init
__(self
,name
,*l
,**tableArgs
)
1235 self
.sourceSQL
= self
._attrSQL
('source_id')
1236 self
.targetSQL
= self
._attrSQL
('target_id')
1238 self
.edgeSQL
= self
._attrSQL
('edge_id')
1239 except AttributeError:
1241 self
._edgeClass
= self
._edgeClass
._edgelessClass
1242 save_graph_db_refs(self
,**kwargs
)
def __getitem__(self, k):
    """Return the edge dictionary for source node k.

    k is packed with pack_source() and wrapped in self._edgeClass
    (SQLEdgeDict, or its edgeless variant), together with this graph.
    """
    packed = self.pack_source(k)
    return self._edgeClass(packed, self)
def __iadd__(self, k):
    """Add k as a source node with no edges (g += node).

    First deletes any row for k whose target is NULL, then inserts a
    (k, NULL, NULL) row. The %(IGNORE)s marker is left for _format_query
    to fill in. Returns self, as augmented assignment requires.
    """
    statements = (
        'delete from %s where %s=%%s and %s is null'
        % (self.name, self.sourceSQL, self.targetSQL),
        'insert %%(IGNORE)s into %s values (%%s,NULL,NULL)' % self.name,
    )
    for template in statements:
        sql, params = self._format_query(template, (self.pack_source(k),))
        self.cursor.execute(sql, params)
    return self
def __isub__(self, k):
    """Remove source node k and every row that has it as source (g -= node).

    Raises KeyError if the delete reports zero affected rows.
    Returns self, as augmented assignment requires.
    """
    template = 'delete from %s where %s=%%s' % (self.name, self.sourceSQL)
    sql, params = self._format_query(template, (self.pack_source(k),))
    self.cursor.execute(sql, params)
    if self.cursor.rowcount == 0:
        raise KeyError('node not found in graph')
    return self
1262 __setitem__
= graph_setitem
1263 def __contains__(self
,k
):
1264 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 1'
1265 %(self
.name
,self
.sourceSQL
),
1266 (self
.pack_source(k
),))
1267 self
.cursor
.execute(sql
, params
)
1268 l
= self
.cursor
.fetchmany(2)
1270 def __invert__(self
):
1271 'get an interface to the inverse graph mapping'
1273 return self
._inverse
1274 except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
1275 attrAlias
= dict(source_id
=self
.targetSQL
, # SWAP SOURCE & TARGET
1276 target_id
=self
.sourceSQL
,
1277 edge_id
=self
.edgeSQL
)
1278 if self
.edgeSQL
is None: # no edge interface
1279 del attrAlias
['edge_id']
1280 self
._inverse
=SQLGraph(self
.name
,self
.cursor
,
1281 attrAlias
=attrAlias
,
1282 **graph_db_inverse_refs(self
))
1283 self
._inverse
._inverse
=self
1284 return self
._inverse
1286 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1287 yield self
.unpack_source(k
)
def iteritems(self):
    """Generate (source node, edge dictionary) pairs.

    Raw keys come from SQLTableMultiNoCache.__iter__; each one is unpacked
    for the node and passed unpacked-as-is to self._edgeClass.
    """
    base_iter = SQLTableMultiNoCache.__iter__
    for raw_id in base_iter(self):
        node = self.unpack_source(raw_id)
        yield node, self._edgeClass(raw_id, self)
def itervalues(self):
    'Generate one edge dictionary per raw key from SQLTableMultiNoCache.__iter__.'
    base_iter = SQLTableMultiNoCache.__iter__
    for raw_id in base_iter(self):
        yield self._edgeClass(raw_id, self)
1295 return [self
.unpack_source(k
) for k
in SQLTableMultiNoCache
.keys(self
)]
def values(self):
    'Return a list of everything produced by self.itervalues().'
    return [edge_dict for edge_dict in self.itervalues()]
def items(self):
    'Return a list of everything produced by self.iteritems().'
    return [pair for pair in self.iteritems()]
1298 edges
=SQLGraphEdgeDescriptor()
1299 update
= update_graph
1301 'get number of source nodes in graph'
1302 self
.cursor
.execute('select count(distinct %s) from %s'
1303 %(self
.sourceSQL
,self
.name
))
1304 return self
.cursor
.fetchone()[0]
1306 override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
1307 ## def __cmp__(self,other):
1311 ## it = iter(self.edges)
1314 ## source,target,edge = it.next()
1315 ## except StopIteration:
1318 ## if d is not None:
1319 ## diff = cmp(n_target,len(d))
1322 ## if source is None:
1325 ## n += 1 # COUNT SOURCE NODES
1332 ## diff = cmp(edge,d[target])
1337 ## n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
1338 ## return cmp(n,len(other))
1340 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
class SQLIDGraph(SQLGraph):
    # SQLGraph whose pack/unpack methods come from
    # add_trivial_packing_methods (presumably passing IDs through unchanged).
    add_trivial_packing_methods(locals())

# Register this as SQLGraph's ID-graph variant.
SQLGraph._IDGraphClass = SQLIDGraph
1348 class SQLEdgeDictClustered(dict):
1349 'simple cache for 2nd level dictionary of target_id:edge_id'
1350 def __init__(self
,g
,fromNode
):
1352 self
.fromNode
=fromNode
def __iadd__(self, l):
    """Cache every (target_id, edge_id) pair of l in this dict; return self.

    dict.__setitem__ is called directly so any __setitem__ override on
    this class is bypassed.
    """
    for pair in l:
        target_id, edge_id = pair
        dict.__setitem__(self, target_id, edge_id)
    return self
1359 class SQLEdgesClusteredDescr(object):
1360 def __get__(self
,obj
,objtype
):
1361 e
=SQLEdgesClustered(obj
.table
,obj
.edge_id
,obj
.source_id
,obj
.target_id
,
1362 graph
=obj
,**graph_db_inverse_refs(obj
,True))
1363 for source_id
,d
in obj
.d
.iteritems(): # COPY EDGE CACHE
1364 e
.load([(edge_id
,source_id
,target_id
)
1365 for (target_id
,edge_id
) in d
.iteritems()])
1368 class SQLGraphClustered(object):
1369 'SQL graph with clustered caching -- loads an entire cluster at a time'
1370 _edgeDictClass
=SQLEdgeDictClustered
1371 def __init__(self
,table
,source_id
='source_id',target_id
='target_id',
1372 edge_id
='edge_id',clusterKey
=None,**kwargs
):
1374 if isinstance(table
,types
.StringType
): # CREATE THE TABLE INTERFACE
1375 if clusterKey
is None:
1376 raise ValueError('you must provide a clusterKey argument!')
1377 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1378 c
= getColumnTypes(attrAlias
=dict(source_id
=source_id
,target_id
=target_id
,
1379 edge_id
=edge_id
),**kwargs
)
1380 kwargs
['createTable'] = \
1381 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1382 % (table
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],
1383 c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1384 table
= SQLTableClustered(table
,clusterKey
=clusterKey
,**kwargs
)
1386 self
.source_id
=source_id
1387 self
.target_id
=target_id
1388 self
.edge_id
=edge_id
1390 save_graph_db_refs(self
,**kwargs
)
1391 _pickleAttrs
= dict(table
=0,source_id
=0,target_id
=0,edge_id
=0,sourceDB
=0,targetDB
=0,
1393 def __getstate__(self
):
1394 state
= standard_getstate(self
)
1395 state
['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1397 def __getitem__(self
,k
):
1398 'get edgeDict for source node k, from cache or by loading its cluster'
1399 try: # GET DIRECTLY FROM CACHE
1402 if hasattr(self
,'_isLoaded'):
1403 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1404 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1405 sql
,params
= self
.table
._format
_query
('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1406 %(self
.source_id
,self
.target_id
,
1407 self
.edge_id
,self
.table
.name
,
1408 self
.table
.name
,self
.source_id
,
1409 self
.table
.clusterKey
,self
.table
.clusterKey
,
1410 self
.table
.primary_key
),
1411 (self
.pack_source(k
),))
1412 self
.table
.cursor
.execute(sql
, params
)
1413 self
.load(self
.table
.cursor
.fetchall()) # CACHE THIS CLUSTER
1414 return self
.d
[k
] # RETURN EDGE DICT FOR THIS NODE
1415 def load(self
,l
=None,unpack
=True):
1416 'load the specified rows (or all, if None provided) into local cache'
1418 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1419 return self
._isLoaded
1420 except AttributeError:
1422 self
.table
.cursor
.execute('select %s,%s,%s from %s'
1423 %(self
.source_id
,self
.target_id
,
1424 self
.edge_id
,self
.table
.name
))
1425 l
=self
.table
.cursor
.fetchall()
1427 self
.d
.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1428 for source
,target
,edge
in l
: # SAVE TO OUR CACHE
1430 source
= self
.unpack_source(source
)
1431 target
= self
.unpack_target(target
)
1432 edge
= self
.unpack_edge(edge
)
1434 self
.d
[source
] += [(target
,edge
)]
1436 d
= self
._edgeDictClass
(self
,source
)
1437 d
+= [(target
,edge
)]
1439 def __invert__(self
):
1440 'interface to reverse graph mapping'
1442 return self
._inverse
# INVERSE MAP ALREADY EXISTS
1443 except AttributeError:
1445 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1446 self
._inverse
=SQLGraphClustered(self
.table
,self
.target_id
,self
.source_id
,
1447 self
.edge_id
,**graph_db_inverse_refs(self
))
1448 self
._inverse
._inverse
=self
1449 for source
,d
in self
.d
.iteritems(): # INVERT OUR CACHE
1450 self
._inverse
.load([(target
,source
,edge
)
1451 for (target
,edge
) in d
.iteritems()],unpack
=False)
1452 return self
._inverse
1453 edges
=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1454 update
= update_graph
1455 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
def __iter__(self):
    'iterate over source nodes from self.keys(); uses db select; does not force load'
    node_list = self.keys()
    return iter(node_list)
1460 'uses db select; does not force load'
1461 self
.table
.cursor
.execute('select distinct(%s) from %s'
1462 %(self
.source_id
,self
.table
.name
))
1463 return [self
.unpack_source(t
[0])
1464 for t
in self
.table
.cursor
.fetchall()]
1465 methodFactory(['iteritems','items','itervalues','values'],
1466 'lambda self:(self.load(),self.d.%s())[1]',locals())
1467 def __contains__(self
,k
):
class SQLIDGraphClustered(SQLGraphClustered):
    # SQLGraphClustered whose pack/unpack methods come from
    # add_trivial_packing_methods (presumably passing IDs through unchanged).
    add_trivial_packing_methods(locals())

# Register this as SQLGraphClustered's ID-graph variant.
SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
1478 class SQLEdgesClustered(SQLGraphClustered
):
1479 'edges interface for SQLGraphClustered'
1480 _edgeDictClass
= list
1481 _pickleAttrs
= SQLGraphClustered
._pickleAttrs
.copy()
1482 _pickleAttrs
.update(dict(graph
=0))
1486 for edge_id
,l
in self
.d
.iteritems():
1487 for source_id
,target_id
in l
:
1488 result
.append((self
.graph
.unpack_source(source_id
),
1489 self
.graph
.unpack_target(target_id
),
1490 self
.graph
.unpack_edge(edge_id
)))
1493 class ForeignKeyInverse(object):
1494 'map each key to a single value according to its foreign key'
1495 def __init__(self
,g
):
1497 def __getitem__(self
,obj
):
1499 source_id
= getattr(obj
,self
.g
.keyColumn
)
1500 if source_id
is None:
1502 return self
.g
.sourceDB
[source_id
]
1503 def __setitem__(self
,obj
,source
):
1505 if source
is not None:
1506 self
.g
[source
][obj
] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1507 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1508 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1509 old_source
= self
[obj
]
1510 if old_source
is not None:
1511 del self
.g
[old_source
][obj
]
1512 def check_obj(self
,obj
):
1513 'raise KeyError if obj not from this db'
1515 if obj
.db
!= self
.g
.targetDB
:
1516 raise AttributeError
1517 except AttributeError:
1518 raise KeyError('key is not from targetDB of this graph!')
1519 def __contains__(self
,obj
):
1526 return self
.g
.targetDB
.itervalues()
1528 return self
.g
.targetDB
.values()
1529 def iteritems(self
):
1531 source_id
= getattr(obj
,self
.g
.keyColumn
)
1532 if source_id
is None:
1535 yield obj
,self
.g
.sourceDB
[source_id
]
1537 return list(self
.iteritems())
1538 def itervalues(self
):
1539 for obj
,val
in self
.iteritems():
1542 return list(self
.itervalues())
1543 def __invert__(self
):
1546 class ForeignKeyEdge(dict):
1547 '''edge interface to a foreign key in an SQL table.
1548 Caches dict of target nodes in itself; provides dict interface.
1549 Adds or deletes edges by setting foreign key values in the table'''
1550 def __init__(self
,g
,k
):
1554 for v
in g
.targetDB
.select('where %s=%%s' % g
.keyColumn
,(k
.id,)): # SEARCH THE DB
1555 dict.__setitem
__(self
,v
,None) # SAVE IN CACHE
1556 def __setitem__(self
,dest
,v
):
1557 if not hasattr(dest
,'db') or dest
.db
!= self
.g
.targetDB
:
1558 raise KeyError('dest is not in the targetDB bound to this graph!')
1560 raise ValueError('sorry,this graph cannot store edge information!')
1561 if not hasattr(dest
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1562 old_source
= self
.g
._inverse
[dest
] # CHECK FOR PRE-EXISTING EDGE
1563 if old_source
is not None: # REMOVE OLD EDGE FROM CACHE
1564 dict.__delitem
__(self
.g
[old_source
],dest
)
1565 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1566 setattr(dest
,self
.g
.keyColumn
,self
.src
.id) # SAVE TO DB ATTRIBUTE
1567 dict.__setitem
__(self
,dest
,None) # SAVE IN CACHE
def __delitem__(self, dest):
    """Remove the edge to dest.

    Clears dest's foreign-key attribute (named by self.g.keyColumn) with
    setattr, which the original code notes saves to the database, then
    drops dest from this dict's cache.
    """
    key_attr = self.g.keyColumn
    setattr(dest, key_attr, None)
    dict.__delitem__(self, dest)
1573 class ForeignKeyGraph(object, UserDict
.DictMixin
):
1574 '''graph interface to a foreign key in an SQL table
1575 Caches dict of target nodes in itself; provides dict interface.
1577 def __init__(self
, sourceDB
, targetDB
, keyColumn
, autoGC
=True, **kwargs
):
1578 '''sourceDB is any database of source nodes;
1579 targetDB must be an SQL database of target nodes;
1580 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1581 if autoGC
: # automatically garbage collect unused objects
1582 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
1584 self
._weakValueDict
= {}
1585 self
.autoGC
= autoGC
1586 self
.sourceDB
= sourceDB
1587 self
.targetDB
= targetDB
1588 self
.keyColumn
= keyColumn
1589 self
._inverse
= ForeignKeyInverse(self
)
1590 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, keyColumn
=0, autoGC
=0)
1591 __getstate__
= standard_getstate
########### SUPPORT FOR PICKLING
1592 __setstate__
= standard_setstate
def _inverse_schema(self):
    'Custom schema rule for inverting this graph: an inverted, unique mapping (via keyColumn).'
    schema = dict(invert=True, uniqueMapping=True)
    return schema
1596 def __getitem__(self
,k
):
1597 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
1598 raise KeyError('object is not in the sourceDB bound to this graph!')
1600 return self
._weakValueDict
[k
.id] # get from cache
1603 d
= ForeignKeyEdge(self
,k
)
1604 self
._weakValueDict
[k
.id] = d
# save in cache
def __setitem__(self, k, v):
    """Always raises KeyError: g[k] = v is not supported.

    Edges must be added through the per-source edge dict instead.
    """
    message = '''do not save as g[k]=v. Instead follow a graph
interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)'''
    raise KeyError(message)
def __delitem__(self, k):
    """Always raises KeyError: del g[k] is not supported.

    Delete individual edges via del g[src][dest] instead.
    """
    message = '''Instead of del g[k], follow a graph
interface: del g[src][dest]'''
    raise KeyError(message)
1613 return self
.sourceDB
.values()
1614 __invert__
= standard_invert
1616 def describeDBTables(name
,cursor
,idDict
):
1618 Get table info about database <name> via <cursor>, and store primary keys
1619 in idDict, along with a list of the tables each key indexes.
1621 cursor
.execute('use %s' % name
)
1622 cursor
.execute('show tables')
1624 l
=[c
[0] for c
in cursor
.fetchall()]
1627 o
=SQLTable(tname
,cursor
)
1629 for f
in o
.description
:
1630 if f
==o
.primary_key
:
1631 idDict
.setdefault(f
, []).append(o
)
1632 elif f
[-3:]=='_id' and f
not in idDict
:
1638 def indexIDs(tables
,idDict
=None):
1639 "Get an index of primary keys in the <tables> dictionary."
1642 for o
in tables
.values():
1644 if o
.primary_key
not in idDict
:
1645 idDict
[o
.primary_key
]=[]
1646 idDict
[o
.primary_key
].append(o
) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1647 for f
in o
.description
:
1648 if f
[-3:]=='_id' and f
not in idDict
:
1654 def suffixSubset(tables
,suffix
):
1655 "Filter table index for those matching a specific suffix"
1657 for name
,t
in tables
.items():
1658 if name
.endswith(suffix
):
1665 def graphDBTables(tables
,idDict
):
1667 for t
in tables
.values():
1668 for f
in t
.description
:
1669 if f
==t
.primary_key
:
1670 edgeInfo
=PRIMARY_KEY
1673 g
.setEdge(f
,t
,edgeInfo
)
1674 g
.setEdge(t
,f
,edgeInfo
)
# Default SQL column type for each Python value type; createTableFromRow
# consults it (exact type first, then isinstance) when no user translation
# is given for a column.
SQLTypeTranslation = {
    types.StringType: 'varchar(32)',
    types.IntType: 'int',
    types.FloatType: 'float',
}
1681 def createTableFromRepr(rows
,tableName
,cursor
,typeTranslation
=None,
1682 optionalDict
=None,indexDict
=()):
1683 """Save rows into SQL tableName using cursor, with optional
1684 translations of columns to specific SQL types (specified
1685 by typeTranslation dict).
1686 - optionDict can specify columns that are allowed to be NULL.
1687 - indexDict can specify columns that must be indexed; columns
1688 whose names end in _id will be indexed by default.
1689 - rows must be an iterator which in turn returns dictionaries,
1690 each representing a tuple of values (indexed by their column
1694 row
=rows
.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1695 except StopIteration:
1696 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1698 createTableFromRow(cursor
, tableName
,row
,typeTranslation
,
1699 optionalDict
,indexDict
)
1702 storeRow(cursor
,tableName
,row
) # SAVE OUR FIRST ROW
1703 for row
in rows
: # NOW SAVE ALL THE ROWS
1704 storeRow(cursor
,tableName
,row
)
1706 def createTableFromRow(cursor
, tableName
, row
,typeTranslation
=None,
1707 optionalDict
=None,indexDict
=()):
1709 for col
,val
in row
.items(): # PREPARE SQL TYPES FOR COLUMNS
1711 if typeTranslation
!=None and col
in typeTranslation
:
1712 coltype
=typeTranslation
[col
] # USER-SUPPLIED TRANSLATION
1713 elif type(val
) in SQLTypeTranslation
:
1714 coltype
=SQLTypeTranslation
[type(val
)]
1715 else: # SEARCH FOR A COMPATIBLE TYPE
1716 for t
in SQLTypeTranslation
:
1717 if isinstance(val
,t
):
1718 coltype
=SQLTypeTranslation
[t
]
1721 raise TypeError("Don't know SQL type to use for %s" % col
)
1722 create_def
='%s %s' %(col
,coltype
)
1723 if optionalDict
==None or col
not in optionalDict
:
1724 create_def
+=' not null'
1725 create_defs
.append(create_def
)
1726 for col
in row
: # CREATE INDEXES FOR ID COLUMNS
1727 if col
[-3:]=='_id' or col
in indexDict
:
1728 create_defs
.append('index(%s)' % col
)
1729 cmd
='create table if not exists %s (%s)' % (tableName
,','.join(create_defs
))
1730 cursor
.execute(cmd
) # CREATE THE TABLE IN THE DATABASE
def storeRow(cursor, tableName, row):
    """Insert one row into tableName.

    row is a dict mapping column name -> value. The column names are listed
    explicitly in the INSERT, so each value is stored in the column named by
    its key. A positional "insert into t values (...)" would depend on the
    dict's iteration order matching the table's column order, which is not
    guaranteed (e.g. with "create table if not exists" on an existing table).
    Values are passed as %s-style query parameters to cursor.execute().
    """
    columns = []
    values = []
    for col, val in row.items():  # one pass keeps names and values aligned
        columns.append('%s' % col)
        values.append(val)
    placeholders = ','.join(['%s'] * len(values))
    cmd = 'insert into %s (%s) values (%s)' % (tableName, ','.join(columns),
                                               placeholders)
    cursor.execute(cmd, tuple(values))
def storeRowDelayed(cursor, tableName, row):
    """Insert one row into tableName using MySQL's INSERT DELAYED.

    row is a dict mapping column name -> value. The column names are listed
    explicitly, so each value is stored in the column named by its key rather
    than relying on dict iteration order matching the table's column order.
    Values are passed as %s-style query parameters to cursor.execute().
    """
    columns = []
    values = []
    for col, val in row.items():  # one pass keeps names and values aligned
        columns.append('%s' % col)
        values.append(val)
    placeholders = ','.join(['%s'] * len(values))
    cmd = 'insert delayed into %s (%s) values (%s)' % (tableName,
                                                       ','.join(columns),
                                                       placeholders)
    cursor.execute(cmd, tuple(values))
1744 class TableGroup(dict):
1745 'provide attribute access to dbname qualified tablenames'
1746 def __init__(self
,db
='test',suffix
=None,**kw
):
1749 if suffix
is not None:
1751 for k
,v
in kw
.items():
1752 if v
is not None and '.' not in v
:
1753 v
=self
.db
+'.'+v
# ADD DATABASE NAME AS PREFIX
1755 def __getattr__(self
,k
):
def sqlite_connect(*args, **kwargs):
    """Open an sqlite connection using the module returned by import_sqlite().

    All arguments are passed through to its connect(). Returns the tuple
    (connection, cursor).
    """
    connection = import_sqlite().connect(*args, **kwargs)
    return connection, connection.cursor()
1764 class DBServerInfo(object):
1765 'picklable reference to a database server'
1766 def __init__(self
, moduleName
='MySQLdb', *args
, **kwargs
):
1768 self
.__class
__ = _DBServerModuleDict
[moduleName
]
1770 raise ValueError('Module name not found in _DBServerModuleDict: '\
1772 self
.moduleName
= moduleName
1774 self
.kwargs
= kwargs
# connection arguments
1777 """returns cursor associated with the DB server info (reused)"""
1780 except AttributeError:
1781 self
._start
_connection
()
1784 def new_cursor(self
, arraysize
=None):
1785 """returns a NEW cursor; you must close it yourself! """
1786 if not hasattr(self
, '_connection'):
1787 self
._start
_connection
()
1788 cursor
= self
._connection
.cursor()
1789 if arraysize
is not None:
1790 cursor
.arraysize
= arraysize
1794 """Close file containing this database"""
1795 self
._cursor
.close()
1796 self
._connection
.close()
1798 del self
._connection
1800 def __getstate__(self
):
1801 """return all picklable arguments"""
1802 return dict(args
=self
.args
, kwargs
=self
.kwargs
,
1803 moduleName
=self
.moduleName
)
1806 class MySQLServerInfo(DBServerInfo
):
1807 'customized for MySQLdb SSCursor support via new_cursor()'
1808 def _start_connection(self
):
1809 self
._connection
,self
._cursor
= mysql_connect(*self
.args
, **self
.kwargs
)
1810 def new_cursor(self
, arraysize
=None):
1811 'provide streaming cursor support'
1813 conn
= self
._conn
_sscursor
1814 except AttributeError:
1815 self
._conn
_sscursor
,cursor
= mysql_connect(useStreaming
=True,
1816 *self
.args
, **self
.kwargs
)
1818 cursor
= self
._conn
_sscursor
.cursor()
1819 if arraysize
is not None:
1820 cursor
.arraysize
= arraysize
1823 DBServerInfo
.close(self
)
1825 self
._conn
_sscursor
.close()
1826 del self
._conn
_sscursor
1827 except AttributeError:
1830 class SQLiteServerInfo(DBServerInfo
):
1831 """picklable reference to a sqlite database"""
1832 def __init__(self
, database
, *args
, **kwargs
):
1833 """Takes same arguments as sqlite3.connect()"""
1834 DBServerInfo
.__init
__(self
, 'sqlite',
1835 SourceFileName(database
), # save abs path!
1837 def _start_connection(self
):
1838 self
._connection
,self
._cursor
= sqlite_connect(*self
.args
, **self
.kwargs
)
1839 def __getstate__(self
):
1840 if self
.args
[0] == ':memory:':
1841 raise ValueError('SQLite in-memory database is not picklable!')
1842 return DBServerInfo
.__getstate
__(self
)
1844 # list of DBServerInfo subclasses for different modules
1845 _DBServerModuleDict
= dict(MySQLdb
=MySQLServerInfo
, sqlite
=SQLiteServerInfo
)
1848 class MapView(object, UserDict
.DictMixin
):
1849 'general purpose 1:1 mapping defined by any SQL query'
1850 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None,
1851 serverInfo
=None, inverseSQL
=None, **kwargs
):
1852 self
.sourceDB
= sourceDB
1853 self
.targetDB
= targetDB
1854 self
.viewSQL
= viewSQL
1855 self
.inverseSQL
= inverseSQL
1857 if serverInfo
is not None: # get cursor from serverInfo
1858 cursor
= serverInfo
.cursor()
1860 try: # can we get it from our other db?
1861 serverInfo
= sourceDB
.serverInfo
1862 except AttributeError:
1863 raise ValueError('you must provide serverInfo or cursor!')
1865 cursor
= serverInfo
.cursor()
1866 self
.cursor
= cursor
1867 self
.serverInfo
= serverInfo
1868 self
.get_sql_format(False) # get sql formatter for this db interface
1869 _schemaModuleDict
= _schemaModuleDict
# default module list
1870 get_sql_format
= get_table_schema
1871 def __getitem__(self
, k
):
1872 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
1873 raise KeyError('object is not in the sourceDB bound to this map!')
1874 sql
,params
= self
._format
_query
(self
.viewSQL
, (k
.id,))
1875 self
.cursor
.execute(sql
, params
) # formatted for this db interface
1876 t
= self
.cursor
.fetchmany(2) # get at most two rows
1878 raise KeyError('%s not found in MapView, or not unique'
1880 return self
.targetDB
[t
[0][0]] # get the corresponding object
1881 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, viewSQL
=0, serverInfo
=0,
1883 __getstate__
= standard_getstate
1884 __setstate__
= standard_setstate
1885 __setitem__
= __delitem__
= clear
= pop
= popitem
= update
= \
1886 setdefault
= read_only_error
1888 'only yield sourceDB items that are actually in this mapping!'
1889 for k
in self
.sourceDB
.itervalues():
1896 return [k
for k
in self
] # don't use list(self); causes infinite loop!
1897 def __invert__(self
):
1899 return self
._inverse
1900 except AttributeError:
1901 if self
.inverseSQL
is None:
1902 raise ValueError('this MapView has no inverseSQL!')
1903 self
._inverse
= self
.__class
__(self
.targetDB
, self
.sourceDB
,
1904 self
.inverseSQL
, self
.cursor
,
1905 serverInfo
=self
.serverInfo
,
1906 inverseSQL
=self
.viewSQL
)
1907 self
._inverse
._inverse
= self
1908 return self
._inverse
1910 class GraphViewEdgeDict(UserDict
.DictMixin
):
1911 'edge dictionary for GraphView: just pre-loaded on init'
1912 def __init__(self
, g
, k
):
1915 sql
,params
= self
.g
._format
_query
(self
.g
.viewSQL
, (k
.id,))
1916 self
.g
.cursor
.execute(sql
, params
) # run the query
1917 l
= self
.g
.cursor
.fetchall() # get results
1919 raise KeyError('key %s not in GraphView' % k
.id)
1920 self
.targets
= [t
[0] for t
in l
] # preserve order of the results
1921 d
= {} # also keep targetID:edgeID mapping
1922 if self
.g
.edgeDB
is not None: # save with edge info
1930 return len(self
.targets
)
1932 for k
in self
.targets
:
1933 yield self
.g
.targetDB
[k
]
1936 def iteritems(self
):
1937 if self
.g
.edgeDB
is not None: # save with edge info
1938 for k
in self
.targets
:
1939 yield (self
.g
.targetDB
[k
], self
.g
.edgeDB
[self
.targetDict
[k
]])
1940 else: # just save the list of targets, no edge info
1941 for k
in self
.targets
:
1942 yield (self
.g
.targetDB
[k
], None)
1943 def __getitem__(self
, o
, exitIfFound
=False):
1944 'for the specified target object, return its associated edge object'
1946 if o
.db
is not self
.g
.targetDB
:
1947 raise KeyError('key is not part of targetDB!')
1948 edgeID
= self
.targetDict
[o
.id]
1949 except AttributeError:
1950 raise KeyError('key has no id or db attribute?!')
1953 if self
.g
.edgeDB
is not None: # return the edge object
1954 return self
.g
.edgeDB
[edgeID
]
1955 else: # no edge info
1957 def __contains__(self
, o
):
1959 self
.__getitem
__(o
, True) # raise KeyError if not found
1963 __setitem__
= __delitem__
= clear
= pop
= popitem
= update
= \
1964 setdefault
= read_only_error
1966 class GraphView(MapView
):
1967 'general purpose graph interface defined by any SQL query'
1968 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None, edgeDB
=None,
1970 'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
1971 self
.edgeDB
= edgeDB
1972 MapView
.__init
__(self
, sourceDB
, targetDB
, viewSQL
, cursor
, **kwargs
)
1973 def __getitem__(self
, k
):
1974 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
1975 raise KeyError('object is not in the sourceDB bound to this map!')
1976 return GraphViewEdgeDict(self
, k
)
1977 _pickleAttrs
= MapView
._pickleAttrs
.copy()
1978 _pickleAttrs
.update(dict(edgeDB
=0))
1980 # @CTB move to sqlgraph.py?
1982 class SQLSequence(SQLRow
, SequenceBase
):
1983 """Transparent access to a DB row representing a sequence.
1985 Use attrAlias dict to rename 'length' to something else.
1987 def _init_subclass(cls
, db
, **kwargs
):
1988 db
.seqInfoDict
= db
# db will act as its own seqInfoDict
1989 SQLRow
._init
_subclass
(db
=db
, **kwargs
)
1990 _init_subclass
= classmethod(_init_subclass
)
1991 def __init__(self
, id):
1992 SQLRow
.__init
__(self
, id)
1993 SequenceBase
.__init
__(self
)
1996 def strslice(self
,start
,end
):
1997 "Efficient access to slice of a sequence, useful for huge contigs"
1998 return self
._select
('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
1999 %(self
.db
._attrSQL
('seq'),start
+1,end
-start
))
2001 class DNASQLSequence(SQLSequence
):
2002 _seqtype
=DNA_SEQTYPE
2004 class RNASQLSequence(SQLSequence
):
2005 _seqtype
=RNA_SEQTYPE
2007 class ProteinSQLSequence(SQLSequence
):
2008 _seqtype
=PROTEIN_SEQTYPE