3 from __future__
import generators
5 from sequence
import SequenceBase
, DNA_SEQTYPE
, RNA_SEQTYPE
, PROTEIN_SEQTYPE
7 from classutil
import methodFactory
,standard_getstate
,\
8 override_rich_cmp
,generate_items
,get_bound_subclass
,standard_setstate
,\
9 get_valid_path
,standard_invert
,RecentValueDictionary
,read_only_error
,\
10 SourceFileName
, split_kwargs
class TupleDescriptor(object):
    """Descriptor mapping a named attribute to one slot of a cached row tuple."""
    def __init__(self, db, attr):
        # resolve the attribute's position in the row tuple once, at bind time
        column_index = db.data[attr]
        self.icol = column_index
    def __get__(self, obj, klass):
        # serve the value straight out of the row tuple cached on the object
        row = obj._data
        return row[self.icol]
    def __set__(self, obj, val):
        # writes are forbidden on the read-only interface
        raise AttributeError('this database is read-only!')
class TupleIDDescriptor(TupleDescriptor):
    """TupleDescriptor variant for the 'id' attribute: reading works as
    usual, but direct assignment is always rejected."""
    def __set__(self, obj, val):
        # changing an object's ID must go through the database mapping
        message = '''You cannot change obj.id directly.
Instead, use db[newID] = obj'''
        raise AttributeError(message)
class TupleDescriptorRW(TupleDescriptor):
    'read-write interface to named attribute'
    def __init__(self, db, attr):
        # save the attr name: __set__ needs it for obj.save_local() below
        self.attr = attr
        self.icol = db.data[attr] # index of this attribute in the tuple
        self.attrSQL = db._attrSQL(attr, sqlColumn=True) # SQL column name
    def __set__(self, obj, val):
        # push the new value to the database first, then mirror it locally
        obj.db._update(obj.id, self.attrSQL, val) # AND UPDATE THE DATABASE
        obj.save_local(self.attr, val)
class SQLDescriptor(object):
    """Descriptor that fetches an attribute's value from the database on
    every access, instead of caching it on the object."""
    def __init__(self, db, attr):
        # precompute the SQL expression used to SELECT this attribute
        self.selectSQL = db._attrSQL(attr)
    def __get__(self, obj, klass):
        query = self.selectSQL
        return obj._select(query)
    def __set__(self, obj, val):
        # read-only interface: reject all assignment
        raise AttributeError('this database is read-only!')
class SQLDescriptorRW(SQLDescriptor):
    'writeable proxy to corresponding column in the database'
    def __set__(self, obj, val):
        # there is no local cache to refresh: the database is the only store
        obj.db._update(obj.id, self.selectSQL, val)
class ReadOnlyDescriptor(object):
    'enforce read-only policy, e.g. for ID attribute'
    def __init__(self, db, attr):
        # Store under a private name ('_' + attr): __get__ must read a
        # different attribute than the one this descriptor shadows, or
        # attribute access would recurse through the descriptor itself.
        # NOTE(review): __init__ body lost in extraction; reconstructed.
        self.attr = '_' + attr
    def __get__(self, obj, klass):
        return getattr(obj, self.attr)
    def __set__(self, obj, val):
        raise AttributeError('attribute %s is read-only!' % self.attr)
def select_from_row(row, what):
    """Return value from SQL expression applied to this row.

    Raises KeyError if the query matches zero rows or more than one
    (limit 2 lets us detect non-uniqueness without fetching everything)."""
    sql,params = row.db._format_query('select %s from %s where %s=%%s limit 2'
                                      % (what,row.db.name,row.db.primary_key),
                                      (row.id,))
    row.db.cursor.execute(sql, params)
    t = row.db.cursor.fetchmany(2) # get at most two rows
    if len(t) != 1: # missing, or primary key not unique
        raise KeyError('%s[%s].%s not found, or not unique'
                       % (row.db.name,str(row.id),what))
    return t[0][0] #return the single field we requested
def init_row_subclass(cls, db):
    'add descriptors for db attributes'
    for attr in db.data: # bind all database columns
        if attr == 'id': # handle ID attribute specially
            setattr(cls, attr, cls._idDescriptor(db, attr))
            continue # don't also bind id as an ordinary column below
        try: # check if this attr maps to an SQL column
            db._attrSQL(attr, columnNumber=True)
        except AttributeError: # treat as SQL expression
            setattr(cls, attr, cls._sqlDescriptor(db, attr))
        else: # treat as interface to our stored tuple
            setattr(cls, attr, cls._columnDescriptor(db, attr))
91 """get list of column names as our attributes """
92 return self
.db
.data
.keys()
95 """Provides attribute interface to a database tuple.
96 Storing the data as a tuple instead of a standard Python object
97 (which is stored using __dict__) uses about five-fold less
98 memory and is also much faster (the tuples returned from the
99 DB API fetch are simply referenced by the TupleO, with no
100 need to copy their individual values into __dict__).
102 This class follows the 'subclass binding' pattern, which
103 means that instead of using __getattr__ to process all
104 attribute requests (which is un-modular and leads to all
105 sorts of trouble), we follow Python's new model for
106 customizing attribute access, namely Descriptors.
107 We use classutil.get_bound_subclass() to automatically
108 create a subclass of this class, calling its _init_subclass()
109 class method to add all the descriptors needed for the
110 database table to which it is bound.
112 See the Pygr Developer Guide section of the docs for a
113 complete discussion of the subclass binding pattern."""
114 _columnDescriptor
= TupleDescriptor
115 _idDescriptor
= TupleIDDescriptor
116 _sqlDescriptor
= SQLDescriptor
117 _init_subclass
= classmethod(init_row_subclass
)
118 _select
= select_from_row
120 def __init__(self
, data
):
121 self
._data
= data
# save our data tuple
def insert_and_cache_id(self, l, **kwargs):
    'insert tuple into db and cache its rowID on self'
    self.db._insert(l) # save to database
    try:
        rowID = kwargs['id'] # use the ID supplied by user
    except KeyError:
        rowID = self.db.get_insert_id() # get auto-inc ID value
    self.cache_id(rowID) # cache this ID on self
class TupleORW(TupleO):
    'read-write version of TupleO'
    _columnDescriptor = TupleDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, data, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            self._data = data
            return
        # new row: build the tuple from kwargs and write it to the database
        self._data = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(self._data, **kwargs)
    def cache_id(self, row_id):
        self.save_local('id', row_id)
    def save_local(self, attr, val):
        icol = self._attrcol[attr]
        try:
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
        except TypeError: # TUPLE CAN'T STORE NEW VALUE, SO USE A LIST
            self._data = list(self._data)
            self._data[icol] = val # FINALLY UPDATE OUR LOCAL CACHE
