3 from __future__
import generators
5 from sequence
import SequenceBase
, DNA_SEQTYPE
, RNA_SEQTYPE
, PROTEIN_SEQTYPE
7 from classutil
import methodFactory
,standard_getstate
,\
8 override_rich_cmp
,generate_items
,get_bound_subclass
,standard_setstate
,\
9 get_valid_path
,standard_invert
,RecentValueDictionary
,read_only_error
,\
10 SourceFileName
, split_kwargs
class TupleDescriptor(object):
    """Read-only descriptor exposing one named column of a row's data tuple.

    The column position is looked up once, at construction time, in
    ``db.data`` (attribute name -> index of that attribute in the tuple).
    Each read then indexes directly into ``obj._data``; every write is
    refused with AttributeError.
    """

    def __init__(self, db, attr):
        column_map = db.data
        self.icol = column_map[attr]  # index of this attribute in the tuple

    def __get__(self, obj, klass):
        row_values = obj._data
        return row_values[self.icol]

    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')
class TupleIDDescriptor(TupleDescriptor):
    """Descriptor for the ``id`` column of a tuple-backed row.

    Reads behave exactly like TupleDescriptor.  Assignment is refused; the
    error message tells the caller to re-key the row through the table
    (``db[newID] = obj``) instead.
    """

    def __set__(self, obj, val):
        # NOTE(review): continuation-line whitespace reproduced as it appears
        # in this copy of the file -- confirm against the pristine source.
        hint = ('You cannot change obj.id directly.\n'
                'Instead, use db[newID] = obj')
        raise AttributeError(hint)
31 class TupleDescriptorRW(TupleDescriptor
):
32 'read-write interface to named attribute'
33 def __init__(self
, db
, attr
):
35 self
.icol
= db
.data
[attr
] # index of this attribute in the tuple
36 self
.attrSQL
= db
._attrSQL
(attr
, sqlColumn
=True) # SQL column name
37 def __set__(self
, obj
, val
):
38 obj
.db
._update
(obj
.id, self
.attrSQL
, val
) # AND UPDATE THE DATABASE
39 obj
.save_local(self
.attr
, val
)
class SQLDescriptor(object):
    """Read-only descriptor that queries the database on every attribute read.

    The SQL expression for the attribute is resolved once through
    ``db._attrSQL(attr)``.  Each read passes that expression to
    ``obj._select`` and returns its result; the descriptor itself stores
    nothing per row.  Every write is refused with AttributeError.
    """

    def __init__(self, db, attr):
        self.selectSQL = db._attrSQL(attr)  # SQL expression for this attr

    def __get__(self, obj, klass):
        run_select = obj._select
        return run_select(self.selectSQL)

    def __set__(self, obj, val):
        raise AttributeError('this database is read-only!')
class SQLDescriptorRW(SQLDescriptor):
    """Writeable variant of SQLDescriptor.

    Assignment is forwarded to ``obj.db._update`` for the row identified
    by ``obj.id``, using the column/expression resolved at construction.
    Nothing is stored on the row object, so a later read queries the
    database again.
    """

    def __set__(self, obj, val):
        table = obj.db
        table._update(obj.id, self.selectSQL, val)  # just update the database
