readme
[lsqlite.git] / lsqlite.c
blob9b86cc0b4bdfa34c35e7d56b542db0fd7e416e48
1 #include <errno.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <lua.h>
5 #include <lualib.h>
6 #include <lauxlib.h>
7 #include <fcntl.h>
8 #include <unistd.h>
9 #include <sqlite3.h>
10 #include <assert.h>
12 typedef struct {
13 sqlite3 *db;
14 int changes;
15 char name[1];
16 } sqlite_db;
18 typedef struct {
19 int nst, npars;
20 sqlite3_stmt *st[0];
21 } sqlite_st;
/*
 * Name strings. Within this file only SQLITEUV is referenced: it is the
 * registry key of the uservalue emulation table used on Lua < 5.3.
 */
#define SQLITEDB "sqlite.db*"
#define SQLITEST "sqlite.st*"
#define SQLITEUV "sqlite.uv*"
#define SQLITEQS "sqlite.qs*"

/* Upper bound on the number of statements in a single SQL string. */
enum { MAX_ST = 128 };

/* Upvalues shared by the module's C functions (installed by luaopen_sqlite). */
#define SLUV_ST   lua_upvalueindex(1)	/* metatable of uncached statement objects */
#define SLUV_DB   lua_upvalueindex(2)	/* metatable of database objects */
#define SLUV_QS   lua_upvalueindex(3)	/* map: statement pointer -> query string */
#define SLUV_ROWS lua_upvalueindex(4)	/* iterator function for db:rows() */
#define SLUV_COLS lua_upvalueindex(5)	/* iterator function for db:cols() */
35 /* Lua 5.[12] compat */
36 #if LUA_VERSION_NUM < 503
37 static void lua_setuservalue(lua_State *L, int idx)
39 lua_pushvalue(L, idx); /* +1 target */
40 lua_pushliteral(L, SQLITEUV);
41 lua_rawget(L, LUA_REGISTRYINDEX); /* +2 tab */
42 lua_pushvalue(L, -2);
43 lua_pushvalue(L, -4); /* top stack caller */
44 lua_rawset(L, -3);
45 lua_pop(L, 3);
47 static void lua_getuservalue(lua_State *L, int idx)
49 lua_pushvalue(L, idx);
50 lua_pushliteral(L, SQLITEUV);
51 lua_rawget(L, LUA_REGISTRYINDEX);
52 lua_pushvalue(L, -2);
53 lua_rawget(L, -2);
54 lua_replace(L, -3);
55 lua_pop(L, 1);
57 #if LUA_VERSION_NUM < 502
58 #define luaL_setfuncs(L, reg, nup) luaL_openlib(L, NULL, reg, nup)
59 #endif
60 #endif
62 /* Load uservalue of given metatable type */
63 static void *sl_uvdata(lua_State *L, int idx, int uvidx)
65 void *data = lua_touserdata(L, idx);
66 if (!data)
67 return data;
68 if (!lua_getmetatable(L, idx))
69 return NULL;
70 if (!lua_rawequal(L, -1, uvidx))
71 data = NULL;
72 lua_pop(L, 1);
73 return data;
76 /* Load statement object */
77 static inline sqlite_st *sl_tost(lua_State *L, int idx, int chk)
79 sqlite_st *st = sl_uvdata(L, idx, SLUV_ST);
80 if (!st || (chk && st->nst == -1)) {
81 luaL_argerror(L, idx, "invalid sqlite statement (stale ref?)");
83 return st;
87 /* Load DB object */
88 static inline sqlite_db *sl_todb(lua_State *L, int idx, int chk)
90 sqlite_db *db = sl_uvdata(L, idx, SLUV_DB);
91 if (!db || (chk && !db->db))
92 luaL_argerror(L, idx, "invalid sqlite database (stale ref?)");
93 return db;
/*
 * Statement caching works as follows:
 *
 * 1. When a statement emerges into the Lua world (that is, it is exposed
 *    via an iterator), it is uncached; otherwise it is cached.
 *
 * 2. Uncached statements have their underlying database set as uservalue.
 *    They also get a metatable attached, whose __gc is sl_cacheback().
 *
 * 3. Cached statements are stored in a per-database cache table, which is
 *    the db object's uservalue. Cached statements have no metatable (hence
 *    no __gc). Their uservalue links to the next cached statement for the
 *    same query string, forming a chain:
 *
 *      db->uservalue[qs] = stmt1
 *      stmt1->uservalue  = stmt2
 *      ...
 *      stmtN->uservalue  = nil
 *
 * All of this allows us to cache statements with little to no GC pressure.
 */
114 static void nuke_stmt(lua_State *L, sqlite_st *st)
116 int i;
117 for (i = 0; i < st->nst; i++)
118 sqlite3_finalize(st->st[i]);
119 st->nst = -1;
120 lua_pushlightuserdata(L, st);
121 lua_pushnil(L);
122 lua_rawset(L, SLUV_QS);
125 static int sl_cacheback(lua_State *L)
127 sqlite_st *st = sl_tost(L, 1, 0);
128 sqlite_db *db;
129 int i;
130 lua_settop(L,1);
131 lua_getuservalue(L, 1); /* load db, stack = 2 */
132 db = sl_todb(L, 2, 0);
134 /* Meanwhile the database closed */
135 if (!db->db) {
136 /* So just clear statements. */
137 nuke_stmt(L, st);
138 return 0;
141 lua_getuservalue(L, 2); /* load db's cache table, stack = 3 */
142 assert(lua_istable(L, -1));
143 assert(lua_gettop(L) == 3);
145 /* Lookup the query string */
146 lua_pushlightuserdata(L, st);
147 lua_rawget(L, SLUV_QS); /* stack = 4 */
148 assert(lua_isstring(L, -1));
150 /* And load cache chain */
151 lua_pushvalue(L, -1);
152 lua_rawget(L, 3);
154 /* Chain that result into our uservalue. */
155 lua_setuservalue(L, 1);
157 /* And set it as cache entry. */
158 lua_pushvalue(L, 4); /* qs */
159 lua_pushvalue(L, 1); /* our entry */
160 lua_rawset(L, 3);
162 /* Kill metatable of this statement .*/
163 lua_pushnil(L);
164 lua_setmetatable(L, 1);
166 /* So that sqlite frees unused memory. */
167 for (i = 0; i < st->nst; i++) {
168 sqlite3_clear_bindings(st->st[i]);
169 sqlite3_reset(st->st[i]);
171 return 0;
174 /* Find or create statement. Returns statement on top, cache table just below. */
175 static sqlite_st *do_prepare(lua_State *L, int dbidx, int qsidx, int uncache,
176 sqlite_db **db)
178 size_t qsl;
179 const char *qs, *qsp;
180 sqlite3_stmt *stmts[MAX_ST];
181 int count = 0;
182 int npars = 0;
183 sqlite_st *st;
185 /* Get cache. */
186 *db = sl_todb(L, dbidx, 1);
187 lua_getuservalue(L, dbidx);
188 assert(lua_istable(L, -1));
190 /* Lookup query string. */
191 lua_pushvalue(L, qsidx);
192 lua_rawget(L, -2);
194 /* Found cached? */
195 if (!lua_isnil(L, -1)) {
196 if (uncache) {
197 /* Unlink the entry. */
198 lua_pushvalue(L, qsidx);
199 lua_getuservalue(L, -2); /* Next in chain. */
200 lua_rawset(L, -4);
202 st = lua_touserdata(L, -1);
203 } else {
204 lua_pop(L, 1);
205 /* Build new. */
206 qs = luaL_checklstring(L, 2, &qsl);
207 for (qsp = qs; *qsp; ) {
208 sqlite3_stmt *s = NULL;
209 int err;
211 if (count >= MAX_ST) {
212 int i;
213 for (i = 0; i < count; i++)
214 sqlite3_finalize(stmts[i]);
215 luaL_error(L, "Too many statements (max %d)", MAX_ST);
217 err = sqlite3_prepare_v2((*db)->db, qsp, qs + qsl - qsp,
218 stmts + count++, &qsp);
219 if (err != SQLITE_OK) {
220 int i;
221 for (i = 0; i < count; i++)
222 sqlite3_finalize(stmts[i]);
223 luaL_error(L, "%s", sqlite3_errmsg((*db)->db));
225 npars += sqlite3_bind_parameter_count(s);
227 st = lua_newuserdata(L, sizeof(*st) + count * sizeof(st->st[0]));
228 memset(st, 0, sizeof(*st) + count * sizeof(st->st[0]));
229 memcpy(st->st, stmts, count*sizeof(st->st[0]));
230 st->npars = npars;
231 st->nst = count;
233 /* Remember for UD->QS lookup */
234 lua_pushlightuserdata(L, st);
235 lua_pushvalue(L, 2);
236 lua_rawset(L, SLUV_QS);
238 if (!uncache) {
239 /* Link in the entry. */
240 lua_pushvalue(L, qsidx);
241 lua_pushvalue(L, -2);
242 assert(lua_istable(L, -4));
243 lua_rawset(L, -4);
244 } else {
245 /* Set its metatable if it gets uncached. */
246 lua_pushvalue(L, SLUV_ST);
247 lua_setmetatable(L, -2);
248 /* And uservalue to db */
249 lua_pushvalue(L, dbidx);
250 lua_setuservalue(L, -2);
252 return st;
255 /* Bind one statement. Return number of parameters bound. */
256 static int do_bind(lua_State *L, sqlite3 *db, sqlite3_stmt *st,
257 int pars, int count, int names)
259 int err,i,j,bn = sqlite3_bind_parameter_count(st);
260 i = 0;
262 if ((err = sqlite3_reset(st)) != SQLITE_OK)
263 goto out;
264 for (i = 1, j = pars; i <= bn; i++) {
265 const char *bname;
266 int tj = j;
267 if (names && (bname = sqlite3_bind_parameter_name(st, i))) {
268 lua_getfield(L, names, bname + 1);
269 tj = -1;
270 } else {
271 if (j >= pars+count)
272 break;
273 j++;
275 if (lua_isboolean(L, tj)) {
276 err = sqlite3_bind_int(st, i, lua_toboolean(L, tj));
277 #if LUA_VERSION_NUM >= 503
278 } else if (lua_isinteger(L, tj)) {
279 err = sqlite3_bind_int64(st, i, lua_tointeger(L, tj));
280 #endif
281 } else if (lua_isnumber(L, tj)) {
282 err = sqlite3_bind_double(st, i, lua_tonumber(L, tj));
283 } else if (lua_isnil(L, tj)) {
284 err = sqlite3_bind_null(st, i);
285 } else {
286 size_t sl;
287 const char *s = luaL_checklstring(L, tj, &sl);
288 err = sqlite3_bind_text(st, i, s, sl, NULL);
290 if (tj == -1)
291 lua_pop(L, 1);
292 if (err != SQLITE_OK)
293 goto out;
295 for (; i <= bn; i++) {
296 j++;
297 if ((err = sqlite3_bind_null(st, i))!=SQLITE_OK)
298 break;
300 out:;
301 luaL_argcheck(L, (err == SQLITE_OK), j-1, sqlite3_errmsg(db));
302 return bn;
305 /* Pull statement from idx, bind arguments. Must have clean (call) stack top. */
306 static sqlite_st *do_binds(lua_State *L, int dbidx, int qsidx, int uncache,
307 sqlite_db **db)
309 int avail = lua_gettop(L) - qsidx;
310 sqlite_st *st = do_prepare(L, dbidx, qsidx, uncache, db);
311 int i, parpos = qsidx+1;
312 int names = 0;
313 if (avail && lua_istable(L, parpos)) {
314 avail--;
315 names = parpos++;
317 for (i = 0; i < st->nst; i++) {
318 int got = do_bind(L, (*db)->db, st->st[i], parpos,
319 avail<0?0:avail, names);
320 avail -= got;
321 parpos += got;
323 return st;
326 /* Push one row column value at `idx` to the stack. */
327 static void push_field(lua_State *L, struct sqlite3_stmt *row, int idx)
329 switch (sqlite3_column_type(row, idx)) {
330 case SQLITE_INTEGER:
331 lua_pushinteger(L, sqlite3_column_int64(row, idx));
332 return;
333 case SQLITE_FLOAT:
334 lua_pushnumber(L, sqlite3_column_double(row, idx));
335 return;
336 case SQLITE_TEXT:
337 case SQLITE_BLOB: {
338 const char *p = sqlite3_column_blob(row, idx);
339 if (!p) lua_pushnil(L); else
340 lua_pushlstring(L, p, sqlite3_column_bytes(row, idx));
341 return;
343 case SQLITE_NULL:
344 lua_pushnil(L);
345 return;
346 default:
347 abort();
351 /* Push all columns of a row on the stack. Returns number of columns. */
352 static int push_fields(lua_State *L, sqlite3_stmt *row)
354 int i, n = sqlite3_data_count(row);
355 for (i = 0; i < n; i++)
356 push_field(L, row, i);
357 return n;
360 /* Set columns as named keys/values in table `tab`. */
361 static void set_fields(lua_State *L, sqlite3_stmt *row, int tab)
363 int i, n = sqlite3_data_count(row);
364 for (i = 0; i < n; i++) {
365 push_field(L, row, i);
366 lua_setfield(L, tab, sqlite3_column_name(row, i));
370 /* Do one row step, accumulate columns on stack (concatenate rows). */
371 static int row_step(lua_State *L, sqlite_st *st, sqlite_db *db)
373 int i, total = 0;
374 for (i = 0; i < st->nst; i++) {
375 int err = sqlite3_step(st->st[i]);
376 if (err == SQLITE_DONE) {
377 sqlite3_reset(st->st[i]);
378 } else if (err == SQLITE_ROW) {
379 total += push_fields(L, st->st[i]);
380 sqlite3_reset(st->st[i]);
381 } else {
382 luaL_error(L, "while executing statement #%d: %s", i,
383 sqlite3_errmsg(db->db));
386 return total;
389 /* changed, col1, col2, .. = db:exec(stmts) */
390 static int sl_exec(lua_State *L)
392 sqlite_db *db;
393 sqlite_st *st = do_binds(L, 1, 2, 0, &db);
394 int total, bchanges = sqlite3_total_changes(db->db);
395 total = row_step(L, st, db);
396 lua_pushinteger(L, sqlite3_total_changes(db->db) - bchanges);
397 lua_replace(L, -total-2);
398 return total+1;
401 /* col1, col2 .. = db:row(stmts) */
402 static int sl_row(lua_State *L)
404 sqlite_db *db;
405 sqlite_st *st = do_binds(L, 1, 2, 0, &db);
406 return row_step(L, st, db);
409 /* Perform one step and collect all columns of a row as k/vs into table `tab`. */
410 static void col_step(lua_State *L, sqlite_st *st, sqlite_db *db, int ttab)
412 int i;
413 for (i = 0; i < st->nst; i++) {
414 int err = sqlite3_step(st->st[i]);
415 if (err == SQLITE_DONE) {
416 sqlite3_reset(st->st[i]);
417 } else if (err == SQLITE_ROW) {
418 set_fields(L, st->st[i], ttab);
419 sqlite3_reset(st->st[i]);
420 } else {
421 luaL_error(L, "while executing statement #%d: %s", i,
422 sqlite3_errmsg(db->db));
427 /* {tab.colname, tab.colname2..} = db:col(stmts) */
428 static int sl_col(lua_State *L)
430 sqlite_db *db;
431 sqlite_st *st = do_binds(L, 1, 2, 0, &db);
432 lua_newtable(L);
433 col_step(L, st, db, lua_gettop(L));
434 return 1;
437 /* {tab.colname, tab.colname2..} = db:tcol(tab, stmts) */
438 static int sl_tcol(lua_State *L)
440 sqlite_db *db;
441 sqlite_st *st = do_binds(L, 1, 3, 0, &db);
442 col_step(L, st, db, 2);
443 lua_settop(L, 2);
444 return 1;
447 /* Return number of rows changed since last call. */
448 static int sl_changes(lua_State *L)
450 sqlite_db *db = sl_todb(L, 1, 1);
451 int prev = db->changes;
452 db->changes = sqlite3_total_changes(db->db);
453 lua_pushinteger(L, db->changes - prev);
454 return 1;
457 /* For loop iterator for db:cols() */
458 static int sl_cols_aux(lua_State *L)
460 sqlite_st *st = sl_tost(L, 1, 1);
461 int err, curridx = luaL_checkinteger(L, 2);
462 lua_settop(L, 2);
463 retry:;
464 if (curridx > st->nst) /* We're done here, put back into cache. */
465 return sl_cacheback(L);
466 err = sqlite3_step(st->st[curridx-1]);
467 if (err == SQLITE_DONE) {
468 lua_pop(L, 1);
469 curridx++;
470 lua_pushinteger(L, curridx);
471 goto retry;
472 } else if (err == SQLITE_ROW) {
473 lua_createtable(L, 0, sqlite3_data_count(st->st[curridx-1])+1);
474 set_fields(L, st->st[curridx-1], -1);
475 return 2;
476 } else {
477 sqlite_db *db;
478 lua_getuservalue(L, 1);
479 db = sl_todb(L, -1, 0);
480 sl_cacheback(L);
481 luaL_error(L, "while executing statement #%d: %s", curridx,
482 sqlite3_errmsg(db->db));
484 return 0;
487 /* for idx,tab in db:cols() iterator producer */
488 static int sl_cols(lua_State *L)
490 sqlite_db *db;
491 do_binds(L, 1, 2, 1, &db);
492 lua_pushvalue(L, SLUV_COLS);
493 lua_pushvalue(L, -2);
494 lua_pushinteger(L, 1);
495 return 3;
498 /* db:rows() iterator */
499 static int sl_rows_aux(lua_State *L)
501 sqlite_st *st = sl_tost(L, 1, 1);
502 int err, curridx = luaL_checkinteger(L, 2);
503 lua_settop(L, 2);
504 retry:;
505 if (curridx > st->nst) /* We're done here, put back into cache. */
506 return sl_cacheback(L);
507 err = sqlite3_step(st->st[curridx-1]);
508 if (err == SQLITE_DONE) {
509 lua_pop(L, 1);
510 curridx++;
511 lua_pushinteger(L, curridx);
512 goto retry;
513 } else if (err == SQLITE_ROW) {
514 return push_fields(L, st->st[curridx-1])+1;
515 } else {
516 sqlite_db *db;
517 lua_getuservalue(L, 1);
518 db = sl_todb(L, -1, 0);
519 sl_cacheback(L);
520 luaL_error(L, "while executing statement #%d: %s", curridx,
521 sqlite3_errmsg(db->db));
523 return 0;
526 /* for idx,col1,col2.. in db:rows() */
527 static int sl_rows(lua_State *L)
529 sqlite_db *db;
530 do_binds(L, 1, 2, 1, &db);
531 lua_pushvalue(L, SLUV_ROWS);
532 lua_pushvalue(L, -2);
533 lua_pushinteger(L, 1);
534 return 3;
537 /* Open a database file. */
538 static int sl_open(lua_State *L)
540 sqlite3 *sql = NULL;
541 sqlite_db *db;
542 size_t nlen;
543 const char *fn = luaL_checklstring(L, 1, &nlen);
545 int err = sqlite3_open(fn, &sql);
546 if (!sql)
547 luaL_error(L, "failed to open '%s': %s", fn, sqlite3_errstr(err));
548 db = lua_newuserdata(L, sizeof(*db) + nlen);
549 db->db = sql;
550 db->changes = sqlite3_total_changes(sql);
551 strcpy(db->name, fn);
552 lua_pushvalue(L, SLUV_DB);
553 lua_setmetatable(L, -2);
554 lua_newtable(L);
555 lua_setuservalue(L, -2);
556 return 1;
559 /* Close the handle. */
560 static int sl_close(lua_State *L)
562 int err = 0;
563 sqlite_db *db = sl_uvdata(L, 1, SLUV_DB);
564 if (!db) return 0;
565 lua_settop(L, 1);
566 if (db->db) {
567 lua_getuservalue(L, 1);
568 lua_pushnil(L);
569 while (lua_next(L, 2)) { /* Drop cached statements. */
570 do {
571 sqlite_st *st = lua_touserdata(L, -1);
572 nuke_stmt(L, st);
573 lua_getuservalue(L, -1);
574 lua_replace(L, -2);
575 } while (!lua_isnil(L, -1));
576 lua_pop(L, 1);
578 lua_pushnil(L); /* Drop cache table. */
579 lua_setuservalue(L, 1);
580 err = sqlite3_close_v2(db->db);
581 if (err == SQLITE_OK)
582 db->db = NULL; /* Signal that it is closed. */
584 lua_pushinteger(L, err);
585 return 1;
588 static luaL_Reg sl_api[] = {
589 { "open", sl_open },
590 { "close", sl_close },
591 { NULL, NULL }
594 static luaL_Reg db_meth[] = {
595 { "exec", sl_exec },
596 { "row", sl_row },
597 { "col", sl_col },
598 { "tcol", sl_tcol },
599 { "rows", sl_rows },
600 { "cols", sl_cols },
601 { "changes", sl_changes },
602 { "__gc", sl_close },
603 { NULL, NULL }
606 int luaopen_sqlite(lua_State *L)
608 int i;
609 #if LUA_VERSION_NUM < 503
610 luaL_newmetatable(L, SQLITEUV);
611 lua_newtable(L);
612 lua_pushliteral(L, "k");
613 lua_setfield(L, -2, "__mode");
614 lua_setmetatable(L, -2);
615 #endif
616 lua_settop(L, 0);
618 lua_newtable(L); /* SLUV_ST */
619 lua_newtable(L); /* SLUV_DB */
620 lua_newtable(L); /* SLUV_QS */
621 for (i = 1; i <= 3; i++) lua_pushvalue(L, i);
622 lua_pushcclosure(L, sl_rows_aux, 3); /* SLUV_ROWS */
623 for (i = 1; i <= 3; i++) lua_pushvalue(L, i);
624 lua_pushcclosure(L, sl_cols_aux, 3); /* SLUV_COLS */
626 /* SLUV_ST */
627 for (i = 1; i <= 5; i++) lua_pushvalue(L, i);
628 lua_pushcclosure(L, sl_cacheback, 5);
629 lua_setfield(L, 1, "__gc");
630 lua_pushvalue(L, 1);
631 lua_setmetatable(L, 1);
633 /* SLUV_DB */
634 lua_pushvalue(L, 2);
635 lua_setfield(L, 2, "__index");
636 lua_pushvalue(L, 2);
637 for (i = 1; i <= 5; i++) lua_pushvalue(L, i);
638 luaL_setfuncs(L, db_meth, 5);
639 lua_pushvalue(L, 2);
640 lua_setmetatable(L, 2);
641 lua_pop(L, 1);
643 /* API */
644 lua_newtable(L);
645 for (i = 1; i <= 5; i++) lua_pushvalue(L, i);
646 luaL_setfuncs(L, sl_api, 5);
647 return 1;