Cleanup.
[straw.git] / straw / storage / SQLiteDAO.py
blob8773e9ee7e2153b793840b295bf61d16199261e3
1 from straw.model import Category, Feed
2 import SQLiteStorage as Storage
class DAO(object):
    """Map model entities onto rows of a storage backend.

    The storage object is expected to provide ``query(sql[, params])``,
    ``insert(table, data)`` (returning the new row id),
    ``update(table, id, data)``, ``_tx_begin()`` and ``_tx_commit()``.

    Entity classes describe their mapping through the class attributes
    ``persistent_table``, ``persistent_properties`` and ``primary_key``.
    A class with exactly one base class and a truthy ``inheritance``
    attribute is stored in two tables: its base class' table and its own.
    """

    def __init__(self, storage):
        """Remember the storage backend every operation is delegated to."""
        self.storage = storage

    def _get_indexed(self, entities, index_field):
        """Return a dict mapping each entity's ``index_field`` value to it.

        When several entities share a value, the last one wins.
        """
        return dict((getattr(entity, index_field), entity)
                    for entity in entities)

    def _get_grouped(self, entities, field):
        """Group entities into lists keyed by their ``field`` value.

        Entities whose value is ``None`` are left out. The input does not
        have to be sorted by ``field``: a key that shows up again later
        extends its existing group instead of replacing it.
        """
        groups = {}

        for entity in entities:
            key = getattr(entity, field)

            if key is None:
                continue

            groups.setdefault(key, []).append(entity)

        return groups

    def _is_inherited(self, clazz):
        """Tell whether ``clazz`` is stored split across two tables."""
        return (len(clazz.__bases__) == 1
                and bool(getattr(clazz, "inheritance", False)))

    def _get_persistent_fields(self, clazz):
        """Return the persistent field names of ``clazz``.

        For an inherited class the base class' fields are appended after
        the class' own fields.
        """
        fields = list(clazz.persistent_properties)

        if self._is_inherited(clazz):
            fields.extend(clazz.__bases__[0].persistent_properties)

        return fields

    def _get_row_data(self, entity, clazz):
        """Collect the values ``entity`` stores in the table of ``clazz``.

        The primary key column of ``clazz`` is excluded.
        """
        return dict((name, getattr(entity, name))
                    for name in clazz.persistent_properties
                    if name != clazz.primary_key)

    def _populate(self, entity, row, fields):
        """Copy the named ``fields`` from ``row`` onto ``entity``."""
        for field in fields:
            setattr(entity, field, row[field])

        return entity

    def tx_begin(self):
        """Begin a transaction on the storage backend."""
        self.storage._tx_begin()

    def tx_commit(self):
        """Commit the current transaction on the storage backend."""
        self.storage._tx_commit()

    def query(self, sql):
        """Run ``sql`` on the storage backend and return its result."""
        return self.storage.query(sql)

    def save(self, entity):
        """Insert or update ``entity`` and return its id.

        An entity whose ``id`` is ``None`` is inserted, any other entity
        is updated. For an inherited class the base table is written
        first; the id returned by that insert becomes the entity's id.
        """
        clazz = entity.__class__
        do_insert = entity.id is None

        if self._is_inherited(clazz):
            base_class = clazz.__bases__[0]
            base_data = self._get_row_data(entity, base_class)

            if do_insert:
                entity.id = self.storage.insert(base_class.persistent_table,
                                                base_data)
            else:
                self.storage.update(base_class.persistent_table, entity.id,
                                    base_data)

        data = self._get_row_data(entity, clazz)

        if do_insert:
            new_id = self.storage.insert(clazz.persistent_table, data)

            # An inherited entity already received its id from the base
            # table insert above; do not overwrite it.
            if entity.id is None:
                entity.id = new_id
        else:
            self.storage.update(clazz.persistent_table, entity.id, data)

        return entity.id

    def get_one(self, clazz, id):
        """Return the ``clazz`` entity whose ``id`` column equals ``id``.

        The id is passed as a bound parameter rather than being formatted
        into the SQL text.

        Raises IndexError when no row matches.
        """
        result = self.get(clazz, " WHERE id = ?", params=(id,))
        return result[0]

    def get(self, clazz, sql="", query=None, params=None):
        """Load ``clazz`` entities from the storage backend.

        Without ``query`` the statement is ``SELECT * FROM <table>``
        followed by ``sql`` (which must bring its own leading space and
        keyword, e.g. ``" WHERE id = ?"``). With ``query`` that statement
        is run as is and ``sql`` is ignored. ``params`` are handed to the
        storage as bound parameters; an empty list is used when omitted.

        Returns a list with one entity per row. Each entity is created
        with ``clazz()`` and gets every name in
        ``clazz.persistent_properties`` set from the row, read by name.
        """
        if params is None:
            params = []

        if not query:
            query = "SELECT * FROM %s%s" % (clazz.persistent_table, sql)

        rows = self.storage.query(query, params)

        return [self._populate(clazz(), row, clazz.persistent_properties)
                for row in rows]

    def get_with_params(self, clazz, params):
        """Load the ``clazz`` entities matching all ``params``.

        ``params`` maps column names to required values. Every condition
        is combined with ``AND``; a value of ``None`` becomes an
        ``IS NULL`` test. An empty mapping selects every row.

        Column names cannot be bound as SQL parameters, so each key must
        be one of ``clazz.persistent_properties``.

        Raises ValueError for a key that is not a persistent property.
        """
        conditions = []
        values = []

        # Sorted so that the generated statement is deterministic.
        for key in sorted(params):
            if key not in clazz.persistent_properties:
                raise ValueError("%r is not a persistent property of %s"
                                 % (key, clazz.__name__))

            value = params[key]

            if value is None:
                conditions.append("%s IS NULL" % key)
            else:
                conditions.append("%s = ?" % key)
                values.append(value)

        where_clause = ""

        if conditions:
            where_clause = " WHERE " + " AND ".join(conditions)

        return self.get(clazz, where_clause, params=values)

    def get_indexed(self, clazz, index_field, sql=""):
        """Load ``clazz`` entities as a dict keyed by ``index_field``.

        ``sql`` is appended to the ``SELECT`` statement exactly as in
        ``get``. Columns are read by name, like everywhere else in this
        class.
        """
        return self._get_indexed(self.get(clazz, sql), index_field)

    def get_nodes(self, sql=""):
        """Load every node as a Category or Feed, keyed by id.

        Rows whose ``type`` column is ``"C"`` become Category objects and
        rows with ``"F"`` become Feed objects. Each entity carries an
        ``unread_count`` attribute, which is 0 when the row has none.

        ``sql`` is accepted for backward compatibility and is ignored.

        Raises ValueError for a row with any other ``type`` value.
        """
        rows = self.storage.query(
            "SELECT *, unread_count FROM nodes n "
            "LEFT JOIN feeds f ON (f.id = n.id) "
            "LEFT JOIN categories c ON (c.id = n.id) "
            "LEFT JOIN (SELECT COUNT(*) AS unread_count, feed_id FROM items i "
            "WHERE i.is_read = 0 GROUP BY i.feed_id) ON feed_id = f.id "
            "ORDER BY n.parent_id, n.norder")

        classes = {"C": Category, "F": Feed}
        fields_by_class = {}
        entities = []

        for row in rows:
            clazz = classes.get(row["type"])

            if clazz is None:
                raise ValueError("Unknown node type: %r" % (row["type"],))

            # The field list only depends on the class, compute it once.
            if clazz not in fields_by_class:
                fields_by_class[clazz] = self._get_persistent_fields(clazz)

            entity = clazz()
            entity.unread_count = 0
            self._populate(entity, row, fields_by_class[clazz])

            if row["unread_count"] is not None:
                entity.unread_count = row["unread_count"]

            entities.append(entity)

        return self._get_indexed(entities, "id")

    def delete_category(self, id):
        """Delete the category ``id`` together with its node row."""
        self.storage.query("DELETE FROM categories WHERE id = ?", (id,))
        self.storage.query("DELETE FROM nodes WHERE id = ?", (id,))

    def delete_feed(self, id):
        """Delete the feed ``id``, its items and its node row."""
        self.storage.query("DELETE FROM items WHERE feed_id = ?", (id,))
        self.storage.query("DELETE FROM feeds WHERE id = ?", (id,))
        self.storage.query("DELETE FROM nodes WHERE id = ?", (id,))