Removed some leftovers from media.xsl.
[enkel.git] / enkel / sqldb / std_adapter.py
blobbb290532d4d360ef3c1c1039a8c9d1d84adcb9d3
# This file is part of the Enkel web programming library.
#
# Copyright (C) 2007 Espen Angell Kristiansen (espeak@users.sourceforge.net)
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19 from cStringIO import StringIO
21 from enkel.model.field.base import *
22 from enkel.model.ds import *
24 from table import *
25 from adapter import Adapter
26 from query import WhereClause, And
27 from manip_fetcher import *
28 from dbfields import *
class StdAdapter(Adapter):
    """ Adapter that implements the database operations with plain SQL.

    This class only builds SQL strings and parameter lists. It relies on
    C{self.execute}, C{self.commit}, C{self.cursor} and C{self.new_cursor},
    which are not defined here (presumably provided by L{Adapter} or a
    subclass). Subclasses must implement L{paramgen}.
    """

    def _get_fk_value(self, datasrc, value):
        """ Get the value of a foreign-key field. Works both
        if the value is a Manip and if it is a normal value.

        @param datasrc: The table the foreign key refers to. Only its
            first primary key is used.
        @param value: A L{Manip} or a plain key value.
        @return: The first primary-key attribute of C{value} if it is a
            Manip, otherwise C{value} unchanged.
        """
        if isinstance(value, Manip):
            return getattr(value, datasrc.primary_keys[0])
        return value

    def _get_sqltype(self, fieldtype, fieldname, field):
        """ Get the SQL column type for a field.

        The isinstance() checks are tried in the order below, so a field
        matching several classes gets the type of the first match.

        @param fieldtype: A C{Table.FT_*} constant. Only C{Table.FT_FK} is
            treated specially (see below).
        @param fieldname: The name of the field.
        @param field: The field object.
        @return: The SQL type, e.g. C{"VARCHAR(<maxlength>)"} for a String.
        @raise ValueError: If the field type is not supported.
        """
        if isinstance(field, String):
            return "VARCHAR(%d)" % field.maxlength
        elif isinstance(field, DbSmallInt):
            return "SMALLINT"
        elif isinstance(field, DbLong):
            return "BIGINT"
        elif isinstance(field, DbInt):
            return "INTEGER"
        elif isinstance(field, Text):
            return "TEXT"
        elif isinstance(field, Date):
            return "DATE"
        elif isinstance(field, Time):
            return "TIME"
        elif isinstance(field, DateTime):
            return "DATETIME"
        elif fieldtype == Table.FT_FK:
            # A foreign-key column gets the type of the referenced field.
            # NOTE(review): assumes the datasource exposes that field as
            # ``.field`` -- confirm against Table.
            return self._get_sqltype(Table.FT_NORM, fieldname,
                    field.datasource.field)
        else:
            raise ValueError("unsupported field-type: %r" % field)

    def _get_sqlfield(self, fieldtype, fieldname, field):
        """ Get a column definition: C{"<fieldname> <sqltype>"}, followed
        by C{" NOT NULL"} when C{field.required} is true.
        """
        if field.required:
            notnull = " NOT NULL"
        else:
            notnull = ""
        return "%s %s%s" % (fieldname,
                self._get_sqltype(fieldtype, fieldname, field), notnull)

    def _manip_to_args(self, table, manip, ignore):
        """ Parse a manip for the information needed by
        L{insert} and L{update}.

        @param table: The table C{manip} belongs to.
        @param manip: The manip to parse.
        @param ignore: Iterable of field names to skip.
        @return: (fieldnames, params, mm_dict). Where fieldnames is a list
            of fieldnames of normal and foreign-key fields. params is
            a list of the values for normal and foreign-key fields. And
            mm_dict is a dict of values for many-to-many fields, mapping
            each many-to-many field name to a list of manips for the
            table in C{table.mmtables}.
        """
        ignore = set(ignore)
        params = []
        fieldnames = []
        mm = {}
        for fieldname, value in manip.iteritems():
            if fieldname in ignore:
                continue
            field = table.model[fieldname]
            if isinstance(field, DatasourceField):
                if isinstance(field, One):
                    params.append(self._get_fk_value(field.datasource, value))
                    fieldnames.append(fieldname)
                elif isinstance(field, Many):
                    mmtable = table.mmtables[fieldname]
                    # The association rows link this manip's (first)
                    # primary key to each referenced row's key.
                    pk = getattr(manip, table.primary_keys[0])
                    m = []
                    for val in value:
                        pk2 = self._get_fk_value(field.datasource, val)
                        m.append(mmtable.manip(source=pk, target=pk2))
                    mm[fieldname] = m
            else:
                params.append(value)
                fieldnames.append(fieldname)
        return fieldnames, params, mm

    def _create_indexes(self, table):
        """ Execute one C{CREATE [UNIQUE] INDEX} statement for each index
        in C{table.indexes}.
        """
        for index in table.indexes:
            if isinstance(index, UniqueIndex):
                unique = " UNIQUE "
            else:
                unique = " "
            q = "CREATE%sINDEX %s ON %s(%s)" % (unique,
                    index.get_name(table.name), table.name,
                    ",".join(index.fields))
            self.execute(q)

    def paramgen(self, index):
        """ Parameter generator following the interface specified
        L{here <query.paramgen_interface>}.

        Must be implemented in subclasses.

        @param index: Zero-based position of the parameter in the
            parameter list passed to C{execute}.
        @return: The placeholder string to put into the SQL statement.
        """
        raise NotImplementedError()

    def drop_table(self, table):
        """ Drop C{table}, committing both before and after the DROP. """
        self.commit()
        self.cursor.execute("DROP TABLE %s" % table.name)
        self.commit()

    def _insert_mm(self, table, mm):
        """ Insert the many-to-many rows returned by L{_manip_to_args}.

        @param mm: Dict mapping many-to-many field names to lists of
            manips for the corresponding table in C{table.mmtables}.
        """
        for fieldname, manips in mm.items():
            mmtable = table.mmtables[fieldname]
            for m in manips:
                mmtable.insert(m)

    def insert(self, table, manip, ignore=None):
        """ Insert C{manip} into C{table}, including its many-to-many rows.

        NOTE(review): the many-to-many rows are built from the primary key
        read from C{manip} before the INSERT runs, so this assumes the key
        is set by the caller rather than generated by the database.

        @param ignore: Optional iterable of field names to leave out.
        """
        if ignore is None:
            ignore = ()
        fieldnames, params, mm = self._manip_to_args(table, manip, ignore)
        placeholders = [self.paramgen(i) for i in range(len(params))]
        qry = "INSERT INTO %s(%s) VALUES(%s)" % (table.name,
                ",".join(fieldnames), ",".join(placeholders))
        self.execute(qry, params)
        self._insert_mm(table, mm)

    def _manipwhere(self, table, manip):
        """ Compile a WHERE condition matching every primary key of
        C{table} against the values in C{manip}.

        @return: (sql, params) as returned by C{And(...).compile()}.
        """
        a = [getattr(table.c, pk) == manip[pk] for pk in table.primary_keys]
        q, params = And(*a).compile(None, self.paramgen)
        return q, params

    def delete(self, table, manip):
        """ Delete the row identified by the primary keys in C{manip},
        after deleting its rows in the many-to-many tables for the
        many-to-many fields present in C{manip}.
        """
        # Only the many-to-many part of the parsed manip is needed here.
        mm = self._manip_to_args(table, manip, [])[2]
        if mm:
            pk = manip[table.get_pk()]
            for fieldname in mm:
                mmtable = table.mmtables[fieldname]
                mmtable.qry_delete(mmtable.c.source == pk)

        q, params = self._manipwhere(table, manip)
        qry = "DELETE FROM %s WHERE %s" % (table.name, q)
        self.execute(qry, params)

    def update(self, table, manip, mm_add=None, ignore=None):
        """ Update the row identified by the primary keys in C{manip}.

        For every many-to-many field present in C{manip}, the existing
        association rows (matched on C{source}) are deleted first unless
        the field name is in C{mm_add}, and then the rows from C{manip}
        are inserted.

        @param mm_add: Optional iterable of many-to-many field names whose
            existing rows are kept (new rows are only added).
        @param ignore: Optional iterable of field names to leave out. The
            primary keys are always left out of the SET clause.
        """
        if mm_add is None:
            mm_add = ()
        # Build a new list rather than extending ``ignore`` in place, so
        # neither the caller's list nor a shared default is modified.
        ignore = list(ignore or ()) + list(table.primary_keys)
        fieldnames, params, mm = self._manip_to_args(table, manip, ignore)

        # A manip with nothing but primary-key and many-to-many fields has
        # no columns to SET; "UPDATE t SET  WHERE ..." is invalid SQL.
        if fieldnames:
            values = ["%s=%s" % (fn, self.paramgen(i))
                    for i, fn in enumerate(fieldnames)]
            # NOTE(review): the WHERE placeholders come from a separate
            # compile() call; if it also numbers from 0, numbered
            # paramstyles would collide with the SET placeholders -- confirm.
            q, p = self._manipwhere(table, manip)
            params.extend(p)
            qry = "UPDATE %s SET %s WHERE %s" % (table.name,
                    ",".join(values), q)
            self.execute(qry, params)

        if mm:
            # remove all many-to-many rows for fields not in "mm_add"
            pk = manip[table.get_pk()]
            for fieldname in mm:
                if fieldname not in mm_add:
                    mmtable = table.mmtables[fieldname]
                    mmtable.qry_delete(mmtable.c.source == pk)
            self._insert_mm(table, mm)

    def _where_clause(self, table, *ops, **named_ops):
        """ Compile C{ops}/C{named_ops} into a WHERE clause for C{table}.

        @return: (sql, params) as returned by C{WhereClause.compile()}.
        """
        return WhereClause(table.name, self.paramgen,
                *ops, **named_ops).compile()

    def _single_table_select(self, fetcher_class, table, ops, named_ops):
        """ Run a SELECT of the non-many-to-many columns of C{table} and
        wrap a new cursor in C{fetcher_class}. Shared by L{select} and
        L{rselect}.

        @param named_ops: Dict of named ops. An C{"ignore"} entry (field
            names to leave out) is removed from it; the rest is passed to
            L{_where_clause}.
        """
        ignore = set(named_ops.pop("ignore", []))  # set for fast lookup
        q, params = self._where_clause(table, *ops, **named_ops)
        fields = [fieldname for fieldname, field in table.model.iteritems()
                if not (fieldname in table.mmtables or fieldname in ignore)]
        query = "SELECT %s FROM %s %s" % (",".join(fields), table.name, q)
        c = self.new_cursor()
        self.execute(query, params, c)
        return fetcher_class(c, table)

    def select(self, table, *ops, **named_ops):
        """ Select rows from C{table}.

        @return: A L{SimpleManipFetcher} over the result.
        """
        return self._single_table_select(SimpleManipFetcher, table,
                ops, named_ops)

    def select_one(self, table, *ops, **named_ops):
        """ Like L{select} with C{limit=1}; returns C{fetchone()}. """
        named_ops["limit"] = 1
        r = self.select(table, *ops, **named_ops)
        return r.fetchone()

    def _field_alias(self, fieldname, table, sep):
        """ Get C{"<table>.<field> AS <table><sep><field>"}. """
        return "%s.%s AS %s%s%s" % (table.name, fieldname,
                table.name, sep, fieldname)

    def jselect(self, table, *ops, **named_ops):
        """ Select rows from C{table} with the tables referenced by its
        foreign keys LEFT OUTER JOINed in.

        Every column is aliased C{<tablename><sep><fieldname>}. If the table
        has no foreign keys this is the same as L{select}.

        NOTE(review): a table joined twice (two foreign keys to the same
        table, or a self-reference) is not aliased, which most databases
        reject.

        @param named_ops: May contain C{"ignore"} (field names to leave
            out) and C{"sep"} (alias separator, default C{"_"}).
        @return: A L{JoinManipFetcher}, or a L{SimpleManipFetcher} when
            the table has no foreign keys.
        """
        if not table.foreign_keys:
            # "ignore" is still in named_ops here, so select() honours it.
            return self.select(table, *ops, **named_ops)

        ignore = set(named_ops.pop("ignore", []))  # set for fast lookup
        # NOTE(review): "sep" is read with get(), so it stays in named_ops
        # and is passed on to WhereClause; it is also not passed to
        # JoinManipFetcher -- confirm both are intended.
        sep = named_ops.get("sep", "_")
        q, params = self._where_clause(table, *ops, **named_ops)

        fields = []
        joins = []
        for t, fieldname, field in table.iter_fields():
            if t == Table.FT_MM or fieldname in ignore:
                continue
            if t == Table.FT_FK:
                ftable = field.datasource
                for fn in ftable.model:
                    if not (fn in ftable.mmtables or fn in ignore):
                        fields.append(self._field_alias(fn, ftable, sep))
                joins.append("\n\tLEFT OUTER JOIN %s ON %s.%s = %s.%s" % (
                        ftable.name,
                        table.name, fieldname,
                        ftable.name, ftable.primary_keys[0]))
            else:
                fields.append(self._field_alias(fieldname, table, sep))

        query = "SELECT %s\nFROM %s%s\n%s" % (
                ",\n\t".join(fields), table.name, "".join(joins), q)
        c = self.new_cursor()
        self.execute(query, params, c)
        return JoinManipFetcher(c, table)

    def jselect_one(self, table, *ops, **named_ops):
        """ Like L{jselect} with C{limit=1}; returns C{fetchone()}. """
        named_ops["limit"] = 1
        r = self.jselect(table, *ops, **named_ops)
        return r.fetchone()

    def rselect(self, table, *ops, **named_ops):
        """ Select rows from C{table} like L{select}, but fetch them with
        a L{RecursiveManipFetcher}.
        """
        return self._single_table_select(RecursiveManipFetcher, table,
                ops, named_ops)

    def rselect_one(self, table, *ops, **named_ops):
        """ Like L{rselect} with C{limit=1}; returns C{fetchone()}. """
        named_ops["limit"] = 1
        r = self.rselect(table, *ops, **named_ops)
        return r.fetchone()