# Register TupleORW as the class objclass() swaps in when a table is
# opened writeable (see SQLTableBase.objclass).
TupleO._RWClass = TupleORW # record this as writeable interface class
class ColumnDescriptor(object):
    'read-write interface to column in a database, cached in obj.__dict__'
    def __init__(self, db, attr, readOnly = False):
        self.attr = attr
        self.col = db._attrSQL(attr, sqlColumn=True) # MAP THIS TO SQL COLUMN NAME
        self.db = db
        if readOnly: # swap in the read-only variant of this class
            self.__class__ = self._readOnlyClass
    def __get__(self, obj, objtype):
        try:
            return obj.__dict__[self.attr]
        except KeyError: # NOT IN CACHE. TRY TO GET IT FROM DATABASE
            if self.col == self.db.primary_key:
                # NOTE(review): this branch's raise was lost in extraction;
                # it must fail rather than query by its own missing value.
                raise AttributeError('primary key not cached!')
            self.db._select('where %s=%%s' % self.db.primary_key, (obj.id,),
                            self.col)
            l = self.db.cursor.fetchall()
            if len(l) != 1: # missing, or primary key not unique
                raise AttributeError('db row not found or not unique!')
            obj.__dict__[self.attr] = l[0][0] # UPDATE THE CACHE
            return l[0][0]
    def __set__(self, obj, val):
        if not hasattr(obj, '_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
            self.db._update(obj.id, self.col, val) # UPDATE THE DATABASE
        obj.__dict__[self.attr] = val # UPDATE THE CACHE
class ReadOnlyColumnDesc(ColumnDescriptor):
    """ColumnDescriptor variant whose value can never be assigned;
    used to protect the ID attribute of database row objects."""
    def __set__(self, obj, val):
        raise AttributeError('The ID of a database object is not writeable.')
# ColumnDescriptor.__init__ switches instances to this class when
# constructed with readOnly=True.
ColumnDescriptor._readOnlyClass = ReadOnlyColumnDesc
class SQLRow(object):
    """Provide transparent interface to a row in the database: attribute access
    will be mapped to SELECT of the appropriate column, but data is not
    cached on this object.
    """
    _columnDescriptor = _sqlDescriptor = SQLDescriptor
    _idDescriptor = ReadOnlyDescriptor
    _init_subclass = classmethod(init_row_subclass)
    _select = select_from_row
    def __init__(self, rowID):
        # store under the private name read by ReadOnlyDescriptor('id')
        # NOTE(review): __init__ body lost in extraction; reconstructed.
        self._id = rowID
class SQLRowRW(SQLRow):
    'read-write version of SQLRow'
    _columnDescriptor = SQLDescriptorRW
    insert_and_cache_id = insert_and_cache_id
    def __init__(self, rowID, newRow=False, **kwargs):
        if not newRow: # just cache data from the database
            return self.cache_id(rowID)
        l = self.db.tuple_from_dict(kwargs) # convert to tuple
        self.insert_and_cache_id(l, **kwargs)
    def cache_id(self, rowID):
        # store under the private name read by ReadOnlyDescriptor('id')
        self._id = rowID
# Register SQLRowRW as the writeable row class used by
# SQLTableBase.objclass() when the table is opened writeable.
SQLRow._RWClass = SQLRowRW
def list_to_dict(names, values):
    'return dictionary of those named args that are present in values[]'
    # zip stops at the shorter sequence, so surplus names (or surplus
    # values) are simply omitted
    return dict(zip(names, values))
def get_name_cursor(name=None, **kwargs):
    '''get table name and cursor by parsing name or using configFile.
    If neither provided, will try to get via your MySQL config file.
    If connect is None, will use MySQLdb.connect()'''
    if name is not None:
        argList = name.split() # TREAT AS WS-SEPARATED LIST
        if argList: # NOTE(review): guard lost in extraction; reconstructed
            name = argList[0] # USE 1ST ARG AS TABLE NAME
            argnames = ('host','user','passwd') # READ ARGS IN THIS ORDER
            kwargs = kwargs.copy() # a copy we can overwrite
            kwargs.update(list_to_dict(argnames, argList[1:]))
    serverInfo = DBServerInfo(**kwargs)
    return name,serverInfo.cursor(),serverInfo
def mysql_connect(connect=None, configFile=None, useStreaming=False, **args):
    """return connection and cursor objects, using .my.cnf if necessary"""
    kwargs = args.copy() # a copy we can modify
    if 'user' not in kwargs and configFile is None: #Find where config file is
        osname = platform.system()
        if osname in ('Microsoft', 'Windows'): # Machine is a Windows box
            paths = []
            try: # handle case where WINDIR not defined by Windows...
                windir = os.environ['WINDIR']
                paths += [(windir, 'my.ini'), (windir, 'my.cnf')]
            except KeyError:
                pass
            try: # handle case where SYSTEMDRIVE not defined...
                sysdrv = os.environ['SYSTEMDRIVE']
                paths += [(sysdrv, os.path.sep + 'my.ini'),
                          (sysdrv, os.path.sep + 'my.cnf')]
            except KeyError:
                pass
            configFile = get_valid_path(*paths)
        else: # treat as normal platform with home directories
            configFile = os.path.join(os.path.expanduser('~'), '.my.cnf')

    # allows for a local mysql local configuration file to be read
    # from the current directory
    configFile = configFile or os.path.join(os.getcwd(), 'mysql.cnf')

    if configFile and os.path.exists(configFile):
        kwargs['read_default_file'] = configFile
        connect = None # force it to use MySQLdb
    if connect is None: # NOTE(review): guard reconstructed from comment above
        import MySQLdb # deferred: only needed when no custom connect given
        connect = MySQLdb.connect
        kwargs['compress'] = True
    if useStreaming: # use server side cursors for scalable result sets
        try:
            from MySQLdb import cursors
            kwargs['cursorclass'] = cursors.SSCursor
        except (ImportError, AttributeError):
            pass # no server-side cursor support; fall back silently
    conn = connect(**kwargs)
    cursor = conn.cursor()
    return conn, cursor
# SQL keyword macros substituted by SQLFormatDict for MySQL back-ends.
_mysqlMacros = dict(IGNORE='ignore', REPLACE='replace',
                    AUTO_INCREMENT='AUTO_INCREMENT', SUBSTRING='substring',
                    SUBSTR_FROM='FROM', SUBSTR_FOR='FOR')
def mysql_table_schema(self, analyzeSchema=True):
    'retrieve table schema from a MySQL database, save on self'
    import MySQLdb # deferred so sqlite-only installs can load this module
    self._format_query = SQLFormatDict(MySQLdb.paramstyle, _mysqlMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    self.cursor.execute('describe %s' % self.name) # get info about columns
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[0] # DESCRIBE row: (Field, Type, Null, Key, Default, Extra)
        self.columnName.append(field) # list of columns in same order as table
        if c[3] == "PRI": # record as primary key
            if self.primary_key is None:
                self.primary_key = field
            else: # compound primary key: keep a list of fields
                try:
                    self.primary_key.append(field)
                except AttributeError:
                    self.primary_key = [self.primary_key,field]
            # NOTE(review): nesting of this type check under the PRI branch
            # was reconstructed from an extraction gap — confirm.
            if c[1][:3].lower() == 'int':
                self.usesIntID = True
            else:
                self.usesIntID = False
        self.indexed[field] = icol
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[1] # SQL COLUMN TYPE
# SQL keyword macros substituted by SQLFormatDict for sqlite back-ends.
_sqliteMacros = dict(IGNORE='or ignore', REPLACE='insert or replace',
                     AUTO_INCREMENT='', SUBSTRING='substr',
                     SUBSTR_FROM=',', SUBSTR_FOR=',')
def import_sqlite():
    'import sqlite3 (for Python 2.5+) or pysqlite2 for earlier Python versions'
    try:
        import sqlite3 as sqlite
    except ImportError:
        from pysqlite2 import dbapi2 as sqlite
    return sqlite
def sqlite_table_schema(self, analyzeSchema=True):
    'retrieve table schema from a sqlite3 database, save on self'
    sqlite = import_sqlite()
    self._format_query = SQLFormatDict(sqlite.paramstyle, _sqliteMacros)
    if not analyzeSchema:
        return
    self.clear_schema() # reset settings and dictionaries
    self.cursor.execute('PRAGMA table_info("%s")' % self.name)
    columns = self.cursor.fetchall()
    self.cursor.execute('select * from %s limit 1' % self.name) # descriptions
    for icol,c in enumerate(columns):
        field = c[1] # PRAGMA table_info row: (cid, name, type, ...)
        self.columnName.append(field) # list of columns in same order as table
        self.description[field] = self.cursor.description[icol]
        self.columnType[field] = c[2] # SQL COLUMN TYPE
    self.cursor.execute('select name from sqlite_master where tbl_name="%s" and type="index" and sql is null' % self.name) # get primary key / unique indexes
    for indexname in self.cursor.fetchall(): # search indexes for primary key
        self.cursor.execute('PRAGMA index_info("%s")' % indexname)
        l = self.cursor.fetchall() # get list of columns in this index
        if len(l) == 1: # assume 1st single-column unique index is primary key!
            self.primary_key = l[0][2]
            break # done searching for primary key!
    if self.primary_key is None: # grrr, INTEGER PRIMARY KEY handled differently
        self.cursor.execute('select sql from sqlite_master where tbl_name="%s" and type="table"' % self.name)
        sql = self.cursor.fetchall()[0][0]
        # parse the CREATE TABLE statement for a 'primary key' clause
        for columnSQL in sql[sql.index('(') + 1 :].split(','):
            if 'primary key' in columnSQL.lower(): # must be the primary key!
                col = columnSQL.split()[0] # get column name
                if col in self.columnType:
                    self.primary_key = col
                    break # done searching for primary key!
                else:
                    raise ValueError('unknown primary key %s in table %s'
                                     % (col,self.name))
    if self.primary_key is not None: # check its type
        if self.columnType[self.primary_key] == 'int' or \
           self.columnType[self.primary_key] == 'integer':
            self.usesIntID = True
        else:
            self.usesIntID = False
class SQLFormatDict(object):
    '''Perform SQL keyword replacements for maintaining compatibility across
    a wide range of SQL backends. Uses Python dict-based string format
    function to do simple string replacements, and also to convert
    params list to the paramstyle required for this interface.
    Create by passing a dict of macros and the db-api paramstyle:
    sfd = SQLFormatDict("qmark", substitutionDict)

    Then transform queries+params as follows; input should be "format" style:
    sql,params = sfd("select * from foo where id=%s and val=%s", (myID,myVal))
    cursor.execute(sql, params)
    '''
    _paramFormats = dict(pyformat='%%(%d)s', numeric=':%d', named=':%d',
                         qmark='(ignore)', format='(ignore)')
    def __init__(self, paramstyle, substitutionDict=None):
        # default is None rather than {} to avoid a shared mutable default;
        # passing a dict behaves exactly as before (it was always copied)
        self.substitutionDict = dict(substitutionDict or {})
        self.paramstyle = paramstyle
        self.paramFormat = self._paramFormats[paramstyle]
        self.makeDict = (paramstyle == 'pyformat' or paramstyle == 'named')
        if paramstyle == 'qmark': # handle these as simple substitution
            self.substitutionDict['?'] = '?'
        elif paramstyle == 'format':
            self.substitutionDict['?'] = '%s'
    def __getitem__(self, k):
        'apply correct substitution for this SQL interface'
        try:
            return self.substitutionDict[k] # apply our substitutions
        except KeyError: # not a macro; try paramstyle replacement below
            pass
        if k == '?': # sequential parameter
            s = self.paramFormat % self.iparam
            self.iparam += 1 # advance to the next parameter
            return s
        raise KeyError('unknown macro: %s' % k)
    def __call__(self, sql, paramList):
        'returns corrected sql,params for this interface'
        self.iparam = 1 # DB-ABI param indexing begins at 1
        sql = sql.replace('%s', '%(?)s') # convert format into pyformat
        s = sql % self # apply all %(x)s replacements in sql
        if self.makeDict: # construct a params dict
            paramDict = {}
            for i,param in enumerate(paramList):
                paramDict[str(i + 1)] = param #DB-ABI param indexing begins at 1
            return s,paramDict
        else: # just return the original params list
            return s,tuple(paramList)
def get_table_schema(self, analyzeSchema=True):
    'run the right schema function based on type of db server connection'
    try:
        modname = self.cursor.__class__.__module__
    except AttributeError:
        raise ValueError('no cursor object or module information!')
    try:
        schema_func = self._schemaModuleDict[modname]
    except KeyError:
        raise KeyError('''unknown db module: %s. Use _schemaModuleDict
attribute to supply a method for obtaining table schema
for this module''' % modname)
    schema_func(self, analyzeSchema) # run the schema function
# Maps DB-API cursor module names to the function that can read table
# schema through that interface; SQLTableBase copies this as its default.
_schemaModuleDict = {'MySQLdb.cursors':mysql_table_schema,
                     'pysqlite2.dbapi2':sqlite_table_schema,
                     'sqlite3':sqlite_table_schema}
446 class SQLTableBase(object, UserDict
.DictMixin
):
447 "Store information about an SQL table as dict keyed by primary key"
448 _schemaModuleDict
= _schemaModuleDict
# default module list
449 get_table_schema
= get_table_schema
450 def __init__(self
,name
,cursor
=None,itemClass
=None,attrAlias
=None,
451 clusterKey
=None,createTable
=None,graph
=None,maxCache
=None,
452 arraysize
=1024, itemSliceClass
=None, dropIfExists
=False,
453 serverInfo
=None, autoGC
=True, orderBy
=None,
454 writeable
=False, **kwargs
):
455 if autoGC
: # automatically garbage collect unused objects
456 self
._weakValueDict
= RecentValueDictionary(autoGC
) # object cache
458 self
._weakValueDict
= {}
460 self
.orderBy
= orderBy
461 self
.writeable
= writeable
463 if serverInfo
is not None: # get cursor from serverInfo
464 cursor
= serverInfo
.cursor()
465 else: # try to read connection info from name or config file
466 name
,cursor
,serverInfo
= get_name_cursor(name
,**kwargs
)
468 warnings
.warn("""The cursor argument is deprecated. Use serverInfo instead! """,
469 DeprecationWarning, stacklevel
=2)
471 if createTable
is not None: # RUN COMMAND TO CREATE THIS TABLE
472 if dropIfExists
: # get rid of any existing table
473 cursor
.execute('drop table if exists ' + name
)
474 self
.get_table_schema(False) # check dbtype, init _format_query
475 sql
,params
= self
._format
_query
(createTable
, ()) # apply macros
476 cursor
.execute(sql
) # create the table
478 if graph
is not None:
480 if maxCache
is not None:
481 self
.maxCache
= maxCache
482 if arraysize
is not None:
483 self
.arraysize
= arraysize
484 cursor
.arraysize
= arraysize
485 self
.get_table_schema() # get schema of columns to serve as attrs
486 self
.data
= {} # map of all attributes, including aliases
487 for icol
,field
in enumerate(self
.columnName
):
488 self
.data
[field
] = icol
# 1st add mappings to columns
490 self
.data
['id']=self
.data
[self
.primary_key
]
491 except (KeyError,TypeError):
493 if hasattr(self
,'_attr_alias'): # apply attribute aliases for this class
494 self
.addAttrAlias(False,**self
._attr
_alias
)
495 self
.objclass(itemClass
) # NEED TO SUBCLASS OUR ITEM CLASS
496 if itemSliceClass
is not None:
497 self
.itemSliceClass
= itemSliceClass
498 get_bound_subclass(self
, 'itemSliceClass', self
.name
) # need to subclass itemSliceClass
499 if attrAlias
is not None: # ADD ATTRIBUTE ALIASES
500 self
.attrAlias
= attrAlias
# RECORD FOR PICKLING PURPOSES
501 self
.data
.update(attrAlias
)
502 if clusterKey
is not None:
503 self
.clusterKey
=clusterKey
504 if serverInfo
is not None:
505 self
.serverInfo
= serverInfo
508 self
._select
(selectCols
='count(*)')
509 return self
.cursor
.fetchone()[0]
512 def __cmp__(self
, other
):
513 'only match self and no other!'
517 return cmp(id(self
), id(other
))
518 _pickleAttrs
= dict(name
=0, clusterKey
=0, maxCache
=0, arraysize
=0,
519 attrAlias
=0, serverInfo
=0, autoGC
=0, orderBy
=0,
521 __getstate__
= standard_getstate
522 def __setstate__(self
,state
):
523 # default cursor provisioning by worldbase is deprecated!
524 ## if 'serverInfo' not in state: # hmm, no address for db server?
525 ## try: # SEE IF WE CAN GET CURSOR DIRECTLY FROM RESOURCE DATABASE
526 ## from Data import getResource
527 ## state['cursor'] = getResource.getTableCursor(state['name'])
528 ## except ImportError:
529 ## pass # FAILED, SO TRY TO GET A CURSOR IN THE USUAL WAYS...
530 self
.__init
__(**state
)
532 return '<SQL table '+self
.name
+'>'
534 def clear_schema(self
):
535 'reset all schema information for this table'
539 self
.usesIntID
= None
540 self
.primary_key
= None
542 def _attrSQL(self
,attr
,sqlColumn
=False,columnNumber
=False):
543 "Translate python attribute name to appropriate SQL expression"
544 try: # MAKE SURE THIS ATTRIBUTE CAN BE MAPPED TO DATABASE EXPRESSION
545 field
=self
.data
[attr
]
547 raise AttributeError('attribute %s not a valid column or alias in %s'
549 if sqlColumn
: # ENSURE THAT THIS TRULY MAPS TO A COLUMN NAME IN THE DB
550 try: # CHECK IF field IS COLUMN NUMBER
551 return self
.columnName
[field
] # RETURN SQL COLUMN NAME
553 try: # CHECK IF field IS SQL COLUMN NAME
554 return self
.columnName
[self
.data
[field
]] # THIS WILL JUST RETURN field...
555 except (KeyError,TypeError):
556 raise AttributeError('attribute %s does not map to an SQL column in %s'
559 try: # CHECK IF field IS A COLUMN NUMBER
560 return field
+0 # ONLY RETURN AN INTEGER
562 try: # CHECK IF field IS ITSELF THE SQL COLUMN NAME
563 return self
.data
[field
]+0 # ONLY RETURN AN INTEGER
564 except (KeyError,TypeError):
565 raise ValueError('attribute %s does not map to a SQL column!' % attr
)
566 if isinstance(field
,types
.StringType
):
567 attr
=field
# USE ALIASED EXPRESSION FOR DATABASE SELECT INSTEAD OF attr
569 attr
=self
.primary_key
571 def addAttrAlias(self
,saveToPickle
=True,**kwargs
):
572 """Add new attributes as aliases of existing attributes.
573 They can be specified either as named args:
574 t.addAttrAlias(newattr=oldattr)
575 or by passing a dictionary kwargs whose keys are newattr
576 and values are oldattr:
577 t.addAttrAlias(**kwargs)
578 saveToPickle=True forces these aliases to be saved if object is pickled.
581 self
.attrAlias
.update(kwargs
)
582 for key
,val
in kwargs
.items():
583 try: # 1st CHECK WHETHER val IS AN EXISTING COLUMN / ALIAS
584 self
.data
[val
]+0 # CHECK WHETHER val MAPS TO A COLUMN NUMBER
585 raise KeyError # YES, val IS ACTUAL SQL COLUMN NAME, SO SAVE IT DIRECTLY
586 except TypeError: # val IS ITSELF AN ALIAS
587 self
.data
[key
] = self
.data
[val
] # SO MAP TO WHAT IT MAPS TO
588 except KeyError: # TREAT AS ALIAS TO SQL EXPRESSION
590 def objclass(self
,oclass
=None):
591 "Create class representing a row in this table by subclassing oclass, adding data"
592 if oclass
is not None: # use this as our base itemClass
593 self
.itemClass
= oclass
595 self
.itemClass
= self
.itemClass
._RWClass
# use its writeable version
596 oclass
= get_bound_subclass(self
, 'itemClass', self
.name
,
597 subclassArgs
=dict(db
=self
)) # bind itemClass
598 if issubclass(oclass
, TupleO
):
599 oclass
._attrcol
= self
.data
# BIND ATTRIBUTE LIST TO TUPLEO INTERFACE
600 if hasattr(oclass
,'_tableclass') and not isinstance(self
,oclass
._tableclass
):
601 self
.__class
__=oclass
._tableclass
# ROW CLASS CAN OVERRIDE OUR CURRENT TABLE CLASS
602 def _select(self
, whereClause
='', params
=(), selectCols
='t1.*',
603 cursor
=None, orderBy
='', limit
=''):
604 'execute the specified query but do not fetch'
605 sql
,params
= self
._format
_query
('select %s from %s t1 %s %s %s'
606 % (selectCols
, self
.name
, whereClause
, orderBy
,
609 self
.cursor
.execute(sql
, params
)
611 cursor
.execute(sql
, params
)
612 def select(self
,whereClause
,params
=None,oclass
=None,selectCols
='t1.*'):
613 "Generate the list of objects that satisfy the database SELECT"
615 oclass
=self
.itemClass
616 self
._select
(whereClause
,params
,selectCols
)
617 l
=self
.cursor
.fetchall()
619 yield self
.cacheItem(t
,oclass
)
620 def query(self
,**kwargs
):
621 'query for intersection of all specified kwargs, returned as iterator'
624 for k
,v
in kwargs
.items(): # CONSTRUCT THE LIST OF WHERE CLAUSES
625 if v
is None: # CONVERT TO SQL NULL TEST
626 criteria
.append('%s IS NULL' % self
._attrSQL
(k
))
627 else: # TEST FOR EQUALITY
628 criteria
.append('%s=%%s' % self
._attrSQL
(k
))
630 return self
.select('where '+' and '.join(criteria
),params
)
631 def _update(self
,row_id
,col
,val
):
632 'update a single field in the specified row to the specified value'
633 sql
,params
= self
._format
_query
('update %s set %s=%%s where %s=%%s'
634 %(self
.name
,col
,self
.primary_key
),
636 self
.cursor
.execute(sql
, params
)
639 return t
[self
.data
['id']] # GET ID FROM TUPLE
640 except TypeError: # treat as alias
641 return t
[self
.data
[self
.data
['id']]]
642 def cacheItem(self
,t
,oclass
):
643 'get obj from cache if possible, or construct from tuple'
646 except KeyError: # NO PRIMARY KEY? IGNORE THE CACHE.
648 try: # IF ALREADY LOADED IN OUR DICTIONARY, JUST RETURN THAT ENTRY
649 return self
._weakValueDict
[id]
653 self
._weakValueDict
[id] = o
# CACHE THIS ITEM IN OUR DICTIONARY
655 def cache_items(self
,rows
,oclass
=None):
657 oclass
=self
.itemClass
659 yield self
.cacheItem(t
,oclass
)
660 def foreignKey(self
,attr
,k
):
661 'get iterator for objects with specified foreign key value'
662 return self
.select('where %s=%%s'%attr
,(k
,))
663 def limit_cache(self
):
664 'APPLY maxCache LIMIT TO CACHE SIZE'
666 if self
.maxCache
<len(self
._weakValueDict
):
667 self
._weakValueDict
.clear()
668 except AttributeError:
671 def get_new_cursor(self
):
672 """Return a new cursor object, or None if not possible """
674 new_cursor
= self
.serverInfo
.new_cursor
675 except AttributeError:
677 return new_cursor(self
.arraysize
)
679 def generic_iterator(self
, cursor
=None, fetch_f
=None, cache_f
=None,
681 'generic iterator that runs fetch, cache and map functions'
682 if fetch_f
is None: # JUST USE CURSOR'S PREFERRED CHUNK SIZE
684 fetch_f
= self
.cursor
.fetchmany
685 else: # isolate this iter from other queries
686 fetch_f
= cursor
.fetchmany
688 cache_f
= self
.cache_items
691 rows
= fetch_f() # FETCH THE NEXT SET OF ROWS
692 print 'fetch_f rows:', len(rows
)
693 if len(rows
)==0: # NO MORE DATA SO ALL DONE
695 for v
in map_f(cache_f(rows
)): # CACHE AND GENERATE RESULTS
697 if cursor
is not None: # close iterator now that we're done
699 def tuple_from_dict(self
, d
):
700 'transform kwarg dict into tuple for storing in database'
701 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
702 for col
,icol
in self
.data
.items():
705 except (KeyError,TypeError):
708 def tuple_from_obj(self
, obj
):
709 'transform object attributes into tuple for storing in database'
710 l
= [None]*len(self
.description
) # DEFAULT COLUMN VALUES ARE NULL
711 for col
,icol
in self
.data
.items():
713 l
[icol
] = getattr(obj
,col
)
714 except (AttributeError,TypeError):
717 def _insert(self
, l
):
718 '''insert tuple into the database. Note this uses the MySQL
719 extension REPLACE, which overwrites any duplicate key.'''
720 s
= '%(REPLACE)s into ' + self
.name
+ ' values (' \
721 + ','.join(['%s']*len(l
)) + ')'
722 sql
,params
= self
._format
_query
(s
, l
)
723 self
.cursor
.execute(sql
, params
)
724 def insert(self
, obj
):
725 '''insert new row by transforming obj to tuple of values'''
726 l
= self
.tuple_from_obj(obj
)
728 def get_insert_id(self
):
729 'get the primary key value for the last INSERT'
730 try: # ATTEMPT TO GET ASSIGNED ID FROM DB
731 auto_id
= self
.cursor
.lastrowid
732 except AttributeError: # CURSOR DOESN'T SUPPORT lastrowid
733 raise NotImplementedError('''your db lacks lastrowid support?''')
735 raise ValueError('lastrowid is None so cannot get ID from INSERT!')
737 def new(self
, **kwargs
):
738 'return a new record with the assigned attributes, added to DB'
739 if not self
.writeable
:
740 raise ValueError('this database is read only!')
741 obj
= self
.itemClass(None, newRow
=True, **kwargs
) # saves itself to db
742 self
._weakValueDict
[obj
.id] = obj
# AND SAVE TO OUR LOCAL DICT CACHE
744 def clear_cache(self
):
746 self
._weakValueDict
.clear()
747 def __delitem__(self
, k
):
748 if not self
.writeable
:
749 raise ValueError('this database is read only!')
750 sql
,params
= self
._format
_query
('delete from %s where %s=%%s'
751 % (self
.name
,self
.primary_key
),(k
,))
752 self
.cursor
.execute(sql
, params
)
754 del self
._weakValueDict
[k
]
def getKeys(self, queryOption='', selectCols=None):
    'uses db select; does not force load'
    if selectCols is None:
        selectCols = self.primary_key
    if queryOption == '' and self.orderBy is not None:
        queryOption = self.orderBy # apply default ordering
    query = 'select %s from %s %s' % (selectCols, self.name, queryOption)
    self.cursor.execute(query)
    # fetch everything immediately: other calls may reuse this cursor
    rows = self.cursor.fetchall()
    return [row[0] for row in rows]
def iter_keys(self, selectCols=None, orderBy='', map_f=iter,
              cache_f=lambda x:[t[0] for t in x], get_f=None, **kwargs):
    'guarantee correct iteration insulated from other queries'
    if selectCols is None:
        selectCols = self.primary_key
    if orderBy == '' and self.orderBy is not None:
        orderBy = self.orderBy # apply default ordering
    cursor = self.get_new_cursor()
    if cursor: # got our own cursor, guaranteeing query isolation
        try: # optional custom streaming hook supplied by the server interface
            # NOTE(review): attribute name repaired from garbled text
            # ('iter_keysAXAFDAA'); confirm the 'self.db' qualification.
            iter_f = self.db.serverInfo.iter_keys
        except AttributeError:
            self._select(cursor=cursor, selectCols=selectCols,
                         orderBy=orderBy, **kwargs)
            return self.generic_iterator(cursor=cursor, cache_f=cache_f,
                                         map_f=map_f)
        else: # use custom iter_keys() method from serverInfo
            return iter_f(self, cursor, selectCols=selectCols,
                          orderBy=orderBy, map_f=map_f, cache_f=cache_f,
                          **kwargs)
    else: # must pre-fetch all keys to ensure query isolation
        if get_f is not None:
            return iter(get_f())
        return iter(self.keys())
794 class SQLTable(SQLTableBase
):
795 "Provide on-the-fly access to rows in the database, caching the results in dict"
796 itemClass
= TupleO
# our default itemClass; constructor can override
799 def load(self
,oclass
=None):
800 "Load all data from the table"
801 try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
802 return self
._isLoaded
803 except AttributeError:
806 oclass
=self
.itemClass
807 self
.cursor
.execute('select * from %s' % self
.name
)
808 l
=self
.cursor
.fetchall()
809 self
._weakValueDict
= {} # just store the whole dataset in memory
811 self
.cacheItem(t
,oclass
) # CACHE IT IN LOCAL DICTIONARY
812 self
._isLoaded
=True # MARK THIS CONTAINER AS FULLY LOADED
814 def __getitem__(self
,k
): # FIRST TRY LOCAL INDEX, THEN TRY DATABASE
816 return self
._weakValueDict
[k
] # DIRECTLY RETURN CACHED VALUE
817 except KeyError: # NOT FOUND, SO TRY THE DATABASE
818 sql
,params
= self
._format
_query
('select * from %s where %s=%%s limit 2'
819 % (self
.name
,self
.primary_key
),(k
,))
820 self
.cursor
.execute(sql
, params
)
821 l
= self
.cursor
.fetchmany(2) # get at most 2 rows
823 raise KeyError('%s not found in %s, or not unique' %(str(k
),self
.name
))
825 return self
.cacheItem(l
[0],self
.itemClass
) # CACHE IT IN LOCAL DICTIONARY
826 def __setitem__(self
, k
, v
):
827 if not self
.writeable
:
828 raise ValueError('this database is read only!')
832 except AttributeError:
833 raise ValueError('object not bound to itemClass for this db!')
838 except AttributeError:
840 else: # delete row with old ID
842 v
.cache_id(k
) # cache the new ID on the object
843 self
.insert(v
) # SAVE TO THE RELATIONAL DB SERVER
844 self
._weakValueDict
[k
] = v
# CACHE THIS ITEM IN OUR DICTIONARY
846 'forces load of entire table into memory'
848 return [(k
,self
[k
]) for k
in self
] # apply orderBy rules...
850 'uses arraysize / maxCache and fetchmany() to manage data transfer'
851 return iter_keys(self
, selectCols
='*', cache_f
=None,
852 map_f
=generate_items
, get_f
=self
.items
)
854 'forces load of entire table into memory'
856 return [self
[k
] for k
in self
] # apply orderBy rules...
    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        # stream whole rows; get_f lets iter_keys fall back to a full
        # pre-load via values() when an isolated cursor is unavailable
        return iter_keys(self, selectCols='*', cache_f=None, get_f=self.values)
def getClusterKeys(self, queryOption=''):
    'uses db select; does not force load'
    self.cursor.execute('select distinct %s from %s %s'
                        %(self.clusterKey,self.name,queryOption))
    # fetch everything at once, since other calls may reuse this cursor
    return [row[0] for row in self.cursor.fetchall()]
class SQLTableClustered(SQLTable):
    '''use clusterKey to load a whole cluster of rows at once,
    specifically, all rows that share the same clusterKey value.'''
    def __init__(self, *args, **kwargs):
        kwargs = kwargs.copy() # get a copy we can alter
        kwargs['autoGC'] = False # don't use WeakValueDictionary
        SQLTable.__init__(self, *args, **kwargs)
    def keys(self):
        return getKeys(self,'order by %s' %self.clusterKey)
    def clusterkeys(self):
        return getClusterKeys(self, 'order by %s' %self.clusterKey)
    def __getitem__(self,k):
        'return row k, loading its entire cluster into the cache on a miss'
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            sql,params = self._format_query('select t2.* from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s'
                                            % (self.name,self.name,self.primary_key,
                                               self.clusterKey,self.clusterKey),(k,))
            self.cursor.execute(sql, params)
            rows = self.cursor.fetchall()
            for t in rows: # LOAD THE ENTIRE CLUSTER INTO OUR LOCAL CACHE
                self.cacheItem(t,self.itemClass)
            return self._weakValueDict[k] # should be in cache, if row k exists
    def itercluster(self,cluster_id):
        'iterate over all items from the specified cluster'
        return self.select('where %s=%%s'%self.clusterKey,(cluster_id,))
    def fetch_cluster(self):
        'use self.cursor.fetchmany to obtain all rows for next cluster'
        icol = self._attrSQL(self.clusterKey,columnNumber=True)
        result = []
        try:
            rows = self._fetch_cluster_cache # USE SAVED ROWS FROM PREVIOUS CALL
            del self._fetch_cluster_cache
        except AttributeError:
            rows = self.cursor.fetchmany()
        try: # NOTE(review): empty-result guard reconstructed -- confirm
            cluster_id = rows[0][icol]
        except IndexError:
            return result
        while len(rows) > 0:
            for i,t in enumerate(rows): # CHECK THAT ALL ROWS FROM THIS CLUSTER
                if cluster_id != t[icol]: # START OF A NEW CLUSTER
                    result += rows[:i] # RETURN ROWS OF CURRENT CLUSTER
                    self._fetch_cluster_cache = rows[i:] # SAVE NEXT CLUSTER
                    return result
            result += rows
            rows = self.cursor.fetchmany() # GET NEXT SET OF ROWS
        return result
    def itervalues(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select('order by %s' %self.clusterKey, cursor=cursor)
        return self.generic_iterator(cursor, self.fetch_cluster)
    def iteritems(self):
        'uses arraysize / maxCache and fetchmany() to manage data transfer'
        cursor = self.get_new_cursor()
        self._select('order by %s' %self.clusterKey, cursor=cursor)
        return self.generic_iterator(cursor, self.fetch_cluster,
                                     map_f=generate_items)
class SQLForeignRelation(object):
    'mapping based on matching a foreign key in an SQL table'
    def __init__(self,table,keyName):
        self.table=table # table whose keyName column stores the foreign keys
        self.keyName=keyName
    def __getitem__(self,k):
        'get list of objects o with getattr(o,keyName)==k.id'
        l=[]
        for o in self.table.select('where %s=%%s'%self.keyName,(k.id,)):
            l.append(o)
        if len(l)==0:
            # BUGFIX: original raised via self.name, an attribute this class
            # never defines, so a miss produced AttributeError instead of the
            # intended KeyError; use the bound table's name instead.
            raise KeyError('%s not found in %s' %(str(k),self.table.name))
        return l
class SQLTableNoCache(SQLTableBase):
    '''Provide on-the-fly access to rows in the database;
    values are simply an object interface (SQLRow) to back-end db query.
    Row data are not stored locally, but always accessed by querying the db'''
    itemClass=SQLRow # DEFAULT OBJECT CLASS FOR ROWS...
    def getID(self,t): return t[0] # GET ID FROM TUPLE
    def select(self,whereClause,params):
        # NOTE(review): trailing argument reconstructed -- confirm vs. upstream
        return SQLTableBase.select(self,whereClause,params,self.oclass,
                                   self._attrSQL('t1.id'))
    def __getitem__(self,k):
        'return SQLRow proxy for key k after verifying it is unique in the db'
        try:
            return self._weakValueDict[k] # DIRECTLY RETURN CACHED VALUE
        except KeyError: # NOT FOUND, SO TRY THE DATABASE
            # NOTE(review): selectCols argument reconstructed -- confirm
            self._select('where %s=%%s' % self.primary_key, (k,),
                         selectCols=self.primary_key)
            t = self.cursor.fetchmany(2)
            if len(t) != 1:
                raise KeyError('id %s non-existent or not unique' % k)
            o = self.itemClass(k) # create obj referencing this ID
            self._weakValueDict[k] = o # cache the SQLRow object
            return o
    def __setitem__(self, k, v):
        'reassign object v to new ID k, both in the db and in the cache'
        if not self.writeable:
            raise ValueError('this database is read only!')
        try:
            if v.db is not self:
                raise AttributeError
        except AttributeError:
            raise ValueError('object not bound to itemClass for this db!')
        try:
            del self[k] # delete row with new ID if any
        except KeyError:
            pass
        try:
            del self._weakValueDict[v.id] # delete from old cache location
        except KeyError:
            pass
        self._update(v.id, self.primary_key, k) # just change its ID in db
        v.cache_id(k) # change the cached ID value
        self._weakValueDict[k] = v # assign to new cache location
    def addAttrAlias(self,**kwargs):
        self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES

SQLRow._tableclass=SQLTableNoCache # SQLRow IS FOR NON-CACHING TABLE INTERFACE
class SQLTableMultiNoCache(SQLTableBase):
    "Trivial on-the-fly access for table with key that returns multiple rows"
    itemClass = TupleO # default itemClass; constructor can override
    _distinct_key='id' # DEFAULT COLUMN TO USE AS KEY
    def keys(self):
        return getKeys(self, selectCols='distinct(%s)'
                       % self._attrSQL(self._distinct_key))
    def __iter__(self):
        return iter_keys(self, 'distinct(%s)' % self._attrSQL(self._distinct_key))
    def __getitem__(self,id):
        'generate itemClass objects for every row matching this key'
        sql,params = self._format_query('select * from %s where %s=%%s'
                                        %(self.name,self._attrSQL(self._distinct_key)),(id,))
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchall() # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for row in rows:
            yield self.itemClass(row)
    def addAttrAlias(self,**kwargs):
        self.data.update(kwargs) # ALIAS KEYS TO EXPRESSION VALUES
class SQLEdges(SQLTableMultiNoCache):
    '''provide iterator over edges as (source,target,edge)
    and getitem[edge] --> [(source,target),...]'''
    _distinct_key='edge_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))
    def keys(self):
        'list all (source,target,edge) tuples with a non-null target'
        self.cursor.execute('select %s,%s,%s from %s where %s is not null order by %s,%s'
                            %(self._attrSQL('source_id'),self._attrSQL('target_id'),
                              self._attrSQL('edge_id'),self.name,
                              self._attrSQL('target_id'),self._attrSQL('source_id'),
                              self._attrSQL('target_id')))
        edges = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for source_id,target_id,edge_id in self.cursor.fetchall():
            edges.append((self.graph.unpack_source(source_id),
                          self.graph.unpack_target(target_id),
                          self.graph.unpack_edge(edge_id)))
        return edges
    def __iter__(self):
        return iter(self.keys())
    def __getitem__(self,edge):
        'list all (source,target) pairs connected by this edge object'
        sql,params = self._format_query('select %s,%s from %s where %s=%%s'
                                        %(self._attrSQL('source_id'),
                                          self._attrSQL('target_id'),
                                          self.name,
                                          self._attrSQL(self._distinct_key)),
                                        (self.graph.pack_edge(edge),))
        self.cursor.execute(sql, params)
        pairs = [] # PREFETCH ALL ROWS, SINCE CURSOR MAY BE REUSED
        for source_id,target_id in self.cursor.fetchall():
            pairs.append((self.graph.unpack_source(source_id),
                          self.graph.unpack_target(target_id)))
        return pairs
class SQLEdgeDict(object):
    '2nd level graph interface to SQL database'
    def __init__(self,fromNode,table):
        self.fromNode=fromNode
        self.table=table
        if not hasattr(self.table,'allowMissingNodes'):
            # verify the source node actually exists in the graph table
            sql,params = self.table._format_query('select %s from %s where %s=%%s limit 1'
                                                  %(self.table.sourceSQL,
                                                    self.table.name,
                                                    self.table.sourceSQL),
                                                  (self.fromNode,))
            self.table.cursor.execute(sql, params)
            if len(self.table.cursor.fetchall())<1:
                raise KeyError('node not in graph!')
    def __getitem__(self,target):
        'return the edge object on the fromNode -> target edge'
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
                                              %(self.table.edgeSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        rows = self.table.cursor.fetchmany(2) # get at most two rows
        if len(rows) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        try:
            return self.table.unpack_edge(rows[0][0]) # RETURN EDGE
        except IndexError:
            raise KeyError('no edge from node to target')
    def __setitem__(self,target,edge):
        'save the fromNode -> target edge, replacing any existing one'
        sql,params = self.table._format_query('replace into %s values (%%s,%%s,%%s)'
                                              % self.table.name,
                                              (self.fromNode,
                                               self.table.pack_target(target),
                                               self.table.pack_edge(edge)))
        self.table.cursor.execute(sql, params)
        if not hasattr(self.table,'sourceDB') or \
           (hasattr(self.table,'targetDB') and self.table.sourceDB==self.table.targetDB):
            self.table += target # ADD AS NODE TO GRAPH
    def __iadd__(self,target):
        self[target] = None
        return self # iadd MUST RETURN self!
    def __delitem__(self,target):
        'delete the fromNode -> target edge row'
        sql,params = self.table._format_query('delete from %s where %s=%%s and %s=%%s'
                                              %(self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        if self.table.cursor.rowcount < 1: # no rows deleted?
            raise KeyError('no edge from node to target')
    def iterator_query(self):
        'fetch all (target_id,edge_id) rows for this source node'
        sql,params = self.table._format_query('select %s,%s from %s where %s=%%s and %s is not null'
                                              %(self.table.targetSQL,
                                                self.table.edgeSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return self.table.cursor.fetchall()
    def keys(self):
        return [self.table.unpack_target(target_id)
                for target_id,edge_id in self.iterator_query()]
    def values(self):
        return [self.table.unpack_edge(edge_id)
                for target_id,edge_id in self.iterator_query()]
    def edges(self):
        return [(self.table.unpack_source(self.fromNode),self.table.unpack_target(target_id),
                 self.table.unpack_edge(edge_id))
                for target_id,edge_id in self.iterator_query()]
    def items(self):
        return [(self.table.unpack_target(target_id),self.table.unpack_edge(edge_id))
                for target_id,edge_id in self.iterator_query()]
    def __iter__(self): return iter(self.keys())
    def itervalues(self): return iter(self.values())
    def iteritems(self): return iter(self.items())
    def __len__(self):
        return len(self.keys())
class SQLEdgelessDict(SQLEdgeDict):
    'for SQLGraph tables that lack edge_id column'
    def __getitem__(self,target):
        'verify the fromNode -> target edge exists; edge info is always None'
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s=%%s limit 2'
                                              %(self.table.targetSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,
                                               self.table.pack_target(target)))
        self.table.cursor.execute(sql, params)
        rows = self.table.cursor.fetchmany(2)
        if len(rows) != 1:
            raise KeyError('either no edge from source to target or not unique!')
        return None # no edge info!
    def iterator_query(self):
        'fetch (target_id, None) pairs, since this table stores no edge column'
        sql,params = self.table._format_query('select %s from %s where %s=%%s and %s is not null'
                                              %(self.table.targetSQL,
                                                self.table.name,
                                                self.table.sourceSQL,
                                                self.table.targetSQL),
                                              (self.fromNode,))
        self.table.cursor.execute(sql, params)
        return [(t[0],None) for t in self.table.cursor.fetchall()]

SQLEdgeDict._edgelessClass = SQLEdgelessDict
class SQLGraphEdgeDescriptor(object):
    'provide an SQLEdges interface on demand'
    def __get__(self,obj,objtype):
        try: # forward the graph's attribute aliases, if it has any
            attrAlias=obj.attrAlias.copy()
        except AttributeError:
            return SQLEdges(obj.name, obj.cursor, graph=obj)
        return SQLEdges(obj.name, obj.cursor, attrAlias=attrAlias,
                        graph=obj)
def getColumnTypes(createTable,attrAlias={},defaultColumnType='int',
                   columnAttrs=('source','target','edge'),**kwargs):
    'return list of [(colname,coltype),...] for source,target,edge'
    result = []
    for attr in columnAttrs:
        try: # apply any user-supplied column-name alias
            attrName = attrAlias[attr+'_id']
        except KeyError:
            attrName = attr+'_id'
        try: # SEE IF USER SPECIFIED A DESIRED TYPE
            result.append((attrName,createTable[attr+'_id']))
            continue
        except (KeyError,TypeError):
            pass
        try: # get type info from primary key for that database
            db = kwargs[attr+'DB']
            if db is None:
                raise KeyError # FORCE IT TO USE DEFAULT TYPE
        except KeyError:
            pass
        else: # INFER THE COLUMN TYPE FROM THE ASSOCIATED DATABASE KEYS...
            it = iter(db)
            try: # GET ONE IDENTIFIER FROM THE DATABASE
                k = it.next()
            except StopIteration: # TABLE IS EMPTY, SO READ SQL TYPE FROM db OBJECT
                try:
                    result.append((attrName,db.columnType[db.primary_key]))
                    continue
                except AttributeError:
                    pass
            else: # GET THE TYPE FROM THIS IDENTIFIER
                if isinstance(k,int) or isinstance(k,long):
                    result.append((attrName,'int'))
                    continue
                elif isinstance(k,str):
                    result.append((attrName,'varchar(32)'))
                    continue
                else:
                    raise ValueError('SQLGraph node / edge must be int or str!')
        result.append((attrName,defaultColumnType)) # nothing worked: use default
        logger.warn('no type info found for %s, so using default: %s'
                    % (attrName, defaultColumnType))
    return result
class SQLGraph(SQLTableMultiNoCache):
    '''provide a graph interface via a SQL table. Key capabilities are:
    - setitem with an empty dictionary: a dummy operation
    - getitem with a key that exists: return a placeholder
    - setitem with non empty placeholder: again a dummy operation
    EXAMPLE TABLE SCHEMA:
    create table mygraph (source_id int not null,target_id int,edge_id int,
    unique(source_id,target_id));
    '''
    _distinct_key='source_id'
    _pickleAttrs = SQLTableMultiNoCache._pickleAttrs.copy()
    _pickleAttrs.update(dict(sourceDB=0,targetDB=0,edgeDB=0,allowMissingNodes=0))
    _edgeClass = SQLEdgeDict
    def __init__(self,name,*l,**kwargs):
        graphArgs,tableArgs = split_kwargs(kwargs,
                ('attrAlias','defaultColumnType','columnAttrs',
                 'sourceDB','targetDB','edgeDB','simpleKeys','unpack_edge',
                 'edgeDictClass','graph'))
        if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
            c = getColumnTypes(**kwargs)
            tableArgs['createTable'] = \
              'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
              % (name,c[0][0],c[0][1],c[1][0],c[1][1],c[2][0],c[2][1],c[0][0],c[1][0])
        try:
            self.allowMissingNodes = kwargs['allowMissingNodes']
        except KeyError: pass
        SQLTableMultiNoCache.__init__(self,name,*l,**tableArgs)
        self.sourceSQL = self._attrSQL('source_id')
        self.targetSQL = self._attrSQL('target_id')
        try:
            self.edgeSQL = self._attrSQL('edge_id')
        except AttributeError: # no edge column: fall back to edgeless interface
            self.edgeSQL = None
            self._edgeClass = self._edgeClass._edgelessClass
        save_graph_db_refs(self,**kwargs)
    def __getitem__(self,k):
        return self._edgeClass(self.pack_source(k),self)
    def __iadd__(self,k):
        'add k as a node (with no edges) if not already present'
        sql,params = self._format_query('delete from %s where %s=%%s and %s is null'
                                        % (self.name,self.sourceSQL,self.targetSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        sql,params = self._format_query('insert %%(IGNORE)s into %s values (%%s,NULL,NULL)'
                                        % self.name,(self.pack_source(k),))
        self.cursor.execute(sql, params)
        return self # iadd MUST RETURN SELF!
    def __isub__(self,k):
        'delete node k and all its edges'
        sql,params = self._format_query('delete from %s where %s=%%s'
                                        % (self.name,self.sourceSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        if self.cursor.rowcount == 0:
            raise KeyError('node not found in graph')
        return self # iadd MUST RETURN SELF!
    __setitem__ = graph_setitem
    def __contains__(self,k):
        sql,params = self._format_query('select * from %s where %s=%%s limit 1'
                                        %(self.name,self.sourceSQL),
                                        (self.pack_source(k),))
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchmany(2)
        return len(rows) > 0
    def __invert__(self):
        'get an interface to the inverse graph mapping'
        try:
            return self._inverse
        except AttributeError: # CONSTRUCT INTERFACE TO INVERSE MAPPING
            attrAlias = dict(source_id=self.targetSQL, # SWAP SOURCE & TARGET
                             target_id=self.sourceSQL,
                             edge_id=self.edgeSQL)
            if self.edgeSQL is None: # no edge interface
                del attrAlias['edge_id']
            self._inverse=SQLGraph(self.name,self.cursor,
                                   attrAlias=attrAlias,
                                   **graph_db_inverse_refs(self))
            self._inverse._inverse=self
            return self._inverse
    def __iter__(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self.unpack_source(k)
    def iteritems(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield (self.unpack_source(k), self._edgeClass(k, self))
    def itervalues(self):
        for k in SQLTableMultiNoCache.__iter__(self):
            yield self._edgeClass(k, self)
    def keys(self):
        return [self.unpack_source(k) for k in SQLTableMultiNoCache.keys(self)]
    def values(self): return list(self.itervalues())
    def items(self): return list(self.iteritems())
    edges=SQLGraphEdgeDescriptor()
    update = update_graph
    def __len__(self):
        'get number of source nodes in graph'
        self.cursor.execute('select count(distinct %s) from %s'
                            %(self.sourceSQL,self.name))
        return self.cursor.fetchone()[0]
    __cmp__ = graph_cmp # NOTE(review): reconstructed binding -- confirm
    override_rich_cmp(locals()) # MUST OVERRIDE __eq__ ETC. TO USE OUR __cmp__!
    add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
class SQLIDGraph(SQLGraph):
    'SQLGraph specialization that stores node/edge IDs directly (no packing)'
    add_trivial_packing_methods(locals())
SQLGraph._IDGraphClass = SQLIDGraph
class SQLEdgeDictClustered(dict):
    'simple cache for 2nd level dictionary of target_id:edge_id'
    def __init__(self,g,fromNode):
        self.g=g
        self.fromNode=fromNode
        dict.__init__(self)
    def __iadd__(self,l):
        'merge a list of (target_id,edge_id) pairs into this cache'
        for target_id,edge_id in l:
            dict.__setitem__(self,target_id,edge_id)
        return self # iadd MUST RETURN SELF!
class SQLEdgesClusteredDescr(object):
    'build an SQLEdgesClustered view of the bound graph on attribute access'
    def __get__(self,obj,objtype):
        e=SQLEdgesClustered(obj.table,obj.edge_id,obj.source_id,obj.target_id,
                            graph=obj,**graph_db_inverse_refs(obj,True))
        for source_id,d in obj.d.iteritems(): # COPY EDGE CACHE
            e.load([(edge_id,source_id,target_id)
                    for (target_id,edge_id) in d.iteritems()])
        return e
class SQLGraphClustered(object):
    'SQL graph with clustered caching -- loads an entire cluster at a time'
    _edgeDictClass=SQLEdgeDictClustered
    def __init__(self,table,source_id='source_id',target_id='target_id',
                 edge_id='edge_id',clusterKey=None,**kwargs):
        'table may be an existing table interface, or a table name to create'
        if isinstance(table,types.StringType): # CREATE THE TABLE INTERFACE
            if clusterKey is None:
                raise ValueError('you must provide a clusterKey argument!')
            if 'createTable' in kwargs: # CREATE A SCHEMA FOR THIS TABLE
                c = getColumnTypes(attrAlias=dict(source_id=source_id,target_id=target_id,
                                                  edge_id=edge_id),**kwargs)
                kwargs['createTable'] = \
                  'create table %s (%s %s not null,%s %s,%s %s,unique(%s,%s))' \
                  % (table,c[0][0],c[0][1],c[1][0],c[1][1],
                     c[2][0],c[2][1],c[0][0],c[1][0])
            table = SQLTableClustered(table,clusterKey=clusterKey,**kwargs)
        self.table = table
        self.source_id=source_id
        self.target_id=target_id
        self.edge_id=edge_id
        self.d = {} # local cache of loaded clusters
        save_graph_db_refs(self,**kwargs)
    _pickleAttrs = dict(table=0,source_id=0,target_id=0,edge_id=0,sourceDB=0,targetDB=0,
                        edgeDB=0)
    def __getstate__(self):
        state = standard_getstate(self)
        state['d'] = {} # UNPICKLE SHOULD RESTORE GRAPH WITH EMPTY CACHE
        return state
    def __getitem__(self,k):
        'get edgeDict for source node k, from cache or by loading its cluster'
        try: # GET DIRECTLY FROM CACHE
            return self.d[k]
        except KeyError:
            if hasattr(self,'_isLoaded'):
                raise # ENTIRE GRAPH LOADED, SO k REALLY NOT IN THIS GRAPH
            # HAVE TO LOAD THE ENTIRE CLUSTER CONTAINING THIS NODE
            sql,params = self.table._format_query('select t2.%s,t2.%s,t2.%s from %s t1,%s t2 where t1.%s=%%s and t1.%s=t2.%s group by t2.%s'
                                                  %(self.source_id,self.target_id,
                                                    self.edge_id,self.table.name,
                                                    self.table.name,self.source_id,
                                                    self.table.clusterKey,self.table.clusterKey,
                                                    self.table.primary_key),
                                                  (self.pack_source(k),))
            self.table.cursor.execute(sql, params)
            self.load(self.table.cursor.fetchall()) # CACHE THIS CLUSTER
            return self.d[k] # RETURN EDGE DICT FOR THIS NODE
    def load(self,l=None,unpack=True):
        'load the specified rows (or all, if None provided) into local cache'
        if l is None:
            try: # IF ALREADY LOADED, NO NEED TO DO ANYTHING
                return self._isLoaded
            except AttributeError:
                pass
            self.table.cursor.execute('select %s,%s,%s from %s'
                                      %(self.source_id,self.target_id,
                                        self.edge_id,self.table.name))
            l=self.table.cursor.fetchall()
            self._isLoaded=True
            self.d.clear() # CLEAR OUR CACHE AS load() WILL REPLICATE EVERYTHING
        for source,target,edge in l: # SAVE TO OUR CACHE
            if unpack:
                source = self.unpack_source(source)
                target = self.unpack_target(target)
                edge = self.unpack_edge(edge)
            try:
                self.d[source] += [(target,edge)]
            except KeyError: # first edge for this source: make its edge dict
                d = self._edgeDictClass(self,source)
                d += [(target,edge)]
                self.d[source] = d
    def __invert__(self):
        'interface to reverse graph mapping'
        try:
            return self._inverse # INVERSE MAP ALREADY EXISTS
        except AttributeError:
            # JUST CREATE INTERFACE WITH SWAPPED TARGET & SOURCE
            self._inverse=SQLGraphClustered(self.table,self.target_id,self.source_id,
                                            self.edge_id,**graph_db_inverse_refs(self))
            self._inverse._inverse=self
            for source,d in self.d.iteritems(): # INVERT OUR CACHE
                self._inverse.load([(target,source,edge)
                                    for (target,edge) in d.iteritems()],unpack=False)
            return self._inverse
    edges=SQLEdgesClusteredDescr() # CONSTRUCT EDGE INTERFACE ON DEMAND
    update = update_graph
    add_standard_packing_methods(locals()) ############ PACK / UNPACK METHODS
    def __iter__(self): ################# ITERATORS
        'uses db select; does not force load'
        return iter(self.keys())
    def keys(self):
        'uses db select; does not force load'
        self.table.cursor.execute('select distinct(%s) from %s'
                                  %(self.source_id,self.table.name))
        return [self.unpack_source(t[0])
                for t in self.table.cursor.fetchall()]
    methodFactory(['iteritems','items','itervalues','values'],
                  'lambda self:(self.load(),self.d.%s())[1]',locals())
    def __contains__(self,k):
        try: # NOTE(review): body reconstructed -- confirm vs. upstream
            self[k]
            return True
        except KeyError:
            return False
class SQLIDGraphClustered(SQLGraphClustered):
    'clustered SQL graph that stores node/edge IDs directly (no packing)'
    add_trivial_packing_methods(locals())
SQLGraphClustered._IDGraphClass = SQLIDGraphClustered
class SQLEdgesClustered(SQLGraphClustered):
    'edges interface for SQLGraphClustered'
    _edgeDictClass = list
    _pickleAttrs = SQLGraphClustered._pickleAttrs.copy()
    _pickleAttrs.update(dict(graph=0))
    def keys(self):
        'list all (source,target,edge) tuples, forcing a full load'
        self.load()
        result = []
        for edge_id,pairs in self.d.iteritems():
            for source_id,target_id in pairs:
                result.append((self.graph.unpack_source(source_id),
                               self.graph.unpack_target(target_id),
                               self.graph.unpack_edge(edge_id)))
        return result
class ForeignKeyInverse(object):
    'map each key to a single value according to its foreign key'
    def __init__(self,g):
        self.g=g
    def __getitem__(self,obj):
        'return source object whose ID is stored in obj.keyColumn (or None)'
        self.check_obj(obj)
        source_id = getattr(obj,self.g.keyColumn)
        if source_id is None:
            return None
        return self.g.sourceDB[source_id]
    def __setitem__(self,obj,source):
        self.check_obj(obj)
        if source is not None:
            self.g[source][obj] = None # ENSURES ALL THE RIGHT CACHING OPERATIONS DONE
        else: # DELETE PRE-EXISTING EDGE IF PRESENT
            if not hasattr(obj,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
                old_source = self[obj]
                if old_source is not None:
                    del self.g[old_source][obj]
    def check_obj(self,obj):
        'raise KeyError if obj not from this db'
        try:
            if obj.db != self.g.targetDB:
                raise AttributeError
        except AttributeError:
            raise KeyError('key is not from targetDB of this graph!')
    def __contains__(self,obj):
        try:
            self.check_obj(obj)
            return True
        except KeyError:
            return False
    def __iter__(self):
        return self.g.targetDB.itervalues()
    def keys(self):
        return self.g.targetDB.values()
    def iteritems(self):
        'yield (obj, source-or-None) for every object in targetDB'
        for obj in self:
            source_id = getattr(obj,self.g.keyColumn)
            if source_id is None:
                yield obj,None
            else:
                yield obj,self.g.sourceDB[source_id]
    def items(self):
        return list(self.iteritems())
    def itervalues(self):
        for obj,val in self.iteritems():
            yield val
    def values(self):
        return list(self.itervalues())
    def __invert__(self):
        return self.g
class ForeignKeyEdge(dict):
    '''edge interface to a foreign key in an SQL table.
    Caches dict of target nodes in itself; provides dict interface.
    Adds or deletes edges by setting foreign key values in the table'''
    def __init__(self,g,k):
        dict.__init__(self)
        self.g=g
        self.src=k
        for v in g.targetDB.select('where %s=%%s' % g.keyColumn,(k.id,)): # SEARCH THE DB
            dict.__setitem__(self,v,None) # SAVE IN CACHE
    def __setitem__(self,dest,v):
        if not hasattr(dest,'db') or dest.db != self.g.targetDB:
            raise KeyError('dest is not in the targetDB bound to this graph!')
        if v is not None:
            raise ValueError('sorry,this graph cannot store edge information!')
        if not hasattr(dest,'_localOnly'): # ONLY CACHE, DON'T SAVE TO DATABASE
            old_source = self.g._inverse[dest] # CHECK FOR PRE-EXISTING EDGE
            if old_source is not None: # REMOVE OLD EDGE FROM CACHE
                dict.__delitem__(self.g[old_source],dest)
        #self.g.targetDB._update(dest.id,self.g.keyColumn,self.src.id) # SAVE TO DB
        setattr(dest,self.g.keyColumn,self.src.id) # SAVE TO DB ATTRIBUTE
        dict.__setitem__(self,dest,None) # SAVE IN CACHE
    def __delitem__(self,dest):
        #self.g.targetDB._update(dest.id,self.g.keyColumn,None) # REMOVE FOREIGN KEY VALUE
        setattr(dest,self.g.keyColumn,None) # SAVE TO DB ATTRIBUTE
        dict.__delitem__(self,dest) # REMOVE FROM CACHE
class ForeignKeyGraph(object, UserDict.DictMixin):
    '''graph interface to a foreign key in an SQL table
    Caches dict of target nodes in itself; provides dict interface.
    '''
    def __init__(self, sourceDB, targetDB, keyColumn, autoGC=True, **kwargs):
        '''sourceDB is any database of source nodes;
        targetDB must be an SQL database of target nodes;
        keyColumn is the foreign key column name in targetDB for looking up sourceDB IDs.'''
        if autoGC: # automatically garbage collect unused objects
            self._weakValueDict = RecentValueDictionary(autoGC) # object cache
        else:
            self._weakValueDict = {}
        self.autoGC = autoGC
        self.sourceDB = sourceDB
        self.targetDB = targetDB
        self.keyColumn = keyColumn
        self._inverse = ForeignKeyInverse(self)
    _pickleAttrs = dict(sourceDB=0, targetDB=0, keyColumn=0, autoGC=0)
    __getstate__ = standard_getstate ########### SUPPORT FOR PICKLING
    __setstate__ = standard_setstate
    def _inverse_schema(self):
        'provide custom schema rule for inverting this graph... just use keyColumn!'
        return dict(invert=True,uniqueMapping=True)
    def __getitem__(self,k):
        'return (cached) edge dict of targets whose foreign key points at k'
        if not hasattr(k,'db') or k.db != self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this graph!')
        try:
            return self._weakValueDict[k.id] # get from cache
        except KeyError:
            pass
        d = ForeignKeyEdge(self,k)
        self._weakValueDict[k.id] = d # save in cache
        return d
    def __setitem__(self, k, v):
        raise KeyError('''do not save as g[k]=v. Instead follow a graph
interface: g[src]+=dest, or g[src][dest]=None (no edge info allowed)''')
    def __delitem__(self, k):
        raise KeyError('''Instead of del g[k], follow a graph
interface: del g[src][dest]''')
    def keys(self):
        return self.sourceDB.values()
    __invert__ = standard_invert
def describeDBTables(name,cursor,idDict):
    """
    Get table info about database <name> via <cursor>, and store primary keys
    in idDict, along with a list of the tables each key indexes.
    """
    cursor.execute('use %s' % name)
    cursor.execute('show tables')
    tables={}
    names=[c[0] for c in cursor.fetchall()]
    for t in names:
        tname=name+'.'+t
        o=SQLTable(tname,cursor)
        tables[tname]=o
        for f in o.description:
            if f==o.primary_key:
                idDict.setdefault(f, []).append(o)
            elif f[-3:]=='_id' and f not in idDict:
                idDict[f]=[] # NOTE(review): reconstructed branch -- confirm
    return tables
def indexIDs(tables,idDict=None):
    "Get an index of primary keys in the <tables> dictionary."
    if idDict is None:
        idDict = {}
    for o in tables.values():
        if o.primary_key is not None:
            if o.primary_key not in idDict:
                idDict[o.primary_key]=[]
            idDict[o.primary_key].append(o) # KEEP LIST OF TABLES WITH THIS PRIMARY KEY
        for f in o.description: # also register foreign-key-like *_id columns
            if f[-3:]=='_id' and f not in idDict:
                idDict[f]=[]
    return idDict
def suffixSubset(tables,suffix):
    "Filter table index for those matching a specific suffix"
    subset={}
    for name,t in tables.items():
        if name.endswith(suffix):
            subset[name]=t
    return subset
def graphDBTables(tables,idDict):
    'build a graph linking each table to its key columns'
    g=dictgraph() # NOTE(review): reconstructed constructor -- confirm
    for t in tables.values():
        for f in t.description:
            if f==t.primary_key:
                edgeInfo=PRIMARY_KEY
            else:
                edgeInfo=None
            g.setEdge(f,t,edgeInfo)
            g.setEdge(t,f,edgeInfo)
    return g
# default Python-type -> SQL-column-type mapping (Python 2 type objects)
SQLTypeTranslation = {types.StringType: 'varchar(32)',
                      types.IntType: 'int',
                      types.FloatType: 'float'}
def createTableFromRepr(rows,tableName,cursor,typeTranslation=None,
                        optionalDict=None,indexDict=()):
    """Save rows into SQL tableName using cursor, with optional
    translations of columns to specific SQL types (specified
    by typeTranslation dict).
    - optionDict can specify columns that are allowed to be NULL.
    - indexDict can specify columns that must be indexed; columns
    whose names end in _id will be indexed by default.
    - rows must be an iterator which in turn returns dictionaries,
    each representing a tuple of values (indexed by their column
    names).
    """
    try:
        row=rows.next() # GET 1ST ROW TO EXTRACT COLUMN INFO
    except StopIteration:
        return # IF rows EMPTY, NO NEED TO SAVE ANYTHING, SO JUST RETURN
    try: # NOTE(review): exception handling reconstructed -- confirm
        createTableFromRow(cursor, tableName,row,typeTranslation,
                           optionalDict,indexDict)
    except:
        pass
    storeRow(cursor,tableName,row) # SAVE OUR FIRST ROW
    for row in rows: # NOW SAVE ALL THE ROWS
        storeRow(cursor,tableName,row)
def createTableFromRow(cursor, tableName, row,typeTranslation=None,
                       optionalDict=None,indexDict=()):
    'create tableName in the db, inferring column types from one sample row'
    create_defs=[]
    for col,val in row.items(): # PREPARE SQL TYPES FOR COLUMNS
        coltype=None
        if typeTranslation!=None and col in typeTranslation:
            coltype=typeTranslation[col] # USER-SUPPLIED TRANSLATION
        elif type(val) in SQLTypeTranslation:
            coltype=SQLTypeTranslation[type(val)]
        else: # SEARCH FOR A COMPATIBLE TYPE
            for t in SQLTypeTranslation:
                if isinstance(val,t):
                    coltype=SQLTypeTranslation[t]
                    break
        if coltype is None:
            raise TypeError("Don't know SQL type to use for %s" % col)
        create_def='%s %s' %(col,coltype)
        if optionalDict==None or col not in optionalDict:
            create_def+=' not null'
        create_defs.append(create_def)
    for col in row: # CREATE INDEXES FOR ID COLUMNS
        if col[-3:]=='_id' or col in indexDict:
            create_defs.append('index(%s)' % col)
    cmd='create table if not exists %s (%s)' % (tableName,','.join(create_defs))
    cursor.execute(cmd) # CREATE THE TABLE IN THE DATABASE
def storeRow(cursor, tableName, row):
    'insert one row (a dict of column values) via a parameterized query'
    placeholders = ','.join(len(row)*['%s'])
    cmd = 'insert into %s values (%s)' % (tableName, placeholders)
    cursor.execute(cmd, tuple(row.values()))
def storeRowDelayed(cursor, tableName, row):
    'insert one row using MySQL "insert delayed" (deferred write)'
    placeholders = ','.join(len(row)*['%s'])
    cmd = 'insert delayed into %s values (%s)' % (tableName, placeholders)
    cursor.execute(cmd, tuple(row.values()))
class TableGroup(dict):
    'provide attribute access to dbname qualified tablenames'
    def __init__(self,db='test',suffix=None,**kw):
        dict.__init__(self)
        self.db=db
        if suffix is not None:
            self.suffix=suffix
        for k,v in kw.items():
            if v is not None and '.' not in v:
                v=self.db+'.'+v # ADD DATABASE NAME AS PREFIX
            self[k]=v
    def __getattr__(self,k):
        return self[k]
def sqlite_connect(*args, **kwargs):
    'open a sqlite connection and return the (connection, cursor) pair'
    sqlite = import_sqlite()
    connection = sqlite.connect(*args, **kwargs)
    return connection, connection.cursor()
class DBServerInfo(object):
    'picklable reference to a database server'

    def __init__(self, moduleName='MySQLdb', *args, **kwargs):
        """Morph self into the DBServerInfo subclass registered for
        moduleName, saving the connection arguments for later use.

        Raises ValueError if moduleName is not a registered module.
        """
        try:
            # reassigning __class__ dispatches to the module-specific subclass
            self.__class__ = _DBServerModuleDict[moduleName]
        except KeyError:
            raise ValueError('Module name not found in _DBServerModuleDict: '
                             + moduleName)
        self.moduleName = moduleName
        self.args = args  # connection arguments
        self.kwargs = kwargs

    def cursor(self):
        """returns cursor associated with the DB server info (reused)"""
        try:
            return self._cursor
        except AttributeError:  # no connection yet; open one lazily
            self._start_connection()
            return self._cursor

    def new_cursor(self, arraysize=None):
        """returns a NEW cursor; you must close it yourself! """
        if not hasattr(self, '_connection'):
            self._start_connection()
        cursor = self._connection.cursor()
        if arraysize is not None:
            cursor.arraysize = arraysize
        return cursor

    def close(self):
        """Close file containing this database"""
        self._cursor.close()
        self._connection.close()
        del self._cursor  # drop cached handles so cursor() reconnects
        del self._connection

    def __getstate__(self):
        """return all picklable arguments"""
        return dict(args=self.args, kwargs=self.kwargs,
                    moduleName=self.moduleName)
class MySQLServerInfo(DBServerInfo):
    'customized for MySQLdb SSCursor support via new_cursor()'

    def _start_connection(self):
        'open the regular (buffered) MySQL connection'
        self._connection, self._cursor = mysql_connect(*self.args,
                                                       **self.kwargs)

    def new_cursor(self, arraysize=None):
        'provide streaming cursor support'
        try:  # reuse the dedicated streaming connection if already open
            conn = self._conn_sscursor
        except AttributeError:
            self._conn_sscursor, cursor = mysql_connect(useStreaming=True,
                                                        *self.args,
                                                        **self.kwargs)
        else:
            cursor = self._conn_sscursor.cursor()
        if arraysize is not None:
            cursor.arraysize = arraysize
        return cursor

    def close(self):
        'close both the regular and the streaming connection (if any)'
        DBServerInfo.close(self)
        try:
            self._conn_sscursor.close()
            del self._conn_sscursor
        except AttributeError:  # streaming connection was never opened
            pass

    def iter_keys(self, db, cursor, map_f=iter,
                  cache_f=lambda x: [t[0] for t in x], **kwargs):
        'iterate over db keys by fetching rows in blocks (see BlockGenerator)'
        block_generator = BlockGenerator(db, self, cursor, **kwargs)
        return db.generic_iterator(cursor=cursor, cache_f=cache_f,
                                   map_f=map_f, fetch_f=block_generator)
class BlockGenerator(object):
    """Callable that fetches successive blocks of rows from a database,
    advancing a WHERE clause past the last ID seen in each block.

    NOTE(review): reconstructed from a partially corrupted snapshot; the
    termination behavior (raising StopIteration when no rows remain) should
    be confirmed against generic_iterator()'s fetch_f contract.
    """

    def __init__(self, db, serverInfo, cursor, whereClause='', **kwargs):
        self.db = db
        self.serverInfo = serverInfo
        self.cursor = cursor
        self.kwargs = kwargs  # extra args forwarded to db._select()
        self.blockSize = 10000
        # TODO confirm: the whereClause argument is accepted but ignored;
        # iteration always starts from the beginning of the table.
        self.whereClause = ''

    def __call__(self):
        'get the next block of data'
        self.db._select(cursor=self.cursor, whereClause=self.whereClause,
                        limit='LIMIT %s' % self.blockSize, **self.kwargs)
        rows = self.cursor.fetchall()
        try:
            lastrow = rows[-1]
        except IndexError:  # no rows left: signal end of iteration
            raise StopIteration
        if len(lastrow) > 1:  # extract the last ID value in this block
            start = lastrow[self.db.data['id']]
        else:
            start = lastrow[0]
        # resume after the last ID next time we are called
        self.whereClause = 'WHERE %s>%s' % (self.db.primary_key, start)
        return rows
class SQLiteServerInfo(DBServerInfo):
    """picklable reference to a sqlite database"""

    def __init__(self, database, *args, **kwargs):
        """Takes same arguments as sqlite3.connect()"""
        DBServerInfo.__init__(self, 'sqlite',
                              SourceFileName(database),  # save abs path!
                              *args, **kwargs)

    def _start_connection(self):
        'open the sqlite connection and cache (connection, cursor)'
        self._connection, self._cursor = sqlite_connect(*self.args,
                                                        **self.kwargs)

    def __getstate__(self):
        'refuse to pickle a reference to an in-memory database'
        if self.args[0] == ':memory:':
            raise ValueError('SQLite in-memory database is not picklable!')
        return DBServerInfo.__getstate__(self)
# registry of DBServerInfo subclasses keyed by DB module name;
# DBServerInfo.__init__ uses this to morph into the right subclass
_DBServerModuleDict = {
    'MySQLdb': MySQLServerInfo,
    'sqlite': SQLiteServerInfo,
}
class MapView(object, UserDict.DictMixin):
    'general purpose 1:1 mapping defined by any SQL query'

    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None,
                 serverInfo=None, inverseSQL=None, **kwargs):
        """Map items of sourceDB to items of targetDB via viewSQL.

        viewSQL: parameterized query taking a source ID and returning the
            matching target ID (must return exactly one row per key).
        cursor / serverInfo: one of these must be supplied, or sourceDB
            must itself carry a serverInfo attribute.
        inverseSQL: optional query for the reverse mapping (see __invert__).
        """
        self.sourceDB = sourceDB
        self.targetDB = targetDB
        self.viewSQL = viewSQL
        self.inverseSQL = inverseSQL
        if cursor is None:
            if serverInfo is not None:  # get cursor from serverInfo
                cursor = serverInfo.cursor()
            else:
                try:  # can we get it from our other db?
                    serverInfo = sourceDB.serverInfo
                except AttributeError:
                    raise ValueError('you must provide serverInfo or cursor!')
                cursor = serverInfo.cursor()
        self.cursor = cursor
        self.serverInfo = serverInfo
        self.get_sql_format(False)  # get sql formatter for this db interface

    _schemaModuleDict = _schemaModuleDict  # default module list
    get_sql_format = get_table_schema

    def __getitem__(self, k):
        'map a sourceDB item to its unique targetDB item'
        if not hasattr(k, 'db') or k.db is not self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this map!')
        sql, params = self._format_query(self.viewSQL, (k.id,))
        self.cursor.execute(sql, params)  # formatted for this db interface
        t = self.cursor.fetchmany(2)  # get at most two rows
        if len(t) != 1:  # zero rows, or more than one: not a 1:1 mapping
            raise KeyError('%s not found in MapView, or not unique'
                           % str(k))
        return self.targetDB[t[0][0]]  # get the corresponding object

    _pickleAttrs = dict(sourceDB=0, targetDB=0, viewSQL=0, serverInfo=0,
                        inverseSQL=0)
    __getstate__ = standard_getstate
    __setstate__ = standard_setstate
    __setitem__ = __delitem__ = clear = pop = popitem = update = \
        setdefault = read_only_error

    def __iter__(self):
        'only yield sourceDB items that are actually in this mapping!'
        for k in self.sourceDB.itervalues():
            try:
                self[k]  # raises KeyError if k has no unique mapping
            except KeyError:
                pass
            else:
                yield k

    def keys(self):
        return [k for k in self]  # don't use list(self); causes infinite loop!

    def __invert__(self):
        'return (and cache) the inverse mapping, built from inverseSQL'
        try:
            return self._inverse
        except AttributeError:
            if self.inverseSQL is None:
                raise ValueError('this MapView has no inverseSQL!')
            self._inverse = self.__class__(self.targetDB, self.sourceDB,
                                           self.inverseSQL, self.cursor,
                                           serverInfo=self.serverInfo,
                                           inverseSQL=self.viewSQL)
            self._inverse._inverse = self
            return self._inverse
class GraphViewEdgeDict(UserDict.DictMixin):
    'edge dictionary for GraphView: just pre-loaded on init'

    def __init__(self, g, k):
        """Run g.viewSQL for source key k and pre-load all its targets.

        Raises KeyError if k has no targets in the graph.
        """
        self.g = g
        self.k = k
        sql, params = self.g._format_query(self.g.viewSQL, (k.id,))
        self.g.cursor.execute(sql, params)  # run the query
        l = self.g.cursor.fetchall()  # get results
        if len(l) <= 0:
            raise KeyError('key %s not in GraphView' % k.id)
        self.targets = [t[0] for t in l]  # preserve order of the results
        d = {}  # also keep targetID:edgeID mapping
        if self.g.edgeDB is not None:  # save with edge info
            for t in l:
                d[t[0]] = t[1]
        self.targetDict = d

    def __len__(self):
        return len(self.targets)

    def __iter__(self):
        'yield the target objects, in query-result order'
        for k in self.targets:
            yield self.g.targetDB[k]

    def keys(self):
        return list(self)

    def iteritems(self):
        'yield (targetObject, edgeObject-or-None) pairs'
        if self.g.edgeDB is not None:  # save with edge info
            for k in self.targets:
                yield (self.g.targetDB[k], self.g.edgeDB[self.targetDict[k]])
        else:  # just save the list of targets, no edge info
            for k in self.targets:
                yield (self.g.targetDB[k], None)

    def __getitem__(self, o, exitIfFound=False):
        'for the specified target object, return its associated edge object'
        try:
            if o.db is not self.g.targetDB:
                raise KeyError('key is not part of targetDB!')
            edgeID = self.targetDict[o.id]
        except AttributeError:
            raise KeyError('key has no id or db attribute?!')
        if exitIfFound:  # caller (e.g. __contains__) only wanted the check
            return
        if self.g.edgeDB is not None:  # return the edge object
            return self.g.edgeDB[edgeID]
        else:  # no edge info
            return None

    def __contains__(self, o):
        try:
            self.__getitem__(o, True)  # raise KeyError if not found
            return True
        except KeyError:
            return False

    __setitem__ = __delitem__ = clear = pop = popitem = update = \
        setdefault = read_only_error
class GraphView(MapView):
    'general purpose graph interface defined by any SQL query'

    def __init__(self, sourceDB, targetDB, viewSQL, cursor=None, edgeDB=None,
                 **kwargs):
        'if edgeDB not None, viewSQL query must return (targetID,edgeID) tuples'
        self.edgeDB = edgeDB
        MapView.__init__(self, sourceDB, targetDB, viewSQL, cursor, **kwargs)

    def __getitem__(self, k):
        'return the pre-loaded edge dictionary for source item k'
        if not hasattr(k, 'db') or k.db is not self.sourceDB:
            raise KeyError('object is not in the sourceDB bound to this map!')
        return GraphViewEdgeDict(self, k)

    _pickleAttrs = MapView._pickleAttrs.copy()
    _pickleAttrs.update(dict(edgeDB=0))
2043 # @CTB move to sqlgraph.py?
class SQLSequence(SQLRow, SequenceBase):
    """Transparent access to a DB row representing a sequence.

    Use attrAlias dict to rename 'length' to something else.
    """

    def _init_subclass(cls, db, **kwargs):
        db.seqInfoDict = db  # db will act as its own seqInfoDict
        SQLRow._init_subclass(db=db, **kwargs)
    _init_subclass = classmethod(_init_subclass)

    def __init__(self, id):
        SQLRow.__init__(self, id)
        SequenceBase.__init__(self)

    def __len__(self):
        # length comes from the DB row's length column (see attrAlias note)
        return self.length

    def strslice(self, start, end):
        "Efficient access to slice of a sequence, useful for huge contigs"
        # double-%% placeholders are filled in later with the DB dialect's
        # substring function names by the SQL formatting layer
        return self._select(
            '%%(SUBSTRING)s(%s %%(SUBSTR_FROM)s %d %%(SUBSTR_FOR)s %d)'
            % (self.db._attrSQL('seq'), start + 1, end - start))
class DNASQLSequence(SQLSequence):
    'SQL-backed sequence marked with the DNA sequence type'
    _seqtype = DNA_SEQTYPE
class RNASQLSequence(SQLSequence):
    'SQL-backed sequence marked with the RNA sequence type'
    _seqtype = RNA_SEQTYPE
class ProteinSQLSequence(SQLSequence):
    'SQL-backed sequence marked with the protein sequence type'
    _seqtype = PROTEIN_SEQTYPE