3 from __future__
import generators
5 from sequence
import SequenceBase
, DNA_SEQTYPE
, RNA_SEQTYPE
, PROTEIN_SEQTYPE
7 from classutil
import methodFactory
,standard_getstate
,\
8 override_rich_cmp
,generate_items
,get_bound_subclass
,standard_setstate
,\
9 get_valid_path
,standard_invert
,RecentValueDictionary
,read_only_error
,\
10 SourceFileName
, split_kwargs
class TupleDescriptor(object):
    """Descriptor mapping a named attribute to its slot in a cached row tuple.

    Read-only: any attempt to assign through it raises AttributeError.
    """
    def __init__(self, db, attr):
        # db.data maps attribute name -> column index within the row tuple
        self.icol = db.data[attr]
    def __get__(self, obj, objtype):
        # fetch the value directly from the row tuple cached on the object
        return obj._data[self.icol]
    def __set__(self, obj, newval):
        raise AttributeError('this database is read-only!')
class TupleIDDescriptor(TupleDescriptor):
    # Specialization for the 'id' (primary key) attribute: forbid direct
    # assignment, since re-keying an object must go through the database
    # mapping interface (db[newID] = obj) instead.
    def __set__(self, obj, val):
        raise AttributeError('''You cannot change obj.id directly.
Instead, use db[newID] = obj''')
class TupleDescriptorRW(TupleDescriptor):
    """Read-write interface to a named attribute stored in the row tuple.

    Writes are pushed to the database first, then mirrored into the
    object's local tuple cache via save_local().
    """
    def __init__(self, db, attr):
        # Fix: __set__ passes self.attr to obj.save_local(), but the visible
        # code never assigned it -- record the attribute name here.
        self.attr = attr
        self.icol = db.data[attr] # index of this attribute in the tuple
        self.attrSQL = db._attrSQL(attr, sqlColumn=True) # SQL column name
    def __set__(self, obj, val):
        obj.db._update(obj.id, self.attrSQL, val) # AND UPDATE THE DATABASE
        obj.save_local(self.attr, val) # keep local tuple cache in sync
class SQLDescriptor(object):
    """Attribute whose value is fetched from the database on every access.

    The attribute is resolved once to an SQL expression; __get__ then
    delegates to the object's _select() each time. No local caching,
    and writes are refused.
    """
    def __init__(self, db, attr):
        # resolve attribute name to the SQL expression that selects it
        self.selectSQL = db._attrSQL(attr)
    def __get__(self, obj, objtype):
        return obj._select(self.selectSQL)
    def __set__(self, obj, newval):
        raise AttributeError('this database is read-only!')
class SQLDescriptorRW(SQLDescriptor):
    'writeable proxy to corresponding column in the database'
    def __set__(self, obj, val):
        # Reads always go back to the database (inherited __get__), so
        # writing the new value there is sufficient -- no local cache.
        obj.db._update(obj.id, self.selectSQL, val) #just update the database
class ReadOnlyDescriptor(object):
    'enforce read-only policy, e.g. for ID attribute'
    def __init__(self, db, attr):
        # Fix: the visible __init__ stored nothing, yet __get__/__set__ read
        # self.attr. Store the value's location under a private name: the
        # public name on the class is this descriptor itself, so
        # getattr(obj, attr) would recurse infinitely.
        self.attr = '_' + attr
    def __get__(self, obj, klass):
        return getattr(obj, self.attr)
    def __set__(self, obj, val):
        raise AttributeError('attribute %s is read-only!' % self.attr)
