3 from __future__
import generators
5 from sequence
import SequenceBase
, DNA_SEQTYPE
, RNA_SEQTYPE
, PROTEIN_SEQTYPE
7 from classutil
import methodFactory
,standard_getstate
,\
8 override_rich_cmp
,generate_items
,get_bound_subclass
,standard_setstate
,\
9 get_valid_path
,standard_invert
,RecentValueDictionary
,read_only_error
,\
10 SourceFileName
, split_kwargs
17 class TupleDescriptor(object):
18 'return tuple entry corresponding to named attribute'
19 def __init__(self
, db
, attr
):
20 self
.icol
= db
.data
[attr
] # index of this attribute in the tuple
21 def __get__(self
, obj
, klass
):
22 return obj
._data
[self
.icol
]
23 def __set__(self
, obj
, val
):
24 raise AttributeError('this database is read-only!')
26 class TupleIDDescriptor(TupleDescriptor
):
27 def __set__(self
, obj
, val
):
28 raise AttributeError('''You cannot change obj.id directly.
29 Instead, use db[newID] = obj''')
31 class TupleDescriptorRW(TupleDescriptor
):
32 'read-write interface to named attribute'
33 def __init__(self
, db
, attr
):
35 self
.icol
= db
.data
[attr
] # index of this attribute in the tuple
36 self
.attrSQL
= db
._attrSQL
(attr
, sqlColumn
=True) # SQL column name
37 def __set__(self
, obj
, val
):
38 obj
.db
._update
(obj
.id, self
.attrSQL
, val
) # AND UPDATE THE DATABASE
39 obj
.save_local(self
.attr
, val
)
41 class SQLDescriptor(object):
42 'return attribute value by querying the database'
43 def __init__(self
, db
, attr
):
44 self
.selectSQL
= db
._attrSQL
(attr
) # SQL expression for this attr
45 def __get__(self
, obj
, klass
):
46 return obj
._select
(self
.selectSQL
)
47 def __set__(self
, obj
, val
):
48 raise AttributeError('this database is read-only!')
50 class SQLDescriptorRW(SQLDescriptor
):
51 'writeable proxy to corresponding column in the database'
52 def __set__(self
, obj
, val
):
53 obj
.db
._update
(obj
.id, self
.selectSQL
, val
) #just update the database
55 class ReadOnlyDescriptor(object):
56 'enforce read-only policy, e.g. for ID attribute'
57 def __init__(self
, db
, attr
):
59 def __get__(self
, obj
, klass
):
60 return getattr(obj
, self
.attr
)
61 def __set__(self
, obj
, val
):
62 raise AttributeError('attribute %s is read-only!' % self
.attr
)
65 def select_from_row(row
, what
):
66 "return value from SQL expression applied to this row"
67 sql
,params
= row
.db
._format
_query
('select %s from %s where %s=%%s limit 2'
68 % (what
,row
.db
.name
,row
.db
.primary_key
),
70 row
.db
.cursor
.execute(sql
, params
)
71 t
= row
.db
.cursor
.fetchmany(2) # get at most two rows
73 raise KeyError('%s[%s].%s not found, or not unique'
74 % (row
.db
.name
,str(row
.id),what
))
75 return t
[0][0] #return the single field we requested
77 def init_row_subclass(cls
, db
):
78 'add descriptors for db attributes'
79 for attr
in db
.data
: # bind all database columns
80 if attr
== 'id': # handle ID attribute specially
81 setattr(cls
, attr
, cls
._idDescriptor
(db
, attr
))
83 try: # check if this attr maps to an SQL column
84 db
._attrSQL
(attr
, columnNumber
=True)
85 except AttributeError: # treat as SQL expression
86 setattr(cls
, attr
, cls
._sqlDescriptor
(db
, attr
))
87 else: # treat as interface to our stored tuple
88 setattr(cls
, attr
, cls
._columnDescriptor
(db
, attr
))
91 """get list of column names as our attributes """
92 return self
.db
.data
.keys()
95 """Provides attribute interface to a database tuple.
96 Storing the data as a tuple instead of a standard Python object
97 (which is stored using __dict__) uses about five-fold less
98 memory and is also much faster (the tuples returned from the
99 DB API fetch are simply referenced by the TupleO, with no
100 need to copy their individual values into __dict__).
102 This class follows the 'subclass binding' pattern, which
103 means that instead of using __getattr__ to process all
104 attribute requests (which is un-modular and leads to all
105 sorts of trouble), we follow Python's new model for
106 customizing attribute access, namely Descriptors.
107 We use classutil.get_bound_subclass() to automatically
108 create a subclass of this class, calling its _init_subclass()
109 class method to add all the descriptors needed for the
110 database table to which it is bound.
112 See the Pygr Developer Guide section of the docs for a
113 complete discussion of the subclass binding pattern."""
114 _columnDescriptor
= TupleDescriptor
115 _idDescriptor
= TupleIDDescriptor
116 _sqlDescriptor
= SQLDescriptor
117 _init_subclass
= classmethod(init_row_subclass
)
118 _select
= select_from_row
120 def __init__(self
, data
):
121 self
._data
= data
# save our data tuple
123 def insert_and_cache_id(self
, l
, **kwargs
):
124 'insert tuple into db and cache its rowID on self'
125 self
.db
._insert
(l
) # save to database
127 rowID
= kwargs
['id'] # use the ID supplied by user
129 rowID
= self
.db
.get_insert_id() # get auto-inc ID value
130 self
.cache_id(rowID
) # cache this ID on self
132 class TupleORW(TupleO
):
133 'read-write version of TupleO'
134 _columnDescriptor
= TupleDescriptorRW
135 insert_and_cache_id
= insert_and_cache_id
136 def __init__(self
, data
, newRow
=False, **kwargs
):
137 if not newRow
: # just cache data from the database
140 self
._data
= self
.db
.tuple_from_dict(kwargs
) # convert to tuple
141 self
.insert_and_cache_id(self
._data
, **kwargs
)
142 def cache_id(self
,row_id
):
143 self
.save_local('id',row_id
)
144 def save_local(self
,attr
,val
):
145 icol
= self
._attrcol
[attr
]
147 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
148 except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
149 self
._data
= list(self
._data
)
150 self
._data
[icol
] = val
# FINALLY UPDATE OUR LOCAL CACHE
152 TupleO
._RWClass
= TupleORW
# record this as writeable interface class
154 class ColumnDescriptor(object):
155 'read-write interface to column in a database, cached in obj.__dict__'
156 def __init__(self
, db
, attr
, readOnly
= False):
158 self
.col
= db
._attrSQL
(attr
, sqlColumn
=True) # MAP THIS TO SQL COLUMN NAME
161 self
.__class
__ = self
._readOnlyClass
162 def __get__(self
, obj
, objtype
):
164 return obj
.__dict
__[self
.attr
]
165 except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
166 if self
.col
==self
.db
.primary_key
:
168 self
.db
._select
('where %s=%%s' % self
.db
.primary_key
,(obj
.id,),self
.col
)
169 l
= self
.db
.cursor
.fetchall()
171 raise AttributeError('db row not found or not unique!')
172 obj
.__dict
__[self
.attr
] = l
[0][0] # UPDATE THE CACHE
174 def __set__(self
, obj
, val
):
175 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
176 self
.db
._update
(obj
.id, self
.col
, val
) # UPDATE THE DATABASE
177 obj
.__dict
__[self
.attr
] = val
# UPDATE THE CACHE
179 ## m = self.consequences
180 ## except AttributeError:
182 ## m(obj,val) # GENERATE CONSEQUENCES
183 ## def bind_consequences(self,f):
184 ## 'make function f be run as consequences method whenever value is set'
186 ## self.consequences = new.instancemethod(f,self,self.__class__)
187 class ReadOnlyColumnDesc(ColumnDescriptor
):
188 def __set__(self
, obj
, val
):
189 raise AttributeError('The ID of a database object is not writeable.')
190 ColumnDescriptor
._readOnlyClass
= ReadOnlyColumnDesc
194 class SQLRow(object):
195 """Provide transparent interface to a row in the database: attribute access
196 will be mapped to SELECT of the appropriate column, but data is not
197 cached on this object.
199 _columnDescriptor
= _sqlDescriptor
= SQLDescriptor
200 _idDescriptor
= ReadOnlyDescriptor
201 _init_subclass
= classmethod(init_row_subclass
)
202 _select
= select_from_row
204 def __init__(self
, rowID
):
208 class SQLRowRW(SQLRow
):
209 'read-write version of SQLRow'
210 _columnDescriptor
= SQLDescriptorRW
211 insert_and_cache_id
= insert_and_cache_id
212 def __init__(self
, rowID
, newRow
=False, **kwargs
):
213 if not newRow
: # just cache data from the database
214 return self
.cache_id(rowID
)
215 l
= self
.db
.tuple_from_dict(kwargs
) # convert to tuple
216 self
.insert_and_cache_id(l
, **kwargs
)
217 def cache_id(self
, rowID
):
220 SQLRow
._RWClass
= SQLRowRW
224 def list_to_dict(names
, values
):
225 'return dictionary of those named args that are present in values[]'
227 for i
,v
in enumerate(values
):
235 def get_name_cursor(name
=None, **kwargs
):
236 '''get table name and cursor by parsing name or using configFile.
237 If neither provided, will try to get via your MySQL config file.
238 If connect is None, will use MySQLdb.connect()'''
240 argList
= name
.split() # TREAT AS WS-SEPARATED LIST
242 name
= argList
[0] # USE 1ST ARG AS TABLE NAME
243 argnames
= ('host','user','passwd') # READ ARGS IN THIS ORDER
244 kwargs
= kwargs
.copy() # a copy we can overwrite
245 kwargs
.update(list_to_dict(argnames
, argList
[1:]))
246 serverInfo
= DBServerInfo(**kwargs
)
247 return name
,serverInfo
.cursor(),serverInfo
249 def mysql_connect(connect
=None, configFile
=None, useStreaming
=False, **args
):
250 """return connection and cursor objects, using .my.cnf if necessary"""
251 kwargs
= args
.copy() # a copy we can modify
252 if 'user' not in kwargs
and configFile
is None: #Find where config file is
253 osname
= platform
.system()
254 if osname
in('Microsoft', 'Windows'): # Machine is a Windows box
256 try: # handle case where WINDIR not defined by Windows...
257 windir
= os
.environ
['WINDIR']
258 paths
+= [(windir
, 'my.ini'), (windir
, 'my.cnf')]
262 sysdrv
= os
.environ
['SYSTEMDRIVE']
263 paths
+= [(sysdrv
, os
.path
.sep
+ 'my.ini'),
264 (sysdrv
, os
.path
.sep
+ 'my.cnf')]
268 configFile
= get_valid_path(*paths
)
269 else: # treat as normal platform with home directories
270 configFile
= os
.path
.join(os
.path
.expanduser('~'), '.my.cnf')
272 # allows for a local mysql local configuration file to be read
273 # from the current directory
274 configFile
= configFile
or os
.path
.join( os
.getcwd(), 'mysql.cnf' )
276 if configFile
and os
.path
.exists(configFile
):
277 kwargs
['read_default_file'] = configFile
278 connect
= None # force it to use MySQLdb
281 connect
= MySQLdb
.connect
282 kwargs
['compress'] = True
283 if useStreaming
: # use server side cursors for scalable result sets
285 from MySQLdb
import cursors
286 kwargs
['cursorclass'] = cursors
.SSCursor
287 except (ImportError, AttributeError):
289 conn
= connect(**kwargs
)
290 cursor
= conn
.cursor()
293 _mysqlMacros
= dict(IGNORE
='ignore', REPLACE
='replace',
294 AUTO_INCREMENT
='AUTO_INCREMENT', SUBSTRING
='substring',
295 SUBSTR_FROM
='FROM', SUBSTR_FOR
='FOR')
297 def mysql_table_schema(self
, analyzeSchema
=True):
298 'retrieve table schema from a MySQL database, save on self'
300 self
._format
_query
= SQLFormatDict(MySQLdb
.paramstyle
, _mysqlMacros
)
301 if not analyzeSchema
:
303 self
.clear_schema() # reset settings and dictionaries
304 self
.cursor
.execute('describe %s' % self
.name
) # get info about columns
305 columns
= self
.cursor
.fetchall()
306 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
307 for icol
,c
in enumerate(columns
):
309 self
.columnName
.append(field
) # list of columns in same order as table
310 if c
[3] == "PRI": # record as primary key
311 if self
.primary_key
is None:
312 self
.primary_key
= field
315 self
.primary_key
.append(field
)
316 except AttributeError:
317 self
.primary_key
= [self
.primary_key
,field
]
318 if c
[1][:3].lower() == 'int':
319 self
.usesIntID
= True
321 self
.usesIntID
= False
323 self
.indexed
[field
] = icol
324 self
.description
[field
] = self
.cursor
.description
[icol
]
325 self
.columnType
[field
] = c
[1] # SQL COLUMN TYPE
327 _sqliteMacros
= dict(IGNORE
='or ignore', REPLACE
='insert or replace',
328 AUTO_INCREMENT
='', SUBSTRING
='substr',
329 SUBSTR_FROM
=',', SUBSTR_FOR
=',')
332 'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
334 import sqlite3
as sqlite
336 from pysqlite2
import dbapi2
as sqlite
339 def sqlite_table_schema(self
, analyzeSchema
=True):
340 'retrieve table schema from a sqlite3 database, save on self'
341 sqlite
= import_sqlite()
342 self
._format
_query
= SQLFormatDict(sqlite
.paramstyle
, _sqliteMacros
)
343 if not analyzeSchema
:
345 self
.clear_schema() # reset settings and dictionaries
346 self
.cursor
.execute('PRAGMA table_info("%s")' % self
.name
)
347 columns
= self
.cursor
.fetchall()
348 self
.cursor
.execute('select * from %s limit 1' % self
.name
) # descriptions
349 for icol
,c
in enumerate(columns
):
351 self
.columnName
.append(field
) # list of columns in same order as table
352 self
.description
[field
] = self
.cursor
.description
[icol
]
353 self
.columnType
[field
] = c
[2] # SQL COLUMN TYPE
354 self
.cursor
.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self
.name
) # get primary key / unique indexes
355 for indexname
in self
.cursor
.fetchall(): # search indexes for primary key
356 self
.cursor
.execute('PRAGMA index_info("%s")' % indexname
)
357 l
= self
.cursor
.fetchall() # get list of columns in this index
358 if len(l
) == 1: # assume 1st single-column unique index is primary key!
359 self
.primary_key
= l
[0][2]
360 break # done searching for primary key!
361 if self
.primary_key
is None: # grrr, INTEGER PRIMARY KEY handled differently
362 self
.cursor
.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self
.name
)
363 sql
= self
.cursor
.fetchall()[0][0]
364 for columnSQL
in sql
[sql
.index('(') + 1 :].split(','):
365 if 'primary key' in columnSQL
.lower(): # must be the primary key!
366 col
= columnSQL
.split()[0] # get column name
367 if col
in self
.columnType
:
368 self
.primary_key
= col
369 break # done searching for primary key!
371 raise ValueError('unknown primary key %s in table %s'
373 if self
.primary_key
is not None: # check its type
374 if self
.columnType
[self
.primary_key
] == 'int' or \
375 self
.columnType
[self
.primary_key
] == 'integer':
376 self
.usesIntID
= True
378 self
.usesIntID
= False
380 class SQLFormatDict(object):
381 '''Perform SQL keyword replacements for maintaining compatibility across
382 a wide range of SQL backends. Uses Python dict-based string format
383 function to do simple string replacements, and also to convert
384 params list to the paramstyle required for this interface.
385 Create by passing a dict of macros and the db-api paramstyle:
386 sfd = SQLFormatDict("qmark", substitutionDict)
388 Then transform queries+params as follows; input should be "format" style:
389 sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
390 cursor.execute(sql, params)
392 _paramFormats
= dict(pyformat
='%%(%d)s', numeric
=':%d', named
=':%d',
393 qmark
='(ignore)', format
='(ignore)')
394 def __init__(self
, paramstyle
, substitutionDict
={}):
395 self
.substitutionDict
= substitutionDict
.copy()
396 self
.paramstyle
= paramstyle
397 self
.paramFormat
= self
._paramFormats
[paramstyle
]
398 self
.makeDict
= (paramstyle
== 'pyformat' or paramstyle
== 'named')
399 if paramstyle
== 'qmark': # handle these as simple substitution
400 self
.substitutionDict
['?'] = '?'
401 elif paramstyle
== 'format':
402 self
.substitutionDict
['?'] = '%s'
403 def __getitem__(self
, k
):
404 'apply correct substitution for this SQL interface'
406 return self
.substitutionDict
[k
] # apply our substitutions
409 if k
== '?': # sequential parameter
410 s
= self
.paramFormat
% self
.iparam
411 self
.iparam
+= 1 # advance to the next parameter
413 raise KeyError('unknown macro: %s' % k
)
414 def __call__(self
, sql
, paramList
):
415 'returns corrected sql,params for this interface'
416 self
.iparam
= 1 # DB-ABI param indexing begins at 1
417 sql
= sql
.replace('%s', '%(?)s') # convert format into pyformat
418 s
= sql
% self
# apply all %(x)s replacements in sql
419 if self
.makeDict
: # construct a params dict
421 for i
,param
in enumerate(paramList
):
422 paramDict
[str(i
+ 1)] = param
#DB-ABI param indexing begins at 1
424 else: # just return the original params list
427 def get_table_schema(self
, analyzeSchema
=True):
428 'run the right schema function based on type of db server connection'
430 modname
= self
.cursor
.__class
__.__module
__
431 except AttributeError:
432 raise ValueError('no cursor object or module information!')
434 schema_func
= self
._schemaModuleDict
[modname
]
436 raise KeyError('''unknown db module: %s. Use _schemaModuleDict
437 attribute to supply a method for obtaining table schema
438 for this module''' % modname
)
439 schema_func(self
, analyzeSchema
) # run the schema function
442 _schemaModuleDict
= {'MySQLdb.cursors':mysql_table_schema
,
443 'pysqlite2.dbapi2':sqlite_table_schema
,
444 'sqlite3':sqlite_table_schema
}
446 class SQLTableBase(object, UserDict
.DictMixin
):
447 "Store information about an SQL table as dict keyed by primary key"
448 _schemaModuleDict
= _schemaModuleDict
# default module list
449 get_table_schema
= get_table_schema
450 def __init__(self
,name
,cursor
=None,itemClass
=None,attrAlias
=None,
451 clusterKey
=None,createTable
=None,graph
=None,maxCache
=None,
452 arraysize
=1024, itemSliceClass
=None, dropIfExists
=False,
453 serverInfo
=None, autoGC
=True, orderBy
=None,
454 writeable
=False, **kwargs
):
455 if autoGC
: # automatically garbage collect unused objects
456 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
458 self
._weakValueDict
= {}
460 self
.orderBy
= orderBy
461 self
.writeable
= writeable
463 if serverInfo
is not None: # get cursor from serverInfo
464 cursor
= serverInfo
.cursor()
465 else: # try to read connection info from name or config file
466 name
,cursor
,serverInfo
= get_name_cursor(name
,**kwargs
)
468 warnings
.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
469 DeprecationWarning, stacklevel
=2)
471 if createTable
is not None: # RUN COMMAND TO CREATE THIS TABLE
472 if dropIfExists
: # get rid of any existing table
473 cursor
.execute('drop table if exists ' + name
)
474 self
.get_table_schema(False) # check dbtype, init _format_query
475 sql
,params
= self
._format
_query
(createTable
, ()) # apply macros
476 cursor
.execute(sql
) # create the table
478 if graph
is not None:
480 if maxCache
is not None:
481 self
.maxCache
= maxCache
482 if arraysize
is not None:
483 self
.arraysize
= arraysize
484 cursor
.arraysize
= arraysize
485 self
.get_table_schema() # get schema of columns to serve as attrs
486 self
.data
= {} # map of all attributes, including aliases
487 for icol
,field
in enumerate(self
.columnName
):
488 self
.data
[field
] = icol
# 1st add mappings to columns
490 self
.data
['id']=self
.data
[self
.primary_key
]
491 except (KeyError,TypeError):
493 if hasattr(self
,'_attr_alias'): # apply attribute aliases for this class
494 self
.addAttrAlias(False,**self
._attr
_alias
)
495 self
.objclass(itemClass
) # NEED TO SUBCLASS OUR ITEM CLASS
496 if itemSliceClass
is not None:
497 self
.itemSliceClass
= itemSliceClass
498 get_bound_subclass(self
, 'itemSliceClass', self
.name
) # need to subclass itemSliceClass
499 if attrAlias
is not None: # ADD ATTRIBUTE ALIASES
500 self
.attrAlias
= attrAlias
# RECORD FOR PICKLING PURPOSES
501 self
.data
.update(attrAlias
)
502 if clusterKey
is not None:
503 self
.clusterKey
=clusterKey
504 if serverInfo
is not None:
505 self
.serverInfo
= serverInfo
508 self
._select
(selectCols
='count(*)')
509 return self
.cursor
.fetchone()[0]
512 _pickleAttrs
= dict(name
=0, clusterKey
=0, maxCache
=0, arraysize
=0,
513 attrAlias
=0, serverInfo
=0, autoGC
=0, orderBy
=0,
515 __getstate__
= standard_getstate
516 def __setstate__(self
,state
):
517 # default cursor provisioning by worldbase is deprecated!
518 ## if 'serverInfo' not in state: # hmm, no address for db server?
519 ## try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
520 ## from Data import getResource
521 ## state['cursor'] = getResource.getTableCursor(state['name'])
522 ## except ImportError:
523 ## pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
524 self
.__init
__(**state
)
526 return '<SQL table '+self
.name
+'>'
528 def clear_schema(self
):
529 'reset all schema information for this table'
533 self
.usesIntID
= None
534 self
.primary_key
= None
536 def _attrSQL(self
,attr
,sqlColumn
=False,columnNumber
=False):
537 "Translate python attribute name to appropriate SQL expression"
538 try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
539 field
=self
.data
[attr
]
541 raise AttributeError('attribute %s not a valid column or alias in %s'
543 if sqlColumn
: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
544 try: # CHECK IF field IS COLUMN NUMBER
545 return self
.columnName
[field
] # RETURN SQL COLUMN NAME
547 try: # CHECK IF field IS SQL COLUMN NAME
548 return self
.columnName
[self
.data
[field
]] # THIS WILL JUST RETURN field...
549 except (KeyError,TypeError):
550 raise AttributeError('attribute %s does not map to an SQL column in %s'
553 try: # CHECK IF field IS A COLUMN NUMBER
554 return field
+0 # ONLY RETURN AN INTEGER
556 try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
557 return self
.data
[field
]+0 # ONLY RETURN AN INTEGER
558 except (KeyError,TypeError):
559 raise ValueError('attribute %s does not map to a SQL column!' % attr
)
560 if isinstance(field
,types
.StringType
):
561 attr
=field
# USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
563 attr
=self
.primary_key
565 def addAttrAlias(self
,saveToPickle
=True,**kwargs
):
566 """Add new attributes as aliases of existing attributes.
567 They can be specified either as named args:
568 t.addAttrAlias(newattr=oldattr)
569 or by passing a dictionary kwargs whose keys are newattr
570 and values are oldattr:
571 t.addAttrAlias(**kwargs)
572 saveToPickle=True forces these aliases to be saved if object is pickled.
575 self
.attrAlias
.update(kwargs
)
576 for key
,val
in kwargs
.items():
577 try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
578 self
.data
[val
]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
579 raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
580 except TypeError: # val IS ITSELF AN ALIAS
581 self
.data
[key
] = self
.data
[val
] # SO MAP TO WHAT IT MAPS TO
582 except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
584 def objclass(self
,oclass
=None):
585 "Create class representing a row in this table by subclassing oclass, adding data"
586 if oclass
is not None: # use this as our base itemClass
587 self
.itemClass
= oclass
589 self
.itemClass
= self
.itemClass
._RWClass
# use its writeable version
590 oclass
= get_bound_subclass(self
, 'itemClass', self
.name
,
591 subclassArgs
=dict(db
=self
)) # bind itemClass
592 if issubclass(oclass
, TupleO
):
593 oclass
._attrcol
= self
.data
# BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
594 if hasattr(oclass
,'_tableclass') and not isinstance(self
,oclass
._tableclass
):
595 self
.__class
__=oclass
._tableclass
# ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
596 def _select(self
, whereClause
='', params
=(), selectCols
='t1.*',
597 cursor
=None, orderBy
='', limit
=''):
598 'execute the specified query but do not fetch'
599 sql
,params
= self
._format
_query
('select %s from %s t1 %s %s %s'
600 % (selectCols
, self
.name
, whereClause
, orderBy
,
603 self
.cursor
.execute(sql
, params
)
605 cursor
.execute(sql
, params
)
606 def select(self
,whereClause
,params
=None,oclass
=None,selectCols
='t1.*'):
607 "Generate the list of objects that satisfy the database SELECT"
609 oclass
=self
.itemClass
610 self
._select
(whereClause
,params
,selectCols
)
611 l
=self
.cursor
.fetchall()
613 yield self
.cacheItem(t
,oclass
)
614 def query(self
,**kwargs
):
615 'query for intersection of all specified kwargs, returned as iterator'
618 for k
,v
in kwargs
.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
619 if v
is None: # CONVERT TO SQL NULL TEST
620 criteria
.append('%s IS NULL' % self
._attrSQL
(k
))
621 else: # TEST FOR EQUALITY
622 criteria
.append('%s=%%s' % self
._attrSQL
(k
))
624 return self
.select('where '+' and '.join(criteria
),params
)
625 def _update(self
,row_id
,col
,val
):
626 'update a single field in the specified row to the specified value'
627 sql
,params
= self
._format
_query
('update %s set %s=%%s where %s=%%s'
628 %(self
.name
,col
,self
.primary_key
),
630 self
.cursor
.execute(sql
, params
)
633 return t
[self
.data
['id']] # GET ID FROM TUPLE
634 except TypeError: # treat as alias
635 return t
[self
.data
[self
.data
['id']]]
636 def cacheItem(self
,t
,oclass
):
637 'get obj from cache if possible, or construct from tuple'
640 except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
642 try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
643 return self
._weakValueDict
[id]
647 self
._weakValueDict
[id] = o
# CACHE THIS ITEM IN OUR DICTIONARY
649 def cache_items(self
,rows
,oclass
=None):
651 oclass
=self
.itemClass
653 yield self
.cacheItem(t
,oclass
)
654 def foreignKey(self
,attr
,k
):
655 'get iterator for objects with specified foreign key value'
656 return self
.select('where %s=%%s'%attr
,(k
,))
657 def limit_cache(self
):
658 'APPLY maxCache LIMIT TO CACHE SIZE'
660 if self
.maxCache
<len(self
._weakValueDict
):
661 self
._weakValueDict
.clear()
662 except AttributeError:
665 def get_new_cursor(self
):
666 """Return a new cursor object, or None if not possible """
668 new_cursor
= self
.serverInfo
.new_cursor
669 except AttributeError:
671 return new_cursor(self
.arraysize
)
673 def generic_iterator(self
, cursor
=None, fetch_f
=None, cache_f
=None,
675 'generic iterator that runs fetch, cache and map functions'
676 if fetch_f
is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
678 fetch_f
= self
.cursor
.fetchmany
679 else: # isolate this iter from other queries
680 fetch_f
= cursor
.fetchmany
682 cache_f
= self
.cache_items
685 rows
= fetch_f() # FETCH THE NEXT SET OF ROWS
686 if len(rows
)==0: # NO MORE DATA SO ALL DONE
688 for v
in map_f(cache_f(rows
)): # CACHE AND GENERATE RESULTS
690 if cursor
is not None: # close iterator now that we're done
692 def tuple_from_dict(self
, d
):
693 'transform kwarg dict into tuple for storing in database'
694 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
695 for col
,icol
in self
.data
.items():
698 except (KeyError,TypeError):
701 def tuple_from_obj(self
, obj
):
702 'transform object attributes into tuple for storing in database'
703 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
704 for col
,icol
in self
.data
.items():
706 l
[icol
] = getattr(obj
,col
)
707 except (AttributeError,TypeError):
710 def _insert(self
, l
):
711 '''insert tuple into the database. Note this uses the MySQL
712 extension REPLACE, which overwrites any duplicate key.'''
713 s
= '%(REPLACE)s into ' + self
.name
+ ' values (' \
714 + ','.join(['%s']*len(l
)) + ')'
715 sql
,params
= self
._format
_query
(s
, l
)
716 self
.cursor
.execute(sql
, params
)
717 def insert(self
, obj
):
718 '''insert new row by transforming obj to tuple of values'''
719 l
= self
.tuple_from_obj(obj
)
721 def get_insert_id(self
):
722 'get the primary key value for the last INSERT'
723 try: # ATTEMPT TO GET ASSIGNED ID FROM DB
724 auto_id
= self
.cursor
.lastrowid
725 except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
726 raise NotImplementedError('''your db lacks lastrowid support?''')
728 raise ValueError('lastrowid is None so cannot get ID from INSERT!')
730 def new(self
, **kwargs
):
731 'return a new record with the assigned attributes, added to DB'
732 if not self
.writeable
:
733 raise ValueError('this database is read only!')
734 obj
= self
.itemClass(None, newRow
=True, **kwargs
) # saves itself to db
735 self
._weakValueDict
[obj
.id] = obj
# AND SAVE TO OUR LOCAL DICT CACHE
737 def clear_cache(self
):
739 self
._weakValueDict
.clear()
740 def __delitem__(self
, k
):
741 if not self
.writeable
:
742 raise ValueError('this database is read only!')
743 sql
,params
= self
._format
_query
('delete from %s where %s=%%s'
744 % (self
.name
,self
.primary_key
),(k
,))
745 self
.cursor
.execute(sql
, params
)
747 del self
._weakValueDict
[k
]
751 def getKeys(self
,queryOption
='', selectCols
=None):
752 'uses db select; does not force load'
753 if selectCols
is None:
754 selectCols
=self
.primary_key
755 if queryOption
=='' and self
.orderBy
is not None:
756 queryOption
= self
.orderBy
# apply default ordering
757 self
.cursor
.execute('select %s from %s %s'
758 %(selectCols
,self
.name
,queryOption
))
759 return [t
[0] for t
in self
.cursor
.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
761 def iter_keys(self
, selectCols
=None, orderBy
='', map_f
=iter,
762 cache_f
=lambda x
:[t
[0] for t
in x
], get_f
=None, **kwargs
):
763 'guarantee correct iteration insulated from other queries'
764 if selectCols
is None:
765 selectCols
=self
.primary_key
766 if orderBy
=='' and self
.orderBy
is not None:
767 orderBy
= self
.orderBy
# apply default ordering
768 cursor
= self
.get_new_cursor()
769 if cursor
: # got our own cursor, guaranteeing query isolation
770 self
._select
(cursor
=cursor
, selectCols
=selectCols
, orderBy
=orderBy
,
772 return self
.generic_iterator(cursor
=cursor
, cache_f
=cache_f
, map_f
=map_f
)
773 else: # must pre-fetch all keys to ensure query isolation
774 if get_f
is not None:
777 return iter(self
.keys())
779 class SQLTable(SQLTableBase
):
780 "Provide on-the-fly access to rows in the database, caching the results in dict"
781 itemClass
= TupleO
# our default itemClass; constructor can override
784 def load(self
,oclass
=None):
785 "Load all data from the table"
786 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
787 return self
._isLoaded
788 except AttributeError:
791 oclass
=self
.itemClass
792 self
.cursor
.execute('select * from %s' % self
.name
)
793 l
=self
.cursor
.fetchall()
794 self
._weakValueDict
= {} # just store the whole dataset in memory
796 self
.cacheItem(t
,oclass
) # CACHE IT IN LOCAL DICTIONARY
797 self
._isLoaded
=True # MARK THIS CONTAINER AS FULLY LOADED
799 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
801 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
802 except KeyError: # NOT FOUND, SO TRY THE DATABASE
803 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 2'
804 % (self
.name
,self
.primary_key
),(k
,))
805 self
.cursor
.execute(sql
, params
)
806 l
= self
.cursor
.fetchmany(2) # get at most 2 rows
808 raise KeyError('%s not found in %s, or not unique' %(str(k
),self
.name
))
810 return self
.cacheItem(l
[0],self
.itemClass
) # CACHE IT IN LOCAL DICTIONARY
811 def __setitem__(self
, k
, v
):
812 if not self
.writeable
:
813 raise ValueError('this database is read only!')
817 except AttributeError:
818 raise ValueError('object not bound to itemClass for this db!')
823 except AttributeError:
825 else: # delete row with old ID
827 v
.cache_id(k
) # cache the new ID on the object
828 self
.insert(v
) # SAVE TO THE RELATIONAL DB SERVER
829 self
._weakValueDict
[k
] = v
# CACHE THIS ITEM IN OUR DICTIONARY
831 'forces load of entire table into memory'
833 return [(k
,self
[k
]) for k
in self
] # apply orderBy rules...
835 'uses arraysize / maxCache and fetchmany() to manage data transfer'
836 return iter_keys(self
, selectCols
='*', cache_f
=None,
837 map_f
=generate_items
, get_f
=self
.items
)
839 'forces load of entire table into memory'
841 return [self
[k
] for k
in self
] # apply orderBy rules...
842 def itervalues(self
):
843 'uses arraysize / maxCache and fetchmany() to manage data transfer'
844 return iter_keys(self
, selectCols
='*', cache_f
=None, get_f
=self
.values
)
846 def getClusterKeys(self
,queryOption
=''):
847 'uses db select; does not force load'
848 self
.cursor
.execute('select distinct %s from %s %s'
849 %(self
.clusterKey
,self
.name
,queryOption
))
850 return [t
[0] for t
in self
.cursor
.fetchall()] # GET ALL AT ONCE, SINCE OTHER CALLS MAY REUSE THIS CURSOR...
853 class SQLTableClustered(SQLTable
):
854 '''use clusterKey to load a whole cluster of rows at once,
855 specifically, all rows that share the same clusterKey value.'''
856 def __init__(self
, *args
, **kwargs
):
857 kwargs
= kwargs
.copy() # get a copy we can alter
858 kwargs
['autoGC'] = False # don't use WeakValueDictionary
859 SQLTable
.__init
__(self
, *args
, **kwargs
)
861 return getKeys(self
,'order by %s' %self
.clusterKey
)
862 def clusterkeys(self
):
863 return getClusterKeys(self
, 'order by %s' %self
.clusterKey
)
864 def __getitem__(self
,k
):
866 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
867 except KeyError: # NOT FOUND, SO TRY THE DATABASE
868 sql
,params
= self
._format
_query
('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
869 % (self
.name
,self
.name
,self
.primary_key
,
870 self
.clusterKey
,self
.clusterKey
),(k
,))
871 self
.cursor
.execute(sql
, params
)
872 l
=self
.cursor
.fetchall()
874 for t
in l
: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
875 self
.cacheItem(t
,self
.itemClass
)
876 return self
._weakValueDict
[k
] # should be in cache, if row k exists
877 def itercluster(self
,cluster_id
):
878 'iterate over all items from the specified cluster'
880 return self
.select('where %s=%%s'%self
.clusterKey
,(cluster_id
,))
881 def fetch_cluster(self
):
882 'use self.cursor.fetchmany to obtain all rows for next cluster'
883 icol
= self
._attrSQL
(self
.clusterKey
,columnNumber
=True)
886 rows
= self
._fetch
_cluster
_cache
# USE SAVED ROWS FROM PREVIOUS CALL
887 del self
._fetch
_cluster
_cache
888 except AttributeError:
889 rows
= self
.cursor
.fetchmany()
891 cluster_id
= rows
[0][icol
]
895 for i
,t
in enumerate(rows
): # CHECK THAT ALL ROWS FROM THIS CLUSTER
896 if cluster_id
!= t
[icol
]: # START OF A NEW CLUSTER
897 result
+= rows
[:i
] # RETURN ROWS OF CURRENT CLUSTER
898 self
._fetch
_cluster
_cache
= rows
[i
:] # SAVE NEXT CLUSTER
901 rows
= self
.cursor
.fetchmany() # GET NEXT SET OF ROWS
903 def itervalues(self
):
904 'uses arraysize / maxCache and fetchmany() to manage data transfer'
905 cursor
= self
.get_new_cursor()
906 self
._select
('order by %s' %self
.clusterKey
, cursor
=cursor
)
907 return self
.generic_iterator(cursor
, self
.fetch_cluster
)
909 'uses arraysize / maxCache and fetchmany() to manage data transfer'
910 cursor
= self
.get_new_cursor()
911 self
._select
('order by %s' %self
.clusterKey
, cursor
=cursor
)
912 return self
.generic_iterator(cursor
, self
.fetch_cluster
,
913 map_f
=generate_items
)
915 class SQLForeignRelation(object):
916 'mapping based on matching a foreign key in an SQL table'
917 def __init__(self
,table
,keyName
):
920 def __getitem__(self
,k
):
921 'get list of objects o with getattr(o,keyName)==k.id'
923 for o
in self
.table
.select('where %s=%%s'%self
.keyName
,(k
.id,)):
926 raise KeyError('%s not found in %s' %(str(k
),self
.name
))
930 class SQLTableNoCache(SQLTableBase
):
931 '''Provide on-the-fly access to rows in the database;
932 values are simply an object interface (SQLRow) to back-end db query.
933 Row data are not stored locally, but always accessed by querying the db'''
934 itemClass
=SQLRow
# DEFAULT OBJECT CLASS FOR ROWS...
937 def getID(self
,t
): return t
[0] # GET ID FROM TUPLE
938 def select(self
,whereClause
,params
):
939 return SQLTableBase
.select(self
,whereClause
,params
,self
.oclass
,
941 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
943 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
944 except KeyError: # NOT FOUND, SO TRY THE DATABASE
945 self
._select
('where %s=%%s' % self
.primary_key
, (k
,),
947 t
= self
.cursor
.fetchmany(2)
949 raise KeyError('id %s non-existent or not unique' % k
)
950 o
= self
.itemClass(k
) # create obj referencing this ID
951 self
._weakValueDict
[k
] = o
# cache the SQLRow object
953 def __setitem__(self
, k
, v
):
954 if not self
.writeable
:
955 raise ValueError('this database is read only!')
959 except AttributeError:
960 raise ValueError('object not bound to itemClass for this db!')
962 del self
[k
] # delete row with new ID if any
966 del self
._weakValueDict
[v
.id] # delete from old cache location
969 self
._update
(v
.id, self
.primary_key
, k
) # just change its ID in db
970 v
.cache_id(k
) # change the cached ID value
971 self
._weakValueDict
[k
] = v
# assign to new cache location
972 def addAttrAlias(self
,**kwargs
):
973 self
.data
.update(kwargs
) # ALIAS KEYS TO EXPRESSION VALUES
975 SQLRow
._tableclass
=SQLTableNoCache
# SQLRow IS FOR NON-CACHING TABLE INTERFACE
978 class SQLTableMultiNoCache(SQLTableBase
):
979 "Trivial on-the-fly access for table with key that returns multiple rows"
980 itemClass
= TupleO
# default itemClass; constructor can override
981 _distinct_key
='id' # DEFAULT COLUMN TO USE AS KEY
983 return getKeys(self
, selectCols
='distinct(%s)'
984 % self
._attrSQL
(self
._distinct
_key
))
986 return iter_keys(self
, 'distinct(%s)' % self
._attrSQL
(self
._distinct
_key
))
987 def __getitem__(self
,id):
988 sql
,params
= self
._format
_query
('select * from %s where %s=%%s'
989 %(self
.name
,self
._attrSQL
(self
._distinct
_key
)),(id,))
990 self
.cursor
.execute(sql
, params
)
991 l
=self
.cursor
.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
993 yield self
.itemClass(row
)
994 def addAttrAlias(self
,**kwargs
):
995 self
.data
.update(kwargs
) # ALIAS KEYS TO EXPRESSION VALUES
999 class SQLEdges(SQLTableMultiNoCache
):
1000 '''provide iterator over edges as (source,target,edge)
1001 and getitem[edge] --> [(source,target),...]'''
1002 _distinct_key
='edge_id'
1003 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1004 _pickleAttrs
.update(dict(graph
=0))
1006 self
.cursor
.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
1007 %(self
._attrSQL
('source_id'),self
._attrSQL
('target_id'),
1008 self
._attrSQL
('edge_id'),self
.name
,
1009 self
._attrSQL
('target_id'),self
._attrSQL
('source_id'),
1010 self
._attrSQL
('target_id')))
1011 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1012 for source_id
,target_id
,edge_id
in self
.cursor
.fetchall():
1013 l
.append((self
.graph
.unpack_source(source_id
),
1014 self
.graph
.unpack_target(target_id
),
1015 self
.graph
.unpack_edge(edge_id
)))
1019 return iter(self
.keys())
1020 def __getitem__(self
,edge
):
1021 sql
,params
= self
._format
_query
('select %s,%s from %s where %s=%%s'
1022 %(self
._attrSQL
('source_id'),
1023 self
._attrSQL
('target_id'),
1025 self
._attrSQL
(self
._distinct
_key
)),
1026 (self
.graph
.pack_edge(edge
),))
1027 self
.cursor
.execute(sql
, params
)
1028 l
= [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
1029 for source_id
,target_id
in self
.cursor
.fetchall():
1030 l
.append((self
.graph
.unpack_source(source_id
),
1031 self
.graph
.unpack_target(target_id
)))
1035 class SQLEdgeDict(object):
1036 '2nd level graph interface to SQL database'
1037 def __init__(self
,fromNode
,table
):
1038 self
.fromNode
=fromNode
1040 if not hasattr(self
.table
,'allowMissingNodes'):
1041 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s limit 1'
1042 %(self
.table
.sourceSQL
,
1044 self
.table
.sourceSQL
),
1046 self
.table
.cursor
.execute(sql
, params
)
1047 if len(self
.table
.cursor
.fetchall())<1:
1048 raise KeyError('node not in graph!')
1050 def __getitem__(self
,target
):
1051 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1052 %(self
.table
.edgeSQL
,
1054 self
.table
.sourceSQL
,
1055 self
.table
.targetSQL
),
1057 self
.table
.pack_target(target
)))
1058 self
.table
.cursor
.execute(sql
, params
)
1059 l
= self
.table
.cursor
.fetchmany(2) # get at most two rows
1061 raise KeyError('either no edge from source to target or not unique!')
1063 return self
.table
.unpack_edge(l
[0][0]) # RETURN EDGE
1065 raise KeyError('no edge from node to target')
1066 def __setitem__(self
,target
,edge
):
1067 sql
,params
= self
.table
._format
_query
('replace into %s values (%%s,%%s,%%s)'
1070 self
.table
.pack_target(target
),
1071 self
.table
.pack_edge(edge
)))
1072 self
.table
.cursor
.execute(sql
, params
)
1073 if not hasattr(self
.table
,'sourceDB') or \
1074 (hasattr(self
.table
,'targetDB') and self
.table
.sourceDB
==self
.table
.targetDB
):
1075 self
.table
+= target
# ADD AS NODE TO GRAPH
1076 def __iadd__(self
,target
):
1078 return self
# iadd MUST RETURN self!
1079 def __delitem__(self
,target
):
1080 sql
,params
= self
.table
._format
_query
('delete from %s where %s=%%s and %s=%%s'
1082 self
.table
.sourceSQL
,
1083 self
.table
.targetSQL
),
1085 self
.table
.pack_target(target
)))
1086 self
.table
.cursor
.execute(sql
, params
)
1087 if self
.table
.cursor
.rowcount
< 1: # no rows deleted?
1088 raise KeyError('no edge from node to target')
1090 def iterator_query(self
):
1091 sql
,params
= self
.table
._format
_query
('select %s,%s from %s where %s=%%s and %s is not null'
1092 %(self
.table
.targetSQL
,
1095 self
.table
.sourceSQL
,
1096 self
.table
.targetSQL
),
1098 self
.table
.cursor
.execute(sql
, params
)
1099 return self
.table
.cursor
.fetchall()
1101 return [self
.table
.unpack_target(target_id
)
1102 for target_id
,edge_id
in self
.iterator_query()]
1104 return [self
.table
.unpack_edge(edge_id
)
1105 for target_id
,edge_id
in self
.iterator_query()]
1107 return [(self
.table
.unpack_source(self
.fromNode
),self
.table
.unpack_target(target_id
),
1108 self
.table
.unpack_edge(edge_id
))
1109 for target_id
,edge_id
in self
.iterator_query()]
1111 return [(self
.table
.unpack_target(target_id
),self
.table
.unpack_edge(edge_id
))
1112 for target_id
,edge_id
in self
.iterator_query()]
1113 def __iter__(self
): return iter(self
.keys())
1114 def itervalues(self
): return iter(self
.values())
1115 def iteritems(self
): return iter(self
.items())
1117 return len(self
.keys())
1120 class SQLEdgelessDict(SQLEdgeDict
):
1121 'for SQLGraph tables that lack edge_id column'
1122 def __getitem__(self
,target
):
1123 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s=%%s limit 2'
1124 %(self
.table
.targetSQL
,
1126 self
.table
.sourceSQL
,
1127 self
.table
.targetSQL
),
1129 self
.table
.pack_target(target
)))
1130 self
.table
.cursor
.execute(sql
, params
)
1131 l
= self
.table
.cursor
.fetchmany(2)
1133 raise KeyError('either no edge from source to target or not unique!')
1134 return None # no edge info!
1135 def iterator_query(self
):
1136 sql
,params
= self
.table
._format
_query
('select %s from %s where %s=%%s and %s is not null'
1137 %(self
.table
.targetSQL
,
1139 self
.table
.sourceSQL
,
1140 self
.table
.targetSQL
),
1142 self
.table
.cursor
.execute(sql
, params
)
1143 return [(t
[0],None) for t
in self
.table
.cursor
.fetchall()]
1145 SQLEdgeDict
._edgelessClass
= SQLEdgelessDict
1147 class SQLGraphEdgeDescriptor(object):
1148 'provide an SQLEdges interface on demand'
1149 def __get__(self
,obj
,objtype
):
1151 attrAlias
=obj
.attrAlias
.copy()
1152 except AttributeError:
1153 return SQLEdges(obj
.name
, obj
.cursor
, graph
=obj
)
1155 return SQLEdges(obj
.name
, obj
.cursor
, attrAlias
=attrAlias
,
1158 def getColumnTypes(createTable
,attrAlias
={},defaultColumnType
='int',
1159 columnAttrs
=('source','target','edge'),**kwargs
):
1160 'return list of [(colname,coltype),...] for source,target,edge'
1162 for attr
in columnAttrs
:
1164 attrName
= attrAlias
[attr
+'_id']
1166 attrName
= attr
+'_id'
1167 try: # SEE IF USER SPECIFIED A DESIRED TYPE
1168 l
.append((attrName
,createTable
[attr
+'_id']))
1170 except (KeyError,TypeError):
1172 try: # get type info from primary key for that database
1173 db
= kwargs
[attr
+'DB']
1175 raise KeyError # FORCE IT TO USE DEFAULT TYPE
1178 else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
1180 try: # GET ONE IDENTIFIER FROM THE DATABASE
1182 except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
1184 l
.append((attrName
,db
.columnType
[db
.primary_key
]))
1186 except AttributeError:
1188 else: # GET THE TYPE FROM THIS IDENTIFIER
1189 if isinstance(k
,int) or isinstance(k
,long):
1190 l
.append((attrName
,'int'))
1192 elif isinstance(k
,str):
1193 l
.append((attrName
,'varchar(32)'))
1196 raise ValueError('SQLGraph node / edge must be int or str!')
1197 l
.append((attrName
,defaultColumnType
))
1198 logger
.warn('no type info found for %s, so using default: %s'
1199 % (attrName
, defaultColumnType
))
1205 class SQLGraph(SQLTableMultiNoCache
):
1206 '''provide a graph interface via a SQL table. Key capabilities are:
1207 - setitem with an empty dictionary: a dummy operation
1208 - getitem with a key that exists: return a placeholder
1209 - setitem with non empty placeholder: again a dummy operation
1210 EXAMPLE TABLE SCHEMA:
1211 create table mygraph (source_id int not null,target_id int,edge_id int,
1212 unique(source_id,target_id));
1214 _distinct_key
='source_id'
1215 _pickleAttrs
= SQLTableMultiNoCache
._pickleAttrs
.copy()
1216 _pickleAttrs
.update(dict(sourceDB
=0,targetDB
=0,edgeDB
=0,allowMissingNodes
=0))
1217 _edgeClass
= SQLEdgeDict
1218 def __init__(self
,name
,*l
,**kwargs
):
1219 graphArgs
,tableArgs
= split_kwargs(kwargs
,
1220 ('attrAlias','defaultColumnType','columnAttrs',
1221 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
1222 'edgeDictClass','graph'))
1223 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1224 c
= getColumnTypes(**kwargs
)
1225 tableArgs
['createTable'] = \
1226 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1227 % (name
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1229 self
.allowMissingNodes
= kwargs
['allowMissingNodes']
1230 except KeyError: pass
1231 SQLTableMultiNoCache
.__init
__(self
,name
,*l
,**tableArgs
)
1232 self
.sourceSQL
= self
._attrSQL
('source_id')
1233 self
.targetSQL
= self
._attrSQL
('target_id')
1235 self
.edgeSQL
= self
._attrSQL
('edge_id')
1236 except AttributeError:
1238 self
._edgeClass
= self
._edgeClass
._edgelessClass
1239 save_graph_db_refs(self
,**kwargs
)
1240 def __getitem__(self
,k
):
1241 return self
._edgeClass
(self
.pack_source(k
),self
)
1242 def __iadd__(self
,k
):
1243 sql
,params
= self
._format
_query
('delete from %s where %s=%%s and %s is null'
1244 % (self
.name
,self
.sourceSQL
,self
.targetSQL
),
1245 (self
.pack_source(k
),))
1246 self
.cursor
.execute(sql
, params
)
1247 sql
,params
= self
._format
_query
('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
1248 % self
.name
,(self
.pack_source(k
),))
1249 self
.cursor
.execute(sql
, params
)
1250 return self
# iadd MUST RETURN SELF!
1251 def __isub__(self
,k
):
1252 sql
,params
= self
._format
_query
('delete from %s where %s=%%s'
1253 % (self
.name
,self
.sourceSQL
),
1254 (self
.pack_source(k
),))
1255 self
.cursor
.execute(sql
, params
)
1256 if self
.cursor
.rowcount
== 0:
1257 raise KeyError('node not found in graph')
1258 return self
# iadd MUST RETURN SELF!
1259 __setitem__
= graph_setitem
1260 def __contains__(self
,k
):
1261 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 1'
1262 %(self
.name
,self
.sourceSQL
),
1263 (self
.pack_source(k
),))
1264 self
.cursor
.execute(sql
, params
)
1265 l
= self
.cursor
.fetchmany(2)
1267 def __invert__(self
):
1268 'get an interface to the inverse graph mapping'
1270 return self
._inverse
1271 except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
1272 attrAlias
= dict(source_id
=self
.targetSQL
, # SWAP SOURCE & TARGET
1273 target_id
=self
.sourceSQL
,
1274 edge_id
=self
.edgeSQL
)
1275 if self
.edgeSQL
is None: # no edge interface
1276 del attrAlias
['edge_id']
1277 self
._inverse
=SQLGraph(self
.name
,self
.cursor
,
1278 attrAlias
=attrAlias
,
1279 **graph_db_inverse_refs(self
))
1280 self
._inverse
._inverse
=self
1281 return self
._inverse
1283 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1284 yield self
.unpack_source(k
)
1285 def iteritems(self
):
1286 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1287 yield (self
.unpack_source(k
), self
._edgeClass
(k
, self
))
1288 def itervalues(self
):
1289 for k
in SQLTableMultiNoCache
.__iter
__(self
):
1290 yield self
._edgeClass
(k
, self
)
1292 return [self
.unpack_source(k
) for k
in SQLTableMultiNoCache
.keys(self
)]
1293 def values(self
): return list(self
.itervalues())
1294 def items(self
): return list(self
.iteritems())
1295 edges
=SQLGraphEdgeDescriptor()
1296 update
= update_graph
1298 'get number of source nodes in graph'
1299 self
.cursor
.execute('select count(distinct %s) from %s'
1300 %(self
.sourceSQL
,self
.name
))
1301 return self
.cursor
.fetchone()[0]
1303 override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
1304 ## def __cmp__(self,other):
1308 ## it = iter(self.edges)
1311 ## source,target,edge = it.next()
1312 ## except StopIteration:
1315 ## if d is not None:
1316 ## diff = cmp(n_target,len(d))
1319 ## if source is None:
1322 ## n += 1 # COUNT SOURCE NODES
1329 ## diff = cmp(edge,d[target])
1334 ## n_target += 1 # COUNT TARGET NODES FOR THIS SOURCE
1335 ## return cmp(n,len(other))
1337 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1339 class SQLIDGraph(SQLGraph
):
1340 add_trivial_packing_methods(locals())
1341 SQLGraph
._IDGraphClass
= SQLIDGraph
1345 class SQLEdgeDictClustered(dict):
1346 'simple cache for 2nd level dictionary of target_id:edge_id'
1347 def __init__(self
,g
,fromNode
):
1349 self
.fromNode
=fromNode
1351 def __iadd__(self
,l
):
1352 for target_id
,edge_id
in l
:
1353 dict.__setitem
__(self
,target_id
,edge_id
)
1354 return self
# iadd MUST RETURN SELF!
1356 class SQLEdgesClusteredDescr(object):
1357 def __get__(self
,obj
,objtype
):
1358 e
=SQLEdgesClustered(obj
.table
,obj
.edge_id
,obj
.source_id
,obj
.target_id
,
1359 graph
=obj
,**graph_db_inverse_refs(obj
,True))
1360 for source_id
,d
in obj
.d
.iteritems(): # COPY EDGE CACHE
1361 e
.load([(edge_id
,source_id
,target_id
)
1362 for (target_id
,edge_id
) in d
.iteritems()])
1365 class SQLGraphClustered(object):
1366 'SQL graph with clustered caching -- loads an entire cluster at a time'
1367 _edgeDictClass
=SQLEdgeDictClustered
1368 def __init__(self
,table
,source_id
='source_id',target_id
='target_id',
1369 edge_id
='edge_id',clusterKey
=None,**kwargs
):
1371 if isinstance(table
,types
.StringType
): # CREATE THE TABLE INTERFACE
1372 if clusterKey
is None:
1373 raise ValueError('you must provide a clusterKey argument!')
1374 if 'createTable' in kwargs
: # CREATE A SCHEMA FOR THIS TABLE
1375 c
= getColumnTypes(attrAlias
=dict(source_id
=source_id
,target_id
=target_id
,
1376 edge_id
=edge_id
),**kwargs
)
1377 kwargs
['createTable'] = \
1378 'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
1379 % (table
,c
[0][0],c
[0][1],c
[1][0],c
[1][1],
1380 c
[2][0],c
[2][1],c
[0][0],c
[1][0])
1381 table
= SQLTableClustered(table
,clusterKey
=clusterKey
,**kwargs
)
1383 self
.source_id
=source_id
1384 self
.target_id
=target_id
1385 self
.edge_id
=edge_id
1387 save_graph_db_refs(self
,**kwargs
)
1388 _pickleAttrs
= dict(table
=0,source_id
=0,target_id
=0,edge_id
=0,sourceDB
=0,targetDB
=0,
1390 def __getstate__(self
):
1391 state
= standard_getstate(self
)
1392 state
['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
1394 def __getitem__(self
,k
):
1395 'get edgeDict for source node k, from cache or by loading its cluster'
1396 try: # GET DIRECTLY FROM CACHE
1399 if hasattr(self
,'_isLoaded'):
1400 raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
1401 # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
1402 sql
,params
= self
.table
._format
_query
('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
1403 %(self
.source_id
,self
.target_id
,
1404 self
.edge_id
,self
.table
.name
,
1405 self
.table
.name
,self
.source_id
,
1406 self
.table
.clusterKey
,self
.table
.clusterKey
,
1407 self
.table
.primary_key
),
1408 (self
.pack_source(k
),))
1409 self
.table
.cursor
.execute(sql
, params
)
1410 self
.load(self
.table
.cursor
.fetchall()) # CACHE THIS CLUSTER
1411 return self
.d
[k
] # RETURN EDGE DICT FOR THIS NODE
1412 def load(self
,l
=None,unpack
=True):
1413 'load the specified rows (or all, if None provided) into local cache'
1415 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
1416 return self
._isLoaded
1417 except AttributeError:
1419 self
.table
.cursor
.execute('select %s,%s,%s from %s'
1420 %(self
.source_id
,self
.target_id
,
1421 self
.edge_id
,self
.table
.name
))
1422 l
=self
.table
.cursor
.fetchall()
1424 self
.d
.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
1425 for source
,target
,edge
in l
: # SAVE TO OUR CACHE
1427 source
= self
.unpack_source(source
)
1428 target
= self
.unpack_target(target
)
1429 edge
= self
.unpack_edge(edge
)
1431 self
.d
[source
] += [(target
,edge
)]
1433 d
= self
._edgeDictClass
(self
,source
)
1434 d
+= [(target
,edge
)]
1436 def __invert__(self
):
1437 'interface to reverse graph mapping'
1439 return self
._inverse
# INVERSE MAP ALREADY EXISTS
1440 except AttributeError:
1442 # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
1443 self
._inverse
=SQLGraphClustered(self
.table
,self
.target_id
,self
.source_id
,
1444 self
.edge_id
,**graph_db_inverse_refs(self
))
1445 self
._inverse
._inverse
=self
1446 for source
,d
in self
.d
.iteritems(): # INVERT OUR CACHE
1447 self
._inverse
.load([(target
,source
,edge
)
1448 for (target
,edge
) in d
.iteritems()],unpack
=False)
1449 return self
._inverse
1450 edges
=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
1451 update
= update_graph
1452 add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
1453 def __iter__(self
): ################# ITERATORS
1454 'uses db select; does not force load'
1455 return iter(self
.keys())
1457 'uses db select; does not force load'
1458 self
.table
.cursor
.execute('select distinct(%s) from %s'
1459 %(self
.source_id
,self
.table
.name
))
1460 return [self
.unpack_source(t
[0])
1461 for t
in self
.table
.cursor
.fetchall()]
1462 methodFactory(['iteritems','items','itervalues','values'],
1463 'lambda self:(self.load(),self.d.%s())[1]',locals())
1464 def __contains__(self
,k
):
1471 class SQLIDGraphClustered(SQLGraphClustered
):
1472 add_trivial_packing_methods(locals())
1473 SQLGraphClustered
._IDGraphClass
= SQLIDGraphClustered
1475 class SQLEdgesClustered(SQLGraphClustered
):
1476 'edges interface for SQLGraphClustered'
1477 _edgeDictClass
= list
1478 _pickleAttrs
= SQLGraphClustered
._pickleAttrs
.copy()
1479 _pickleAttrs
.update(dict(graph
=0))
1483 for edge_id
,l
in self
.d
.iteritems():
1484 for source_id
,target_id
in l
:
1485 result
.append((self
.graph
.unpack_source(source_id
),
1486 self
.graph
.unpack_target(target_id
),
1487 self
.graph
.unpack_edge(edge_id
)))
1490 class ForeignKeyInverse(object):
1491 'map each key to a single value according to its foreign key'
1492 def __init__(self
,g
):
1494 def __getitem__(self
,obj
):
1496 source_id
= getattr(obj
,self
.g
.keyColumn
)
1497 if source_id
is None:
1499 return self
.g
.sourceDB
[source_id
]
1500 def __setitem__(self
,obj
,source
):
1502 if source
is not None:
1503 self
.g
[source
][obj
] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
1504 else: # DELETE PRE-EXISTING EDGE IF PRESENT
1505 if not hasattr(obj
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1506 old_source
= self
[obj
]
1507 if old_source
is not None:
1508 del self
.g
[old_source
][obj
]
1509 def check_obj(self
,obj
):
1510 'raise KeyError if obj not from this db'
1512 if obj
.db
!= self
.g
.targetDB
:
1513 raise AttributeError
1514 except AttributeError:
1515 raise KeyError('key is not from targetDB of this graph!')
1516 def __contains__(self
,obj
):
1523 return self
.g
.targetDB
.itervalues()
1525 return self
.g
.targetDB
.values()
1526 def iteritems(self
):
1528 source_id
= getattr(obj
,self
.g
.keyColumn
)
1529 if source_id
is None:
1532 yield obj
,self
.g
.sourceDB
[source_id
]
1534 return list(self
.iteritems())
1535 def itervalues(self
):
1536 for obj
,val
in self
.iteritems():
1539 return list(self
.itervalues())
1540 def __invert__(self
):
1543 class ForeignKeyEdge(dict):
1544 '''edge interface to a foreign key in an SQL table.
1545 Caches dict of target nodes in itself; provides dict interface.
1546 Adds or deletes edges by setting foreign key values in the table'''
1547 def __init__(self
,g
,k
):
1551 for v
in g
.targetDB
.select('where %s=%%s' % g
.keyColumn
,(k
.id,)): # SEARCH THE DB
1552 dict.__setitem
__(self
,v
,None) # SAVE IN CACHE
1553 def __setitem__(self
,dest
,v
):
1554 if not hasattr(dest
,'db') or dest
.db
!= self
.g
.targetDB
:
1555 raise KeyError('dest is not in the targetDB bound to this graph!')
1557 raise ValueError('sorry,this graph cannot store edge information!')
1558 if not hasattr(dest
,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
1559 old_source
= self
.g
._inverse
[dest
] # CHECK FOR PRE-EXISTING EDGE
1560 if old_source
is not None: # REMOVE OLD EDGE FROM CACHE
1561 dict.__delitem
__(self
.g
[old_source
],dest
)
1562 #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
1563 setattr(dest
,self
.g
.keyColumn
,self
.src
.id) # SAVE TO DB ATTRIBUTE
1564 dict.__setitem
__(self
,dest
,None) # SAVE IN CACHE
1565 def __delitem__(self
,dest
):
1566 #self.g.targetDB._update(dest.id,self.g.keyColumn,None) # REMOVE FOREIGN KEY VALUE
1567 setattr(dest
,self
.g
.keyColumn
,None) # SAVE TO DB ATTRIBUTE
1568 dict.__delitem
__(self
,dest
) # REMOVE FROM CACHE
1570 class ForeignKeyGraph(object, UserDict
.DictMixin
):
1571 '''graph interface to a foreign key in an SQL table
1572 Caches dict of target nodes in itself; provides dict interface.
1574 def __init__(self
, sourceDB
, targetDB
, keyColumn
, autoGC
=True, **kwargs
):
1575 '''sourceDB is any database of source nodes;
1576 targetDB must be an SQL database of target nodes;
1577 keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
1578 if autoGC
: # automatically garbage collect unused objects
1579 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
1581 self
._weakValueDict
= {}
1582 self
.autoGC
= autoGC
1583 self
.sourceDB
= sourceDB
1584 self
.targetDB
= targetDB
1585 self
.keyColumn
= keyColumn
1586 self
._inverse
= ForeignKeyInverse(self
)
1587 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, keyColumn
=0, autoGC
=0)
1588 __getstate__
= standard_getstate
########### SUPPORT FOR PICKLING
1589 __setstate__
= standard_setstate
1590 def _inverse_schema(self
):
1591 'provide custom schema rule for inverting this graph... just use keyColumn!'
1592 return dict(invert
=True,uniqueMapping
=True)
1593 def __getitem__(self
,k
):
1594 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
1595 raise KeyError('object is not in the sourceDB bound to this graph!')
1597 return self
._weakValueDict
[k
.id] # get from cache
1600 d
= ForeignKeyEdge(self
,k
)
1601 self
._weakValueDict
[k
.id] = d
# save in cache
1603 def __setitem__(self
, k
, v
):
1604 raise KeyError('''do not save as g[k]=v. Instead follow a graph
1605 interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
1606 def __delitem__(self
, k
):
1607 raise KeyError('''Instead of del g[k], follow a graph
1608 interface: del g[src][dest]''')
1610 return self
.sourceDB
.values()
1611 __invert__
= standard_invert
1613 def describeDBTables(name
,cursor
,idDict
):
1615 Get table info about database <name> via <cursor>, and store primary keys
1616 in idDict, along with a list of the tables each key indexes.
1618 cursor
.execute('use %s' % name
)
1619 cursor
.execute('show tables')
1621 l
=[c
[0] for c
in cursor
.fetchall()]
1624 o
=SQLTable(tname
,cursor
)
1626 for f
in o
.description
:
1627 if f
==o
.primary_key
:
1628 idDict
.setdefault(f
, []).append(o
)
1629 elif f
[-3:]=='_id' and f
not in idDict
:
1635 def indexIDs(tables
,idDict
=None):
1636 "Get an index of primary keys in the <tables> dictionary."
1639 for o
in tables
.values():
1641 if o
.primary_key
not in idDict
:
1642 idDict
[o
.primary_key
]=[]
1643 idDict
[o
.primary_key
].append(o
) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
1644 for f
in o
.description
:
1645 if f
[-3:]=='_id' and f
not in idDict
:
1651 def suffixSubset(tables
,suffix
):
1652 "Filter table index for those matching a specific suffix"
1654 for name
,t
in tables
.items():
1655 if name
.endswith(suffix
):
1662 def graphDBTables(tables
,idDict
):
1664 for t
in tables
.values():
1665 for f
in t
.description
:
1666 if f
==t
.primary_key
:
1667 edgeInfo
=PRIMARY_KEY
1670 g
.setEdge(f
,t
,edgeInfo
)
1671 g
.setEdge(t
,f
,edgeInfo
)
1674 SQLTypeTranslation
= {types
.StringType
:'varchar(32)',
1675 types
.IntType
:'int',
1676 types
.FloatType
:'float'}
1678 def createTableFromRepr(rows
,tableName
,cursor
,typeTranslation
=None,
1679 optionalDict
=None,indexDict
=()):
1680 """Save rows into SQL tableName using cursor, with optional
1681 translations of columns to specific SQL types (specified
1682 by typeTranslation dict).
1683 - optionDict can specify columns that are allowed to be NULL.
1684 - indexDict can specify columns that must be indexed; columns
1685 whose names end in _id will be indexed by default.
1686 - rows must be an iterator which in turn returns dictionaries,
1687 each representing a tuple of values (indexed by their column
1691 row
=rows
.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
1692 except StopIteration:
1693 return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
1695 createTableFromRow(cursor
, tableName
,row
,typeTranslation
,
1696 optionalDict
,indexDict
)
1699 storeRow(cursor
,tableName
,row
) # SAVE OUR FIRST ROW
1700 for row
in rows
: # NOW SAVE ALL THE ROWS
1701 storeRow(cursor
,tableName
,row
)
1703 def createTableFromRow(cursor
, tableName
, row
,typeTranslation
=None,
1704 optionalDict
=None,indexDict
=()):
1706 for col
,val
in row
.items(): # PREPARE SQL TYPES FOR COLUMNS
1708 if typeTranslation
!=None and col
in typeTranslation
:
1709 coltype
=typeTranslation
[col
] # USER-SUPPLIED TRANSLATION
1710 elif type(val
) in SQLTypeTranslation
:
1711 coltype
=SQLTypeTranslation
[type(val
)]
1712 else: # SEARCH FOR A COMPATIBLE TYPE
1713 for t
in SQLTypeTranslation
:
1714 if isinstance(val
,t
):
1715 coltype
=SQLTypeTranslation
[t
]
1718 raise TypeError("Don't know SQL type to use for %s" % col
)
1719 create_def
='%s %s' %(col
,coltype
)
1720 if optionalDict
==None or col
not in optionalDict
:
1721 create_def
+=' not null'
1722 create_defs
.append(create_def
)
1723 for col
in row
: # CREATE INDEXES FOR ID COLUMNS
1724 if col
[-3:]=='_id' or col
in indexDict
:
1725 create_defs
.append('index(%s)' % col
)
1726 cmd
='create table if not exists %s (%s)' % (tableName
,','.join(create_defs
))
1727 cursor
.execute(cmd
) # CREATE THE TABLE IN THE DATABASE
1730 def storeRow(cursor
, tableName
, row
):
1731 row_format
=','.join(len(row
)*['%s'])
1732 cmd
='insert into %s values (%s)' % (tableName
,row_format
)
1733 cursor
.execute(cmd
,tuple(row
.values()))
1735 def storeRowDelayed(cursor
, tableName
, row
):
1736 row_format
=','.join(len(row
)*['%s'])
1737 cmd
='insert delayed into %s values (%s)' % (tableName
,row_format
)
1738 cursor
.execute(cmd
,tuple(row
.values()))
1741 class TableGroup(dict):
1742 'provide attribute access to dbname qualified tablenames'
1743 def __init__(self
,db
='test',suffix
=None,**kw
):
1746 if suffix
is not None:
1748 for k
,v
in kw
.items():
1749 if v
is not None and '.' not in v
:
1750 v
=self
.db
+'.'+v
# ADD DATABASE NAME AS PREFIX
1752 def __getattr__(self
,k
):
1755 def sqlite_connect(*args
, **kwargs
):
1756 sqlite
= import_sqlite()
1757 connection
= sqlite
.connect(*args
, **kwargs
)
1758 cursor
= connection
.cursor()
1759 return connection
, cursor
1761 class DBServerInfo(object):
1762 'picklable reference to a database server'
1763 def __init__(self
, moduleName
='MySQLdb', *args
, **kwargs
):
1765 self
.__class
__ = _DBServerModuleDict
[moduleName
]
1767 raise ValueError('Module name not found in _DBServerModuleDict: '\
1769 self
.moduleName
= moduleName
1771 self
.kwargs
= kwargs
# connection arguments
1774 """returns cursor associated with the DB server info (reused)"""
1777 except AttributeError:
1778 self
._start
_connection
()
1781 def new_cursor(self
, arraysize
=None):
1782 """returns a NEW cursor; you must close it yourself! """
1783 if not hasattr(self
, '_connection'):
1784 self
._start
_connection
()
1785 cursor
= self
._connection
.cursor()
1786 if arraysize
is not None:
1787 cursor
.arraysize
= arraysize
1791 """Close file containing this database"""
1792 self
._cursor
.close()
1793 self
._connection
.close()
1795 del self
._connection
1797 def __getstate__(self
):
1798 """return all picklable arguments"""
1799 return dict(args
=self
.args
, kwargs
=self
.kwargs
,
1800 moduleName
=self
.moduleName
)
1803 class MySQLServerInfo(DBServerInfo
):
1804 'customized for MySQLdb SSCursor support via new_cursor()'
1805 def _start_connection(self
):
1806 self
._connection
,self
._cursor
= mysql_connect(*self
.args
, **self
.kwargs
)
1807 def new_cursor(self
, arraysize
=None):
1808 'provide streaming cursor support'
1810 conn
= self
._conn
_sscursor
1811 except AttributeError:
1812 self
._conn
_sscursor
,cursor
= mysql_connect(useStreaming
=True,
1813 *self
.args
, **self
.kwargs
)
1815 cursor
= self
._conn
_sscursor
.cursor()
1816 if arraysize
is not None:
1817 cursor
.arraysize
= arraysize
1820 DBServerInfo
.close(self
)
1822 self
._conn
_sscursor
.close()
1823 del self
._conn
_sscursor
1824 except AttributeError:
1826 def iter_keys(self
, db
, selectCols
=None, orderBy
='', map_f
=iter,
1827 cache_f
=lambda x
:[t
[0] for t
in x
], get_f
=None, **kwargs
):
1828 cursor
= self
.new_cursor()
1829 block_generator
= BlockGenerator(db
, self
, cursor
, selectCols
=None,
1830 orderBy
='', **kwargs
)
1831 return db
.generic_iterator(cursor
=cursor
, cache_f
=cache_f
,
1832 map_f
=map_f
, fetch_f
=block_generator
)
1835 class BlockGenerator(object):
1836 def __init__(self
, db
, serverInfo
, cursor
, whereClause
='', **kwargs
):
1838 self
.serverInfo
= serverInfo
1839 self
.cursor
= cursor
1840 self
.kwargs
= kwargs
1841 self
.blockSize
= 10000
1842 self
.whereClause
= ''
1843 #self.__iter__() # start me up!
1845 ## def __iter__(self):
1846 ## 'initialize this iterator'
1847 ## self.db._select(cursor=cursor, selectCols='min(%s),max(%s),count(*)'
1848 ## % (self.db.name, self.db.name))
1849 ## l = self.cursor.fetchall()
1850 ## self.minID, self.maxID, self.count = l[0]
1851 ## self.start = self.minID - 1 # only works for int
1855 ## 'get the next start position'
1856 ## if self.start >= self.maxID:
1857 ## raise StopIteration
1861 'get the next block of data'
1863 ## start = self.next()
1864 ## except StopIteration:
1866 self
.db
._select
(cursor
=self
.cursor
, whereClause
=self
.whereClause
,
1867 limit
='LIMIT %s' % self
.blockSize
, **kwargs
)
1868 rows
= self
.cursor
.fetchall()
1870 if len(lastrow
) > 1: # extract the last ID value in this block
1871 start
= lastrow
[self
.db
.data
['id']]
1874 self
.whereClause
= '%s>%s' %(self
.db
.primary_key
,start
)
1879 class SQLiteServerInfo(DBServerInfo
):
1880 """picklable reference to a sqlite database"""
1881 def __init__(self
, database
, *args
, **kwargs
):
1882 """Takes same arguments as sqlite3.connect()"""
1883 DBServerInfo
.__init
__(self
, 'sqlite',
1884 SourceFileName(database
), # save abs path!
1886 def _start_connection(self
):
1887 self
._connection
,self
._cursor
= sqlite_connect(*self
.args
, **self
.kwargs
)
1888 def __getstate__(self
):
1889 if self
.args
[0] == ':memory:':
1890 raise ValueError('SQLite in-memory database is not picklable!')
1891 return DBServerInfo
.__getstate
__(self
)
1893 # list of DBServerInfo subclasses for different modules
1894 _DBServerModuleDict
= dict(MySQLdb
=MySQLServerInfo
, sqlite
=SQLiteServerInfo
)
1897 class MapView(object, UserDict
.DictMixin
):
1898 'general purpose 1:1 mapping defined by any SQL query'
1899 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None,
1900 serverInfo
=None, inverseSQL
=None, **kwargs
):
1901 self
.sourceDB
= sourceDB
1902 self
.targetDB
= targetDB
1903 self
.viewSQL
= viewSQL
1904 self
.inverseSQL
= inverseSQL
1906 if serverInfo
is not None: # get cursor from serverInfo
1907 cursor
= serverInfo
.cursor()
1909 try: # can we get it from our other db?
1910 serverInfo
= sourceDB
.serverInfo
1911 except AttributeError:
1912 raise ValueError('you must provide serverInfo or cursor!')
1914 cursor
= serverInfo
.cursor()
1915 self
.cursor
= cursor
1916 self
.serverInfo
= serverInfo
1917 self
.get_sql_format(False) # get sql formatter for this db interface
1918 _schemaModuleDict
= _schemaModuleDict
# default module list
1919 get_sql_format
= get_table_schema
1920 def __getitem__(self
, k
):
1921 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
1922 raise KeyError('object is not in the sourceDB bound to this map!')
1923 sql
,params
= self
._format
_query
(self
.viewSQL
, (k
.id,))
1924 self
.cursor
.execute(sql
, params
) # formatted for this db interface
1925 t
= self
.cursor
.fetchmany(2) # get at most two rows
1927 raise KeyError('%s not found in MapView, or not unique'
1929 return self
.targetDB
[t
[0][0]] # get the corresponding object
1930 _pickleAttrs
= dict(sourceDB
=0, targetDB
=0, viewSQL
=0, serverInfo
=0,
1932 __getstate__
= standard_getstate
1933 __setstate__
= standard_setstate
1934 __setitem__
= __delitem__
= clear
= pop
= popitem
= update
= \
1935 setdefault
= read_only_error
1937 'only yield sourceDB items that are actually in this mapping!'
1938 for k
in self
.sourceDB
.itervalues():
1945 return [k
for k
in self
] # don't use list(self); causes infinite loop!
1946 def __invert__(self
):
1948 return self
._inverse
1949 except AttributeError:
1950 if self
.inverseSQL
is None:
1951 raise ValueError('this MapView has no inverseSQL!')
1952 self
._inverse
= self
.__class
__(self
.targetDB
, self
.sourceDB
,
1953 self
.inverseSQL
, self
.cursor
,
1954 serverInfo
=self
.serverInfo
,
1955 inverseSQL
=self
.viewSQL
)
1956 self
._inverse
._inverse
= self
1957 return self
._inverse
1959 class GraphViewEdgeDict(UserDict
.DictMixin
):
1960 'edge dictionary for GraphView: just pre-loaded on init'
1961 def __init__(self
, g
, k
):
1964 sql
,params
= self
.g
._format
_query
(self
.g
.viewSQL
, (k
.id,))
1965 self
.g
.cursor
.execute(sql
, params
) # run the query
1966 l
= self
.g
.cursor
.fetchall() # get results
1968 raise KeyError('key %s not in GraphView' % k
.id)
1969 self
.targets
= [t
[0] for t
in l
] # preserve order of the results
1970 d
= {} # also keep targetID:edgeID mapping
1971 if self
.g
.edgeDB
is not None: # save with edge info
1979 return len(self
.targets
)
1981 for k
in self
.targets
:
1982 yield self
.g
.targetDB
[k
]
1985 def iteritems(self
):
1986 if self
.g
.edgeDB
is not None: # save with edge info
1987 for k
in self
.targets
:
1988 yield (self
.g
.targetDB
[k
], self
.g
.edgeDB
[self
.targetDict
[k
]])
1989 else: # just save the list of targets, no edge info
1990 for k
in self
.targets
:
1991 yield (self
.g
.targetDB
[k
], None)
1992 def __getitem__(self
, o
, exitIfFound
=False):
1993 'for the specified target object, return its associated edge object'
1995 if o
.db
is not self
.g
.targetDB
:
1996 raise KeyError('key is not part of targetDB!')
1997 edgeID
= self
.targetDict
[o
.id]
1998 except AttributeError:
1999 raise KeyError('key has no id or db attribute?!')
2002 if self
.g
.edgeDB
is not None: # return the edge object
2003 return self
.g
.edgeDB
[edgeID
]
2004 else: # no edge info
2006 def __contains__(self
, o
):
2008 self
.__getitem
__(o
, True) # raise KeyError if not found
2012 __setitem__
= __delitem__
= clear
= pop
= popitem
= update
= \
2013 setdefault
= read_only_error
2015 class GraphView(MapView
):
2016 'general purpose graph interface defined by any SQL query'
2017 def __init__(self
, sourceDB
, targetDB
, viewSQL
, cursor
=None, edgeDB
=None,
2019 'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
2020 self
.edgeDB
= edgeDB
2021 MapView
.__init
__(self
, sourceDB
, targetDB
, viewSQL
, cursor
, **kwargs
)
2022 def __getitem__(self
, k
):
2023 if not hasattr(k
,'db') or k
.db
!= self
.sourceDB
:
2024 raise KeyError('object is not in the sourceDB bound to this map!')
2025 return GraphViewEdgeDict(self
, k
)
2026 _pickleAttrs
= MapView
._pickleAttrs
.copy()
2027 _pickleAttrs
.update(dict(edgeDB
=0))
2029 # @CTB move to sqlgraph.py?
2031 class SQLSequence(SQLRow
, SequenceBase
):
2032 """Transparent access to a DB row representing a sequence.
2034 Use attrAlias dict to rename 'length' to something else.
2036 def _init_subclass(cls
, db
, **kwargs
):
2037 db
.seqInfoDict
= db
# db will act as its own seqInfoDict
2038 SQLRow
._init
_subclass
(db
=db
, **kwargs
)
2039 _init_subclass
= classmethod(_init_subclass
)
2040 def __init__(self
, id):
2041 SQLRow
.__init
__(self
, id)
2042 SequenceBase
.__init
__(self
)
2045 def strslice(self
,start
,end
):
2046 "Efficient access to slice of a sequence, useful for huge contigs"
2047 return self
._select
('%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
2048 %(self
.db
._attrSQL
('seq'),start
+1,end
-start
))
2050 class DNASQLSequence(SQLSequence
):
2051 _seqtype
=DNA_SEQTYPE
2053 class RNASQLSequence(SQLSequence
):
2054 _seqtype
=RNA_SEQTYPE
2056 class ProteinSQLSequence(SQLSequence
):
2057 _seqtype
=PROTEIN_SEQTYPE