1 # This file is part of the Enkel web programming library.
3 # Copyright (C) 2007 Espen Angell Kristiansen (espeak@users.sourceforge.net)
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # of the License, or (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19 from cStringIO
import StringIO
21 from enkel
.model
.field
.base
import *
22 from enkel
.model
.ds
import *
25 from adapter
import Adapter
26 from query
import WhereClause
, And
27 from manip_fetcher
import *
28 from dbfields
import *
31 class StdAdapter(Adapter
):
32 def _get_fk_value(self
, datasrc
, value
):
33 """ Get the value of a foreign-key field. Works both
34 if the value is a Manip and if it is a normal value.
36 if isinstance(value
, Manip
):
37 pk
= datasrc
.primary_keys
[0]
38 return getattr(value
, pk
)
43 def _get_sqltype(self
, fieldtype
, fieldname
, field
):
44 if isinstance(field
, String
):
45 return "VARCHAR(%d)" % field
.maxlength
46 elif isinstance(field
, DbSmallInt
):
48 elif isinstance(field
, DbLong
):
50 elif isinstance(field
, DbInt
):
52 elif isinstance(field
, Text
):
54 elif isinstance(field
, Date
):
56 elif isinstance(field
, Time
):
58 elif isinstance(field
, DateTime
):
60 elif fieldtype
== Table
.FT_FK
:
61 return self
._get
_sqltype
(Table
.FT_NORM
, fieldname
,
62 field
.datasource
.field
)
64 raise ValueError("unsupported field-type: %r" % field
)
66 def _get_sqlfield(self
, fieldtype
, fieldname
, field
):
71 return "%s %s%s" % (fieldname
,
72 self
._get
_sqltype
(fieldtype
, fieldname
, field
), n
)
75 def _manip_to_args(self
, table
, manip
, ignore
):
76 """ Parse a manip for the information needed by
77 L{insert} and L{update}.
79 @return: (fieldnames, params, mm_dict). Where fieldnames is a list
80 of fieldnames of normal and foreign-key fields. params is
81 a list of the values for normal and foreign-key fields. And
82 mm_dict is a dict of values for many-to-many fields.
88 for fieldname
, value
in manip
.iteritems():
89 if fieldname
in ignore
:
91 field
= table
.model
[fieldname
]
92 if isinstance(field
, DatasourceField
):
93 if isinstance(field
, One
):
94 v
= self
._get
_fk
_value
(
95 field
.datasource
, value
)
97 fieldnames
.append(fieldname
)
98 elif isinstance(field
, Many
):
99 mmtable
= table
.mmtables
[fieldname
]
101 pk
= getattr(manip
, table
.primary_keys
[0])
103 pk2
= self
._get
_fk
_value
(field
.datasource
, val
)
104 m
.append(mmtable
.manip(source
=pk
, target
=pk2
))
108 fieldnames
.append(fieldname
)
109 return fieldnames
, params
, mm
112 def _create_indexes(self
, table
):
113 for index
in table
.indexes
:
114 if isinstance(index
, UniqueIndex
):
118 q
= "CREATE%sINDEX %s ON %s(%s)" % (u
,
119 index
.get_name(table
.name
), table
.name
,
120 ",".join(index
.fields
))
def paramgen(self, index):
    """ Return the SQL placeholder for the query parameter at C{index}.

    This hook follows the interface specified
    L{here <query.paramgen_interface>}. Within this class it is used, for
    example, by L{insert}, which calls it once per value with the value's
    zero-based position (from C{enumerate}) and joins the results into the
    C{VALUES(...)} list.

    The base adapter defines no placeholder syntax of its own, so concrete
    subclasses must override this method.

    @raise NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
def drop_table(self, table):
    """ Drop C{table} by executing C{DROP TABLE <table.name>}.

    @param table: Table object; only its C{name} attribute is read.
    """
    # The table name is interpolated straight into the SQL text, because
    # identifiers cannot be passed as bound parameters. It must therefore
    # come from a trusted table definition, never from user input.
    statement = "DROP TABLE %s" % table.name
    # NOTE(review): this goes through self.cursor directly, whereas the
    # other statements in this class go through self.execute(); confirm
    # that the difference is intended.
    self.cursor.execute(statement)
141 def insert(self
, table
, manip
, ignore
=[]):
142 fieldnames
, params
, mm
= self
._manip
_to
_args
(
143 table
, manip
, ignore
)
144 values
= [self
.paramgen(i
) for i
, v
in enumerate(params
)]
146 qry
= "INSERT INTO %s(%s) VALUES(%s)" % (table
.name
,
147 ",".join(fieldnames
), ",".join(values
))
148 self
.execute(qry
, params
)
150 # insert into many-to-many tables
151 for fieldname
, manips
in mm
.iteritems():
152 mmtable
= table
.mmtables
[fieldname
]
157 def _manipwhere(self
, table
, manip
):
158 a
= [getattr(table
.c
, pk
) == manip
[pk
] for pk
in table
.primary_keys
]
159 q
, params
= And(*a
).compile(None, self
.paramgen
)
162 def delete(self
, table
, manip
):
163 f
, p
, mm
= self
._manip
_to
_args
(table
, manip
, [])
166 pk
= manip
[table
.get_pk()]
168 mmtable
= table
.mmtables
[fieldname
]
169 mmtable
.qry_delete(mmtable
.c
.source
== pk
)
171 q
, params
= self
._manipwhere
(table
, manip
)
172 qry
= "DELETE FROM %s WHERE %s" % (table
.name
, q
)
173 self
.execute(qry
, params
)
176 def update(self
, table
, manip
, mm_add
=[], ignore
=[]):
177 ignore
.extend(table
.primary_keys
)
178 fieldnames
, params
, mm
= self
._manip
_to
_args
(
179 table
, manip
, ignore
)
180 values
= ["%s=%s" % (fn
, self
.paramgen(i
))
181 for i
, fn
in enumerate(fieldnames
)]
183 q
, p
= self
._manipwhere
(table
, manip
)
185 qry
= "UPDATE %s SET %s WHERE %s" % (table
.name
, ",".join(values
), q
)
186 self
.execute(qry
, params
)
189 # remove all many-to-many columns not in "mm_add"
190 pk
= manip
[table
.get_pk()]
191 for fieldname
, manips
in mm
.iteritems():
192 if not fieldname
in mm_add
:
193 mmtable
= table
.mmtables
[fieldname
]
194 mmtable
.qry_delete(mmtable
.c
.source
== pk
)
196 # insert into many-to-many tables
197 for fieldname
, manips
in mm
.iteritems():
198 mmtable
= table
.mmtables
[fieldname
]
def _where_clause(self, table, *ops, **named_ops):
    """ Compile the condition part of a query on C{table}.

    Every positional operation and every named option is forwarded
    unchanged to C{WhereClause}, together with this adapter's
    L{paramgen}. For example, L{select_one} passes a C{limit} option
    through this path.

    @return: the result of C{WhereClause(...).compile()}. The callers in
        this class unpack it as C{(sql_fragment, params)}.
    """
    clause = WhereClause(table.name, self.paramgen, *ops, **named_ops)
    return clause.compile()
208 def select(self
, table
, *ops
, **named_ops
):
209 ignore
= set(named_ops
.pop("ignore", [])) # using set for fast lookup
211 q
, params
= self
._where
_clause
(table
, *ops
, **named_ops
)
214 for fieldname
, field
in table
.model
.iteritems():
215 if fieldname
in table
.mmtables
or fieldname
in ignore
:
217 fields
.append(fieldname
)
219 query
= "SELECT %s FROM %s %s" % (",".join(fields
), table
.name
, q
)
220 c
= self
.new_cursor()
221 self
.execute(query
, params
, c
)
222 return SimpleManipFetcher(c
, table
)
224 def select_one(self
, table
, *ops
, **named_ops
):
225 named_ops
["limit"] = 1
226 r
= self
.select(table
, *ops
, **named_ops
)
232 def _field_alias(self
, fieldname
, table
, sep
):
233 return "%s.%s AS %s%s%s" % (table
.name
, fieldname
,
234 table
.name
, sep
, fieldname
)
236 def jselect(self
, table
, *ops
, **named_ops
):
237 ignore
= set(named_ops
.pop("ignore", [])) # using set for fast lookup
239 if not table
.foreign_keys
:
240 return self
.select(table
, *ops
, **named_ops
)
242 sep
= named_ops
.get("sep", "_")
243 q
, params
= self
._where
_clause
(table
, *ops
, **named_ops
)
247 for t
, fieldname
, field
in table
.iter_fields():
248 if t
== Table
.FT_MM
or fieldname
in ignore
:
251 ftable
= field
.datasource
252 for fn
in ftable
.model
:
253 if not (fn
in ftable
.mmtables
or field
in ignore
):
254 fields
.append(self
._field
_alias
(fn
, ftable
, sep
))
255 joins
.write("\n\tLEFT OUTER JOIN %s ON %s.%s = %s.%s" % (
257 table
.name
, fieldname
,
258 ftable
.name
, ftable
.primary_keys
[0]))
260 fields
.append(self
._field
_alias
(fieldname
, table
, sep
))
262 query
= "SELECT %s\nFROM %s%s\n%s" % (
263 ",\n\t".join(fields
), table
.name
, joins
.getvalue(), q
)
264 c
= self
.new_cursor()
265 self
.execute(query
, params
, c
)
266 return JoinManipFetcher(c
, table
)
268 def jselect_one(self
, table
, *ops
, **named_ops
):
269 named_ops
["limit"] = 1
270 r
= self
.jselect(table
, *ops
, **named_ops
)
276 def rselect(self
, table
, *ops
, **named_ops
):
277 ignore
= set(named_ops
.pop("ignore", [])) # using set for fast lookup
279 q
, params
= self
._where
_clause
(table
, *ops
, **named_ops
)
282 for fieldname
, field
in table
.model
.iteritems():
283 if fieldname
in table
.mmtables
or fieldname
in ignore
:
285 fields
.append(fieldname
)
287 query
= "SELECT %s FROM %s %s" % (",".join(fields
), table
.name
, q
)
288 c
= self
.new_cursor()
289 self
.execute(query
, params
, c
)
290 return RecursiveManipFetcher(c
, table
)
292 def rselect_one(self
, table
, *ops
, **named_ops
):
293 named_ops
["limit"] = 1
294 r
= self
.rselect(table
, *ops
, **named_ops
)