65 def select_from_row(row
, what
):
66 "return value from SQL expression applied to this row"
67 sql
,params
= row
.db
._format
_query
('select %s from %s where %s=%%s limit 2'
68 % (what
,row
.db
.name
,row
.db
.primary_key
),
70 row
.db
.cursor
.execute(sql
, params
)
71 t
= row
.db
.cursor
.fetchmany(2) # get at most two rows
73 raise KeyError('%s[%s].%s not found, or not unique'
74 % (row
.db
.name
,str(row
.id),what
))
75 return t
[0][0] #return the single field we requested
def init_row_subclass(cls, db):
    """Attach one descriptor per database attribute to class *cls*.

    For each name in db.data: 'id' gets the class's _idDescriptor;
    names that resolve to a real SQL column get _columnDescriptor;
    anything else is treated as an SQL expression via _sqlDescriptor.
    """
    for attr in db.data: # bind all database columns
        if attr == 'id': # handle ID attribute specially
            setattr(cls, attr, cls._idDescriptor(db, attr))
            continue # fix: without this, 'id' fell through and was rebound below
        try: # check if this attr maps to an SQL column
            db._attrSQL(attr, columnNumber=True)
        except AttributeError: # treat as SQL expression
            setattr(cls, attr, cls._sqlDescriptor(db, attr))
        else: # treat as interface to our stored tuple
            setattr(cls, attr, cls._columnDescriptor(db, attr))
91 """get list of column names as our attributes """
92 return self
.db
.data
.keys()
95 """Provides attribute interface to a database tuple.
96 Storing the data as a tuple instead of a standard Python object
97 (which is stored using __dict__) uses about five-fold less
98 memory and is also much faster (the tuples returned from the
99 DB API fetch are simply referenced by the TupleO, with no
100 need to copy their individual values into __dict__).
102 This class follows the 'subclass binding' pattern, which
103 means that instead of using __getattr__ to process all
104 attribute requests (which is un-modular and leads to all
105 sorts of trouble), we follow Python's new model for
106 customizing attribute access, namely Descriptors.
107 We use classutil.get_bound_subclass() to automatically
108 create a subclass of this class, calling its _init_subclass()
109 class method to add all the descriptors needed for the
110 database table to which it is bound.
112 See the Pygr Developer Guide section of the docs for a
113 complete discussion of the subclass binding pattern."""
114 _columnDescriptor
= TupleDescriptor
115 _idDescriptor
= TupleIDDescriptor
116 _sqlDescriptor
= SQLDescriptor
117 _init_subclass
= classmethod(init_row_subclass
)
118 _select
= select_from_row
120 def __init__(self
, data
):
121 self
._data
= data
# save our data tuple
def insert_and_cache_id(self, l, **kwargs):
    """Insert tuple *l* into the database and cache its row ID on self.

    If the caller supplied an explicit 'id' keyword, that ID is used;
    otherwise the auto-increment ID assigned by the database is read back.
    """
    self.db._insert(l) # save to database
    try:
        rowID = kwargs['id'] # use the ID supplied by user
    except KeyError:
        # fix: the visible code unconditionally overwrote the user-supplied
        # ID; only fall back to the DB-assigned ID when none was given
        rowID = self.db.get_insert_id() # get auto-inc ID value
    self.cache_id(rowID) # cache this ID on self
132 class TupleORW(TupleO
):
133 'read-write version of TupleO'
134 _columnDescriptor
= TupleDescriptorRW
135 insert_and_cache_id
= insert_and_cache_id
136 def __init__(self
, data
, newRow
=False, **kwargs
):
137 if not newRow
: # just cache data from the database
140 self
._data
= self
.db
.tuple_from_dict(kwargs
) # convert to tuple
141 self
.insert_and_cache_id(self
._data
, **kwargs
)
142 def cache_id(self
,row_id
):
143 self
.save_local('id',row_id
)
144 def save_local(self
,attr
,val
):
145 icol
= self
._attrcol
[attr
]
147 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
148 except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
149 self
._data
= list(self
._data
)
150 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
# Register the read-write companion class so writeable tables can look up
# the RW item class from the read-only one.
TupleO._RWClass = TupleORW # record this as writeable interface class
154 class ColumnDescriptor(object):
155 'read-write interface to column in a database, cached in obj.__dict__'
156 def __init__(self
, db
, attr
, readOnly
= False):
158 self
.col
= db
._attrSQL
(attr
, sqlColumn
=True) # MAP THIS TO SQL COLUMN NAME
161 self
.__class
__ = self
._readOnlyClass
162 def __get__(self
, obj
, objtype
):
164 return obj
.__dict
__[self
.attr
]
165 except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
166 if self
.col
==self
.db
.primary_key
:
168 self
.db
._select
('where %s=%%s' % self
.db
.primary_key
,(obj
.id,),self
.col
)
169 l
= self
.db
.cursor
.fetchall()
171 raise AttributeError('db row not found or not unique!')
172 obj
.__dict
__[self
.attr
] = l
[0][0] # UPDATE THE CACHE
174 def __set__(self
, obj
, val
):
175 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
176 self
.db
._update
(obj
.id, self
.col
, val
) # UPDATE THE DATABASE
177 obj
.__dict
__[self
.attr
] = val
# UPDATE THE CACHE
179 ## m = self.consequences
180 ## except AttributeError:
182 ## m(obj,val) # GENERATE CONSEQUENCES
183 ## def bind_consequences(self,f):
184 ## 'make function f be run as consequences method whenever value is set'
186 ## self.consequences = new.instancemethod(f,self,self.__class__)
class ReadOnlyColumnDesc(ColumnDescriptor):
    'variant of ColumnDescriptor that refuses all writes (e.g. for the ID)'
    def __set__(self, obj, val):
        raise AttributeError('The ID of a database object is not writeable.')
# Register the read-only variant; ColumnDescriptor switches an instance's
# __class__ to this when readOnly is requested.
ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc
194 class SQLRow(object):
195 """Provide transparent interface to a row in the database: attribute access
196 will be mapped to SELECT of the appropriate column, but data is not
197 cached on this object.
199 _columnDescriptor
= _sqlDescriptor
= SQLDescriptor
200 _idDescriptor
= ReadOnlyDescriptor
201 _init_subclass
= classmethod(init_row_subclass
)
202 _select
= select_from_row
204 def __init__(self
, rowID
):
208 class SQLRowRW(SQLRow
):
209 'read-write version of SQLRow'
210 _columnDescriptor
= SQLDescriptorRW
211 insert_and_cache_id
= insert_and_cache_id
    def __init__(self, rowID, newRow=False, **kwargs):
        """Bind to an existing row (default) or create a new one.

        With newRow=True, kwargs supply the column values for a new row,
        which is converted to a tuple and inserted immediately.
        """
        if not newRow: # just cache data from the database
            return self.cache_id(rowID)
        l = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(l, **kwargs)
217 def cache_id(self
, rowID
):
# Register the read-write companion class for SQLRow, mirroring
# TupleO._RWClass above.
SQLRow._RWClass = SQLRowRW
224 def list_to_dict(names
, values
):
225 'return dictionary of those named args that are present in values[]'
227 for i
,v
in enumerate(values
):
235 def get_name_cursor(name
=None, **kwargs
):
236 '''get table name and cursor by parsing name or using configFile.
237 If neither provided, will try to get via your MySQL config file.
238 If connect is None, will use MySQLdb.connect()'''
240 argList
= name
.split() # TREAT AS WS-SEPARATED LIST
242 name
= argList
[0] # USE 1ST ARG AS TABLE NAME
243 argnames
= ('host','user','passwd') # READ ARGS IN THIS ORDER
244 kwargs
= kwargs
.copy() # a copy we can overwrite
245 kwargs
.update(list_to_dict(argnames
, argList
[1:]))
246 serverInfo
= DBServerInfo(**kwargs
)
247 return name
,serverInfo
.cursor(),serverInfo
249 def mysql_connect(connect
=None, configFile
=None, useStreaming
=False, **args
):
250 """return connection and cursor objects, using .my.cnf if necessary"""
251 kwargs
= args
.copy() # a copy we can modify
252 if 'user' not in kwargs
and configFile
is None: #Find where config file is
253 osname
= platform
.system()
254 if osname
in('Microsoft', 'Windows'): # Machine is a Windows box
256 try: # handle case where WINDIR not defined by Windows...
257 windir
= os
.environ
['WINDIR']
258 paths
+= [(windir
, 'my.ini'), (windir
, 'my.cnf')]
262 sysdrv
= os
.environ
['SYSTEMDRIVE']
263 paths
+= [(sysdrv
, os
.path
.sep
+ 'my.ini'),
264 (sysdrv
, os
.path
.sep
+ 'my.cnf')]
268 configFile
= get_valid_path(*paths
)
269 else: # treat as normal platform with home directories
270 configFile
= os
.path
.join(os
.path
.expanduser('~'), '.my.cnf')
272 # allows for a local mysql local configuration file to be read
273 # from the current directory
274 configFile
= configFile
or os
.path
.join( os
.getcwd(), 'mysql.cnf' )
276 if configFile
and os
.path
.exists(configFile
):
277 kwargs
['read_default_file'] = configFile
278 connect
= None # force it to use MySQLdb
281 connect
= MySQLdb
.connect
282 kwargs
['compress'] = True
283 if useStreaming
: # use server side cursors for scalable result sets
285 from MySQLdb
import cursors
286 kwargs
['cursorclass'] = cursors
.SSCursor
287 except (ImportError, AttributeError):
289 conn
= connect(**kwargs
)
290 cursor
= conn
.cursor()
# SQL keyword macros for the MySQL dialect; consumed by SQLFormatDict to
# rewrite portable queries for this backend.
_mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
                    AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
                    SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')
297 def mysql_table_schema(self
, analyzeSchema
=True):
298 'retrieve table schema from a MySQL database, save on self'
300 self
._format
_query
= SQLFormatDict(MySQLdb
.paramstyle
, _mysqlMacros
)
301 if not analyzeSchema
:
303 self
.clear_schema() # reset settings and dictionaries
304 self
.cursor
.execute('describe %s' % self
.name
) # get info about columns
305 columns
= self
.cursor
.fetchall()
306 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
307 for icol
,c
in enumerate(columns
):
309 self
.columnName
.append(field
) # list of columns in same order as table
310 if c
[3] == "PRI": # record as primary key
311 if self
.primary_key
is None:
312 self
.primary_key
= field
315 self
.primary_key
.append(field
)
316 except AttributeError:
317 self
.primary_key
= [self
.primary_key
,field
]
318 if c
[1][:3].lower() == 'int':
319 self
.usesIntID
= True
321 self
.usesIntID
= False
323 self
.indexed
[field
] = icol
324 self
.description
[field
] = self
.cursor
.description
[icol
]
325 self
.columnType
[field
] = c
[1] # SQL COLUMN TYPE
# SQL keyword macros for the sqlite dialect (sqlite spells REPLACE as
# "insert or replace", has no AUTO_INCREMENT keyword, and substr() takes
# comma-separated arguments instead of FROM/FOR).
_sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
                     AUTO_INCREMENT='', SUBSTRING='substr',
                     SUBSTR_FROM=',', SUBSTR_FOR=',')
332 'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
334 import sqlite3
as sqlite
336 from pysqlite2
import dbapi2
as sqlite
339 def sqlite_table_schema(self
, analyzeSchema
=True):
340 'retrieve table schema from a sqlite3 database, save on self'
341 sqlite
= import_sqlite()
342 self
._format
_query
= SQLFormatDict(sqlite
.paramstyle
, _sqliteMacros
)
343 if not analyzeSchema
:
345 self
.clear_schema() # reset settings and dictionaries
346 self
.cursor
.execute('PRAGMA table_info("%s")' % self
.name
)
347 columns
= self
.cursor
.fetchall()
348 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
349 for icol
,c
in enumerate(columns
):
351 self
.columnName
.append(field
) # list of columns in same order as table
352 self
.description
[field
] = self
.cursor
.description
[icol
]
353 self
.columnType
[field
] = c
[2] # SQL COLUMN TYPE
354 self
.cursor
.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self
.name
) # get primary key / unique indexes
355 for indexname
in self
.cursor
.fetchall(): # search indexes for primary key
356 self
.cursor
.execute('PRAGMA index_info("%s")' % indexname
)
357 l
= self
.cursor
.fetchall() # get list of columns in this index
358 if len(l
) == 1: # assume 1st single-column unique index is primary key!
359 self
.primary_key
= l
[0][2]
360 break # done searching for primary key!
361 if self
.primary_key
is None: # grrr, INTEGER PRIMARY KEY handled differently
362 self
.cursor
.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self
.name
)
363 sql
= self
.cursor
.fetchall()[0][0]
364 for columnSQL
in sql
[sql
.index('(') + 1 :].split(','):
365 if 'primary key' in columnSQL
.lower(): # must be the primary key!
366 col
= columnSQL
.split()[0] # get column name
367 if col
in self
.columnType
:
368 self
.primary_key
= col
369 break # done searching for primary key!
371 raise ValueError('unknown primary key %s in table %s'
373 if self
.primary_key
is not None: # check its type
374 if self
.columnType
[self
.primary_key
] == 'int' or \
375 self
.columnType
[self
.primary_key
] == 'integer':
376 self
.usesIntID
= True
378 self
.usesIntID
= False
380 class SQLFormatDict(object):
381 '''Perform SQL keyword replacements for maintaining compatibility across
382 a wide range of SQL backends. Uses Python dict-based string format
383 function to do simple string replacements, and also to convert
384 params list to the paramstyle required for this interface.
385 Create by passing a dict of macros and the db-api paramstyle:
386 sfd = SQLFormatDict("qmark", substitutionDict)
388 Then transform queries+params as follows; input should be "format" style:
389 sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
390 cursor.execute(sql, params)
392 _paramFormats
= dict(pyformat
='%%(%d)s', numeric
=':%d', named
=':%d',
393 qmark
='(ignore)', format
='(ignore)')
    def __init__(self, paramstyle, substitutionDict={}):
        """Record the target DB-API paramstyle and macro substitutions.

        NOTE(review): the mutable default argument is safe here only
        because it is copied immediately below -- never mutate the
        shared default object.
        """
        self.substitutionDict = substitutionDict.copy()
        self.paramstyle = paramstyle # DB-API paramstyle of the target driver
        self.paramFormat = self._paramFormats[paramstyle]
        # pyformat/named interfaces need the params delivered as a dict
        self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
        if paramstyle == 'qmark': # handle these as simple substitution
            self.substitutionDict['?'] = '?'
        elif paramstyle == 'format':
            self.substitutionDict['?'] = '%s'
403 def __getitem__(self
, k
):
404 'apply correct substitution for this SQL interface'
406 return self
.substitutionDict
[k
] # apply our substitutions
409 if k
== '?': # sequential parameter
410 s
= self
.paramFormat
% self
.iparam
411 self
.iparam
+= 1 # advance to the next parameter
413 raise KeyError('unknown macro: %s' % k
)
414 def __call__(self
, sql
, paramList
):
415 'returns corrected sql,params for this interface'
416 self
.iparam
= 1 # DB-ABI param indexing begins at 1
417 sql
= sql
.replace('%s', '%(?)s') # convert format into pyformat
418 s
= sql
% self
# apply all %(x)s replacements in sql
419 if self
.makeDict
: # construct a params dict
421 for i
,param
in enumerate(paramList
):
422 paramDict
[str(i
+ 1)] = param
#DB-ABI param indexing begins at 1
424 else: # just return the original params list
427 def get_table_schema(self
, analyzeSchema
=True):
428 'run the right schema function based on type of db server connection'
430 modname
= self
.cursor
.__class
__.__module
__
431 except AttributeError:
432 raise ValueError('no cursor object or module information!')
434 schema_func
= self
._schemaModuleDict
[modname
]
436 raise KeyError('''unknown db module: %s. Use _schemaModuleDict
437 attribute to supply a method for obtaining table schema
438 for this module''' % modname
)
439 schema_func(self
, analyzeSchema
) # run the schema function
# Map DB-API cursor module name -> schema-introspection function; used by
# get_table_schema() to pick the analyzer matching the connection type.
_schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
                     'pysqlite2.dbapi2':sqlite_table_schema,
                     'sqlite3':sqlite_table_schema}
446 class SQLTableBase(object, UserDict
.DictMixin
):
447 "Store information about an SQL table as dict keyed by primary key"
448 _schemaModuleDict
= _schemaModuleDict
# default module list
449 get_table_schema
= get_table_schema
450 def __init__(self
,name
,cursor
=None,itemClass
=None,attrAlias
=None,
451 clusterKey
=None,createTable
=None,graph
=None,maxCache
=None,
452 arraysize
=1024, itemSliceClass
=None, dropIfExists
=False,
453 serverInfo
=None, autoGC
=True, orderBy
=None,
454 writeable
=False, iterSQL
=None, iterColumns
=None, **kwargs
):
455 if autoGC
: # automatically garbage collect unused objects
456 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
458 self
._weakValueDict
= {}
460 self
.orderBy
= orderBy
461 if orderBy
and serverInfo
and serverInfo
._serverType
== 'mysql':
462 if iterSQL
and iterColumns
: # both required for mysql!
463 self
.iterSQL
, self
.iterColumns
= iterSQL
, iterColumns
465 raise ValueError('For MySQL tables with orderBy, you MUST specify iterSQL and iterColumns as well!')
467 self
.writeable
= writeable
469 if serverInfo
is not None: # get cursor from serverInfo
470 cursor
= serverInfo
.cursor()
471 else: # try to read connection info from name or config file
472 name
,cursor
,serverInfo
= get_name_cursor(name
,**kwargs
)
474 warnings
.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
475 DeprecationWarning, stacklevel
=2)
477 if createTable
is not None: # RUN COMMAND TO CREATE THIS TABLE
478 if dropIfExists
: # get rid of any existing table
479 cursor
.execute('drop table if exists ' + name
)
480 self
.get_table_schema(False) # check dbtype, init _format_query
481 sql
,params
= self
._format
_query
(createTable
, ()) # apply macros
482 cursor
.execute(sql
) # create the table
484 if graph
is not None:
486 if maxCache
is not None:
487 self
.maxCache
= maxCache
488 if arraysize
is not None:
489 self
.arraysize
= arraysize
490 cursor
.arraysize
= arraysize
491 self
.get_table_schema() # get schema of columns to serve as attrs
492 self
.data
= {} # map of all attributes, including aliases
493 for icol
,field
in enumerate(self
.columnName
):
494 self
.data
[field
] = icol
# 1st add mappings to columns
496 self
.data
['id']=self
.data
[self
.primary_key
]
497 except (KeyError,TypeError):
499 if hasattr(self
,'_attr_alias'): # apply attribute aliases for this class
500 self
.addAttrAlias(False,**self
._attr
_alias
)
501 self
.objclass(itemClass
) # NEED TO SUBCLASS OUR ITEM CLASS
502 if itemSliceClass
is not None:
503 self
.itemSliceClass
= itemSliceClass
504 get_bound_subclass(self
, 'itemSliceClass', self
.name
) # need to subclass itemSliceClass
505 if attrAlias
is not None: # ADD ATTRIBUTE ALIASES
506 self
.attrAlias
= attrAlias
# RECORD FOR PICKLING PURPOSES
507 self
.data
.update(attrAlias
)
508 if clusterKey
is not None:
509 self
.clusterKey
=clusterKey
510 if serverInfo
is not None:
511 self
.serverInfo
= serverInfo
514 self
._select
(selectCols
='count(*)')
515 return self
.cursor
.fetchone()[0]
518 def __cmp__(self
, other
):
519 'only match self and no other!'
523 return cmp(id(self
), id(other
))
524 _pickleAttrs
= dict(name
=0, clusterKey
=0, maxCache
=0, arraysize
=0,
525 attrAlias
=0, serverInfo
=0, autoGC
=0, orderBy
=0,
526 writeable
=0, iterSQL
=0, iterColumns
=0)
527 __getstate__
= standard_getstate
528 def __setstate__(self
,state
):
529 # default cursor provisioning by worldbase is deprecated!
530 ## if 'serverInfo' not in state: # hmm, no address for db server?
531 ## try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
532 ## from Data import getResource
533 ## state['cursor'] = getResource.getTableCursor(state['name'])
534 ## except ImportError:
535 ## pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
536 self
.__init
__(**state
)
538 return '<SQL table '+self
.name
+'>'
540 def clear_schema(self
):
541 'reset all schema information for this table'
545 self
.usesIntID
= None
546 self
.primary_key
= None
548 def _attrSQL(self
,attr
,sqlColumn
=False,columnNumber
=False):
549 "Translate python attribute name to appropriate SQL expression"
550 try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
551 field
=self
.data
[attr
]
553 raise AttributeError('attribute %s not a valid column or alias in %s'
555 if sqlColumn
: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
556 try: # CHECK IF field IS COLUMN NUMBER
557 return self
.columnName
[field
] # RETURN SQL COLUMN NAME
559 try: # CHECK IF field IS SQL COLUMN NAME
560 return self
.columnName
[self
.data
[field
]] # THIS WILL JUST RETURN field...
561 except (KeyError,TypeError):
562 raise AttributeError('attribute %s does not map to an SQL column in %s'
565 try: # CHECK IF field IS A COLUMN NUMBER
566 return field
+0 # ONLY RETURN AN INTEGER
568 try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
569 return self
.data
[field
]+0 # ONLY RETURN AN INTEGER
570 except (KeyError,TypeError):
571 raise ValueError('attribute %s does not map to a SQL column!' % attr
)
572 if isinstance(field
,types
.StringType
):
573 attr
=field
# USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
575 attr
=self
.primary_key
577 def addAttrAlias(self
,saveToPickle
=True,**kwargs
):
578 """Add new attributes as aliases of existing attributes.
579 They can be specified either as named args:
580 t.addAttrAlias(newattr=oldattr)
581 or by passing a dictionary kwargs whose keys are newattr
582 and values are oldattr:
583 t.addAttrAlias(**kwargs)
584 saveToPickle=True forces these aliases to be saved if object is pickled.
587 self
.attrAlias
.update(kwargs
)
588 for key
,val
in kwargs
.items():
589 try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
590 self
.data
[val
]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
591 raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
592 except TypeError: # val IS ITSELF AN ALIAS
593 self
.data
[key
] = self
.data
[val
] # SO MAP TO WHAT IT MAPS TO
594 except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
596 def objclass(self
,oclass
=None):
597 "Create class representing a row in this table by subclassing oclass, adding data"
598 if oclass
is not None: # use this as our base itemClass
599 self
.itemClass
= oclass
601 self
.itemClass
= self
.itemClass
._RWClass
# use its writeable version
602 oclass
= get_bound_subclass(self
, 'itemClass', self
.name
,
603 subclassArgs
=dict(db
=self
)) # bind itemClass
604 if issubclass(oclass
, TupleO
):
605 oclass
._attrcol
= self
.data
# BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
606 if hasattr(oclass
,'_tableclass') and not isinstance(self
,oclass
._tableclass
):
607 self
.__class
__=oclass
._tableclass
# ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
608 def _select(self
, whereClause
='', params
=(), selectCols
='t1.*',
609 cursor
=None, orderBy
='', limit
=''):
610 'execute the specified query but do not fetch'
611 sql
,params
= self
._format
_query
('select %s from %s t1 %s %s %s'
612 % (selectCols
, self
.name
, whereClause
, orderBy
,
615 self
.cursor
.execute(sql
, params
)
617 cursor
.execute(sql
, params
)
618 def select(self
,whereClause
,params
=None,oclass
=None,selectCols
='t1.*'):
619 "Generate the list of objects that satisfy the database SELECT"
621 oclass
=self
.itemClass
622 self
._select
(whereClause
,params
,selectCols
)
623 l
=self
.cursor
.fetchall()
625 yield self
.cacheItem(t
,oclass
)
626 def query(self
,**kwargs
):
627 'query for intersection of all specified kwargs, returned as iterator'
630 for k
,v
in kwargs
.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
631 if v
is None: # CONVERT TO SQL NULL TEST
632 criteria
.append('%s IS NULL' % self
._attrSQL
(k
))
633 else: # TEST FOR EQUALITY
634 criteria
.append('%s=%%s' % self
._attrSQL
(k
))
636 return self
.select('where '+' and '.join(criteria
),params
)
637 def _update(self
,row_id
,col
,val
):
638 'update a single field in the specified row to the specified value'
639 sql
,params
= self
._format
_query
('update %s set %s=%%s where %s=%%s'
640 %(self
.name
,col
,self
.primary_key
),
642 self
.cursor
.execute(sql
, params
)
645 return t
[self
.data
['id']] # GET ID FROM TUPLE
646 except TypeError: # treat as alias
647 return t
[self
.data
[self
.data
['id']]]
648 def cacheItem(self
,t
,oclass
):
649 'get obj from cache if possible, or construct from tuple'
652 except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
654 try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
655 return self
._weakValueDict
[id]
659 self
._weakValueDict
[id] = o
# CACHE THIS ITEM IN OUR DICTIONARY
661 def cache_items(self
,rows
,oclass
=None):
663 oclass
=self
.itemClass
665 yield self
.cacheItem(t
,oclass
)
666 def foreignKey(self
,attr
,k
):
667 'get iterator for objects with specified foreign key value'
668 return self
.select('where %s=%%s'%attr
,(k
,))
669 def limit_cache(self
):
670 'APPLY maxCache LIMIT TO CACHE SIZE'
672 if self
.maxCache
<len(self
._weakValueDict
):
673 self
._weakValueDict
.clear()
674 except AttributeError:
677 def get_new_cursor(self
):
678 """Return a new cursor object, or None if not possible """
680 new_cursor
= self
.serverInfo
.new_cursor
681 except AttributeError:
683 return new_cursor(self
.arraysize
)
685 def generic_iterator(self
, cursor
=None, fetch_f
=None, cache_f
=None,
686 map_f
=iter, cursorHolder
=None):
687 """generic iterator that runs fetch, cache and map functions.
688 cursorHolder is used only to keep a ref in this function's locals,
689 so that if it is prematurely terminated (by deleting its
690 iterator), cursorHolder.__del__() will close the cursor."""
691 if fetch_f
is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
693 fetch_f
= self
.cursor
.fetchmany
694 else: # isolate this iter from other queries
695 fetch_f
= cursor
.fetchmany
697 cache_f
= self
.cache_items
700 rows
= fetch_f() # FETCH THE NEXT SET OF ROWS
701 if len(rows
)==0: # NO MORE DATA SO ALL DONE
703 for v
in map_f(cache_f(rows
)): # CACHE AND GENERATE RESULTS
705 def tuple_from_dict(self
, d
):
706 'transform kwarg dict into tuple for storing in database'
707 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
708 for col
,icol
in self
.data
.items():
711 except (KeyError,TypeError):
714 def tuple_from_obj(self
, obj
):
715 'transform object attributes into tuple for storing in database'
716 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
717 for col
,icol
in self
.data
.items():
719 l
[icol
] = getattr(obj
,col
)
720 except (AttributeError,TypeError):
723 def _insert(self
, l
):
724 '''insert tuple into the database. Note this uses the MySQL
725 extension REPLACE, which overwrites any duplicate key.'''
726 s
= '%(REPLACE)s into ' + self
.name
+ ' values (' \
727 + ','.join(['%s']*len(l
)) + ')'
728 sql
,params
= self
._format
_query
(s
, l
)
729 self
.cursor
.execute(sql
, params
)
730 def insert(self
, obj
):
731 '''insert new row by transforming obj to tuple of values'''
732 l
= self
.tuple_from_obj(obj
)
734 def get_insert_id(self
):
735 'get the primary key value for the last INSERT'
736 try: # ATTEMPT TO GET ASSIGNED ID FROM DB
737 auto_id
= self
.cursor
.lastrowid
738 except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
739 raise NotImplementedError('''your db lacks lastrowid support?''')
741 raise ValueError('lastrowid is None so cannot get ID from INSERT!')
743 def new(self
, **kwargs
):
744 'return a new record with the assigned attributes, added to DB'
745 if not self
.writeable
:
746 raise ValueError('this database is read only!')
747 obj
= self
.itemClass(None, newRow
=True, **kwargs
) # saves itself to db
748 self
._weakValueDict
[obj
.id] = obj
# AND SAVE TO OUR LOCAL DICT CACHE
750 def clear_cache(self
):
752 self
._weakValueDict
.clear()
753 def __delitem__(self
, k
):
754 if not self
.writeable
:
755 raise ValueError('this database is read only!')
756 sql
,params
= self
._format
_query
('delete from %s where %s=%%s'
757 % (self
.name
,self
.primary_key
),(k
,))
758 self
.cursor
.execute(sql
, params
)
760 del self
._weakValueDict
[k
]
764 def getKeys(self
,queryOption
='', selectCols
=None):
765 'uses db select; does not force load'
766 if selectCols
is None:
767 selectCols
=self
.primary_key
768 if queryOption
=='' and self
.orderBy
is not None:
769 queryOption
= self
.orderBy
# apply default ordering
770 self
.cursor
.execute('select %s from %s %s'
771 %(selectCols
,self
.name
,queryOption
))
772 return [t
[0] for t
in self
.cursor
.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
774 def iter_keys(self
, selectCols
=None, orderBy
='', map_f
=iter,
775 cache_f
=lambda x
:[t
[0] for t
in x
], get_f
=None, **kwargs
):
776 'guarantee correct iteration insulated from other queries'
777 if selectCols
is None:
778 selectCols
=self
.primary_key
779 if orderBy
=='' and self
.orderBy
is not None:
780 orderBy
= self
.orderBy
# apply default ordering
781 cursor
= self
.get_new_cursor()
782 if cursor
: # got our own cursor, guaranteeing query isolation
783 if hasattr(self
.serverInfo
, 'iter_keys') \
784 and self
.serverInfo
.custom_iter_keys
:
785 # use custom iter_keys() method from serverInfo
786 return self
.serverInfo
.iter_keys(self
, cursor
, selectCols
=selectCols
,
787 map_f
=map_f
, orderBy
=orderBy
,
788 cache_f
=cache_f
, **kwargs
)
790 self
._select
(cursor
=cursor
, selectCols
=selectCols
,
791 orderBy
=orderBy
, **kwargs
)
792 return self
.generic_iterator(cursor
=cursor
, cache_f
=cache_f
,
794 cursorHolder
=CursorCloser(cursor
))
795 else: # must pre-fetch all keys to ensure query isolation
796 if get_f
is not None:
799 return iter(self
.keys())
801 class SQLTable(SQLTableBase
):
802 "Provide on-the-fly access to rows in the database, caching the results in dict"
803 itemClass
= TupleO
# our default itemClass; constructor can override
806 def load(self
,oclass
=None):
807 "Load all data from the table"
808 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
809 return self
._isLoaded
810 except AttributeError:
813 oclass
=self
.itemClass
814 self
.cursor
.execute('select * from %s' % self
.name
)
815 l
=self
.cursor
.fetchall()
816 self
._weakValueDict
= {} # just store the whole dataset in memory
818 self
.cacheItem(t
,oclass
) # CACHE IT IN LOCAL DICTIONARY
819 self
._isLoaded
=True # MARK THIS CONTAINER AS FULLY LOADED
821 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
823 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
824 except KeyError: # NOT FOUND, SO TRY THE DATABASE
825 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 2'
826 % (self
.name
,self
.primary_key
),(k
,))
827 self
.cursor
.execute(sql
, params
)
828 l
= self
.cursor
.fetchmany(2) # get at most 2 rows
830 raise KeyError('%s not found in %s, or not unique' %(str(k
),self
.name
))
832 return self
.cacheItem(l
[0],self
.itemClass
) # CACHE IT IN LOCAL DICTIONARY
833 def __setitem__(self
, k
, v
):
834 if not self
.writeable
:
835 raise ValueError('this database is read only!')
839 except AttributeError:
840 raise ValueError('object not bound to itemClass for this db!')
845 except AttributeError:
847 else: # delete row with old ID
849 v
.cache_id(k
) # cache the new ID on the object
850 self
.insert(v
) # SAVE TO THE RELATIONAL DB SERVER
851 self
._weakValueDict
[k
] = v
# CACHE THIS ITEM IN OUR DICTIONARY
853 'forces load of entire table into memory'
855 return [(k
,self
[k
]) for k
in self
] # apply orderBy rules...
857 'uses arraysize / maxCache and fetchmany() to manage data transfer'
858 return iter_keys(self
, selectCols
='*', cache_f
=None,
859 map_f
=generate_items
, get_f
=self
.items
)
861 'forces load of entire table into memory'
863 return [self
[k
] for k
in self
] # apply orderBy rules...
864 def itervalues(self
):
865 'uses arraysize / maxCache and fetchmany() to manage data transfer'
866 return iter_keys(self
, selectCols
='*', cache_f
=None, get_f
=self
.values
)
def getClusterKeys(self, queryOption=''):
    'uses db select; does not force load'
    # Build the DISTINCT query over the cluster-key column; queryOption may
    # append e.g. an "order by" clause.
    sql = 'select distinct %s from %s %s' % (self.clusterKey, self.name,
                                             queryOption)
    self.cursor.execute(sql)
    # Fetch everything at once, since other calls may reuse this cursor.
    rows = self.cursor.fetchall()
    return [row[0] for row in rows]
875 class SQLTableClustered(SQLTable
):
876 '''use clusterKey to load a whole cluster of rows at once,
877 specifically, all rows that share the same clusterKey value.'''
def __init__(self, *args, **kwargs):
    '''Construct a clustered table interface.
    Forces autoGC off: clusters are loaded wholesale, so cached rows must
    not be garbage-collected individually by a WeakValueDictionary.'''
    opts = kwargs.copy() # get a copy we can alter
    opts['autoGC'] = False # don't use WeakValueDictionary
    SQLTable.__init__(self, *args, **opts)
883 return getKeys(self
,'order by %s' %self
.clusterKey
)
def clusterkeys(self):
    'distinct cluster-key values, ordered by the cluster key; does not force load'
    return getClusterKeys(self, 'order by %s' % self.clusterKey)
886 def __getitem__(self
,k
):
888 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
889 except KeyError: # NOT FOUND, SO TRY THE DATABASE
890 sql
,params
= self
._format
_query
('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
891 % (self
.name
,self
.name
,self
.primary_key
,
892 self
.clusterKey
,self
.clusterKey
),(k
,))
893 self
.cursor
.execute(sql
, params
)
894 l
=self
.cursor
.fetchall()
896 for t
in l
: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
897 self
.cacheItem(t
,self
.itemClass
)
898 return self
._weakValueDict
[k
] # should be in cache, if row k exists
899 def itercluster(self
,cluster_id
):
900 'iterate over all items from the specified cluster'
902 return self
.select('where %s=%%s'%self
.clusterKey
,(cluster_id
,))
903 def fetch_cluster(self
):
904 'use self.cursor.fetchmany to obtain all rows for next cluster'
905 icol
= self
._attrSQL
(self
.clusterKey
,columnNumber
=True)
908 rows
= self
._fetch
_cluster
_cache
# USE SAVED ROWS FROM PREVIOUS CALL
909 del self
._fetch
_cluster
_cache
910 except AttributeError:
911 rows
= self
.cursor
.fetchmany()
913 cluster_id
= rows
[0][icol
]
917 for i
,t
in enumerate(rows
): # CHECK THAT ALL ROWS FROM THIS CLUSTER
918 if cluster_id
!= t
[icol
]: # START OF A NEW CLUSTER
919 result
+= rows
[:i
] # RETURN ROWS OF CURRENT CLUSTER
920 self
._fetch
_cluster
_cache
= rows
[i
:] # SAVE NEXT CLUSTER
923 rows
= self
.cursor
.fetchmany() # GET NEXT SET OF ROWS
def itervalues(self):
    'uses arraysize / maxCache and fetchmany() to manage data transfer'
    # A dedicated cursor keeps this iteration isolated from other queries;
    # ordering by clusterKey lets fetch_cluster pull whole clusters at a time.
    c = self.get_new_cursor()
    self._select('order by %s' % self.clusterKey, cursor=c)
    return self.generic_iterator(c, self.fetch_cluster,
                                 cursorHolder=CursorHolder(c))
932 'uses arraysize / maxCache and fetchmany() to manage data transfer'
933 cursor
= self
.get_new_cursor()
934 self
._select
('order by %s' %self
.clusterKey
, cursor
=cursor
)
935 return self
.generic_iterator(cursor
, self
.fetch_cluster
,
936 map_f
=generate_items
,
937 cursorHolder
=CursorHolder(cursor
))
939 class SQLForeignRelation(object):
940 'mapping based on matching a foreign key in an SQL table'
941 def __init__(self
,table
,keyName
):
944 def __getitem__(self
,k
):
945 'get list of objects o with getattr(o,keyName)==k.id'
947 for o
in self
.table
.select('where %s=%%s'%self
.keyName
,(k
.id,)):
950 raise KeyError('%s not found in %s' %(str(k
),self
.name
))
954 class SQLTableNoCache(SQLTableBase
):
955 '''Provide on-the-fly access to rows in the database;
956 values are simply an object interface (SQLRow) to back-end db query.
957 Row data are not stored locally, but always accessed by querying the db'''
958 itemClass
=SQLRow
# DEFAULT OBJECT CLASS FOR ROWS...
def getID(self, t):
    'extract the row ID, which is the first column of the row tuple'
    return t[0]
962 def select(self
,whereClause
,params
):
963 return SQLTableBase
.select(self
,whereClause
,params
,self
.oclass
,
965 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
967 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
968 except KeyError: # NOT FOUND, SO TRY THE DATABASE
969 self
._select
('where %s=%%s' % self
.primary_key
, (k
,),
971 t
= self
.cursor
.fetchmany(2)
973 raise KeyError('id %s non-existent or not unique' % k
)
974 o
= self
.itemClass(k
) # create obj referencing this ID
975 self
._weakValueDict
[k
] = o
# cache the SQLRow object
977 def __setitem__(self
, k
, v
):
978 if not self
.writeable
:
979 raise ValueError('this database is read only!')
983 except AttributeError:
984 raise ValueError('object not bound to itemClass for this db!')
986 del self
[k
] # delete row with new ID if any
990 del self
._weakValueDict
[v
.id] # delete from old cache location
993 self
._update
(v
.id, self
.primary_key
, k
) # just change its ID in db
994 v
.cache_id(k
) # change the cached ID value
995 self
._weakValueDict
[k
] = v
# assign to new cache location
def addAttrAlias(self, **kwargs):
    'register attribute aliases: each keyword maps an attr name to an SQL expression'
    for alias, expr in kwargs.items():
        self.data[alias] = expr
999 SQLRow
._tableclass
=SQLTableNoCache
# SQLRow IS FOR NON-CACHING TABLE INTERFACE
1002 class SQLTableMultiNoCache(SQLTableBase
):
1003 "Trivial on-the-fly access for table with key that returns multiple rows"
1004 itemClass
= TupleO
# default itemClass; constructor can override
1005 _distinct_key
='id' # DEFAULT COLUMN TO USE AS KEY
1007 return getKeys(self
, selectCols
='distinct(%s)'
1008 % self
._attrSQL
(self
._distinct
_key
))
1010 return iter_keys(self
, 'distinct(%s)' % self
._attrSQL
(self
._distinct
_key
))
1011 def __getitem__(self
,id):
1012 sql
,params
= self
._format
_query
('select * from %s where %s=%%s'
1013 %(self
.name
,self
._attrSQL
(self
._distinct
_key
)),(id,))
1014 self
.cursor
.execute(sql
, params
)
1015 l
=self
.cursor
.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1017 yield self
.itemClass(row
)
def addAttrAlias(self, **kwargs):
    'register attribute aliases: each keyword maps an attr name to an SQL expression'
    for alias, expr in kwargs.items():
        self.data[alias] = expr
1023 class SQLEdges(SQLTableMultiNoCache
):
1024 '''provide iterator over edges as (source,target,edge)
1025 and getitem[edge] --> [(source,target),...]'''
1026 _distinct_key
='edge_id'
1027 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1028 _pickleAttrs
.update(dict(graph
=0))
1030 self
.cursor
.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
1031 %(self
._attrSQL
('source_id'),self
._attrSQL
('target_id'),
1032 self
._attrSQL
('edge_id'),self
.name
,
1033 self
._attrSQL
('target_id'),self
._attrSQL
('source_id'),
1034 self
._attrSQL
('target_id')))
1035 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1036 for source_id
,target_id
,edge_id
in self
.cursor
.fetchall():
1037 l
.append((self
.graph
.unpack_source(source_id
),
1038 self
.graph
.unpack_target(target_id
),
1039 self
.graph
.unpack_edge(edge_id
)))
1043 return iter(self
.keys())
1044 def __getitem__(self
,edge
):
1045 sql
,params
= self
._format
_query
('select %s,%s from %s where %s=%%s'
1046 %(self
._attrSQL
('source_id'),
1047 self
._attrSQL
('target_id'),
1049 self
._attrSQL
(self
._distinct
_key
)),
1050 (self
.graph
.pack_edge(edge
),))
1051 self
.cursor
.execute(sql
, params
)
1052 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1053 for source_id
,target_id
in self
.cursor
.fetchall():
1054 l
.append((self
.graph
.unpack_source(source_id
),
1055 self
.graph
.unpack_target(target_id
)))
1059 class SQLEdgeDict(object):
1060 '2nd level graph interface to SQL database'
1061 def __init__(self
,fromNode
,table
):
1062 self
.fromNode
=fromNode
1064 if not hasattr(self
.table
,'allowMissingNodes'):
1065 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s limit 1'
1066 %(self
.table
.sourceSQL
,
1068 self
.table
.sourceSQL
),
1070 self
.table
.cursor
.execute(sql
, params
)
1071 if len(self
.table
.cursor
.fetchall())<1:
1072 raise KeyError('node not in graph!')
1074 def __getitem__(self
,target
):
1075 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1076 %(self
.table
.edgeSQL
,
1078 self
.table
.sourceSQL
,
1079 self
.table
.targetSQL
),
1081 self
.table
.pack_target(target
)))
1082 self
.table
.cursor
.execute(sql
, params
)
1083 l
= self
.table
.cursor
.fetchmany(2) # get at most two rows
1085 raise KeyError('either no edge from source to target or not unique!')
1087 return self
.table
.unpack_edge(l
[0][0]) # RETURN EDGE
1089 raise KeyError('no edge from node to target')
1090 def __setitem__(self
,target
,edge
):
1091 sql
,params
= self
.table
._format
_query
('replace into %s values (%%s,%%s,%%s)'
1094 self
.table
.pack_target(target
),
1095 self
.table
.pack_edge(edge
)))
1096 self
.table
.cursor
.execute(sql
, params
)
1097 if not hasattr(self
.table
,'sourceDB') or \
1098 (hasattr(self
.table
,'targetDB') and
1099 self
.table
.sourceDB
is self
.table
.targetDB
):
1100 self
.table
+= target
# ADD AS NODE TO GRAPH
1101 def __iadd__(self
,target
):
1103 return self
# iadd MUST RETURN self!
1104 def __delitem__(self
,target
):
1105 sql
,params
= self
.table
._format
_query
('delete from %s where %s=%%s and %s=%%s'
1107 self
.table
.sourceSQL
,
1108 self
.table
.targetSQL
),
1110 self
.table
.pack_target(target
)))
1111 self
.table
.cursor
.execute(sql
, params
)
1112 if self
.table
.cursor
.rowcount
< 1: # no rows deleted?
1113 raise KeyError('no edge from node to target')
1115 def iterator_query(self
):
1116 sql
,params
= self
.table
._format
_query
('select %s,%s from %s where %s=%%s and %s is not null'
1117 %(self
.table
.targetSQL
,
1120 self
.table
.sourceSQL
,
1121 self
.table
.targetSQL
),
1123 self
.table
.cursor
.execute(sql
, params
)
1124 return self
.table
.cursor
.fetchall()
1126 return [self
.table
.unpack_target(target_id
)
1127 for target_id
,edge_id
in self
.iterator_query()]
1129 return [self
.table
.unpack_edge(edge_id
)
1130 for target_id
,edge_id
in self
.iterator_query()]
1132 return [(self
.table
.unpack_source(self
.fromNode
),self
.table
.unpack_target(target_id
),
1133 self
.table
.unpack_edge(edge_id
))
1134 for target_id
,edge_id
in self
.iterator_query()]
1136 return [(self
.table
.unpack_target(target_id
),self
.table
.unpack_edge(edge_id
))
1137 for target_id
,edge_id
in self
.iterator_query()]
def __iter__(self):
    'iterate over target nodes reachable from this source node'
    return iter(self.keys())
def itervalues(self):
    'iterate over edge values for this source node'
    return iter(self.values())
def iteritems(self):
    'iterate over (target, edge) pairs for this source node'
    return iter(self.items())
1142 return len(self
.keys())
1145 class SQLEdgelessDict(SQLEdgeDict
):
1146 'for SQLGraph tables that lack edge_id column'
1147 def __getitem__(self
,target
):
1148 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1149 %(self
.table
.targetSQL
,
1151 self
.table
.sourceSQL
,
1152 self
.table
.targetSQL
),
1154 self
.table
.pack_target(target
)))
1155 self
.table
.cursor
.execute(sql
, params
)
1156 l
= self
.table
.cursor
.fetchmany(2)
1158 raise KeyError('either no edge from source to target or not unique!')
1159 return None # no edge info!
1160 def iterator_query(self
):
1161 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s is not null'
1162 %(self
.table
.targetSQL
,
1164 self
.table
.sourceSQL
,
1165 self
.table
.targetSQL
),
1167 self
.table
.cursor
.execute(sql
, params
)
1168 return [(t
[0],None) for t
in self
.table
.cursor
.fetchall()]
1170 SQLEdgeDict
._edgelessClass
= SQLEdgelessDict
1172 class SQLGraphEdgeDescriptor(object):
1173 'provide an SQLEdges interface on demand'
1174 def __get__(self
,obj
,objtype
):
1176 attrAlias
=obj
.attrAlias
.copy()
1177 except AttributeError:
1178 return SQLEdges(obj
.name
, obj
.cursor
, graph
=obj
)
1180 return SQLEdges(obj
.name
, obj
.cursor
, attrAlias
=attrAlias
,
1183 def getColumnTypes(createTable
,attrAlias
={},defaultColumnType
='int',
1184 columnAttrs
=('source','target','edge'),**kwargs
):
1185 'return list of [(colname,coltype),...] for source,target,edge'
1187 for attr
in columnAttrs
:
1189 attrName
= attrAlias
[attr
+'_id']
1191 attrName
= attr
+'_id'
1192 try: # SEE IF USER SPECIFIED A DESIRED TYPE
1193 l
.append((attrName
,createTable
[attr
+'_id']))
1195 except (KeyError,TypeError):
1197 try: # get type info from primary key for that database
1198 db
= kwargs
[attr
+'DB']
1200 raise KeyError # FORCE IT TO USE DEFAULT TYPE
1203 else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
1205 try: # GET ONE IDENTIFIER FROM THE DATABASE
1207 except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
1209 l
.append((attrName
,db
.columnType
[db
.primary_key
]))
1211 except AttributeError:
1213 else: # GET THE TYPE FROM THIS IDENTIFIER
1214 if isinstance(k
,int) or isinstance(k
,long):
1215 l
.append((attrName
,'int'))
1217 elif isinstance(k
,str):
1218 l
.append((attrName
,'varchar(32)'))
1221 raise ValueError('SQLGraph node / edge must be int or str!')
1222 l
.append((attrName
,defaultColumnType
))
1223 logger
.warn('no type info found for %s, so using default: %s'
1224 % (attrName
, defaultColumnType
))
1230 class SQLGraph(SQLTableMultiNoCache
):
1231 '''provide a graph interface via a SQL table. Key capabilities are:
1232 - setitem with an empty dictionary: a dummy operation
1233 - getitem with a key that exists: return a placeholder
1234 - setitem with non empty placeholder: again a dummy operation
1235 EXAMPLE TABLE SCHEMA:
1236 create table mygraph (source_id int not null,target_id int,edge_id int,
1237 unique(source_id,target_id));
1239 _distinct_key
='source_id'
1240 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1241 _pickleAttrs
.update(dict(sourceDB
=0,targetDB
=0,edgeDB
=0,allowMissingNodes
=0))
1242 _edgeClass
= SQLEdgeDict
1243 def __init__(self
,name
,*l
,**kwargs
):
1244 graphArgs
,tableArgs
= split_kwargs(kwargs
,
1245 ('attrAlias','defaultColumnType','columnAttrs',
1246 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
1247 'edgeDictClass','graph'))
1248 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1249 c
= getColumnTypes(**kwargs
)
1250 tableArgs
['createTable'] = \
1251 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1252 % (name
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1254 self
.allowMissingNodes
= kwargs
['allowMissingNodes']
1255 except KeyError: pass
1256 SQLTableMultiNoCache
.__init
__(self
,name
,*l
,**tableArgs
)
1257 self
.sourceSQL
= self
._attrSQL
('source_id')
1258 self
.targetSQL
= self
._attrSQL
('target_id')
1260 self
.edgeSQL
= self
._attrSQL
('edge_id')
1261 except AttributeError:
1263 self
._edgeClass
= self
._edgeClass
._edgelessClass
1264 save_graph_db_refs(self
,**kwargs
)
def __getitem__(self, k):
    'return the edge-dictionary interface for source node k'
    packed = self.pack_source(k)
    return self._edgeClass(packed, self)
def __iadd__(self, k):
    'add k as a node (with no edges yet) to the graph; iadd must return self'
    node = self.pack_source(k)
    # First remove any pre-existing edgeless placeholder row for this node,
    # so the insert below cannot create a duplicate.
    delete_q = 'delete from %s where %s=%%s and %s is null' \
               % (self.name, self.sourceSQL, self.targetSQL)
    sql, params = self._format_query(delete_q, (node,))
    self.cursor.execute(sql, params)
    # Insert a fresh placeholder row with NULL target / edge.
    insert_q = 'insert %%(IGNORE)s into %s values (%%s,NULL,NULL)' % self.name
    sql, params = self._format_query(insert_q, (node,))
    self.cursor.execute(sql, params)
    return self # iadd MUST RETURN SELF!
def __isub__(self, k):
    'delete node k (all rows with this source); isub must return self'
    sql, params = self._format_query('delete from %s where %s=%%s'
                                     % (self.name, self.sourceSQL),
                                     (self.pack_source(k),))
    self.cursor.execute(sql, params)
    if self.cursor.rowcount == 0: # nothing deleted --> node was absent
        raise KeyError('node not found in graph')
    return self # isub MUST RETURN SELF!
1284 __setitem__
= graph_setitem
1285 def __contains__(self
,k
):
1286 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 1'
1287 %(self
.name
,self
.sourceSQL
),
1288 (self
.pack_source(k
),))
1289 self
.cursor
.execute(sql
, params
)
1290 l
= self
.cursor
.fetchmany(2)
1292 def __invert__(self
):
1293 'get an interface to the inverse graph mapping'
1295 return self
._inverse
1296 except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
1297 attrAlias
= dict(source_id
=self
.targetSQL
, # SWAP SOURCE & TARGET
1298 target_id
=self
.sourceSQL
,
1299 edge_id
=self
.edgeSQL
)
1300 if self
.edgeSQL
is None: # no edge interface
1301 del attrAlias
['edge_id']
1302 self
._inverse
=SQLGraph(self
.name
,self
.cursor
,
1303 attrAlias
=attrAlias
,
1304 **graph_db_inverse_refs(self
))
1305 self
._inverse
._inverse
=self
1306 return self
._inverse
1308 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1309 yield self
.unpack_source(k
)
def iteritems(self):
    'yield (source node, edge-dictionary) pairs for every source node'
    for source_id in SQLTableMultiNoCache.__iter__(self):
        yield (self.unpack_source(source_id),
               self._edgeClass(source_id, self))
def itervalues(self):
    'yield the edge-dictionary of every source node'
    for source_id in SQLTableMultiNoCache.__iter__(self):
        yield self._edgeClass(source_id, self)
1317 return [self
.unpack_source(k
) for k
in SQLTableMultiNoCache
.keys(self
)]
def values(self):
    'list of edge dictionaries, one per source node'
    return list(self.itervalues())
def items(self):
    'list of (source node, edge dictionary) pairs'
    return list(self.iteritems())
1320 edges
=SQLGraphEdgeDescriptor()
1321 update
= update_graph
1323 'get number of source nodes in graph'
1324 self
.cursor
.execute('select count(distinct %s) from %s'
1325 %(self
.sourceSQL
,self
.name
))
1326 return self
.cursor
.fetchone()[0]
1328 override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
1329 ## def __cmp__(self,other):
1333 ## it = iter(self.edges)
1336 ## source,target,edge = it.next()
1337 ## except StopIteration:
1340 ## if d is not None:
1341 ## diff = cmp(n_target,len(d))
1344 ## if source is None:
1347 ## n += 1 # COUNT SOURCE NODES
1354 ## diff = cmp(edge,d[target])
1359 ## n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
1360 ## return cmp(n,len(other))
1362 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1364 class SQLIDGraph(SQLGraph
):
1365 add_trivial_packing_methods(locals())
1366 SQLGraph
._IDGraphClass
= SQLIDGraph
1370 class SQLEdgeDictClustered(dict):
1371 'simple cache for 2nd level dictionary of target_id:edge_id'
1372 def __init__(self
,g
,fromNode
):
1374 self
.fromNode
=fromNode
def __iadd__(self, pairs):
    'merge a sequence of (target_id, edge_id) pairs into this cache dict'
    for target_id, edge_id in pairs:
        # bypass any subclass __setitem__: this is a pure cache fill
        dict.__setitem__(self, target_id, edge_id)
    return self # iadd MUST RETURN SELF!
1381 class SQLEdgesClusteredDescr(object):
1382 def __get__(self
,obj
,objtype
):
1383 e
=SQLEdgesClustered(obj
.table
,obj
.edge_id
,obj
.source_id
,obj
.target_id
,
1384 graph
=obj
,**graph_db_inverse_refs(obj
,True))
1385 for source_id
,d
in obj
.d
.iteritems(): # COPY EDGE CACHE
1386 e
.load([(edge_id
,source_id
,target_id
)
1387 for (target_id
,edge_id
) in d
.iteritems()])
1390 class SQLGraphClustered(object):
1391 'SQL graph with clustered caching -- loads an entire cluster at a time'
1392 _edgeDictClass
=SQLEdgeDictClustered
1393 def __init__(self
,table
,source_id
='source_id',target_id
='target_id',
1394 edge_id
='edge_id',clusterKey
=None,**kwargs
):
1396 if isinstance(table
,types
.StringType
): # CREATE THE TABLE INTERFACE
1397 if clusterKey
is None:
1398 raise ValueError('you must provide a clusterKey argument!')
1399 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1400 c
= getColumnTypes(attrAlias
=dict(source_id
=source_id
,target_id
=target_id
,
1401 edge_id
=edge_id
),**kwargs
)
1402 kwargs
['createTable'] = \
1403 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1404 % (table
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],
1405 c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1406 table
= SQLTableClustered(table
,clusterKey
=clusterKey
,**kwargs
)
1408 self
.source_id
=source_id
1409 self
.target_id
=target_id
1410 self
.edge_id
=edge_id
1412 save_graph_db_refs(self
,**kwargs
)
1413 _pickleAttrs
= dict(table
=0,source_id
=0,target_id
=0,edge_id
=0,sourceDB
=0,targetDB
=0,
1415 def __getstate__(self
):
1416 state
= standard_getstate(self
)
1417 state
['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1419 def __getitem__(self
,k
):
1420 'get edgeDict for source node k, from cache or by loading its cluster'
1421 try: # GET DIRECTLY FROM CACHE
1424 if hasattr(self
,'_isLoaded'):
1425 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1426 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1427 sql
,params
= self
.table
._format
_query
('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1428 %(self
.source_id
,self
.target_id
,
1429 self
.edge_id
,self
.table
.name
,
1430 self
.table
.name
,self
.source_id
,
1431 self
.table
.clusterKey
,self
.table
.clusterKey
,
1432 self
.table
.primary_key
),
1433 (self
.pack_source(k
),))
1434 self
.table
.cursor
.execute(sql
, params
)
1435 self
.load(self
.table
.cursor
.fetchall()) # CACHE THIS CLUSTER
1436 return self
.d
[k
] # RETURN EDGE DICT FOR THIS NODE
1437 def load(self
,l
=None,unpack
=True):
1438 'load the specified rows (or all, if None provided) into local cache'
1440 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1441 return self
._isLoaded
1442 except AttributeError:
1444 self
.table
.cursor
.execute('select %s,%s,%s from %s'
1445 %(self
.source_id
,self
.target_id
,
1446 self
.edge_id
,self
.table
.name
))
1447 l
=self
.table
.cursor
.fetchall()
1449 self
.d
.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1450 for source
,target
,edge
in l
: # SAVE TO OUR CACHE
1452 source
= self
.unpack_source(source
)
1453 target
= self
.unpack_target(target
)
1454 edge
= self
.unpack_edge(edge
)
1456 self
.d
[source
] += [(target
,edge
)]
1458 d
= self
._edgeDictClass
(self
,source
)
1459 d
+= [(target
,edge
)]
1461 def __invert__(self
):
1462 'interface to reverse graph mapping'
1464 return self
._inverse
# INVERSE MAP ALREADY EXISTS
1465 except AttributeError:
1467 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1468 self
._inverse
=SQLGraphClustered(self
.table
,self
.target_id
,self
.source_id
,
1469 self
.edge_id
,**graph_db_inverse_refs(self
))
1470 self
._inverse
._inverse
=self
1471 for source
,d
in self
.d
.iteritems(): # INVERT OUR CACHE
1472 self
._inverse
.load([(target
,source
,edge
)
1473 for (target
,edge
) in d
.iteritems()],unpack
=False)
1474 return self
._inverse
1475 edges
=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1476 update
= update_graph
1477 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1478 def __iter__(self
): ################# ITERATORS
1479 'uses db select; does not force load'
1480 return iter(self
.keys())
1482 'uses db select; does not force load'
1483 self
.table
.cursor
.execute('select distinct(%s) from %s'
1484 %(self
.source_id
,self
.table
.name
))
1485 return [self
.unpack_source(t
[0])
1486 for t
in self
.table
.cursor
.fetchall()]
1487 methodFactory(['iteritems','items','itervalues','values'],
1488 'lambda self:(self.load(),self.d.%s())[1]',locals())
1489 def __contains__(self
,k
):
1496 class SQLIDGraphClustered(SQLGraphClustered
):
1497 add_trivial_packing_methods(locals())
1498 SQLGraphClustered
._IDGraphClass
= SQLIDGraphClustered
1500 class SQLEdgesClustered(SQLGraphClustered
):
1501 'edges interface for SQLGraphClustered'
1502 _edgeDictClass
= list
1503 _pickleAttrs
= SQLGraphClustered
._pickleAttrs
.copy()
1504 _pickleAttrs
.update(dict(graph
=0))
1508 for edge_id
,l
in self
.d
.iteritems():
1509 for source_id
,target_id
in l
:
1510 result
.append((self
.graph
.unpack_source(source_id
),
1511 self
.graph
.unpack_target(target_id
),
1512 self
.graph
.unpack_edge(edge_id
)))
1515 class ForeignKeyInverse(object):
1516 'map each key to a single value according to its foreign key'
1517 def __init__(self
,g
):
1519 def __getitem__(self
,obj
):
1521 source_id
= getattr(obj
,self
.g
.keyColumn
)
1522 if source_id
is None:
1524 return self
.g
.sourceDB
[source_id
]
1525 def __setitem__(self
,obj
,source
):
1527 if source
is not None:
1528 self
.g
[source
][obj
] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1529 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1530 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1531 old_source
= self
[obj
]
1532 if old_source
is not None:
1533 del self
.g
[old_source
][obj
]
1534 def check_obj(self
,obj
):
1535 'raise KeyError if obj not from this db'
1537 if obj
.db
is not self
.g
.targetDB
:
1538 raise AttributeError
1539 except AttributeError:
1540 raise KeyError('key is not from targetDB of this graph!')
1541 def __contains__(self
,obj
):
1548 return self
.g
.targetDB
.itervalues()
1550 return self
.g
.targetDB
.values()
1551 def iteritems(self
):
1553 source_id
= getattr(obj
,self
.g
.keyColumn
)
1554 if source_id
is None:
1557 yield obj
,self
.g
.sourceDB
[source_id
]
1559 return list(self
.iteritems())
1560 def itervalues(self
):
1561 for obj
,val
in self
.iteritems():
1564 return list(self
.itervalues())
1565 def __invert__(self
):
1568 class ForeignKeyEdge(dict):
1569 '''edge interface to a foreign key in an SQL table.
1570 Caches dict of target nodes in itself; provides dict interface.
1571 Adds or deletes edges by setting foreign key values in the table'''
1572 def __init__(self
,g
,k
):
1576 for v
in g
.targetDB
.select('where %s=%%s' % g
.keyColumn
,(k
.id,)): # SEARCH THE DB
1577 dict.__setitem
__(self
,v
,None) # SAVE IN CACHE
1578 def __setitem__(self
,dest
,v
):
1579 if not hasattr(dest
,'db') or dest
.db
is not self
.g
.targetDB
:
1580 raise KeyError('dest is not in the targetDB bound to this graph!')
1582 raise ValueError('sorry,this graph cannot store edge information!')
1583 if not hasattr(dest
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1584 old_source
= self
.g
._inverse
[dest
] # CHECK FOR PRE-EXISTING EDGE
1585 if old_source
is not None: # REMOVE OLD EDGE FROM CACHE
1586 dict.__delitem
__(self
.g
[old_source
],dest
)
1587 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1588 setattr(dest
,self
.g
.keyColumn
,self
.src
.id) # SAVE TO DB ATTRIBUTE
1589 dict.__setitem
__(self
,dest
,None) # SAVE IN CACHE
def __delitem__(self, dest):
    'unlink dest: clear its foreign-key attribute, then drop it from the cache'
    # Assigning None to the foreign-key attribute persists the removal via
    # the attribute descriptor (instead of a direct targetDB._update call).
    setattr(dest, self.g.keyColumn, None) # SAVE TO DB ATTRIBUTE
    dict.__delitem__(self, dest) # REMOVE FROM CACHE
1595 class ForeignKeyGraph(object, UserDict
.DictMixin
):
1596 '''graph interface to a foreign key in an SQL table
1597 Caches dict of target nodes in itself; provides dict interface.
1599 def __init__(self
, sourceDB
, targetDB
, keyColumn
, autoGC
=True, **kwargs
):
1600 '''sourceDB is any database of source nodes;
1601 targetDB must be an SQL database of target nodes;
1602 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1603 if autoGC
: # automatically garbage collect unused objects
1604 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
1606 self
._weakValueDict
= {}
1607 self
.autoGC
= autoGC
1608 self
.sourceDB
= sourceDB
1609 self
.targetDB
= targetDB
1610 self
.keyColumn
= keyColumn
1611 self
._inverse
= ForeignKeyInverse(self
)
1612 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, keyColumn
=0, autoGC
=0)
1613 __getstate__
= standard_getstate
########### SUPPORT FOR PICKLING
1614 __setstate__
= standard_setstate
def _inverse_schema(self):
    'provide custom schema rule for inverting this graph... just use keyColumn!'
    return {'invert': True, 'uniqueMapping': True}
1618 def __getitem__(self
,k
):
1619 if not hasattr(k
,'db') or k
.db
is not self
.sourceDB
:
1620 raise KeyError('object is not in the sourceDB bound to this graph!')
1622 return self
._weakValueDict
[k
.id] # get from cache
1625 d
= ForeignKeyEdge(self
,k
)
1626 self
._weakValueDict
[k
.id] = d
# save in cache
def __setitem__(self, k, v):
    'disallowed: edges must be added through the graph interface, not assignment'
    raise KeyError('''do not save as g[k]=v. Instead follow a graph
interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
def __delitem__(self, k):
    'disallowed: edges must be removed through the graph interface'
    raise KeyError('''Instead of del g[k], follow a graph
interface: del g[src][dest]''')
1635 return self
.sourceDB
.values()
1636 __invert__
= standard_invert
1638 def describeDBTables(name
,cursor
,idDict
):
1640 Get table info about database <name> via <cursor>, and store primary keys
1641 in idDict, along with a list of the tables each key indexes.
1643 cursor
.execute('use %s' % name
)
1644 cursor
.execute('show tables')
1646 l
=[c
[0] for c
in cursor
.fetchall()]
1649 o
=SQLTable(tname
,cursor
)
1651 for f
in o
.description
:
1652 if f
==o
.primary_key
:
1653 idDict
.setdefault(f
, []).append(o
)
1654 elif f
[-3:]=='_id' and f
not in idDict
:
1660 def indexIDs(tables
,idDict
=None):
1661 "Get an index of primary keys in the <tables> dictionary."
1664 for o
in tables
.values():
1666 if o
.primary_key
not in idDict
:
1667 idDict
[o
.primary_key
]=[]
1668 idDict
[o
.primary_key
].append(o
) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1669 for f
in o
.description
:
1670 if f
[-3:]=='_id' and f
not in idDict
:
1676 def suffixSubset(tables
,suffix
):
1677 "Filter table index for those matching a specific suffix"
1679 for name
,t
in tables
.items():
1680 if name
.endswith(suffix
):
1687 def graphDBTables(tables
,idDict
):
1689 for t
in tables
.values():
1690 for f
in t
.description
:
1691 if f
==t
.primary_key
:
1692 edgeInfo
=PRIMARY_KEY
1695 g
.setEdge(f
,t
,edgeInfo
)
1696 g
.setEdge(t
,f
,edgeInfo
)
1699 SQLTypeTranslation
= {types
.StringType
:'varchar(32)',
1700 types
.IntType
:'int',
1701 types
.FloatType
:'float'}
1703 def createTableFromRepr(rows
,tableName
,cursor
,typeTranslation
=None,
1704 optionalDict
=None,indexDict
=()):
1705 """Save rows into SQL tableName using cursor, with optional
1706 translations of columns to specific SQL types (specified
1707 by typeTranslation dict).
1708 - optionDict can specify columns that are allowed to be NULL.
1709 - indexDict can specify columns that must be indexed; columns
1710 whose names end in _id will be indexed by default.
1711 - rows must be an iterator which in turn returns dictionaries,
1712 each representing a tuple of values (indexed by their column
1716 row
=rows
.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1717 except StopIteration:
1718 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1720 createTableFromRow(cursor
, tableName
,row
,typeTranslation
,
1721 optionalDict
,indexDict
)
1724 storeRow(cursor
,tableName
,row
) # SAVE OUR FIRST ROW
1725 for row
in rows
: # NOW SAVE ALL THE ROWS
1726 storeRow(cursor
,tableName
,row
)
1728 def createTableFromRow(cursor
, tableName
, row
,typeTranslation
=None,
1729 optionalDict
=None,indexDict
=()):
1731 for col
,val
in row
.items(): # PREPARE SQL TYPES FOR COLUMNS
1733 if typeTranslation
!=None and col
in typeTranslation
:
1734 coltype
=typeTranslation
[col
] # USER-SUPPLIED TRANSLATION
1735 elif type(val
) in SQLTypeTranslation
:
1736 coltype
=SQLTypeTranslation
[type(val
)]
1737 else: # SEARCH FOR A COMPATIBLE TYPE
1738 for t
in SQLTypeTranslation
:
1739 if isinstance(val
,t
):
1740 coltype
=SQLTypeTranslation
[t
]
1743 raise TypeError("Don't know SQL type to use for %s" % col
)
1744 create_def
='%s %s' %(col
,coltype
)
1745 if optionalDict
==None or col
not in optionalDict
:
1746 create_def
+=' not null'
1747 create_defs
.append(create_def
)
1748 for col
in row
: # CREATE INDEXES FOR ID COLUMNS
1749 if col
[-3:]=='_id' or col
in indexDict
:
1750 create_defs
.append('index(%s)' % col
)
1751 cmd
='create table if not exists %s (%s)' % (tableName
,','.join(create_defs
))
1752 cursor
.execute(cmd
) # CREATE THE TABLE IN THE DATABASE
def storeRow(cursor, tableName, row):
    'insert one row (a dict of column values) into tableName via cursor'
    placeholders = ','.join(['%s'] * len(row)) # one %s per column
    cursor.execute('insert into %s values (%s)' % (tableName, placeholders),
                   tuple(row.values()))
def storeRowDelayed(cursor, tableName, row):
    'insert the dict row into tableName using a MySQL delayed insert'
    placeholders = ','.join(['%s'] * len(row))
    sql = 'insert delayed into %s values (%s)' % (tableName, placeholders)
    cursor.execute(sql, tuple(row.values()))
class TableGroup(dict):
    'provide attribute access to dbname qualified tablenames'
    def __init__(self, db='test', suffix=None, **kw):
        """Store each keyword tablename, prefixing unqualified names with db.

        - db: database name used as prefix for names lacking a '.'
        - suffix: optional suffix attribute saved on the group
        - kw: attribute name -> tablename (or None) mappings
        """
        dict.__init__(self)
        self.db = db
        if suffix is not None:
            self.suffix = suffix
        for k, v in kw.items():
            if v is not None and '.' not in v:
                v = self.db + '.' + v # ADD DATABASE NAME AS PREFIX
            self[k] = v # save qualified name under its attribute key
    def __getattr__(self, k):
        # expose stored tablenames as attributes; missing keys raise KeyError
        return self[k]
def sqlite_connect(*args, **kwargs):
    'open a sqlite connection and return the (connection, cursor) pair'
    sqlite = import_sqlite()
    conn = sqlite.connect(*args, **kwargs)
    return conn, conn.cursor()
class DBServerInfo(object):
    'picklable reference to a database server'
    def __init__(self, moduleName='MySQLdb', serverSideCursors=True,
                 blockIterators=True, *args, **kwargs):
        """Save connection settings; the actual connection is opened lazily.

        - moduleName: key into _DBServerModuleDict selecting the subclass.
        - args / kwargs: passed through to the DB module's connect().
        """
        try: # rebind to the subclass registered for this DB module
            self.__class__ = _DBServerModuleDict[moduleName]
        except KeyError:
            raise ValueError('Module name not found in _DBServerModuleDict: '
                             + moduleName)
        self.moduleName = moduleName
        self.args = args # connection arguments
        self.kwargs = kwargs
        self.serverSideCursors = serverSideCursors
        self.custom_iter_keys = blockIterators
        if self.serverSideCursors and not self.custom_iter_keys:
            raise ValueError('serverSideCursors=True requires blockIterators=True!')

    def cursor(self):
        """returns cursor associated with the DB server info (reused)"""
        try:
            return self._cursor
        except AttributeError: # not connected yet; connect lazily
            self._start_connection()
            return self._cursor

    def new_cursor(self, arraysize=None):
        """returns a NEW cursor; you must close it yourself! """
        if not hasattr(self, '_connection'):
            self._start_connection()
        cursor = self._connection.cursor()
        if arraysize is not None:
            cursor.arraysize = arraysize
        return cursor

    def close(self):
        """Close file containing this database"""
        self._cursor.close()
        self._connection.close()
        del self._cursor # force cursor() / new_cursor() to reconnect
        del self._connection

    def __getstate__(self):
        """return all picklable arguments"""
        return dict(args=self.args, kwargs=self.kwargs,
                    moduleName=self.moduleName,
                    serverSideCursors=self.serverSideCursors,
                    blockIterators=self.custom_iter_keys)

    def __setstate__(self, state):
        """re-run __init__ from the dict produced by __getstate__().

        NOTE(review): pickle passes __getstate__'s return value as a single
        argument, so this must accept one state dict.
        """
        self.__init__(state['moduleName'],
                      serverSideCursors=state['serverSideCursors'],
                      blockIterators=state['blockIterators'],
                      *state['args'], **state['kwargs'])
class MySQLServerInfo(DBServerInfo):
    'customized for MySQLdb SSCursor support via new_cursor()'
    _serverType = 'mysql'

    def _start_connection(self):
        # open the regular (non-streaming) connection and its shared cursor
        self._connection, self._cursor = mysql_connect(*self.args,
                                                       **self.kwargs)

    def new_cursor(self, arraysize=None):
        'provide streaming cursor support'
        if not self.serverSideCursors: # use regular MySQLdb cursor
            return DBServerInfo.new_cursor(self, arraysize)
        try: # reuse our dedicated streaming connection if already open
            conn = self._conn_sscursor
        except AttributeError:
            self._conn_sscursor, cursor = mysql_connect(useStreaming=True,
                                                        *self.args,
                                                        **self.kwargs)
        else:
            cursor = conn.cursor()
        if arraysize is not None:
            cursor.arraysize = arraysize
        return cursor

    def close(self):
        # close the regular connection, then the streaming one if present
        DBServerInfo.close(self)
        try:
            self._conn_sscursor.close()
            del self._conn_sscursor
        except AttributeError: # no streaming connection was ever opened
            pass

    def iter_keys(self, db, cursor, map_f=iter,
                  cache_f=lambda x: [t[0] for t in x], **kwargs):
        # fetch rows in blocks to work around slow row-at-a-time iteration
        block_generator = BlockGenerator(db, self, cursor, **kwargs)
        return db.generic_iterator(cursor=cursor, cache_f=cache_f,
                                   map_f=map_f, fetch_f=block_generator)
class CursorCloser(object):
    """container for ensuring cursor.close() is called, when this obj deleted.
    For Python 2.5+, we could replace this with a try... finally clause
    in a generator function such as generic_iterator(); see PEP 342 or
    What's New in Python 2.5. """
    def __init__(self, cursor):
        self.cursor = cursor

    def __del__(self):
        # close the cursor on garbage collection, as the docstring promises;
        # subclasses may set self.cursor = None after closing it themselves.
        if self.cursor is not None:
            self.cursor.close()
class BlockGenerator(CursorCloser):
    'workaround for MySQLdb iteration horrible performance'
    def __init__(self, db, serverInfo, cursor, whereClause='', **kwargs):
        self.db = db
        self.serverInfo = serverInfo
        self.cursor = cursor
        self.kwargs = kwargs
        self.whereClause = ''
        if kwargs['orderBy']: # use iterSQL/iterColumns for WHERE / SELECT
            self.whereSQL = db.iterSQL
            if kwargs['selectCols'] == db.primary_key: # extract iterColumns
                self.whereColumns = ','.join(db.iterColumns) # required!!
            else: # extracting all columns
                self.whereParams = [db.data[col] for col in db.iterColumns]
        else: # just use primary key
            self.whereSQL = 'WHERE %s>%%s' % self.db.primary_key
            self.whereParams = (db.data['id'],)
        # NOTE(review): initial params reconstructed — first block runs with
        # an empty WHERE clause; confirm against db._select()'s signature.
        self.params = ()
        self.done = False

    def __call__(self):
        'get the next block of data'
        if self.done: # previous block was the last one
            return ()
        self.db._select(self.whereClause, self.params, cursor=self.cursor,
                        limit='LIMIT %s' % self.cursor.arraysize,
                        **(self.kwargs))
        rows = self.cursor.fetchall()
        if len(rows) < self.cursor.arraysize: # iteration complete
            self.done = True
            return rows
        lastrow = rows[-1] # extract params from the last row in this block
        if len(lastrow) > 1: # full rows: pull WHERE params by column index
            self.params = [lastrow[icol] for icol in self.whereParams]
        else:
            try: # get whereColumns values for last row
                self.db._select('WHERE %s=%%s' % self.db.primary_key,
                                lastrow, self.whereColumns, self.cursor)
            except AttributeError: # no whereColumns: id is the only param
                self.params = lastrow
            else:
                self.params = self.cursor.fetchall()[0]
        self.whereClause = self.whereSQL # later blocks resume after params
        return rows
class SQLiteServerInfo(DBServerInfo):
    """picklable reference to a sqlite database"""
    _serverType = 'sqlite'
    def __init__(self, database, *args, **kwargs):
        """Takes same arguments as sqlite3.connect()"""
        DBServerInfo.__init__(self, 'sqlite', # save abs path!
                              database=SourceFileName(database),
                              *args, **kwargs)
    def _start_connection(self):
        self._connection, self._cursor = sqlite_connect(*self.args,
                                                        **self.kwargs)
    def __getstate__(self):
        # NOTE(review): __init__ passes database via keyword, so it lands in
        # self.kwargs rather than self.args — confirm this args[0] check
        # against how instances are actually constructed.
        if self.args[0] == ':memory:':
            raise ValueError('SQLite in-memory database is not picklable!')
        return DBServerInfo.__getstate__(self)
# list of DBServerInfo subclasses for different modules
_DBServerModuleDict = {
    'MySQLdb': MySQLServerInfo,
    'sqlite': SQLiteServerInfo,
}
class MapView(object, UserDict.DictMixin):
    'general purpose 1:1 mapping defined by any SQL query'
    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
                 serverInfo=None, inverseSQL=None, **kwargs):
        """Map sourceDB objects to targetDB objects via the viewSQL query.

        Provide either cursor or serverInfo; otherwise sourceDB.serverInfo
        is used as a fallback.
        """
        self.sourceDB = sourceDB
        self.targetDB = targetDB
        self.viewSQL = viewSQL
        self.inverseSQL = inverseSQL
        if cursor is None:
            if serverInfo is not None: # get cursor from serverInfo
                cursor = serverInfo.cursor()
            else:
                try: # can we get it from our other db?
                    serverInfo = sourceDB.serverInfo
                except AttributeError:
                    raise ValueError('you must provide serverInfo or cursor!')
                cursor = serverInfo.cursor()
        self.cursor = cursor
        self.serverInfo = serverInfo
        self.get_sql_format(False) # get sql formatter for this db interface

    _schemaModuleDict = _schemaModuleDict # default module list
    get_sql_format = get_table_schema

    def __getitem__(self, k):
        if not hasattr(k, 'db') or k.db is not self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this map!')
        sql, params = self._format_query(self.viewSQL, (k.id,))
        self.cursor.execute(sql, params) # formatted for this db interface
        t = self.cursor.fetchmany(2) # get at most two rows
        if len(t) != 1: # mapping must produce exactly one target
            raise KeyError('%s not found in MapView, or not unique'
                           % str(k))
        return self.targetDB[t[0][0]] # get the corresponding object

    _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
                        inverseSQL=0)
    __getstate__ = standard_getstate
    __setstate__ = standard_setstate
    __setitem__ = __delitem__ = clear = pop = popitem = update = \
        setdefault = read_only_error

    def __iter__(self):
        'only yield sourceDB items that are actually in this mapping!'
        for k in self.sourceDB.itervalues():
            try: # probe the mapping; skip keys with no target
                self[k]
                yield k
            except KeyError:
                pass

    def keys(self):
        return [k for k in self] # don't use list(self); causes infinite loop!

    def __invert__(self):
        try: # reuse cached inverse mapping if we already built it
            return self._inverse
        except AttributeError:
            if self.inverseSQL is None:
                raise ValueError('this MapView has no inverseSQL!')
            self._inverse = self.__class__(self.targetDB, self.sourceDB,
                                           self.inverseSQL, self.cursor,
                                           serverInfo=self.serverInfo,
                                           inverseSQL=self.viewSQL)
            self._inverse._inverse = self
            return self._inverse
