Fixes (workarounds) in OPML parsing, more work on GUI...
[straw/fork.git] / straw / storage / SQLiteDAO.py
blob47f8217a30d8f57d024718f3b925fdfd00f95102
1 from straw.model import Category, Feed
2 import SQLiteStorage as Storage
class DAO(object):
    """Data-access object that maps model entities to rows through a storage backend.

    The storage object must provide ``query``, ``insert``, ``update``,
    ``_tx_begin`` and ``_tx_commit`` -- those are the only methods called here.
    Entity classes are expected to define ``persistent_table``,
    ``persistent_properties`` (column names, in table order) and
    ``primary_key``; an optional truthy ``inheritance`` attribute on a class
    with a single base marks it as stored across a base table and its own table.
    """

    def __init__(self, storage):
        self.storage = storage

    def _get_indexed(self, entities, index_field):
        """Return a dict mapping ``getattr(entity, index_field)`` to each entity.

        When several entities share an index value, the last one wins.
        """
        return dict((getattr(entity, index_field), entity) for entity in entities)

    def _get_grouped(self, entities, field):
        """Group entities into lists keyed by their ``field`` value.

        Entities whose value is ``None`` are left out and no ``None`` key is
        created. Input order is preserved within each group, and the input no
        longer has to be sorted by ``field`` (previously a non-contiguous
        repeat of a key discarded the entities collected for it earlier).
        """
        groups = {}
        for entity in entities:
            key = getattr(entity, field)
            if key is None:
                continue
            groups.setdefault(key, []).append(entity)
        return groups

    def _is_inherited(self, clazz):
        """Return a truthy value if ``clazz`` is persisted across its base's table too.

        That is the case when it has exactly one base class and a truthy
        ``inheritance`` attribute.
        """
        return len(clazz.__bases__) == 1 and getattr(clazz, "inheritance", False)

    def _get_persistent_fields(self, clazz):
        """Return ``clazz``'s persistent properties, followed by its base's if inherited."""
        fields = list(clazz.persistent_properties)

        if self._is_inherited(clazz):
            fields.extend(clazz.__bases__[0].persistent_properties)

        return fields

    def _persistent_data(self, entity, clazz):
        """Return {property: value} for ``clazz``'s persistent properties, minus its primary key."""
        return dict((prop, getattr(entity, prop))
                    for prop in clazz.persistent_properties
                    if prop != clazz.primary_key)

    def _row_to_entity(self, clazz, row):
        """Build a ``clazz`` instance from a positional row.

        Columns are assigned in ``persistent_properties`` order.
        """
        entity = clazz()
        for index, field in enumerate(clazz.persistent_properties):
            setattr(entity, field, row[index])
        return entity

    def tx_begin(self):
        self.storage._tx_begin()

    def tx_commit(self):
        self.storage._tx_commit()

    def save(self, entity):
        """Insert ``entity`` if its ``id`` is None, otherwise update it; return its id.

        For inherited classes the base-class row is written first. On insert,
        the id returned for that base row becomes ``entity.id``, and the id
        returned for the subclass row is ignored.
        """
        clazz = entity.__class__
        do_insert = entity.id is None

        if self._is_inherited(clazz):
            base_class = clazz.__bases__[0]
            base_data = self._persistent_data(entity, base_class)

            if do_insert:
                entity.id = self.storage.insert(base_class.persistent_table, base_data)
            else:
                self.storage.update(base_class.persistent_table, entity.id, base_data)

        # Built after the base insert so a freshly assigned id is visible here.
        data = self._persistent_data(entity, clazz)

        if do_insert:
            new_id = self.storage.insert(clazz.persistent_table, data)

            if entity.id is None:
                entity.id = new_id
        else:
            self.storage.update(clazz.persistent_table, entity.id, data)

        return entity.id

    def get_one(self, clazz, id):
        """Return the first ``clazz`` entity whose ``id`` column equals ``id``.

        The id is bound as a query parameter rather than spliced into the SQL.
        Raises IndexError when no row matches.
        """
        result = self.get(clazz, " WHERE id = ?", (id,))
        return result[0]

    def get(self, clazz, sql="", params=None):
        """Return a list of ``clazz`` entities from ``SELECT * FROM <table><sql>``.

        ``sql`` is appended verbatim and must come from trusted code; put
        values in ``params`` (an empty list is sent when omitted).
        """
        if params is None:
            params = []

        rows = self.storage.query("SELECT * FROM %s%s" % (clazz.persistent_table, sql), params)
        return [self._row_to_entity(clazz, row) for row in rows]

    def get_with_params(self, clazz, params):
        """Return ``clazz`` entities whose columns equal every value in ``params``.

        ``params`` maps column name -> required value; the conditions are
        ANDed together and the values are bound as query parameters. An empty
        mapping returns every row. Raises ValueError for keys that are not
        persistent properties of ``clazz``, since column names cannot be
        bound as parameters and are interpolated into the SQL.
        """
        columns = sorted(params.keys())
        if not columns:
            return self.get(clazz)

        unknown = [column for column in columns if column not in clazz.persistent_properties]
        if unknown:
            raise ValueError("unknown column(s) for %s: %s"
                             % (clazz.__name__, ", ".join(str(c) for c in unknown)))

        where_clause = " AND ".join("%s = ?" % column for column in columns)
        values = [params[column] for column in columns]
        return self.get(clazz, " WHERE " + where_clause, values)

    def get_indexed(self, clazz, index_field, sql=""):
        """Return ``clazz`` entities from ``SELECT * FROM <table><sql>``, keyed by ``index_field``.

        ``sql`` is appended verbatim and must come from trusted code.
        """
        rows = self.storage.query("SELECT * FROM %s%s" % (clazz.persistent_table, sql))
        return self._get_indexed((self._row_to_entity(clazz, row) for row in rows), index_field)

    def get_nodes(self, sql=""):
        """Load every node as a Category or Feed and return ``(by_id, by_parent_id)``.

        ``by_id`` maps id -> entity and ``by_parent_id`` maps parent_id -> list
        of children. Rows are ordered by parent_id, then norder.
        Each entity gets ``unread_count`` from the joined unread-items
        subquery, or 0 when there is none. Raises ValueError for a node
        ``type`` other than "C" or "F".

        NOTE(review): ``sql`` is accepted but unused; kept for interface compatibility.
        """
        result = self.storage.query(
            "SELECT *, unread_count FROM nodes n"
            " LEFT JOIN feeds f ON (f.id = n.id)"
            " LEFT JOIN categories c ON (c.id = n.id)"
            " LEFT JOIN (SELECT COUNT(*) AS unread_count, feed_id FROM items i"
            " WHERE i.is_read = 0 GROUP BY i.feed_id) ON feed_id = f.id"
            " ORDER BY n.parent_id, n.norder")
        node_classes = {"C": Category, "F": Feed}
        entities = []

        # Rows are indexed by column name, so the storage presumably returns
        # name-addressable rows (e.g. sqlite3.Row) -- TODO confirm.
        for row in result:
            node_type = row["type"]
            try:
                clazz = node_classes[node_type]
            except KeyError:
                raise ValueError("unknown node type %r" % (node_type,))

            entity = clazz()
            entity.unread_count = 0

            for field in self._get_persistent_fields(clazz):
                setattr(entity, field, row[field])

            if row["unread_count"] is not None:
                entity.unread_count = row["unread_count"]

            entities.append(entity)

        return self._get_indexed(entities, "id"), self._get_grouped(entities, "parent_id")