Don't hardcode any data file paths
[straw.git] / straw / storage / SQLiteStorage.py
blob5a70e4748fb8aea4852b3b2727bbb1458dbb2008
1 from straw.defs import STRAW_DATA_DIR
2 import os
3 from pysqlite2 import dbapi2 as sqlite
4 import threading
# File name of the SQLite database.  Not referenced by the code in this
# section; presumably combined with a data directory elsewhere -- confirm.
DATABASE_FILE_NAME = "data.db"
8 class SQLiteStorage:
9 """
10 Provides an entry point to the database.
12 """
def __init__(self, db_path):
    """Remember where the database lives and make sure its schema exists.

    @param db_path: path handed to C{sqlite.connect} when a per-thread
        connection is opened
    """
    self._db_path = db_path
    # Per-thread bookkeeping, keyed by the thread object.
    self._txs = {}
    self._connections = {}
    # Everything _init_db relies on is set up above.
    self._init_db()
def _get_sql(self, script_name='create_01.sql'):
    """Load an SQL script from the data directory and split it into chunks.

    The script is read from the C{sql} subdirectory of C{STRAW_DATA_DIR}
    and split on the literal C{--} marker, so each chunk is expected to
    hold one statement.  Chunks may be empty or whitespace-only.

    @param script_name: file name of the script inside the C{sql}
        directory; defaults to the initial schema script
    @return: list of SQL statement strings
    """
    path = os.path.join(STRAW_DATA_DIR, 'sql', script_name)
    # "with" closes the handle even if read() raises.
    with open(path, 'r') as f:
        sql = f.read()
    return sql.split("--")
def _init_db(self):
    """Create the schema if the database does not look initialized yet.

    The check is a C{SELECT COUNT(*) FROM feeds}; any exception from it is
    taken to mean the schema is missing.  Schema creation is best-effort:
    on failure the error is printed and the pending transaction on this
    thread's connection is rolled back rather than left open.
    """
    try:
        # TODO: use smarter strategy maybe?
        self.query("SELECT COUNT(*) FROM feeds")
    except Exception:
        # DB seems not initialized -- fall through and create the schema.
        pass
    else:
        return

    c = self._connect()

    try:
        for statement in self._get_sql():
            # Splitting the script on "--" can produce blank chunks.
            if not statement.strip():
                continue
            c.execute(statement)
        self._tx_commit()
    except Exception as e:
        print("DB init failed -- %s" % e)
        # Without this, the half-applied statements would stay in an open
        # transaction and be persisted by the next commit on this connection.
        c.rollback()
def _connect(self):
    """Return the calling thread's connection, opening it on first use.

    Connections are cached in C{self._connections} keyed by the current
    thread object.  A freshly opened connection gets the cache-size and
    synchronous pragmas applied and C{sqlite.Row} as its row factory.
    """
    thread = threading.current_thread()
    try:
        return self._connections[thread]
    except KeyError:
        pass

    # Cache first, then configure -- same order as before.
    conn = self._connections[thread] = sqlite.connect(self._db_path)
    for pragma in ("PRAGMA cache_size = 20000;",
                   "PRAGMA synchronous = NORMAL;"):
        conn.execute(pragma)
    conn.row_factory = sqlite.Row
    return conn
def _tx_begin(self):
    """Mark the calling thread as being inside an explicit transaction.

    While the flag is set, C{insert} and C{update} skip their automatic
    commit; C{_tx_commit} clears it again.
    """
    # Set the flag unconditionally.  The old "only if the key is missing"
    # check left it False after the first _tx_commit, so every later
    # transaction on the same thread silently fell back to autocommit.
    self._txs[threading.current_thread()] = True
def _tx_commit(self):
    """Commit this thread's connection and clear its transaction flag."""
    connection = self._connect()
    connection.commit()
    self._txs[threading.current_thread()] = False
def _in_tx(self):
    """Tell whether the calling thread has an explicit transaction open.

    @return: the flag stored for this thread, or C{False} if none was
        ever recorded
    """
    return self._txs.get(threading.current_thread(), False)
def query(self, query, params=None):
    """Run a statement on this thread's connection and return all rows.

    No commit is issued here.

    @param query: SQL text, using C{?} placeholders for parameters
    @param params: sequence of values bound to the placeholders, or
        C{None} for no parameters
    @return: list of every row the statement produced
    """
    # "is None" rather than "== None": == may be overloaded by the argument.
    if params is None:
        params = ()
    # _connect() returns a connection; its execute() creates the cursor.
    connection = self._connect()
    return connection.execute(query, params).fetchall()
def insert(self, table, data):
    """
    Inserts some data into the database.

    The statement is committed immediately unless the calling thread is
    inside a transaction started with C{_tx_begin}.

    @param table: a name of the table to insert to
    @param data: a dictionary where keys are field names and values
        are field values in the given table
    @return: C{lastrowid} of the cursor that ran the insert, or C{None}
        when C{data} is empty (nothing is inserted then)
    """
    if not data:
        return None

    # Take the field names once and derive the values from them, so the
    # column list and the bound values always line up.  A plain list is
    # also accepted as parameters where a dict view would not be.
    fields = list(data)
    values = [data[field] for field in fields]

    # Table and column names cannot be bound as parameters and are
    # interpolated into the SQL -- only pass trusted identifiers.
    query = "INSERT INTO %s (%s) VALUES (%s)" % (
        table, ", ".join(fields), ", ".join(["?"] * len(fields)))

    cursor = self._connect().cursor()
    cursor.execute(query, values)

    if not self._in_tx():
        self._tx_commit()

    return cursor.lastrowid
def update(self, table, id, data):
    """
    Updates single row identified by a primary key in the table.

    The statement is committed immediately unless the calling thread is
    inside a transaction started with C{_tx_begin}.

    @param table: a name of the table to do the update in
    @param id: primary key of the item to update (matched against the
        C{id} column)
    @param data: a dictionary where keys are field names and values
        are field values in the given table
    @return: C{lastrowid} of the cursor that ran the update, or C{None}
        when C{data} is empty (nothing is updated then)
    """
    if not data:
        # An empty SET clause is invalid SQL; do nothing, like insert().
        return None

    # Field names are taken once so assignments and values line up; a
    # fresh list is built so the caller's dict is never touched.
    fields = list(data)
    params = [data[field] for field in fields]
    params.append(id)

    # Identifiers are interpolated (they cannot be bound); values are bound.
    assignments = ", ".join(["%s = ?" % field for field in fields])
    query = "UPDATE %s SET %s WHERE id = ?" % (table, assignments)

    cursor = self._connect().cursor()
    cursor.execute(query, params)

    if not self._in_tx():
        self._tx_commit()

    # NOTE(review): lastrowid is not meaningful after an UPDATE; kept for
    # backward compatibility.  cursor.rowcount is probably what callers
    # want -- confirm against callers before changing.
    return cursor.lastrowid