class GraphViewEdgeDict(UserDict.DictMixin):
    'edge dictionary for GraphView: just pre-loaded on init'
    def __init__(self, g, k):
        self.g = g
        self.k = k
        sql, params = self.g._format_query(self.g.viewSQL, (k.id,))
        self.g.cursor.execute(sql, params) # run the query
        l = self.g.cursor.fetchall() # get results
        if len(l) == 0: # no targets means k is not a node of this graph
            raise KeyError('key %s not in GraphView' % k.id)
        self.targets = [t[0] for t in l] # preserve order of the results
        d = {} # also keep targetID:edgeID mapping
        if self.g.edgeDB is not None: # save with edge info
            for t in l:
                d[t[0]] = t[1]
        else: # NOTE(review): no edge info; storing None per target — confirm
            for t in l:
                d[t[0]] = None
        self.targetDict = d

    def __len__(self):
        return len(self.targets)

    def __iter__(self):
        for k in self.targets:
            yield self.g.targetDB[k]

    def keys(self):
        return list(self) # target objects, preserving query order

    def iteritems(self):
        if self.g.edgeDB is not None: # save with edge info
            for k in self.targets:
                yield (self.g.targetDB[k], self.g.edgeDB[self.targetDict[k]])
        else: # just save the list of targets, no edge info
            for k in self.targets:
                yield (self.g.targetDB[k], None)

    def __getitem__(self, o, exitIfFound=False):
        'for the specified target object, return its associated edge object'
        try:
            if o.db is not self.g.targetDB:
                raise KeyError('key is not part of targetDB!')
            edgeID = self.targetDict[o.id]
        except AttributeError:
            raise KeyError('key has no id or db attribute?!')
        if exitIfFound: # caller only wanted the membership check
            return
        if self.g.edgeDB is not None: # return the edge object
            return self.g.edgeDB[edgeID]
        else: # no edge info
            return None

    def __contains__(self, o):
        try:
            self.__getitem__(o, True) # raise KeyError if not found
            return True
        except KeyError:
            return False

    __setitem__ = __delitem__ = clear = pop = popitem = update = \
        setdefault = read_only_error
