c89f26307e7b169a6c249962776b2d6fa67f9a07
[straw.git] / straw / storage / SQLiteStorage.py
blobc89f26307e7b169a6c249962776b2d6fa67f9a07
1 from straw.defs import STRAW_DATA_DIR
2 import os
3 from pysqlite2 import dbapi2 as sqlite
4 import threading
# File name of the SQLite database. Not referenced by the code visible in this
# module; presumably joined with a data directory by whoever builds the
# db_path passed to SQLiteStorage -- verify against callers.
DATABASE_FILE_NAME = "data.db"
class SQLiteStorage:
    """
    Provides an entry point to the database.

    A separate sqlite connection is opened lazily for every calling thread and
    cached in C{self._connections}, keyed by the thread object (see
    _connect()).  insert() and update_with_where() commit immediately unless
    the calling thread has opened a transaction with _tx_begin(); in that case
    the caller ends it with _tx_commit() or _tx_rollback().

    NOTE(review): this class never closes its connections, so one entry stays
    in C{self._connections} for every thread that ever touched the storage.
    """

    def __init__(self, db_path):
        """
        @param db_path: path passed to C{sqlite.connect()} for every
            per-thread connection
        """
        # thread object -> that thread's open connection
        self._connections = {}
        # thread object -> True while that thread has an explicit transaction
        self._txs = {}
        self._db_path = db_path
        self._init_db()

    def _get_sql(self):
        """
        Read the schema script and split it into individual statements.

        The script is STRAW_DATA_DIR/sql/create_01.sql.  It is split on every
        "--", so the file presumably uses "--" purely as a statement
        separator (an ordinary SQL comment would be glued onto the statement
        that follows it).

        @return: list of SQL strings, possibly including whitespace-only ones
        """
        f = open(os.path.join(STRAW_DATA_DIR, 'sql', 'create_01.sql'), 'r')
        try:
            sql = f.read()
        finally:
            # Close the handle even if read() fails.
            f.close()
        return sql.split("--")

    def _init_db(self):
        """
        Create the schema if the database does not appear to have one yet.

        The probe is C{SELECT COUNT(*) FROM feeds}.  If it raises a sqlite
        error, every statement from _get_sql() is executed inside one
        transaction.  An initialization failure is printed, not raised.
        """
        try:
            # Check if we need init.
            # TODO: use smarter strategy maybe?
            self.query("SELECT COUNT(*) FROM feeds")
        except sqlite.Error:
            # DB seems not initialized; fall through and create the schema.
            pass
        else:
            return

        c = self._connect()

        try:
            sql = self._get_sql()

            self._tx_begin()

            for statement in sql:
                c.execute(statement)

            self._tx_commit()
        except Exception as e:
            # Best effort: report and carry on, but first roll back and clear
            # this thread's transaction flag.  Otherwise _in_tx() would stay
            # True here and later insert()/update() calls would never commit.
            try:
                self._tx_rollback()
            except sqlite.Error:
                pass
            print("DB init failed -- %s" % e)

    def _connect(self):
        """
        Return the calling thread's connection, opening it on first use.

        A new connection gets C{cache_size = 20000}, C{synchronous = NORMAL}
        and C{sqlite.Row} as its row factory.  It is cached only after that
        configuration succeeded.
        """
        key = threading.current_thread()
        conn = self._connections.get(key)

        if conn is None:
            conn = sqlite.connect(self._db_path)
            conn.execute("PRAGMA cache_size = 20000;")
            conn.execute("PRAGMA synchronous = NORMAL;")
            conn.row_factory = sqlite.Row
            self._connections[key] = conn

        return conn

    def _tx_begin(self):
        """
        Mark an explicit transaction as open for the calling thread.

        While it is open, insert() and update_with_where() skip their
        automatic commit.  This works again after every _tx_commit() or
        _tx_rollback(); it is not limited to the thread's first transaction.
        """
        self._txs[threading.current_thread()] = True

    def _tx_commit(self):
        """Commit the calling thread's connection and end its transaction."""
        key = threading.current_thread()
        self._connect().commit()
        self._txs[key] = False

    def _tx_rollback(self):
        """
        Roll back the calling thread's connection and end its transaction.

        The transaction flag is cleared even if the rollback itself fails.
        """
        key = threading.current_thread()
        try:
            self._connect().rollback()
        finally:
            self._txs[key] = False

    def _in_tx(self):
        """Return True if the calling thread has an explicit transaction open."""
        return self._txs.get(threading.current_thread(), False)

    def query(self, query, params=None):
        """
        Execute a statement on the calling thread's connection.

        No commit is issued.

        @param query: SQL text, with C{?} placeholders for params
        @param params: sequence of placeholder values; None means no values
        @return: all result rows (C{sqlite.Row} objects, see _connect())
        """
        if params is None:
            params = ()
        #print query
        return self._connect().execute(query, params).fetchall()

    @staticmethod
    def _fields_and_values(data):
        """Return (field names, values) from the dict data, in matching order."""
        fields = list(data.keys())
        return fields, [data[name] for name in fields]

    def _execute_write(self, query, params):
        """
        Execute a data-modifying statement on a fresh cursor.

        The statement is committed right away unless the calling thread is
        inside an explicit transaction (see _tx_begin()).

        @return: the cursor the statement ran on
        """
        cursor = self._connect().cursor()
        cursor.execute(query, params)
        if not self._in_tx():
            self._tx_commit()
        return cursor

    def insert(self, table, data):
        """
        Inserts some data into the database.

        @param table: a name of the table to insert to
        @param data: a dictionary where keys are field names and values
            are field values in the given table
        @return: None if data is empty, otherwise the new row's
            C{lastrowid}

        Table and field names are interpolated into the SQL text (only the
        values are parameterized), so they must come from trusted code.
        """
        if not data:
            return None

        fields, values = self._fields_and_values(data)
        placeholders = ", ".join(["?"] * len(fields))
        query = "INSERT INTO %s (%s) VALUES (%s)" % (table, ", ".join(fields),
                                                     placeholders)
        return self._execute_write(query, values).lastrowid

    def update(self, table, id, data):
        """
        Updates single row identified by a primary key in the table.

        @param table: a name of the table to do the update in
        @param id: primary key of the item to update.  It is formatted with
            C{%d} into C{WHERE id = ...}, so it must be an integer and the
            table's key column must be named C{id}.
        @param data: a dictionary where keys are field names and values
            are field values in the given table
        @return: whatever update_with_where() returns
        """
        return self.update_with_where(table, "WHERE id = %d" % id, data)

    def update_with_where(self, table, where_clause, data):
        """
        Updates the rows of a table selected by a raw WHERE clause.

        @param table: a name of the table to do the update in
        @param where_clause: SQL appended verbatim after the SET list,
            including the C{WHERE} keyword.  It is not parameterized, so it
            must not contain untrusted text.
        @param data: a dictionary where keys are field names and values
            are field values in the given table
        @return: None when data is empty (nothing to update, as in
            insert()); otherwise the cursor's C{lastrowid}.
            NOTE(review): sqlite's lastrowid reflects inserted rows, so it
            does not identify the updated row; C{rowcount} is probably what
            callers want.
        """
        if not data:
            return None

        fields, values = self._fields_and_values(data)
        assignments = ", ".join(["%s = ?" % name for name in fields])
        query = "UPDATE %s SET %s %s" % (table, assignments, where_clause)
        return self._execute_write(query, values).lastrowid