55 class ReadOnlyDescriptor(object):
56 'enforce read-only policy, e.g. for ID attribute'
57 def __init__(self
, db
, attr
):
59 def __get__(self
, obj
, klass
):
60 return getattr(obj
, self
.attr
)
61 def __set__(self
, obj
, val
):
62 raise AttributeError('attribute %s is read-only!' % self
.attr
)
65 def select_from_row(row
, what
):
66 "return value from SQL expression applied to this row"
67 sql
,params
= row
.db
._format
_query
('select %s from %s where %s=%%s limit 2'
68 % (what
,row
.db
.name
,row
.db
.primary_key
),
70 row
.db
.cursor
.execute(sql
, params
)
71 t
= row
.db
.cursor
.fetchmany(2) # get at most two rows
73 raise KeyError('%s[%s].%s not found, or not unique'
74 % (row
.db
.name
,str(row
.id),what
))
75 return t
[0][0] #return the single field we requested
77 def init_row_subclass(cls
, db
):
78 'add descriptors for db attributes'
79 for attr
in db
.data
: # bind all database columns
80 if attr
== 'id': # handle ID attribute specially
81 setattr(cls
, attr
, cls
._idDescriptor
(db
, attr
))
83 try: # check if this attr maps to an SQL column
84 db
._attrSQL
(attr
, columnNumber
=True)
85 except AttributeError: # treat as SQL expression
86 setattr(cls
, attr
, cls
._sqlDescriptor
(db
, attr
))
87 else: # treat as interface to our stored tuple
88 setattr(cls
, attr
, cls
._columnDescriptor
(db
, attr
))
91 """get list of column names as our attributes """
92 return self
.db
.data
.keys()
95 """Provides attribute interface to a database tuple.
96 Storing the data as a tuple instead of a standard Python object
97 (which is stored using __dict__) uses about five-fold less
98 memory and is also much faster (the tuples returned from the
99 DB API fetch are simply referenced by the TupleO, with no
100 need to copy their individual values into __dict__).
102 This class follows the 'subclass binding' pattern, which
103 means that instead of using __getattr__ to process all
104 attribute requests (which is un-modular and leads to all
105 sorts of trouble), we follow Python's new model for
106 customizing attribute access, namely Descriptors.
107 We use classutil.get_bound_subclass() to automatically
108 create a subclass of this class, calling its _init_subclass()
109 class method to add all the descriptors needed for the
110 database table to which it is bound.
112 See the Pygr Developer Guide section of the docs for a
113 complete discussion of the subclass binding pattern."""
114 _columnDescriptor
= TupleDescriptor
115 _idDescriptor
= TupleIDDescriptor
116 _sqlDescriptor
= SQLDescriptor
117 _init_subclass
= classmethod(init_row_subclass
)
118 _select
= select_from_row
120 def __init__(self
, data
):
121 self
._data
= data
# save our data tuple
123 def insert_and_cache_id(self
, l
, **kwargs
):
124 'insert tuple into db and cache its rowID on self'
125 self
.db
._insert
(l
) # save to database
127 rowID
= kwargs
['id'] # use the ID supplied by user
129 rowID
= self
.db
.get_insert_id() # get auto-inc ID value
130 self
.cache_id(rowID
) # cache this ID on self
132 class TupleORW(TupleO
):
133 'read-write version of TupleO'
134 _columnDescriptor
= TupleDescriptorRW
135 insert_and_cache_id
= insert_and_cache_id
136 def __init__(self
, data
, newRow
=False, **kwargs
):
137 if not newRow
: # just cache data from the database
140 self
._data
= self
.db
.tuple_from_dict(kwargs
) # convert to tuple
141 self
.insert_and_cache_id(self
._data
, **kwargs
)
142 def cache_id(self
,row_id
):
143 self
.save_local('id',row_id
)
144 def save_local(self
,attr
,val
):
145 icol
= self
._attrcol
[attr
]
147 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
148 except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
149 self
._data
= list(self
._data
)
150 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
# Register TupleORW as TupleO's writeable interface class; SQLTableBase.objclass()
# reads itemClass._RWClass when it needs the writeable row class.
setattr(TupleO, '_RWClass', TupleORW)
154 class ColumnDescriptor(object):
155 'read-write interface to column in a database, cached in obj.__dict__'
156 def __init__(self
, db
, attr
, readOnly
= False):
158 self
.col
= db
._attrSQL
(attr
, sqlColumn
=True) # MAP THIS TO SQL COLUMN NAME
161 self
.__class
__ = self
._readOnlyClass
162 def __get__(self
, obj
, objtype
):
164 return obj
.__dict
__[self
.attr
]
165 except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
166 if self
.col
==self
.db
.primary_key
:
168 self
.db
._select
('where %s=%%s' % self
.db
.primary_key
,(obj
.id,),self
.col
)
169 l
= self
.db
.cursor
.fetchall()
171 raise AttributeError('db row not found or not unique!')
172 obj
.__dict
__[self
.attr
] = l
[0][0] # UPDATE THE CACHE
174 def __set__(self
, obj
, val
):
175 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
176 self
.db
._update
(obj
.id, self
.col
, val
) # UPDATE THE DATABASE
177 obj
.__dict
__[self
.attr
] = val
# UPDATE THE CACHE
179 ## m = self.consequences
180 ## except AttributeError:
182 ## m(obj,val) # GENERATE CONSEQUENCES
183 ## def bind_consequences(self,f):
184 ## 'make function f be run as consequences method whenever value is set'
186 ## self.consequences = new.instancemethod(f,self,self.__class__)
class ReadOnlyColumnDesc(ColumnDescriptor):
    """ColumnDescriptor whose value can be read (and cached) but never assigned.

    The error message indicates this is intended for a row's ID column.
    """

    def __set__(self, obj, val):
        raise AttributeError('The ID of a database object is not writeable.')


# ColumnDescriptor.__init__ rebinds self.__class__ to self._readOnlyClass
# (presumably when constructed with readOnly=True -- confirm), so point it here.
setattr(ColumnDescriptor, '_readOnlyClass', ReadOnlyColumnDesc)
194 class SQLRow(object):
195 """Provide transparent interface to a row in the database: attribute access
196 will be mapped to SELECT of the appropriate column, but data is not
197 cached on this object.
199 _columnDescriptor
= _sqlDescriptor
= SQLDescriptor
200 _idDescriptor
= ReadOnlyDescriptor
201 _init_subclass
= classmethod(init_row_subclass
)
202 _select
= select_from_row
204 def __init__(self
, rowID
):
208 class SQLRowRW(SQLRow
):
209 'read-write version of SQLRow'
210 _columnDescriptor
= SQLDescriptorRW
211 insert_and_cache_id
= insert_and_cache_id
212 def __init__(self
, rowID
, newRow
=False, **kwargs
):
213 if not newRow
: # just cache data from the database
214 return self
.cache_id(rowID
)
215 l
= self
.db
.tuple_from_dict(kwargs
) # convert to tuple
216 self
.insert_and_cache_id(l
, **kwargs
)
217 def cache_id(self
, rowID
):
# Register SQLRowRW as SQLRow's writeable interface class; SQLTableBase.objclass()
# reads itemClass._RWClass when it needs the writeable row class.
setattr(SQLRow, '_RWClass', SQLRowRW)
224 def list_to_dict(names
, values
):
225 'return dictionary of those named args that are present in values[]'
227 for i
,v
in enumerate(values
):
235 def get_name_cursor(name
=None, **kwargs
):
236 '''get table name and cursor by parsing name or using configFile.
237 If neither provided, will try to get via your MySQL config file.
238 If connect is None, will use MySQLdb.connect()'''
240 argList
= name
.split() # TREAT AS WS-SEPARATED LIST
242 name
= argList
[0] # USE 1ST ARG AS TABLE NAME
243 argnames
= ('host','user','passwd') # READ ARGS IN THIS ORDER
244 kwargs
= kwargs
.copy() # a copy we can overwrite
245 kwargs
.update(list_to_dict(argnames
, argList
[1:]))
246 serverInfo
= DBServerInfo(**kwargs
)
247 return name
,serverInfo
.cursor(),serverInfo
249 def mysql_connect(connect
=None, configFile
=None, useStreaming
=False, **args
):
250 """return connection and cursor objects, using .my.cnf if necessary"""
251 kwargs
= args
.copy() # a copy we can modify
252 if 'user' not in kwargs
and configFile
is None: #Find where config file is
253 osname
= platform
.system()
254 if osname
in('Microsoft', 'Windows'): # Machine is a Windows box
256 try: # handle case where WINDIR not defined by Windows...
257 windir
= os
.environ
['WINDIR']
258 paths
+= [(windir
, 'my.ini'), (windir
, 'my.cnf')]
262 sysdrv
= os
.environ
['SYSTEMDRIVE']
263 paths
+= [(sysdrv
, os
.path
.sep
+ 'my.ini'),
264 (sysdrv
, os
.path
.sep
+ 'my.cnf')]
268 configFile
= get_valid_path(*paths
)
269 else: # treat as normal platform with home directories
270 configFile
= os
.path
.join(os
.path
.expanduser('~'), '.my.cnf')
272 # allows for a local mysql local configuration file to be read
273 # from the current directory
274 configFile
= configFile
or os
.path
.join( os
.getcwd(), 'mysql.cnf' )
276 if configFile
and os
.path
.exists(configFile
):
277 kwargs
['read_default_file'] = configFile
278 connect
= None # force it to use MySQLdb
281 connect
= MySQLdb
.connect
282 kwargs
['compress'] = True
283 if useStreaming
: # use server side cursors for scalable result sets
285 from MySQLdb
import cursors
286 kwargs
['cursorclass'] = cursors
.SSCursor
287 except (ImportError, AttributeError):
289 conn
= connect(**kwargs
)
290 cursor
= conn
.cursor()
293 _mysqlMacros
= dict(IGNORE
='ignore', REPLACE
='replace',
294 AUTO_INCREMENT
='AUTO_INCREMENT', SUBSTRING
='substring',
295 SUBSTR_FROM
='FROM', SUBSTR_FOR
='FOR')
297 def mysql_table_schema(self
, analyzeSchema
=True):
298 'retrieve table schema from a MySQL database, save on self'
300 self
._format
_query
= SQLFormatDict(MySQLdb
.paramstyle
, _mysqlMacros
)
301 if not analyzeSchema
:
303 self
.clear_schema() # reset settings and dictionaries
304 self
.cursor
.execute('describe %s' % self
.name
) # get info about columns
305 columns
= self
.cursor
.fetchall()
306 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
307 for icol
,c
in enumerate(columns
):
309 self
.columnName
.append(field
) # list of columns in same order as table
310 if c
[3] == "PRI": # record as primary key
311 if self
.primary_key
is None:
312 self
.primary_key
= field
315 self
.primary_key
.append(field
)
316 except AttributeError:
317 self
.primary_key
= [self
.primary_key
,field
]
318 if c
[1][:3].lower() == 'int':
319 self
.usesIntID
= True
321 self
.usesIntID
= False
323 self
.indexed
[field
] = icol
324 self
.description
[field
] = self
.cursor
.description
[icol
]
325 self
.columnType
[field
] = c
[1] # SQL COLUMN TYPE
327 _sqliteMacros
= dict(IGNORE
='or ignore', REPLACE
='insert or replace',
328 AUTO_INCREMENT
='', SUBSTRING
='substr',
329 SUBSTR_FROM
=',', SUBSTR_FOR
=',')
332 'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
334 import sqlite3
as sqlite
336 from pysqlite2
import dbapi2
as sqlite
339 def sqlite_table_schema(self
, analyzeSchema
=True):
340 'retrieve table schema from a sqlite3 database, save on self'
341 sqlite
= import_sqlite()
342 self
._format
_query
= SQLFormatDict(sqlite
.paramstyle
, _sqliteMacros
)
343 if not analyzeSchema
:
345 self
.clear_schema() # reset settings and dictionaries
346 self
.cursor
.execute('PRAGMA table_info("%s")' % self
.name
)
347 columns
= self
.cursor
.fetchall()
348 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
349 for icol
,c
in enumerate(columns
):
351 self
.columnName
.append(field
) # list of columns in same order as table
352 self
.description
[field
] = self
.cursor
.description
[icol
]
353 self
.columnType
[field
] = c
[2] # SQL COLUMN TYPE
354 self
.cursor
.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self
.name
) # get primary key / unique indexes
355 for indexname
in self
.cursor
.fetchall(): # search indexes for primary key
356 self
.cursor
.execute('PRAGMA index_info("%s")' % indexname
)
357 l
= self
.cursor
.fetchall() # get list of columns in this index
358 if len(l
) == 1: # assume 1st single-column unique index is primary key!
359 self
.primary_key
= l
[0][2]
360 break # done searching for primary key!
361 if self
.primary_key
is None: # grrr, INTEGER PRIMARY KEY handled differently
362 self
.cursor
.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self
.name
)
363 sql
= self
.cursor
.fetchall()[0][0]
364 for columnSQL
in sql
[sql
.index('(') + 1 :].split(','):
365 if 'primary key' in columnSQL
.lower(): # must be the primary key!
366 col
= columnSQL
.split()[0] # get column name
367 if col
in self
.columnType
:
368 self
.primary_key
= col
369 break # done searching for primary key!
371 raise ValueError('unknown primary key %s in table %s'
373 if self
.primary_key
is not None: # check its type
374 if self
.columnType
[self
.primary_key
] == 'int' or \
375 self
.columnType
[self
.primary_key
] == 'integer':
376 self
.usesIntID
= True
378 self
.usesIntID
= False
380 class SQLFormatDict(object):
381 '''Perform SQL keyword replacements for maintaining compatibility across
382 a wide range of SQL backends. Uses Python dict-based string format
383 function to do simple string replacements, and also to convert
384 params list to the paramstyle required for this interface.
385 Create by passing a dict of macros and the db-api paramstyle:
386 sfd = SQLFormatDict("qmark", substitutionDict)
388 Then transform queries+params as follows; input should be "format" style:
389 sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
390 cursor.execute(sql, params)
392 _paramFormats
= dict(pyformat
='%%(%d)s', numeric
=':%d', named
=':%d',
393 qmark
='(ignore)', format
='(ignore)')
394 def __init__(self
, paramstyle
, substitutionDict
={}):
395 self
.substitutionDict
= substitutionDict
.copy()
396 self
.paramstyle
= paramstyle
397 self
.paramFormat
= self
._paramFormats
[paramstyle
]
398 self
.makeDict
= (paramstyle
== 'pyformat' or paramstyle
== 'named')
399 if paramstyle
== 'qmark': # handle these as simple substitution
400 self
.substitutionDict
['?'] = '?'
401 elif paramstyle
== 'format':
402 self
.substitutionDict
['?'] = '%s'
403 def __getitem__(self
, k
):
404 'apply correct substitution for this SQL interface'
406 return self
.substitutionDict
[k
] # apply our substitutions
409 if k
== '?': # sequential parameter
410 s
= self
.paramFormat
% self
.iparam
411 self
.iparam
+= 1 # advance to the next parameter
413 raise KeyError('unknown macro: %s' % k
)
414 def __call__(self
, sql
, paramList
):
415 'returns corrected sql,params for this interface'
416 self
.iparam
= 1 # DB-ABI param indexing begins at 1
417 sql
= sql
.replace('%s', '%(?)s') # convert format into pyformat
418 s
= sql
% self
# apply all %(x)s replacements in sql
419 if self
.makeDict
: # construct a params dict
421 for i
,param
in enumerate(paramList
):
422 paramDict
[str(i
+ 1)] = param
#DB-ABI param indexing begins at 1
424 else: # just return the original params list
427 def get_table_schema(self
, analyzeSchema
=True):
428 'run the right schema function based on type of db server connection'
430 modname
= self
.cursor
.__class
__.__module
__
431 except AttributeError:
432 raise ValueError('no cursor object or module information!')
434 schema_func
= self
._schemaModuleDict
[modname
]
436 raise KeyError('''unknown db module: %s. Use _schemaModuleDict
437 attribute to supply a method for obtaining table schema
438 for this module''' % modname
)
439 schema_func(self
, analyzeSchema
) # run the schema function
# Maps a cursor's class module name (self.cursor.__class__.__module__) to the
# function that reads the table schema; consulted by get_table_schema().
_schemaModuleDict = dict([
    ('MySQLdb.cursors', mysql_table_schema),
    ('pysqlite2.dbapi2', sqlite_table_schema),
    ('sqlite3', sqlite_table_schema),
])
446 class SQLTableBase(object, UserDict
.DictMixin
):
447 "Store information about an SQL table as dict keyed by primary key"
448 _schemaModuleDict
= _schemaModuleDict
# default module list
449 get_table_schema
= get_table_schema
450 def __init__(self
,name
,cursor
=None,itemClass
=None,attrAlias
=None,
451 clusterKey
=None,createTable
=None,graph
=None,maxCache
=None,
452 arraysize
=1024, itemSliceClass
=None, dropIfExists
=False,
453 serverInfo
=None, autoGC
=True, orderBy
=None,
454 writeable
=False, iterSQL
=None, iterColumns
=None, **kwargs
):
455 if autoGC
: # automatically garbage collect unused objects
456 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
458 self
._weakValueDict
= {}
460 self
.orderBy
= orderBy
461 if orderBy
and serverInfo
and serverInfo
._serverType
== 'mysql':
462 if iterSQL
and iterColumns
: # both required for mysql!
463 self
.iterSQL
, self
.iterColumns
= iterSQL
, iterColumns
465 raise ValueError('For MySQL tables with orderBy, you MUST specify iterSQL and iterColumns as well!')
467 self
.writeable
= writeable
469 if serverInfo
is not None: # get cursor from serverInfo
470 cursor
= serverInfo
.cursor()
471 else: # try to read connection info from name or config file
472 name
,cursor
,serverInfo
= get_name_cursor(name
,**kwargs
)
474 warnings
.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
475 DeprecationWarning, stacklevel
=2)
477 if createTable
is not None: # RUN COMMAND TO CREATE THIS TABLE
478 if dropIfExists
: # get rid of any existing table
479 cursor
.execute('drop table if exists ' + name
)
480 self
.get_table_schema(False) # check dbtype, init _format_query
481 sql
,params
= self
._format
_query
(createTable
, ()) # apply macros
482 cursor
.execute(sql
) # create the table
484 if graph
is not None:
486 if maxCache
is not None:
487 self
.maxCache
= maxCache
488 if arraysize
is not None:
489 self
.arraysize
= arraysize
490 cursor
.arraysize
= arraysize
491 self
.get_table_schema() # get schema of columns to serve as attrs
492 self
.data
= {} # map of all attributes, including aliases
493 for icol
,field
in enumerate(self
.columnName
):
494 self
.data
[field
] = icol
# 1st add mappings to columns
496 self
.data
['id']=self
.data
[self
.primary_key
]
497 except (KeyError,TypeError):
499 if hasattr(self
,'_attr_alias'): # apply attribute aliases for this class
500 self
.addAttrAlias(False,**self
._attr
_alias
)
501 self
.objclass(itemClass
) # NEED TO SUBCLASS OUR ITEM CLASS
502 if itemSliceClass
is not None:
503 self
.itemSliceClass
= itemSliceClass
504 get_bound_subclass(self
, 'itemSliceClass', self
.name
) # need to subclass itemSliceClass
505 if attrAlias
is not None: # ADD ATTRIBUTE ALIASES
506 self
.attrAlias
= attrAlias
# RECORD FOR PICKLING PURPOSES
507 self
.data
.update(attrAlias
)
508 if clusterKey
is not None:
509 self
.clusterKey
=clusterKey
510 if serverInfo
is not None:
511 self
.serverInfo
= serverInfo
514 self
._select
(selectCols
='count(*)')
515 return self
.cursor
.fetchone()[0]
518 def __cmp__(self
, other
):
519 'only match self and no other!'
523 return cmp(id(self
), id(other
))
524 _pickleAttrs
= dict(name
=0, clusterKey
=0, maxCache
=0, arraysize
=0,
525 attrAlias
=0, serverInfo
=0, autoGC
=0, orderBy
=0,
526 writeable
=0, iterSQL
=0, iterColumns
=0)
527 __getstate__
= standard_getstate
528 def __setstate__(self
,state
):
529 # default cursor provisioning by worldbase is deprecated!
530 ## if 'serverInfo' not in state: # hmm, no address for db server?
531 ## try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
532 ## from Data import getResource
533 ## state['cursor'] = getResource.getTableCursor(state['name'])
534 ## except ImportError:
535 ## pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
536 self
.__init
__(**state
)
538 return '<SQL table '+self
.name
+'>'
540 def clear_schema(self
):
541 'reset all schema information for this table'
545 self
.usesIntID
= None
546 self
.primary_key
= None
548 def _attrSQL(self
,attr
,sqlColumn
=False,columnNumber
=False):
549 "Translate python attribute name to appropriate SQL expression"
550 try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
551 field
=self
.data
[attr
]
553 raise AttributeError('attribute %s not a valid column or alias in %s'
555 if sqlColumn
: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
556 try: # CHECK IF field IS COLUMN NUMBER
557 return self
.columnName
[field
] # RETURN SQL COLUMN NAME
559 try: # CHECK IF field IS SQL COLUMN NAME
560 return self
.columnName
[self
.data
[field
]] # THIS WILL JUST RETURN field...
561 except (KeyError,TypeError):
562 raise AttributeError('attribute %s does not map to an SQL column in %s'
565 try: # CHECK IF field IS A COLUMN NUMBER
566 return field
+0 # ONLY RETURN AN INTEGER
568 try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
569 return self
.data
[field
]+0 # ONLY RETURN AN INTEGER
570 except (KeyError,TypeError):
571 raise ValueError('attribute %s does not map to a SQL column!' % attr
)
572 if isinstance(field
,types
.StringType
):
573 attr
=field
# USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
575 attr
=self
.primary_key
577 def addAttrAlias(self
,saveToPickle
=True,**kwargs
):
578 """Add new attributes as aliases of existing attributes.
579 They can be specified either as named args:
580 t.addAttrAlias(newattr=oldattr)
581 or by passing a dictionary kwargs whose keys are newattr
582 and values are oldattr:
583 t.addAttrAlias(**kwargs)
584 saveToPickle=True forces these aliases to be saved if object is pickled.
587 self
.attrAlias
.update(kwargs
)
588 for key
,val
in kwargs
.items():
589 try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
590 self
.data
[val
]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
591 raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
592 except TypeError: # val IS ITSELF AN ALIAS
593 self
.data
[key
] = self
.data
[val
] # SO MAP TO WHAT IT MAPS TO
594 except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
596 def objclass(self
,oclass
=None):
597 "Create class representing a row in this table by subclassing oclass, adding data"
598 if oclass
is not None: # use this as our base itemClass
599 self
.itemClass
= oclass
601 self
.itemClass
= self
.itemClass
._RWClass
# use its writeable version
602 oclass
= get_bound_subclass(self
, 'itemClass', self
.name
,
603 subclassArgs
=dict(db
=self
)) # bind itemClass
604 if issubclass(oclass
, TupleO
):
605 oclass
._attrcol
= self
.data
# BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
606 if hasattr(oclass
,'_tableclass') and not isinstance(self
,oclass
._tableclass
):
607 self
.__class
__=oclass
._tableclass
# ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
608 def _select(self
, whereClause
='', params
=(), selectCols
='t1.*',
609 cursor
=None, orderBy
='', limit
=''):
610 'execute the specified query but do not fetch'
611 sql
,params
= self
._format
_query
('select %s from %s t1 %s %s %s'
612 % (selectCols
, self
.name
, whereClause
, orderBy
,
615 self
.cursor
.execute(sql
, params
)
617 cursor
.execute(sql
, params
)
618 def select(self
,whereClause
,params
=None,oclass
=None,selectCols
='t1.*'):
619 "Generate the list of objects that satisfy the database SELECT"
621 oclass
=self
.itemClass
622 self
._select
(whereClause
,params
,selectCols
)
623 l
=self
.cursor
.fetchall()
625 yield self
.cacheItem(t
,oclass
)
626 def query(self
,**kwargs
):
627 'query for intersection of all specified kwargs, returned as iterator'
630 for k
,v
in kwargs
.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
631 if v
is None: # CONVERT TO SQL NULL TEST
632 criteria
.append('%s IS NULL' % self
._attrSQL
(k
))
633 else: # TEST FOR EQUALITY
634 criteria
.append('%s=%%s' % self
._attrSQL
(k
))
636 return self
.select('where '+' and '.join(criteria
),params
)
637 def _update(self
,row_id
,col
,val
):
638 'update a single field in the specified row to the specified value'
639 sql
,params
= self
._format
_query
('update %s set %s=%%s where %s=%%s'
640 %(self
.name
,col
,self
.primary_key
),
642 self
.cursor
.execute(sql
, params
)
645 return t
[self
.data
['id']] # GET ID FROM TUPLE
646 except TypeError: # treat as alias
647 return t
[self
.data
[self
.data
['id']]]
648 def cacheItem(self
,t
,oclass
):
649 'get obj from cache if possible, or construct from tuple'
652 except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
654 try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
655 return self
._weakValueDict
[id]
659 self
._weakValueDict
[id] = o
# CACHE THIS ITEM IN OUR DICTIONARY
661 def cache_items(self
,rows
,oclass
=None):
663 oclass
=self
.itemClass
665 yield self
.cacheItem(t
,oclass
)
666 def foreignKey(self
,attr
,k
):
667 'get iterator for objects with specified foreign key value'
668 return self
.select('where %s=%%s'%attr
,(k
,))
669 def limit_cache(self
):
670 'APPLY maxCache LIMIT TO CACHE SIZE'
672 if self
.maxCache
<len(self
._weakValueDict
):
673 self
._weakValueDict
.clear()
674 except AttributeError:
677 def get_new_cursor(self
):
678 """Return a new cursor object, or None if not possible """
680 new_cursor
= self
.serverInfo
.new_cursor
681 except AttributeError:
683 return new_cursor(self
.arraysize
)
685 def generic_iterator(self
, cursor
=None, fetch_f
=None, cache_f
=None,
686 map_f
=iter, cursorHolder
=None):
687 """generic iterator that runs fetch, cache and map functions.
688 cursorHolder is used only to keep a ref in this function's locals,
689 so that if it is prematurely terminated (by deleting its
690 iterator), cursorHolder.__del__() will close the cursor."""
691 if fetch_f
is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
693 fetch_f
= self
.cursor
.fetchmany
694 else: # isolate this iter from other queries
695 fetch_f
= cursor
.fetchmany
697 cache_f
= self
.cache_items
700 rows
= fetch_f() # FETCH THE NEXT SET OF ROWS
701 if len(rows
)==0: # NO MORE DATA SO ALL DONE
703 for v
in map_f(cache_f(rows
)): # CACHE AND GENERATE RESULTS
705 def tuple_from_dict(self
, d
):
706 'transform kwarg dict into tuple for storing in database'
707 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
708 for col
,icol
in self
.data
.items():
711 except (KeyError,TypeError):
714 def tuple_from_obj(self
, obj
):
715 'transform object attributes into tuple for storing in database'
716 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
717 for col
,icol
in self
.data
.items():
719 l
[icol
] = getattr(obj
,col
)
720 except (AttributeError,TypeError):
723 def _insert(self
, l
):
724 '''insert tuple into the database. Note this uses the MySQL
725 extension REPLACE, which overwrites any duplicate key.'''
726 s
= '%(REPLACE)s into ' + self
.name
+ ' values (' \
727 + ','.join(['%s']*len(l
)) + ')'
728 sql
,params
= self
._format
_query
(s
, l
)
729 self
.cursor
.execute(sql
, params
)
730 def insert(self
, obj
):
731 '''insert new row by transforming obj to tuple of values'''
732 l
= self
.tuple_from_obj(obj
)
734 def get_insert_id(self
):
735 'get the primary key value for the last INSERT'
736 try: # ATTEMPT TO GET ASSIGNED ID FROM DB
737 auto_id
= self
.cursor
.lastrowid
738 except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
739 raise NotImplementedError('''your db lacks lastrowid support?''')
741 raise ValueError('lastrowid is None so cannot get ID from INSERT!')
743 def new(self
, **kwargs
):
744 'return a new record with the assigned attributes, added to DB'
745 if not self
.writeable
:
746 raise ValueError('this database is read only!')
747 obj
= self
.itemClass(None, newRow
=True, **kwargs
) # saves itself to db
748 self
._weakValueDict
[obj
.id] = obj
# AND SAVE TO OUR LOCAL DICT CACHE
750 def clear_cache(self
):
752 self
._weakValueDict
.clear()
753 def __delitem__(self
, k
):
754 if not self
.writeable
:
755 raise ValueError('this database is read only!')
756 sql
,params
= self
._format
_query
('delete from %s where %s=%%s'
757 % (self
.name
,self
.primary_key
),(k
,))
758 self
.cursor
.execute(sql
, params
)
760 del self
._weakValueDict
[k
]
764 def getKeys(self
,queryOption
='', selectCols
=None):
765 'uses db select; does not force load'
766 if selectCols
is None:
767 selectCols
=self
.primary_key
768 if queryOption
=='' and self
.orderBy
is not None:
769 queryOption
= self
.orderBy
# apply default ordering
770 self
.cursor
.execute('select %s from %s %s'
771 %(selectCols
,self
.name
,queryOption
))
772 return [t
[0] for t
in self
.cursor
.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
774 def iter_keys(self
, selectCols
=None, orderBy
='', map_f
=iter,
775 cache_f
=lambda x
:[t
[0] for t
in x
], get_f
=None, **kwargs
):
776 'guarantee correct iteration insulated from other queries'
777 if selectCols
is None:
778 selectCols
=self
.primary_key
779 if orderBy
=='' and self
.orderBy
is not None:
780 orderBy
= self
.orderBy
# apply default ordering
781 cursor
= self
.get_new_cursor()
782 if cursor
: # got our own cursor, guaranteeing query isolation
783 if hasattr(self
.serverInfo
, 'iter_keys') \
784 and self
.serverInfo
.custom_iter_keys
:
785 # use custom iter_keys() method from serverInfo
786 return self
.serverInfo
.iter_keys(self
, cursor
, selectCols
=selectCols
,
787 map_f
=map_f
, orderBy
=orderBy
,
788 cache_f
=cache_f
, **kwargs
)
790 self
._select
(cursor
=cursor
, selectCols
=selectCols
,
791 orderBy
=orderBy
, **kwargs
)
792 return self
.generic_iterator(cursor
=cursor
, cache_f
=cache_f
,
794 cursorHolder
=CursorCloser(cursor
))
795 else: # must pre-fetch all keys to ensure query isolation
796 if get_f
is not None:
799 return iter(self
.keys())
801 class SQLTable(SQLTableBase
):
802 "Provide on-the-fly access to rows in the database, caching the results in dict"
803 itemClass
= TupleO
# our default itemClass; constructor can override
806 def load(self
,oclass
=None):
807 "Load all data from the table"
808 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
809 return self
._isLoaded
810 except AttributeError:
813 oclass
=self
.itemClass
814 self
.cursor
.execute('select * from %s' % self
.name
)
815 l
=self
.cursor
.fetchall()
816 self
._weakValueDict
= {} # just store the whole dataset in memory
818 self
.cacheItem(t
,oclass
) # CACHE IT IN LOCAL DICTIONARY
819 self
._isLoaded
=True # MARK THIS CONTAINER AS FULLY LOADED
821 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
823 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
824 except KeyError: # NOT FOUND, SO TRY THE DATABASE
825 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 2'
826 % (self
.name
,self
.primary_key
),(k
,))
827 self
.cursor
.execute(sql
, params
)
828 l
= self
.cursor
.fetchmany(2) # get at most 2 rows
830 raise KeyError('%s not found in %s, or not unique' %(str(k
),self
.name
))
832 return self
.cacheItem(l
[0],self
.itemClass
) # CACHE IT IN LOCAL DICTIONARY
833 def __setitem__(self
, k
, v
):
834 if not self
.writeable
:
835 raise ValueError('this database is read only!')
839 except AttributeError:
840 raise ValueError('object not bound to itemClass for this db!')
845 except AttributeError:
847 else: # delete row with old ID
849 v
.cache_id(k
) # cache the new ID on the object
850 self
.insert(v
) # SAVE TO THE RELATIONAL DB SERVER
851 self
._weakValueDict
[k
] = v
# CACHE THIS ITEM IN OUR DICTIONARY
853 'forces load of entire table into memory'
855 return [(k
,self
[k
]) for k
in self
] # apply orderBy rules...
857 'uses arraysize / maxCache and fetchmany() to manage data transfer'
858 return iter_keys(self
, selectCols
='*', cache_f
=None,
859 map_f
=generate_items
, get_f
=self
.items
)
861 'forces load of entire table into memory'
863 return [self
[k
] for k
in self
] # apply orderBy rules...
864 def itervalues(self
):
865 'uses arraysize / maxCache and fetchmany() to manage data transfer'
866 return iter_keys(self
, selectCols
='*', cache_f
=None, get_f
=self
.values
)
868 def getClusterKeys(self
,queryOption
=''):
869 'uses db select; does not force load'
870 self
.cursor
.execute('select distinct %s from %s %s'
871 %(self
.clusterKey
,self
.name
,queryOption
))
872 return [t
[0] for t
in self
.cursor
.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
875 class SQLTableClustered(SQLTable
):
876 '''use clusterKey to load a whole cluster of rows at once,
877 specifically, all rows that share the same clusterKey value.'''
878 def __init__(self
, *args
, **kwargs
):
879 kwargs
= kwargs
.copy() # get a copy we can alter
880 kwargs
['autoGC'] = False # don't use WeakValueDictionary
881 SQLTable
.__init
__(self
, *args
, **kwargs
)
883 return getKeys(self
,'order by %s' %self
.clusterKey
)
884 def clusterkeys(self
):
885 return getClusterKeys(self
, 'order by %s' %self
.clusterKey
)
886 def __getitem__(self
,k
):
888 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
889 except KeyError: # NOT FOUND, SO TRY THE DATABASE
890 sql
,params
= self
._format
_query
('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
891 % (self
.name
,self
.name
,self
.primary_key
,
892 self
.clusterKey
,self
.clusterKey
),(k
,))
893 self
.cursor
.execute(sql
, params
)
894 l
=self
.cursor
.fetchall()
896 for t
in l
: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
897 self
.cacheItem(t
,self
.itemClass
)
898 return self
._weakValueDict
[k
] # should be in cache, if row k exists
899 def itercluster(self
,cluster_id
):
900 'iterate over all items from the specified cluster'
902 return self
.select('where %s=%%s'%self
.clusterKey
,(cluster_id
,))
903 def fetch_cluster(self
):
904 'use self.cursor.fetchmany to obtain all rows for next cluster'
905 icol
= self
._attrSQL
(self
.clusterKey
,columnNumber
=True)
908 rows
= self
._fetch
_cluster
_cache
# USE SAVED ROWS FROM PREVIOUS CALL
909 del self
._fetch
_cluster
_cache
910 except AttributeError:
911 rows
= self
.cursor
.fetchmany()
913 cluster_id
= rows
[0][icol
]
917 for i
,t
in enumerate(rows
): # CHECK THAT ALL ROWS FROM THIS CLUSTER
918 if cluster_id
!= t
[icol
]: # START OF A NEW CLUSTER
919 result
+= rows
[:i
] # RETURN ROWS OF CURRENT CLUSTER
920 self
._fetch
_cluster
_cache
= rows
[i
:] # SAVE NEXT CLUSTER
923 rows
= self
.cursor
.fetchmany() # GET NEXT SET OF ROWS
925 def itervalues(self
):
926 'uses arraysize / maxCache and fetchmany() to manage data transfer'
927 cursor
= self
.get_new_cursor()
928 self
._select
('order by %s' %self
.clusterKey
, cursor
=cursor
)
929 return self
.generic_iterator(cursor
, self
.fetch_cluster
,
930 cursorHolder
=CursorHolder(cursor
))
932 'uses arraysize / maxCache and fetchmany() to manage data transfer'
933 cursor
= self
.get_new_cursor()
934 self
._select
('order by %s' %self
.clusterKey
, cursor
=cursor
)
935 return self
.generic_iterator(cursor
, self
.fetch_cluster
,
936 map_f
=generate_items
,
937 cursorHolder
=CursorHolder(cursor
))
939 class SQLForeignRelation(object):
940 'mapping based on matching a foreign key in an SQL table'
941 def __init__(self
,table
,keyName
):
944 def __getitem__(self
,k
):
945 'get list of objects o with getattr(o,keyName)==k.id'
947 for o
in self
.table
.select('where %s=%%s'%self
.keyName
,(k
.id,)):
950 raise KeyError('%s not found in %s' %(str(k
),self
.name
))
954 class SQLTableNoCache(SQLTableBase
):
955 '''Provide on-the-fly access to rows in the database;
956 values are simply an object interface (SQLRow) to back-end db query.
957 Row data are not stored locally, but always accessed by querying the db'''
958 itemClass
=SQLRow
# DEFAULT OBJECT CLASS FOR ROWS...
961 def getID(self
,t
): return t
[0] # GET ID FROM TUPLE
962 def select(self
,whereClause
,params
):
963 return SQLTableBase
.select(self
,whereClause
,params
,self
.oclass
,
965 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
967 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
968 except KeyError: # NOT FOUND, SO TRY THE DATABASE
969 self
._select
('where %s=%%s' % self
.primary_key
, (k
,),
971 t
= self
.cursor
.fetchmany(2)
973 raise KeyError('id %s non-existent or not unique' % k
)
974 o
= self
.itemClass(k
) # create obj referencing this ID
975 self
._weakValueDict
[k
] = o
# cache the SQLRow object
977 def __setitem__(self
, k
, v
):
978 if not self
.writeable
:
979 raise ValueError('this database is read only!')
983 except AttributeError:
984 raise ValueError('object not bound to itemClass for this db!')
986 del self
[k
] # delete row with new ID if any
990 del self
._weakValueDict
[v
.id] # delete from old cache location
993 self
._update
(v
.id, self
.primary_key
, k
) # just change its ID in db
994 v
.cache_id(k
) # change the cached ID value
995 self
._weakValueDict
[k
] = v
# assign to new cache location
996 def addAttrAlias(self
,**kwargs
):
997 self
.data
.update(kwargs
) # ALIAS KEYS TO EXPRESSION VALUES
999 SQLRow
._tableclass
=SQLTableNoCache
# SQLRow IS FOR NON-CACHING TABLE INTERFACE
1002 class SQLTableMultiNoCache(SQLTableBase
):
1003 "Trivial on-the-fly access for table with key that returns multiple rows"
1004 itemClass
= TupleO
# default itemClass; constructor can override
1005 _distinct_key
='id' # DEFAULT COLUMN TO USE AS KEY
1006 def __init__(self
, *args
, **kwargs
):
1007 SQLTableBase
.__init
__(self
, *args
, **kwargs
)
1008 self
.distinct_key
= self
._attrSQL
(self
._distinct
_key
)
1009 if not self
.orderBy
:
1010 self
.orderBy
= 'GROUP BY %s ORDER BY %s' % (self
.distinct_key
,
1012 self
.iterSQL
= 'WHERE %s>%%s' % self
.distinct_key
1013 self
.iterColumns
= (self
.distinct_key
,)
1015 return getKeys(self
, selectCols
=self
.distinct_key
)
1017 return iter_keys(self
, selectCols
=self
.distinct_key
)
1018 def __getitem__(self
,id):
1019 sql
,params
= self
._format
_query
('select * from %s where %s=%%s'
1020 %(self
.name
,self
._attrSQL
(self
._distinct
_key
)),(id,))
1021 self
.cursor
.execute(sql
, params
)
1022 l
=self
.cursor
.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1024 yield self
.itemClass(row
)
1025 def addAttrAlias(self
,**kwargs
):
1026 self
.data
.update(kwargs
) # ALIAS KEYS TO EXPRESSION VALUES
1030 class SQLEdges(SQLTableMultiNoCache
):
1031 '''provide iterator over edges as (source,target,edge)
1032 and getitem[edge] --> [(source,target),...]'''
1033 _distinct_key
='edge_id'
1034 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1035 _pickleAttrs
.update(dict(graph
=0))
1037 self
.cursor
.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
1038 %(self
._attrSQL
('source_id'),self
._attrSQL
('target_id'),
1039 self
._attrSQL
('edge_id'),self
.name
,
1040 self
._attrSQL
('target_id'),self
._attrSQL
('source_id'),
1041 self
._attrSQL
('target_id')))
1042 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1043 for source_id
,target_id
,edge_id
in self
.cursor
.fetchall():
1044 l
.append((self
.graph
.unpack_source(source_id
),
1045 self
.graph
.unpack_target(target_id
),
1046 self
.graph
.unpack_edge(edge_id
)))
1050 return iter(self
.keys())
1051 def __getitem__(self
,edge
):
1052 sql
,params
= self
._format
_query
('select %s,%s from %s where %s=%%s'
1053 %(self
._attrSQL
('source_id'),
1054 self
._attrSQL
('target_id'),
1056 self
._attrSQL
(self
._distinct
_key
)),
1057 (self
.graph
.pack_edge(edge
),))
1058 self
.cursor
.execute(sql
, params
)
1059 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1060 for source_id
,target_id
in self
.cursor
.fetchall():
1061 l
.append((self
.graph
.unpack_source(source_id
),
1062 self
.graph
.unpack_target(target_id
)))
1066 class SQLEdgeDict(object):
1067 '2nd level graph interface to SQL database'
1068 def __init__(self
,fromNode
,table
):
1069 self
.fromNode
=fromNode
1071 if not hasattr(self
.table
,'allowMissingNodes'):
1072 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s limit 1'
1073 %(self
.table
.sourceSQL
,
1075 self
.table
.sourceSQL
),
1077 self
.table
.cursor
.execute(sql
, params
)
1078 if len(self
.table
.cursor
.fetchall())<1:
1079 raise KeyError('node not in graph!')
1081 def __getitem__(self
,target
):
1082 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1083 %(self
.table
.edgeSQL
,
1085 self
.table
.sourceSQL
,
1086 self
.table
.targetSQL
),
1088 self
.table
.pack_target(target
)))
1089 self
.table
.cursor
.execute(sql
, params
)
1090 l
= self
.table
.cursor
.fetchmany(2) # get at most two rows
1092 raise KeyError('either no edge from source to target or not unique!')
1094 return self
.table
.unpack_edge(l
[0][0]) # RETURN EDGE
1096 raise KeyError('no edge from node to target')
1097 def __setitem__(self
,target
,edge
):
1098 sql
,params
= self
.table
._format
_query
('replace into %s values (%%s,%%s,%%s)'
1101 self
.table
.pack_target(target
),
1102 self
.table
.pack_edge(edge
)))
1103 self
.table
.cursor
.execute(sql
, params
)
1104 if not hasattr(self
.table
,'sourceDB') or \
1105 (hasattr(self
.table
,'targetDB') and
1106 self
.table
.sourceDB
is self
.table
.targetDB
):
1107 self
.table
+= target
# ADD AS NODE TO GRAPH
1108 def __iadd__(self
,target
):
1110 return self
# iadd MUST RETURN self!
1111 def __delitem__(self
,target
):
1112 sql
,params
= self
.table
._format
_query
('delete from %s where %s=%%s and %s=%%s'
1114 self
.table
.sourceSQL
,
1115 self
.table
.targetSQL
),
1117 self
.table
.pack_target(target
)))
1118 self
.table
.cursor
.execute(sql
, params
)
1119 if self
.table
.cursor
.rowcount
< 1: # no rows deleted?
1120 raise KeyError('no edge from node to target')
1122 def iterator_query(self
):
1123 sql
,params
= self
.table
._format
_query
('select %s,%s from %s where %s=%%s and %s is not null'
1124 %(self
.table
.targetSQL
,
1127 self
.table
.sourceSQL
,
1128 self
.table
.targetSQL
),
1130 self
.table
.cursor
.execute(sql
, params
)
1131 return self
.table
.cursor
.fetchall()
1133 return [self
.table
.unpack_target(target_id
)
1134 for target_id
,edge_id
in self
.iterator_query()]
1136 return [self
.table
.unpack_edge(edge_id
)
1137 for target_id
,edge_id
in self
.iterator_query()]
1139 return [(self
.table
.unpack_source(self
.fromNode
),self
.table
.unpack_target(target_id
),
1140 self
.table
.unpack_edge(edge_id
))
1141 for target_id
,edge_id
in self
.iterator_query()]
1143 return [(self
.table
.unpack_target(target_id
),self
.table
.unpack_edge(edge_id
))
1144 for target_id
,edge_id
in self
.iterator_query()]
1145 def __iter__(self
): return iter(self
.keys())
1146 def itervalues(self
): return iter(self
.values())
1147 def iteritems(self
): return iter(self
.items())
1149 return len(self
.keys())
1152 class SQLEdgelessDict(SQLEdgeDict
):
1153 'for SQLGraph tables that lack edge_id column'
1154 def __getitem__(self
,target
):
1155 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1156 %(self
.table
.targetSQL
,
1158 self
.table
.sourceSQL
,
1159 self
.table
.targetSQL
),
1161 self
.table
.pack_target(target
)))
1162 self
.table
.cursor
.execute(sql
, params
)
1163 l
= self
.table
.cursor
.fetchmany(2)
1165 raise KeyError('either no edge from source to target or not unique!')
1166 return None # no edge info!
1167 def iterator_query(self
):
1168 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s is not null'
1169 %(self
.table
.targetSQL
,
1171 self
.table
.sourceSQL
,
1172 self
.table
.targetSQL
),
1174 self
.table
.cursor
.execute(sql
, params
)
1175 return [(t
[0],None) for t
in self
.table
.cursor
.fetchall()]
1177 SQLEdgeDict
._edgelessClass
= SQLEdgelessDict
1179 class SQLGraphEdgeDescriptor(object):
1180 'provide an SQLEdges interface on demand'
1181 def __get__(self
,obj
,objtype
):
1183 attrAlias
=obj
.attrAlias
.copy()
1184 except AttributeError:
1185 return SQLEdges(obj
.name
, obj
.cursor
, graph
=obj
)
1187 return SQLEdges(obj
.name
, obj
.cursor
, attrAlias
=attrAlias
,
1190 def getColumnTypes(createTable
,attrAlias
={},defaultColumnType
='int',
1191 columnAttrs
=('source','target','edge'),**kwargs
):
1192 'return list of [(colname,coltype),...] for source,target,edge'
1194 for attr
in columnAttrs
:
1196 attrName
= attrAlias
[attr
+'_id']
1198 attrName
= attr
+'_id'
1199 try: # SEE IF USER SPECIFIED A DESIRED TYPE
1200 l
.append((attrName
,createTable
[attr
+'_id']))
1202 except (KeyError,TypeError):
1204 try: # get type info from primary key for that database
1205 db
= kwargs
[attr
+'DB']
1207 raise KeyError # FORCE IT TO USE DEFAULT TYPE
1210 else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
1212 try: # GET ONE IDENTIFIER FROM THE DATABASE
1214 except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
1216 l
.append((attrName
,db
.columnType
[db
.primary_key
]))
1218 except AttributeError:
1220 else: # GET THE TYPE FROM THIS IDENTIFIER
1221 if isinstance(k
,int) or isinstance(k
,long):
1222 l
.append((attrName
,'int'))
1224 elif isinstance(k
,str):
1225 l
.append((attrName
,'varchar(32)'))
1228 raise ValueError('SQLGraph node / edge must be int or str!')
1229 l
.append((attrName
,defaultColumnType
))
1230 logger
.warn('no type info found for %s, so using default: %s'
1231 % (attrName
, defaultColumnType
))
1237 class SQLGraph(SQLTableMultiNoCache
):
1238 '''provide a graph interface via a SQL table. Key capabilities are:
1239 - setitem with an empty dictionary: a dummy operation
1240 - getitem with a key that exists: return a placeholder
1241 - setitem with non empty placeholder: again a dummy operation
1242 EXAMPLE TABLE SCHEMA:
1243 create table mygraph (source_id int not null,target_id int,edge_id int,
1244 unique(source_id,target_id));
1246 _distinct_key
='source_id'
1247 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1248 _pickleAttrs
.update(dict(sourceDB
=0,targetDB
=0,edgeDB
=0,allowMissingNodes
=0))
1249 _edgeClass
= SQLEdgeDict
1250 def __init__(self
,name
,*l
,**kwargs
):
1251 graphArgs
,tableArgs
= split_kwargs(kwargs
,
1252 ('attrAlias','defaultColumnType','columnAttrs',
1253 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
1254 'edgeDictClass','graph'))
1255 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1256 c
= getColumnTypes(**kwargs
)
1257 tableArgs
['createTable'] = \
1258 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1259 % (name
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1261 self
.allowMissingNodes
= kwargs
['allowMissingNodes']
1262 except KeyError: pass
1263 SQLTableMultiNoCache
.__init
__(self
,name
,*l
,**tableArgs
)
1264 self
.sourceSQL
= self
._attrSQL
('source_id')
1265 self
.targetSQL
= self
._attrSQL
('target_id')
1267 self
.edgeSQL
= self
._attrSQL
('edge_id')
1268 except AttributeError:
1270 self
._edgeClass
= self
._edgeClass
._edgelessClass
1271 save_graph_db_refs(self
,**kwargs
)
1272 def __getitem__(self
,k
):
1273 return self
._edgeClass
(self
.pack_source(k
),self
)
1274 def __iadd__(self
,k
):
1275 sql
,params
= self
._format
_query
('delete from %s where %s=%%s and %s is null'
1276 % (self
.name
,self
.sourceSQL
,self
.targetSQL
),
1277 (self
.pack_source(k
),))
1278 self
.cursor
.execute(sql
, params
)
1279 sql
,params
= self
._format
_query
('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
1280 % self
.name
,(self
.pack_source(k
),))
1281 self
.cursor
.execute(sql
, params
)
1282 return self
# iadd MUST RETURN SELF!
1283 def __isub__(self
,k
):
1284 sql
,params
= self
._format
_query
('delete from %s where %s=%%s'
1285 % (self
.name
,self
.sourceSQL
),
1286 (self
.pack_source(k
),))
1287 self
.cursor
.execute(sql
, params
)
1288 if self
.cursor
.rowcount
== 0:
1289 raise KeyError('node not found in graph')
1290 return self
# iadd MUST RETURN SELF!
1291 __setitem__
= graph_setitem
1292 def __contains__(self
,k
):
1293 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 1'
1294 %(self
.name
,self
.sourceSQL
),
1295 (self
.pack_source(k
),))
1296 self
.cursor
.execute(sql
, params
)
1297 l
= self
.cursor
.fetchmany(2)
1299 def __invert__(self
):
1300 'get an interface to the inverse graph mapping'
1302 return self
._inverse
1303 except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
1304 attrAlias
= dict(source_id
=self
.targetSQL
, # SWAP SOURCE & TARGET
1305 target_id
=self
.sourceSQL
,
1306 edge_id
=self
.edgeSQL
)
1307 if self
.edgeSQL
is None: # no edge interface
1308 del attrAlias
['edge_id']
1309 self
._inverse
=SQLGraph(self
.name
,self
.cursor
,
1310 attrAlias
=attrAlias
,
1311 **graph_db_inverse_refs(self
))
1312 self
._inverse
._inverse
=self
1313 return self
._inverse
1315 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1316 yield self
.unpack_source(k
)
1317 def iteritems(self
):
1318 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1319 yield (self
.unpack_source(k
), self
._edgeClass
(k
, self
))
1320 def itervalues(self
):
1321 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1322 yield self
._edgeClass
(k
, self
)
1324 return [self
.unpack_source(k
) for k
in SQLTableMultiNoCache
.keys(self
)]
1325 def values(self
): return list(self
.itervalues())
1326 def items(self
): return list(self
.iteritems())
1327 edges
=SQLGraphEdgeDescriptor()
1328 update
= update_graph
1330 'get number of source nodes in graph'
1331 self
.cursor
.execute('select count(distinct %s) from %s'
1332 %(self
.sourceSQL
,self
.name
))
1333 return self
.cursor
.fetchone()[0]
1335 override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
1336 ## def __cmp__(self,other):
1340 ## it = iter(self.edges)
1343 ## source,target,edge = it.next()
1344 ## except StopIteration:
1347 ## if d is not None:
1348 ## diff = cmp(n_target,len(d))
1351 ## if source is None:
1354 ## n += 1 # COUNT SOURCE NODES
1361 ## diff = cmp(edge,d[target])
1366 ## n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
1367 ## return cmp(n,len(other))
1369 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1371 class SQLIDGraph(SQLGraph
):
1372 add_trivial_packing_methods(locals())
1373 SQLGraph
._IDGraphClass
= SQLIDGraph
1377 class SQLEdgeDictClustered(dict):
1378 'simple cache for 2nd level dictionary of target_id:edge_id'
1379 def __init__(self
,g
,fromNode
):
1381 self
.fromNode
=fromNode
1383 def __iadd__(self
,l
):
1384 for target_id
,edge_id
in l
:
1385 dict.__setitem
__(self
,target_id
,edge_id
)
1386 return self
# iadd MUST RETURN SELF!
1388 class SQLEdgesClusteredDescr(object):
1389 def __get__(self
,obj
,objtype
):
1390 e
=SQLEdgesClustered(obj
.table
,obj
.edge_id
,obj
.source_id
,obj
.target_id
,
1391 graph
=obj
,**graph_db_inverse_refs(obj
,True))
1392 for source_id
,d
in obj
.d
.iteritems(): # COPY EDGE CACHE
1393 e
.load([(edge_id
,source_id
,target_id
)
1394 for (target_id
,edge_id
) in d
.iteritems()])
1397 class SQLGraphClustered(object):
1398 'SQL graph with clustered caching -- loads an entire cluster at a time'
1399 _edgeDictClass
=SQLEdgeDictClustered
1400 def __init__(self
,table
,source_id
='source_id',target_id
='target_id',
1401 edge_id
='edge_id',clusterKey
=None,**kwargs
):
1403 if isinstance(table
,types
.StringType
): # CREATE THE TABLE INTERFACE
1404 if clusterKey
is None:
1405 raise ValueError('you must provide a clusterKey argument!')
1406 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1407 c
= getColumnTypes(attrAlias
=dict(source_id
=source_id
,target_id
=target_id
,
1408 edge_id
=edge_id
),**kwargs
)
1409 kwargs
['createTable'] = \
1410 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1411 % (table
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],
1412 c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1413 table
= SQLTableClustered(table
,clusterKey
=clusterKey
,**kwargs
)
1415 self
.source_id
=source_id
1416 self
.target_id
=target_id
1417 self
.edge_id
=edge_id
1419 save_graph_db_refs(self
,**kwargs
)
1420 _pickleAttrs
= dict(table
=0,source_id
=0,target_id
=0,edge_id
=0,sourceDB
=0,targetDB
=0,
1422 def __getstate__(self
):
1423 state
= standard_getstate(self
)
1424 state
['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1426 def __getitem__(self
,k
):
1427 'get edgeDict for source node k, from cache or by loading its cluster'
1428 try: # GET DIRECTLY FROM CACHE
1431 if hasattr(self
,'_isLoaded'):
1432 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1433 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1434 sql
,params
= self
.table
._format
_query
('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1435 %(self
.source_id
,self
.target_id
,
1436 self
.edge_id
,self
.table
.name
,
1437 self
.table
.name
,self
.source_id
,
1438 self
.table
.clusterKey
,self
.table
.clusterKey
,
1439 self
.table
.primary_key
),
1440 (self
.pack_source(k
),))
1441 self
.table
.cursor
.execute(sql
, params
)
1442 self
.load(self
.table
.cursor
.fetchall()) # CACHE THIS CLUSTER
1443 return self
.d
[k
] # RETURN EDGE DICT FOR THIS NODE
1444 def load(self
,l
=None,unpack
=True):
1445 'load the specified rows (or all, if None provided) into local cache'
1447 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1448 return self
._isLoaded
1449 except AttributeError:
1451 self
.table
.cursor
.execute('select %s,%s,%s from %s'
1452 %(self
.source_id
,self
.target_id
,
1453 self
.edge_id
,self
.table
.name
))
1454 l
=self
.table
.cursor
.fetchall()
1456 self
.d
.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1457 for source
,target
,edge
in l
: # SAVE TO OUR CACHE
1459 source
= self
.unpack_source(source
)
1460 target
= self
.unpack_target(target
)
1461 edge
= self
.unpack_edge(edge
)
1463 self
.d
[source
] += [(target
,edge
)]
1465 d
= self
._edgeDictClass
(self
,source
)
1466 d
+= [(target
,edge
)]
1468 def __invert__(self
):
1469 'interface to reverse graph mapping'
1471 return self
._inverse
# INVERSE MAP ALREADY EXISTS
1472 except AttributeError:
1474 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1475 self
._inverse
=SQLGraphClustered(self
.table
,self
.target_id
,self
.source_id
,
1476 self
.edge_id
,**graph_db_inverse_refs(self
))
1477 self
._inverse
._inverse
=self
1478 for source
,d
in self
.d
.iteritems(): # INVERT OUR CACHE
1479 self
._inverse
.load([(target
,source
,edge
)
1480 for (target
,edge
) in d
.iteritems()],unpack
=False)
1481 return self
._inverse
1482 edges
=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1483 update
= update_graph
1484 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1485 def __iter__(self
): ################# ITERATORS
1486 'uses db select; does not force load'
1487 return iter(self
.keys())
1489 'uses db select; does not force load'
1490 self
.table
.cursor
.execute('select distinct(%s) from %s'
1491 %(self
.source_id
,self
.table
.name
))
1492 return [self
.unpack_source(t
[0])
1493 for t
in self
.table
.cursor
.fetchall()]
1494 methodFactory(['iteritems','items','itervalues','values'],
1495 'lambda self:(self.load(),self.d.%s())[1]',locals())
1496 def __contains__(self
,k
):
1503 class SQLIDGraphClustered(SQLGraphClustered
):
1504 add_trivial_packing_methods(locals())
1505 SQLGraphClustered
._IDGraphClass
= SQLIDGraphClustered
1507 class SQLEdgesClustered(SQLGraphClustered
):
1508 'edges interface for SQLGraphClustered'
1509 _edgeDictClass
= list
1510 _pickleAttrs
= SQLGraphClustered
._pickleAttrs
.copy()
1511 _pickleAttrs
.update(dict(graph
=0))
1515 for edge_id
,l
in self
.d
.iteritems():
1516 for source_id
,target_id
in l
:
1517 result
.append((self
.graph
.unpack_source(source_id
),
1518 self
.graph
.unpack_target(target_id
),
1519 self
.graph
.unpack_edge(edge_id
)))
1522 class ForeignKeyInverse(object):
1523 'map each key to a single value according to its foreign key'
1524 def __init__(self
,g
):
1526 def __getitem__(self
,obj
):
1528 source_id
= getattr(obj
,self
.g
.keyColumn
)
1529 if source_id
is None:
1531 return self
.g
.sourceDB
[source_id
]
1532 def __setitem__(self
,obj
,source
):
1534 if source
is not None:
1535 self
.g
[source
][obj
] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1536 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1537 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1538 old_source
= self
[obj
]
1539 if old_source
is not None:
1540 del self
.g
[old_source
][obj
]
1541 def check_obj(self
,obj
):
1542 'raise KeyError if obj not from this db'
1544 if obj
.db
is not self
.g
.targetDB
:
1545 raise AttributeError
1546 except AttributeError:
1547 raise KeyError('key is not from targetDB of this graph!')
1548 def __contains__(self
,obj
):
1555 return self
.g
.targetDB
.itervalues()
1557 return self
.g
.targetDB
.values()
1558 def iteritems(self
):
1560 source_id
= getattr(obj
,self
.g
.keyColumn
)
1561 if source_id
is None:
1564 yield obj
,self
.g
.sourceDB
[source_id
]
1566 return list(self
.iteritems())
1567 def itervalues(self
):
1568 for obj
,val
in self
.iteritems():
1571 return list(self
.itervalues())
1572 def __invert__(self
):
1575 class ForeignKeyEdge(dict):
1576 '''edge interface to a foreign key in an SQL table.
1577 Caches dict of target nodes in itself; provides dict interface.
1578 Adds or deletes edges by setting foreign key values in the table'''
1579 def __init__(self
,g
,k
):
1583 for v
in g
.targetDB
.select('where %s=%%s' % g
.keyColumn
,(k
.id,)): # SEARCH THE DB
1584 dict.__setitem
__(self
,v
,None) # SAVE IN CACHE
1585 def __setitem__(self
,dest
,v
):
1586 if not hasattr(dest
,'db') or dest
.db
is not self
.g
.targetDB
:
1587 raise KeyError('dest is not in the targetDB bound to this graph!')
1589 raise ValueError('sorry,this graph cannot store edge information!')
1590 if not hasattr(dest
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1591 old_source
= self
.g
._inverse
[dest
] # CHECK FOR PRE-EXISTING EDGE
1592 if old_source
is not None: # REMOVE OLD EDGE FROM CACHE
1593 dict.__delitem
__(self
.g
[old_source
],dest
)
1594 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1595 setattr(dest
,self
.g
.keyColumn
,self
.src
.id) # SAVE TO DB ATTRIBUTE
1596 dict.__setitem
__(self
,dest
,None) # SAVE IN CACHE
1597 def __delitem__(self
,dest
):
1598 #self.g.targetDB._update(dest.id,self.g.keyColumn,None) # REMOVE FOREIGN KEY VALUE
1599 setattr(dest
,self
.g
.keyColumn
,None) # SAVE TO DB ATTRIBUTE
1600 dict.__delitem
__(self
,dest
) # REMOVE FROM CACHE
1602 class ForeignKeyGraph(object, UserDict
.DictMixin
):
1603 '''graph interface to a foreign key in an SQL table
1604 Caches dict of target nodes in itself; provides dict interface.
1606 def __init__(self
, sourceDB
, targetDB
, keyColumn
, autoGC
=True, **kwargs
):
1607 '''sourceDB is any database of source nodes;
1608 targetDB must be an SQL database of target nodes;
1609 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1610 if autoGC
: # automatically garbage collect unused objects
1611 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
1613 self
._weakValueDict
= {}
1614 self
.autoGC
= autoGC
1615 self
.sourceDB
= sourceDB
1616 self
.targetDB
= targetDB
1617 self
.keyColumn
= keyColumn
1618 self
._inverse
= ForeignKeyInverse(self
)
1619 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, keyColumn
=0, autoGC
=0)
1620 __getstate__
= standard_getstate
########### SUPPORT FOR PICKLING
1621 __setstate__
= standard_setstate
1622 def _inverse_schema(self
):
1623 'provide custom schema rule for inverting this graph... just use keyColumn!'
1624 return dict(invert
=True,uniqueMapping
=True)
1625 def __getitem__(self
,k
):
1626 if not hasattr(k
,'db') or k
.db
is not self
.sourceDB
:
1627 raise KeyError('object is not in the sourceDB bound to this graph!')
1629 return self
._weakValueDict
[k
.id] # get from cache
1632 d
= ForeignKeyEdge(self
,k
)
1633 self
._weakValueDict
[k
.id] = d
# save in cache
1635 def __setitem__(self
, k
, v
):
1636 raise KeyError('''do not save as g[k]=v. Instead follow a graph
1637 interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
1638 def __delitem__(self
, k
):
1639 raise KeyError('''Instead of del g[k], follow a graph
1640 interface: del g[src][dest]''')
1642 return self
.sourceDB
.values()
1643 __invert__
= standard_invert
1645 def describeDBTables(name
,cursor
,idDict
):
1647 Get table info about database <name> via <cursor>, and store primary keys
1648 in idDict, along with a list of the tables each key indexes.
1650 cursor
.execute('use %s' % name
)
1651 cursor
.execute('show tables')
1653 l
=[c
[0] for c
in cursor
.fetchall()]
1656 o
=SQLTable(tname
,cursor
)
1658 for f
in o
.description
:
1659 if f
==o
.primary_key
:
1660 idDict
.setdefault(f
, []).append(o
)
1661 elif f
[-3:]=='_id' and f
not in idDict
:
1667 def indexIDs(tables
,idDict
=None):
1668 "Get an index of primary keys in the <tables> dictionary."
1671 for o
in tables
.values():
1673 if o
.primary_key
not in idDict
:
1674 idDict
[o
.primary_key
]=[]
1675 idDict
[o
.primary_key
].append(o
) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1676 for f
in o
.description
:
1677 if f
[-3:]=='_id' and f
not in idDict
:
1683 def suffixSubset(tables
,suffix
):
1684 "Filter table index for those matching a specific suffix"
1686 for name
,t
in tables
.items():
1687 if name
.endswith(suffix
):
1694 def graphDBTables(tables
,idDict
):
1696 for t
in tables
.values():
1697 for f
in t
.description
:
1698 if f
==t
.primary_key
:
1699 edgeInfo
=PRIMARY_KEY
1702 g
.setEdge(f
,t
,edgeInfo
)
1703 g
.setEdge(t
,f
,edgeInfo
)
1706 SQLTypeTranslation
= {types
.StringType
:'varchar(32)',
1707 types
.IntType
:'int',
1708 types
.FloatType
:'float'}
1710 def createTableFromRepr(rows
,tableName
,cursor
,typeTranslation
=None,
1711 optionalDict
=None,indexDict
=()):
1712 """Save rows into SQL tableName using cursor, with optional
1713 translations of columns to specific SQL types (specified
1714 by typeTranslation dict).
1715 - optionDict can specify columns that are allowed to be NULL.
1716 - indexDict can specify columns that must be indexed; columns
1717 whose names end in _id will be indexed by default.
1718 - rows must be an iterator which in turn returns dictionaries,
1719 each representing a tuple of values (indexed by their column
1723 row
=rows
.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1724 except StopIteration:
1725 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1727 createTableFromRow(cursor
, tableName
,row
,typeTranslation
,
1728 optionalDict
,indexDict
)
1731 storeRow(cursor
,tableName
,row
) # SAVE OUR FIRST ROW
1732 for row
in rows
: # NOW SAVE ALL THE ROWS
1733 storeRow(cursor
,tableName
,row
)
1735 def createTableFromRow(cursor
, tableName
, row
,typeTranslation
=None,
1736 optionalDict
=None,indexDict
=()):
1738 for col
,val
in row
.items(): # PREPARE SQL TYPES FOR COLUMNS
1740 if typeTranslation
!=None and col
in typeTranslation
:
1741 coltype
=typeTranslation
[col
] # USER-SUPPLIED TRANSLATION
1742 elif type(val
) in SQLTypeTranslation
:
1743 coltype
=SQLTypeTranslation
[type(val
)]
1744 else: # SEARCH FOR A COMPATIBLE TYPE
1745 for t
in SQLTypeTranslation
:
1746 if isinstance(val
,t
):
1747 coltype
=SQLTypeTranslation
[t
]
1750 raise TypeError("Don't know SQL type to use for %s" % col
)
1751 create_def
='%s %s' %(col
,coltype
)
1752 if optionalDict
==None or col
not in optionalDict
:
1753 create_def
+=' not null'
1754 create_defs
.append(create_def
)
1755 for col
in row
: # CREATE INDEXES FOR ID COLUMNS
1756 if col
[-3:]=='_id' or col
in indexDict
:
1757 create_defs
.append('index(%s)' % col
)
1758 cmd
='create table if not exists %s (%s)' % (tableName
,','.join(create_defs
))
1759 cursor
.execute(cmd
) # CREATE THE TABLE IN THE DATABASE
def _execute_row_insert(cursor, tableName, row, insertVerb):
    """Run one parameterized INSERT of dict *row* into *tableName*.

    insertVerb is the statement prefix, e.g. 'insert into'.
    Column names and values are collected in a single row.items() pass and
    the column list is spelled out in the SQL, so every value is bound to its
    own column regardless of the table's physical column order or of the
    dict's iteration order (an INSERT without a column list matches values
    purely by position).
    """
    cols = []
    vals = []
    for col, val in row.items():
        cols.append(col)
        vals.append(val)
    placeholders = ','.join(['%s'] * len(vals))  # DB-API 'format' paramstyle
    cmd = '%s %s (%s) values (%s)' % (insertVerb, tableName,
                                      ','.join(cols), placeholders)
    cursor.execute(cmd, tuple(vals))

def storeRow(cursor, tableName, row):
    """Insert dict *row* (column name -> value) into SQL table *tableName*.

    Values are passed to cursor.execute() as bound parameters; only the
    table and column names are interpolated into the SQL text.
    """
    _execute_row_insert(cursor, tableName, row, 'insert into')

def storeRowDelayed(cursor, tableName, row):
    """Like storeRow(), but issues INSERT DELAYED.

    NOTE: INSERT DELAYED is MySQL-specific syntax (and ignored/rejected by
    newer MySQL servers) -- confirm the target server supports it.
    """
    _execute_row_insert(cursor, tableName, row, 'insert delayed into')
1773 class TableGroup(dict):
1774 'provide attribute access to dbname qualified tablenames'
1775 def __init__(self
,db
='test',suffix
=None,**kw
):
1778 if suffix
is not None:
1780 for k
,v
in kw
.items():
1781 if v
is not None and '.' not in v
:
1782 v
=self
.db
+'.'+v
# ADD DATABASE NAME AS PREFIX
1784 def __getattr__(self
,k
):
def sqlite_connect(*args, **kwargs):
    """Connect through the module returned by import_sqlite().

    Every positional and keyword argument is forwarded unchanged to that
    module's connect(); the result is a (connection, cursor) pair where the
    cursor was just created on that connection.
    """
    conn = import_sqlite().connect(*args, **kwargs)
    return conn, conn.cursor()
1793 class DBServerInfo(object):
1794 'picklable reference to a database server'
1795 def __init__(self
, moduleName
='MySQLdb', serverSideCursors
=True,
1796 blockIterators
=True, *args
, **kwargs
):
1798 self
.__class
__ = _DBServerModuleDict
[moduleName
]
1800 raise ValueError('Module name not found in _DBServerModuleDict: '\
1802 self
.moduleName
= moduleName
1803 self
.args
= args
# connection arguments
1804 self
.kwargs
= kwargs
1805 self
.serverSideCursors
= serverSideCursors
1806 self
.custom_iter_keys
= blockIterators
1807 if self
.serverSideCursors
and not self
.custom_iter_keys
:
1808 raise ValueError('serverSideCursors=True requires blockIterators=True!')
1811 """returns cursor associated with the DB server info (reused)"""
1814 except AttributeError:
1815 self
._start
_connection
()
1818 def new_cursor(self
, arraysize
=None):
1819 """returns a NEW cursor; you must close it yourself! """
1820 if not hasattr(self
, '_connection'):
1821 self
._start
_connection
()
1822 cursor
= self
._connection
.cursor()
1823 if arraysize
is not None:
1824 cursor
.arraysize
= arraysize
1828 """Close file containing this database"""
1829 self
._cursor
.close()
1830 self
._connection
.close()
1832 del self
._connection
1834 def __getstate__(self
):
1835 """return all picklable arguments"""
1836 return dict(args
=self
.args
, kwargs
=self
.kwargs
,
1837 moduleName
=self
.moduleName
,
1838 serverSideCursors
=self
.serverSideCursors
,
1839 custom_iter_keys
=self
.custom_iter_keys
)
1842 class MySQLServerInfo(DBServerInfo
):
1843 'customized for MySQLdb SSCursor support via new_cursor()'
1844 _serverType
= 'mysql'
1845 def _start_connection(self
):
1846 self
._connection
,self
._cursor
= mysql_connect(*self
.args
, **self
.kwargs
)
1847 def new_cursor(self
, arraysize
=None):
1848 'provide streaming cursor support'
1849 if not self
.serverSideCursors
: # use regular MySQLdb cursor
1850 return DBServerInfo
.new_cursor(self
, arraysize
)
1852 conn
= self
._conn
_sscursor
1853 except AttributeError:
1854 self
._conn
_sscursor
,cursor
= mysql_connect(useStreaming
=True,
1855 *self
.args
, **self
.kwargs
)
1857 cursor
= self
._conn
_sscursor
.cursor()
1858 if arraysize
is not None:
1859 cursor
.arraysize
= arraysize
1862 DBServerInfo
.close(self
)
1864 self
._conn
_sscursor
.close()
1865 del self
._conn
_sscursor
1866 except AttributeError:
1868 def iter_keys(self
, db
, cursor
, map_f
=iter,
1869 cache_f
=lambda x
:[t
[0] for t
in x
], **kwargs
):
1870 block_generator
= BlockGenerator(db
, cursor
, **kwargs
)
1872 cache_f
= block_generator
.cache_f
1873 except AttributeError:
1875 return db
.generic_iterator(cursor
=cursor
, cache_f
=cache_f
,
1876 map_f
=map_f
, fetch_f
=block_generator
)
1878 class CursorCloser(object):
1879 """container for ensuring cursor.close() is called, when this obj deleted.
1880 For Python 2.5+, we could replace this with a try... finally clause
1881 in a generator function such as generic_iterator(); see PEP 342 or
1882 What's New in Python 2.5. """
1883 def __init__(self
, cursor
):
1884 self
.cursor
= cursor
1888 class BlockGenerator(CursorCloser
):
1889 'workaround for MySQLdb iteration horrible performance'
1890 def __init__(self
, db
, cursor
, selectCols
, whereClause
='', **kwargs
):
1892 self
.cursor
= cursor
1893 self
.selectCols
= selectCols
1894 self
.kwargs
= kwargs
1895 self
.whereClause
= ''
1896 if kwargs
['orderBy']: # use iterSQL/iterColumns for WHERE / SELECT
1897 self
.whereSQL
= db
.iterSQL
1898 if selectCols
== '*': # extracting all columns
1899 self
.whereParams
= [db
.data
[col
] for col
in db
.iterColumns
]
1900 else: # selectCols is single column
1901 iterColumns
= list(db
.iterColumns
)
1902 try: # if selectCols in db.iterColumns, just use that
1903 i
= iterColumns
.index(selectCols
)
1904 except ValueError: # have to append selectCols
1905 i
= len(db
.iterColumns
)
1906 iterColumns
+= [selectCols
]
1907 self
.selectCols
= ','.join(iterColumns
)
1908 self
.whereParams
= range(len(db
.iterColumns
))
1909 if i
> 0: # need to extract desired column
1910 self
.cache_f
= lambda x
:[t
[i
] for t
in x
]
1911 else: # just use primary key
1912 self
.whereSQL
= 'WHERE %s>%%s' % self
.db
.primary_key
1913 self
.whereParams
= (db
.data
['id'],)
1918 'get the next block of data'
1921 self
.db
._select
(self
.whereClause
, self
.params
, cursor
=self
.cursor
,
1922 limit
='LIMIT %s' % self
.cursor
.arraysize
,
1923 selectCols
=self
.selectCols
, **(self
.kwargs
))
1924 rows
= self
.cursor
.fetchall()
1925 if len(rows
) < self
.cursor
.arraysize
: # iteration complete
1928 lastrow
= rows
[-1] # extract params from the last row in this block
1929 if len(lastrow
) > 1:
1930 self
.params
= [lastrow
[icol
] for icol
in self
.whereParams
]
1932 self
.params
= lastrow
1933 self
.whereClause
= self
.whereSQL
1938 class SQLiteServerInfo(DBServerInfo
):
1939 """picklable reference to a sqlite database"""
1940 _serverType
= 'sqlite'
1941 def __init__(self
, database
, *args
, **kwargs
):
1942 """Takes same arguments as sqlite3.connect()"""
1943 DBServerInfo
.__init
__(self
, 'sqlite', # save abs path!
1944 database
=SourceFileName(database
),
1946 def _start_connection(self
):
1947 self
._connection
,self
._cursor
= sqlite_connect(*self
.args
, **self
.kwargs
)
1948 def __getstate__(self
):
1949 database
= self
.kwargs
.get('database', False) or self
.args
[0]
1950 if database
== ':memory:':
1951 raise ValueError('SQLite in-memory database is not picklable!')
1952 return DBServerInfo
.__getstate
__(self
)
# Mapping of DB-API module name -> DBServerInfo subclass.  DBServerInfo's
# constructor looks its moduleName argument up here and rebinds
# self.__class__ to the matching subclass.
_DBServerModuleDict = {'MySQLdb': MySQLServerInfo,
                       'sqlite': SQLiteServerInfo}
1958 class MapView(object, UserDict
.DictMixin
):
1959 'general purpose 1:1 mapping defined by any SQL query'
1960 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None,
1961 serverInfo
=None, inverseSQL
=None, **kwargs
):
1962 self
.sourceDB
= sourceDB
1963 self
.targetDB
= targetDB
1964 self
.viewSQL
= viewSQL
1965 self
.inverseSQL
= inverseSQL
1967 if serverInfo
is not None: # get cursor from serverInfo
1968 cursor
= serverInfo
.cursor()
1970 try: # can we get it from our other db?
1971 serverInfo
= sourceDB
.serverInfo
1972 except AttributeError:
1973 raise ValueError('you must provide serverInfo or cursor!')
1975 cursor
= serverInfo
.cursor()
1976 self
.cursor
= cursor
1977 self
.serverInfo
= serverInfo
1978 self
.get_sql_format(False) # get sql formatter for this db interface
1979 _schemaModuleDict
= _schemaModuleDict
# default module list
1980 get_sql_format
= get_table_schema
1981 def __getitem__(self
, k
):
1982 if not hasattr(k
,'db') or k
.db
is not self
.sourceDB
:
1983 raise KeyError('object is not in the sourceDB bound to this map!')
1984 sql
,params
= self
._format
_query
(self
.viewSQL
, (k
.id,))
1985 self
.cursor
.execute(sql
, params
) # formatted for this db interface
1986 t
= self
.cursor
.fetchmany(2) # get at most two rows
1988 raise KeyError('%s not found in MapView, or not unique'
1990 return self
.targetDB
[t
[0][0]] # get the corresponding object
1991 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, viewSQL
=0, serverInfo
=0,
1993 __getstate__
= standard_getstate
1994 __setstate__
= standard_setstate
1995 __setitem__
= __delitem__
= clear
= pop
= popitem
= update
= \
1996 setdefault
= read_only_error
1998 'only yield sourceDB items that are actually in this mapping!'
1999 for k
in self
.sourceDB
.itervalues():
2006 return [k
for k
in self
] # don't use list(self); causes infinite loop!
2007 def __invert__(self
):
2009 return self
._inverse
2010 except AttributeError:
2011 if self
.inverseSQL
is None:
2012 raise ValueError('this MapView has no inverseSQL!')
2013 self
._inverse
= self
.__class
__(self
.targetDB
, self
.sourceDB
,
2014 self
.inverseSQL
, self
.cursor
,
2015 serverInfo
=self
.serverInfo
,
2016 inverseSQL
=self
.viewSQL
)
2017 self
._inverse
._inverse
= self
2018 return self
._inverse
2020 class GraphViewEdgeDict(UserDict
.DictMixin
):
2021 'edge dictionary for GraphView: just pre-loaded on init'
2022 def __init__(self
, g
, k
):
2025 sql
,params
= self
.g
._format
_query
(self
.g
.viewSQL
, (k
.id,))
2026 self
.g
.cursor
.execute(sql
, params
) # run the query
2027 l
= self
.g
.cursor
.fetchall() # get results
2029 raise KeyError('key %s not in GraphView' % k
.id)
2030 self
.targets
= [t
[0] for t
in l
] # preserve order of the results
2031 d
= {} # also keep targetID:edgeID mapping
2032 if self
.g
.edgeDB
is not None: # save with edge info
2040 return len(self
.targets
)
2042 for k
in self
.targets
:
2043 yield self
.g
.targetDB
[k
]
2046 def iteritems(self
):
2047 if self
.g
.edgeDB
is not None: # save with edge info
2048 for k
in self
.targets
:
2049 yield (self
.g
.targetDB
[k
], self
.g
.edgeDB
[self
.targetDict
[k
]])
2050 else: # just save the list of targets, no edge info
2051 for k
in self
.targets
:
2052 yield (self
.g
.targetDB
[k
], None)
2053 def __getitem__(self
, o
, exitIfFound
=False):
2054 'for the specified target object, return its associated edge object'
2056 if o
.db
is not self
.g
.targetDB
:
2057 raise KeyError('key is not part of targetDB!')
2058 edgeID
= self
.targetDict
[o
.id]
2059 except AttributeError:
2060 raise KeyError('key has no id or db attribute?!')
2063 if self
.g
.edgeDB
is not None: # return the edge object
2064 return self
.g
.edgeDB
[edgeID
]
2065 else: # no edge info
2067 def __contains__(self
, o
):
2069 self
.__getitem
__(o
, True) # raise KeyError if not found
2073 __setitem__
= __delitem__
= clear
= pop
= popitem
= update
= \
2074 setdefault
= read_only_error
2076 class GraphView(MapView
):
2077 'general purpose graph interface defined by any SQL query'
2078 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None, edgeDB
=None,
2080 'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
2081 self
.edgeDB
= edgeDB
2082 MapView
.__init
__(self
, sourceDB
, targetDB
, viewSQL
, cursor
, **kwargs
)
2083 def __getitem__(self
, k
):
2084 if not hasattr(k
,'db') or k
.db
is not self
.sourceDB
:
2085 raise KeyError('object is not in the sourceDB bound to this map!')
2086 return GraphViewEdgeDict(self
, k
)
2087 _pickleAttrs
= MapView
._pickleAttrs
.copy()
2088 _pickleAttrs
.update(dict(edgeDB
=0))
2090 # @CTB move to sqlgraph.py?
2092 class SQLSequence(SQLRow
, SequenceBase
):
2093 """Transparent access to a DB row representing a sequence.
2095 Use attrAlias dict to rename 'length' to something else.
2097 def _init_subclass(cls
, db
, **kwargs
):
2098 db
.seqInfoDict
= db
# db will act as its own seqInfoDict
2099 SQLRow
._init
_subclass
(db
=db
, **kwargs
)
2100 _init_subclass
= classmethod(_init_subclass
)
2101 def __init__(self
, id):
2102 SQLRow
.__init
__(self
, id)
2103 SequenceBase
.__init
__(self
)
2106 def strslice(self
,start
,end
):
2107 "Efficient access to slice of a sequence, useful for huge contigs"
2108 return self
._select
('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
2109 %(self
.db
._attrSQL
('seq'),start
+1,end
-start
))
class DNASQLSequence(SQLSequence):
    'SQLSequence row whose sequence type is fixed to DNA_SEQTYPE'
    _seqtype = DNA_SEQTYPE
class RNASQLSequence(SQLSequence):
    'SQLSequence row whose sequence type is fixed to RNA_SEQTYPE'
    _seqtype = RNA_SEQTYPE
2117 class ProteinSQLSequence(SQLSequence
):
2118 _seqtype
=PROTEIN_SEQTYPE