class GraphView(MapView):
    'general purpose graph interface defined by any SQL query'
    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
                 **kwargs):
        'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
        self.edgeDB = edgeDB
        MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)

    def __getitem__(self, k):
        if not hasattr(k, 'db') or k.db is not self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this map!')
        return GraphViewEdgeDict(self, k)

    _pickleAttrs = MapView._pickleAttrs.copy()
    _pickleAttrs.update(dict(edgeDB=0))
2079 # @CTB move to sqlgraph.py?
class SQLSequence(SQLRow, SequenceBase):
    """Transparent access to a DB row representing a sequence.

    Use attrAlias dict to rename 'length' to something else.
    """
    def _init_subclass(cls, db, **kwargs):
        db.seqInfoDict = db # db will act as its own seqInfoDict
        SQLRow._init_subclass(db=db, **kwargs)
    _init_subclass = classmethod(_init_subclass)

    def __init__(self, id):
        SQLRow.__init__(self, id)
        SequenceBase.__init__(self)

    def __len__(self):
        return self.length # 'length' column (renamable via attrAlias)

    def strslice(self, start, end):
        "Efficient access to slice of a sequence, useful for huge contigs"
        return self._select('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
                            % (self.db._attrSQL('seq'), start + 1, end - start))
class DNASQLSequence(SQLSequence):
    'SQLSequence typed as DNA'
    _seqtype = DNA_SEQTYPE
class RNASQLSequence(SQLSequence):
    'SQLSequence typed as RNA'
    _seqtype = RNA_SEQTYPE
2106 class ProteinSQLSequence(SQLSequence
):
2107 _seqtype
=PROTEIN_SEQTYPE