Little tweaks to item fetching.
[straw.git] / straw / storage / SQLiteDAO.py
blob83d6ba63786c4c0822f3078956dbd6c39442b274
1 from straw.model import Category, Feed
2 import SQLiteStorage as Storage
class DAO(object):
    """Data-access object mapping model entities to rows of a storage backend.

    The only storage members called here are ``query``, ``insert``,
    ``update``, ``update_with_where``, ``_tx_begin`` and ``_tx_commit``.
    Entity classes describe their mapping through class attributes read by
    this class: ``persistent_table``, ``persistent_properties``,
    ``primary_key`` and, optionally, ``pdict_table`` and ``inheritance``.
    """

    def __init__(self, storage):
        self.storage = storage

    def _get_indexed(self, entities, index_field):
        """Return a dict mapping ``getattr(entity, index_field)`` to the entity.

        When several entities share a key, the last one wins.
        """
        return dict([(getattr(entity, index_field), entity) for entity in entities])

    def _get_grouped(self, entities, field):
        """Group entities into a dict of lists keyed by ``getattr(entity, field)``.

        Entities whose value is None are skipped (no None key is created).
        Input order is preserved within each group, and the input does not
        need to be sorted by ``field``.
        """
        groups = {}

        for entity in entities:
            current = getattr(entity, field)

            if current is None:
                continue

            # setdefault instead of "new list whenever the value changes", so
            # that unsorted input does not discard an earlier run of the same key.
            groups.setdefault(current, []).append(entity)

        return groups

    def _is_inherited(self, clazz):
        """Return a truthy value when ``clazz`` uses inheritance handling.

        That is the case when it has exactly one direct base class and a
        truthy ``inheritance`` attribute. For such classes ``save`` writes the
        base class's persistent properties to the base class's table.
        """
        return len(clazz.__bases__) == 1 and hasattr(clazz, "inheritance") and clazz.inheritance

    def _get_pdict_table(self, entity):
        """Return the name of the table holding ``entity``'s pdict entries, or None.

        For inherited classes the base class's ``pdict_table`` takes priority.
        Otherwise the entity class's own (possibly inherited) attribute is used.
        """
        clazz = entity.__class__

        if self._is_inherited(clazz) and hasattr(clazz.__bases__[0], "pdict_table"):
            return clazz.__bases__[0].pdict_table

        return getattr(clazz, "pdict_table", None)

    def _get_persistent_fields(self, clazz):
        """Return ``clazz``'s persistent properties, followed by its base class's
        properties when ``clazz`` is inherited."""
        fields = list(clazz.persistent_properties)

        if self._is_inherited(clazz):
            fields.extend(clazz.__bases__[0].persistent_properties)

        return fields

    def _get_row_data(self, entity, clazz):
        """Return {property: value} for ``clazz``'s persistent properties of
        ``entity``, leaving out ``clazz.primary_key``."""
        return dict([(name, getattr(entity, name))
                     for name in clazz.persistent_properties
                     if name != clazz.primary_key])

    @staticmethod
    def _escape_sql_string(value):
        """Return ``value`` rendered with %s and with single quotes doubled, so
        it can be placed between single quotes in an SQL statement."""
        return ("%s" % (value,)).replace("'", "''")

    def tx_begin(self):
        """Begin a transaction on the underlying storage."""
        self.storage._tx_begin()

    def tx_commit(self):
        """Commit the current transaction on the underlying storage."""
        self.storage._tx_commit()

    def query(self, sql):
        """Run raw ``sql`` on the storage and return its result."""
        return self.storage.query(sql)

    def save(self, entity):
        """Insert or update ``entity`` and return its id.

        An entity whose ``id`` is None is inserted; otherwise its rows are
        updated by id. The primary-key property is never part of the written
        data.

        For inherited classes, the base class's persistent properties are
        written to the base class's table first, and an insert there assigns
        ``entity.id``. The class's own properties then go to its own table.
        When a ``pdict_table`` applies, the pdict is saved afterwards.
        """
        clazz = entity.__class__
        do_insert = entity.id is None
        pdict_table = None

        if self._is_inherited(clazz):
            base_class = clazz.__bases__[0]
            base_data = self._get_row_data(entity, base_class)
            pdict_table = getattr(base_class, "pdict_table", None)

            if do_insert:
                entity.id = self.storage.insert(base_class.persistent_table, base_data)
            else:
                self.storage.update(base_class.persistent_table, entity.id, base_data)
        else:
            pdict_table = getattr(clazz, "pdict_table", None)

        # Computed after the base-table write on purpose: entity.id may just
        # have been assigned there.
        data = self._get_row_data(entity, clazz)

        if do_insert:
            new_id = self.storage.insert(clazz.persistent_table, data)

            # Keep an id already assigned by the base-table insert.
            if not entity.id:
                entity.id = new_id
        else:
            self.storage.update(clazz.persistent_table, entity.id, data)

        if pdict_table:
            self.save_pdict(entity)

        return entity.id

    def save_pdict(self, entity):
        """Synchronise ``entity.pdict`` with the rows stored in its pdict table.

        Entries with a falsy value are removed from ``entity.pdict`` (the dict
        is mutated in place) and from storage. New keys are inserted, existing
        keys are updated, and stored keys no longer present in
        ``entity.pdict`` are deleted. Does nothing when no pdict table
        applies. Always returns None.
        """
        pdict_table = self._get_pdict_table(entity)

        if not pdict_table:
            return None

        current_pdict = self.load_pdict(entity)
        new_pdict = entity.pdict

        # Iterate over a snapshot of the keys because entries are deleted in
        # the loop (iterating the live view would fail on Python 3).
        for key in list(new_pdict.keys()):
            value = new_pdict[key]

            if not value:
                del new_pdict[key]
                continue

            if key not in current_pdict:
                self.storage.insert(pdict_table, { "entity_id": entity.id,
                                                   "entry_key": key,
                                                   "entry_value": value })
            else:
                # update_with_where takes a raw WHERE clause, so the key has to
                # be escaped as an SQL string literal.
                self.storage.update_with_where(pdict_table, "WHERE entity_id = %d AND entry_key = '%s'"
                                               % (entity.id, self._escape_sql_string(key)),
                                               { "entry_key": key, "entry_value": value })
                del current_pdict[key]

        # Whatever is left in current_pdict is no longer in new_pdict, i.e. it
        # has been deleted by the user.
        for key in current_pdict:
            self.storage.query("DELETE FROM %s WHERE entity_id = ? AND entry_key = ?" % pdict_table,
                               (entity.id, key))

    def load_pdict(self, entity):
        """Return ``entity``'s stored pdict as {entry_key: entry_value}.

        Returns None when no pdict table applies to ``entity``.
        """
        pdict_table = self._get_pdict_table(entity)

        if not pdict_table:
            return None

        result = self.storage.query("SELECT entry_key, entry_value FROM %s WHERE entity_id = ?" % pdict_table, params = (entity.id,))

        return dict([(entry["entry_key"], entry["entry_value"]) for entry in result])

    def get_one(self, clazz, id):
        """Return the ``clazz`` entity whose id equals ``id``.

        The id is bound as a query parameter rather than interpolated into
        the SQL. Raises IndexError when no such row exists.
        """
        return self.get(clazz, " WHERE id = ?", params = (id,))[0]

    def get(self, clazz, sql = "", query = None, params = None):
        """Load a list of ``clazz`` entities.

        Runs ``query`` when given. Otherwise runs
        ``SELECT * FROM <persistent_table>`` followed by ``sql`` (for example
        a WHERE clause). ``params`` is passed to the storage as the query
        parameters; None means no parameters.

        Every returned column that is also listed in
        ``clazz.persistent_properties`` is set as an attribute on a fresh
        ``clazz()`` instance. Other columns are ignored.
        """
        if params is None:
            params = []

        if query:
            result = self.storage.query(query, params)
        else:
            result = self.storage.query("SELECT * FROM %s%s" % (clazz.persistent_table, sql), params)

        if not result:
            return []

        persistent = clazz.persistent_properties
        # A real list (not a lazy filter object) so it can be reused for every row.
        fields = [column for column in result[0].keys() if column in persistent]

        entities = []

        for row in result:
            entity = clazz()

            for field in fields:
                setattr(entity, field, row[field])

            entities.append(entity)

        return entities

    def get_indexed(self, clazz, index_field, sql = ""):
        """Load ``clazz`` entities as ``get`` does and return them in a dict
        keyed by each entity's ``index_field`` value.

        Columns are mapped to attributes by name, as in ``get``, rather than
        by their position in the result.
        """
        return self._get_indexed(self.get(clazz, sql), index_field)

    def get_nodes(self, sql = ""):
        """Load all nodes as Category or Feed entities and return them in a
        dict keyed by id.

        Each entity receives the persistent fields from
        ``_get_persistent_fields``, an ``unread_count`` (items with
        is_read = 0 for its feed id, 0 when there are none) and its ``pdict``.

        NOTE(review): ``sql`` is accepted for compatibility but is ignored.

        Raises ValueError for a node whose ``type`` is neither "C" nor "F".
        """
        result = self.storage.query("SELECT *, unread_count FROM nodes n LEFT JOIN feeds f ON (f.id = n.id) LEFT JOIN categories c ON (c.id = n.id) LEFT JOIN (SELECT COUNT(*) AS unread_count, feed_id FROM items i WHERE i.is_read = 0 GROUP BY i.feed_id) ON feed_id = f.id ORDER BY n.parent_id, n.norder")
        node_classes = { "C": Category, "F": Feed }
        entities = []

        for row in result:
            node_type = row["type"]

            # Previously an unknown type left the class unbound, or reused
            # the previous row's class.
            if node_type not in node_classes:
                raise ValueError("unknown node type %r" % (node_type,))

            clazz = node_classes[node_type]
            entity = clazz()
            entity.unread_count = 0

            for field in self._get_persistent_fields(clazz):
                setattr(entity, field, row[field])

            if row["unread_count"] is not None:
                entity.unread_count = row["unread_count"]

            entity.pdict = self.load_pdict(entity)
            entities.append(entity)

        return self._get_indexed(entities, "id")

    def delete_category(self, id):
        """Delete category ``id``, its pdict entries and its node row."""
        self.storage.query("DELETE FROM categories WHERE id = ?", (id,))
        self.storage.query("DELETE FROM node_pdict_entries WHERE entity_id = ?", (id,))
        self.storage.query("DELETE FROM nodes WHERE id = ?", (id,))

    def delete_feed(self, id):
        """Delete feed ``id`` together with its items, pdict entries and node row."""
        self.storage.query("DELETE FROM items WHERE feed_id = ?", (id,))
        self.storage.query("DELETE FROM feeds WHERE id = ?", (id,))
        self.storage.query("DELETE FROM node_pdict_entries WHERE entity_id = ?", (id,))
        self.storage.query("DELETE FROM nodes WHERE id = ?", (id,))