dbwrap: Add overflow protection to dbwrap_record_watchers_key()
[Samba.git] / source3 / lib / dbwrap / dbwrap_watch.c
blob a6e032a722a2fd3a769005db48273ada3b3eb8cd
1 /*
2 Unix SMB/CIFS implementation.
3 Watch dbwrap record changes
4 Copyright (C) Volker Lendecke 2012
6 This program is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 3 of the License, or
9 (at your option) any later version.
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with this program. If not, see <http://www.gnu.org/licenses/>.
20 #include "includes.h"
21 #include "system/filesys.h"
22 #include "dbwrap/dbwrap.h"
23 #include "dbwrap_watch.h"
24 #include "dbwrap_open.h"
25 #include "lib/util/util_tdb.h"
26 #include "lib/util/tevent_ntstatus.h"
27 #include "server_id_watch.h"
29 static struct db_context *dbwrap_record_watchers_db(void)
31 static struct db_context *watchers_db;
33 if (watchers_db == NULL) {
34 char *db_path = lock_path("dbwrap_watchers.tdb");
35 if (db_path == NULL) {
36 return NULL;
39 watchers_db = db_open(
40 NULL, db_path, 0,
41 TDB_CLEAR_IF_FIRST | TDB_INCOMPATIBLE_HASH,
42 O_RDWR|O_CREAT, 0600, DBWRAP_LOCK_ORDER_3,
43 DBWRAP_FLAG_NONE);
44 TALLOC_FREE(db_path);
46 return watchers_db;
49 static ssize_t dbwrap_record_watchers_key(struct db_context *db,
50 struct db_record *rec,
51 uint8_t *wkey, size_t wkey_len)
53 size_t db_id_len = dbwrap_db_id(db, NULL, 0);
54 uint8_t db_id[db_id_len];
55 size_t needed;
56 TDB_DATA key;
58 dbwrap_db_id(db, db_id, db_id_len);
60 key = dbwrap_record_get_key(rec);
62 needed = sizeof(uint32_t) + db_id_len;
63 if (needed < sizeof(uint32_t)) {
64 return -1;
67 needed += key.dsize;
68 if (needed < key.dsize) {
69 return -1;
72 if (wkey_len >= needed) {
73 SIVAL(wkey, 0, db_id_len);
74 memcpy(wkey + sizeof(uint32_t), db_id, db_id_len);
75 memcpy(wkey + sizeof(uint32_t) + db_id_len,
76 key.dptr, key.dsize);
79 return needed;
82 static bool dbwrap_record_watchers_key_parse(
83 TDB_DATA wkey, uint8_t **p_db_id, size_t *p_db_id_len, TDB_DATA *key)
85 size_t db_id_len;
87 if (wkey.dsize < sizeof(uint32_t)) {
88 DEBUG(1, ("Invalid watchers key\n"));
89 return false;
91 db_id_len = IVAL(wkey.dptr, 0);
92 if (db_id_len > (wkey.dsize - sizeof(uint32_t))) {
93 DEBUG(1, ("Invalid watchers key, wkey.dsize=%d, "
94 "db_id_len=%d\n", (int)wkey.dsize, (int)db_id_len));
95 return false;
97 if (p_db_id != NULL) {
98 *p_db_id = wkey.dptr + sizeof(uint32_t);
100 if (p_db_id_len != NULL) {
101 *p_db_id_len = db_id_len;
103 if (key != NULL) {
104 key->dptr = wkey.dptr + sizeof(uint32_t) + db_id_len;
105 key->dsize = wkey.dsize - sizeof(uint32_t) - db_id_len;
107 return true;
110 static NTSTATUS dbwrap_record_add_watcher(TDB_DATA w_key, struct server_id id)
112 struct TALLOC_CTX *frame = talloc_stackframe();
113 struct db_context *db;
114 struct db_record *rec;
115 TDB_DATA value;
116 struct server_id *ids;
117 size_t num_ids;
118 NTSTATUS status;
120 db = dbwrap_record_watchers_db();
121 if (db == NULL) {
122 status = map_nt_error_from_unix(errno);
123 goto fail;
125 rec = dbwrap_fetch_locked(db, talloc_tos(), w_key);
126 if (rec == NULL) {
127 status = map_nt_error_from_unix(errno);
128 goto fail;
130 value = dbwrap_record_get_value(rec);
132 if ((value.dsize % sizeof(struct server_id)) != 0) {
133 status = NT_STATUS_INTERNAL_DB_CORRUPTION;
134 goto fail;
137 ids = (struct server_id *)value.dptr;
138 num_ids = value.dsize / sizeof(struct server_id);
140 ids = talloc_array(talloc_tos(), struct server_id,
141 num_ids + 1);
142 if (ids == NULL) {
143 status = NT_STATUS_NO_MEMORY;
144 goto fail;
146 memcpy(ids, value.dptr, value.dsize);
147 ids[num_ids] = id;
148 num_ids += 1;
150 status = dbwrap_record_store(
151 rec, make_tdb_data((uint8_t *)ids, talloc_get_size(ids)), 0);
152 fail:
153 TALLOC_FREE(frame);
154 return status;
157 static NTSTATUS dbwrap_record_del_watcher(TDB_DATA w_key, struct server_id id)
159 struct TALLOC_CTX *frame = talloc_stackframe();
160 struct db_context *db;
161 struct db_record *rec;
162 struct server_id *ids;
163 size_t i, num_ids;
164 TDB_DATA value;
165 NTSTATUS status;
167 db = dbwrap_record_watchers_db();
168 if (db == NULL) {
169 status = map_nt_error_from_unix(errno);
170 goto fail;
172 rec = dbwrap_fetch_locked(db, talloc_tos(), w_key);
173 if (rec == NULL) {
174 status = map_nt_error_from_unix(errno);
175 goto fail;
177 value = dbwrap_record_get_value(rec);
179 if ((value.dsize % sizeof(struct server_id)) != 0) {
180 status = NT_STATUS_INTERNAL_DB_CORRUPTION;
181 goto fail;
184 ids = (struct server_id *)value.dptr;
185 num_ids = value.dsize / sizeof(struct server_id);
187 for (i=0; i<num_ids; i++) {
188 if (serverid_equal(&id, &ids[i])) {
189 ids[i] = ids[num_ids-1];
190 value.dsize -= sizeof(struct server_id);
191 break;
194 if (value.dsize == 0) {
195 status = dbwrap_record_delete(rec);
196 goto done;
198 status = dbwrap_record_store(rec, value, 0);
199 fail:
200 done:
201 TALLOC_FREE(frame);
202 return status;
205 struct dbwrap_record_watch_state {
206 struct tevent_context *ev;
207 struct db_context *db;
208 struct tevent_req *req;
209 struct messaging_context *msg;
210 TDB_DATA w_key;
211 bool blockerdead;
212 struct server_id blocker;
/* Internal helpers of the watch request, defined further below. */
static int dbwrap_record_watch_state_destructor(
	struct dbwrap_record_watch_state *state);
static bool dbwrap_record_watch_filter(struct messaging_rec *rec,
				       void *private_data);
static void dbwrap_record_watch_blocker_died(struct tevent_req *subreq);
static void dbwrap_record_watch_done(struct tevent_req *subreq);
222 struct tevent_req *dbwrap_record_watch_send(TALLOC_CTX *mem_ctx,
223 struct tevent_context *ev,
224 struct db_record *rec,
225 struct messaging_context *msg,
226 struct server_id blocker)
228 struct tevent_req *req, *subreq;
229 struct dbwrap_record_watch_state *state;
230 struct db_context *watchers_db;
231 NTSTATUS status;
232 ssize_t needed;
234 req = tevent_req_create(mem_ctx, &state,
235 struct dbwrap_record_watch_state);
236 if (req == NULL) {
237 return NULL;
239 state->db = dbwrap_record_get_db(rec);
240 state->ev = ev;
241 state->req = req;
242 state->msg = msg;
243 state->blocker = blocker;
245 watchers_db = dbwrap_record_watchers_db();
246 if (watchers_db == NULL) {
247 tevent_req_nterror(req, map_nt_error_from_unix(errno));
248 return tevent_req_post(req, ev);
251 needed = dbwrap_record_watchers_key(state->db, rec, NULL, 0);
252 if (needed == -1) {
253 tevent_req_nterror(req, NT_STATUS_INSUFFICIENT_RESOURCES);
254 return tevent_req_post(req, ev);
256 state->w_key.dsize = needed;
258 state->w_key.dptr = talloc_array(state, uint8_t, state->w_key.dsize);
259 if (tevent_req_nomem(state->w_key.dptr, req)) {
260 return tevent_req_post(req, ev);
262 dbwrap_record_watchers_key(
263 state->db, rec, state->w_key.dptr, state->w_key.dsize);
265 subreq = messaging_filtered_read_send(
266 state, ev, state->msg, dbwrap_record_watch_filter, state);
267 if (tevent_req_nomem(subreq, req)) {
268 return tevent_req_post(req, ev);
270 tevent_req_set_callback(subreq, dbwrap_record_watch_done, req);
272 if (blocker.pid != 0) {
273 subreq = server_id_watch_send(state, ev, msg, blocker);
274 if (tevent_req_nomem(subreq, req)) {
275 return tevent_req_post(req, ev);
277 tevent_req_set_callback(
278 subreq, dbwrap_record_watch_blocker_died, req);
281 status = dbwrap_record_add_watcher(
282 state->w_key, messaging_server_id(state->msg));
283 if (tevent_req_nterror(req, status)) {
284 return tevent_req_post(req, ev);
286 talloc_set_destructor(state, dbwrap_record_watch_state_destructor);
288 return req;
291 static bool dbwrap_record_watch_filter(struct messaging_rec *rec,
292 void *private_data)
294 struct dbwrap_record_watch_state *state = talloc_get_type_abort(
295 private_data, struct dbwrap_record_watch_state);
297 if (rec->msg_type != MSG_DBWRAP_MODIFIED) {
298 return false;
300 if (rec->num_fds != 0) {
301 return false;
303 if (rec->buf.length != state->w_key.dsize) {
304 return false;
306 return memcmp(rec->buf.data, state->w_key.dptr, rec->buf.length) == 0;
309 static int dbwrap_record_watch_state_destructor(
310 struct dbwrap_record_watch_state *s)
312 if (s->msg != NULL) {
313 dbwrap_record_del_watcher(
314 s->w_key, messaging_server_id(s->msg));
316 return 0;
319 static void dbwrap_watch_record_stored_fn(TDB_DATA key, TDB_DATA data,
320 void *private_data)
322 struct messaging_context *msg = private_data;
323 size_t i, num_ids;
325 if ((data.dsize % sizeof(struct server_id)) != 0) {
326 DBG_WARNING("Invalid data size: %zu\n", data.dsize);
327 return;
329 num_ids = data.dsize / sizeof(struct server_id);
331 for (i=0; i<num_ids; i++) {
332 struct server_id dst;
333 NTSTATUS status;
335 memcpy(&dst, data.dptr + i * sizeof(struct server_id),
336 sizeof(struct server_id));
338 status = messaging_send_buf(msg, dst, MSG_DBWRAP_MODIFIED,
339 key.dptr, key.dsize);
340 if (!NT_STATUS_IS_OK(status)) {
341 struct server_id_buf tmp;
342 DBG_WARNING("messaging_send to %s failed: %s\n",
343 server_id_str_buf(dst, &tmp),
344 nt_errstr(status));
349 static void dbwrap_watch_record_stored(struct db_context *db,
350 struct db_record *rec,
351 void *private_data)
353 struct messaging_context *msg = talloc_get_type_abort(
354 private_data, struct messaging_context);
355 struct db_context *watchers_db;
357 size_t wkey_len = dbwrap_record_watchers_key(db, rec, NULL, 0);
358 uint8_t wkey_buf[wkey_len];
359 TDB_DATA wkey = { .dptr = wkey_buf, .dsize = wkey_len };
361 NTSTATUS status;
363 watchers_db = dbwrap_record_watchers_db();
364 if (watchers_db == NULL) {
365 return;
368 dbwrap_record_watchers_key(db, rec, wkey_buf, wkey_len);
370 status = dbwrap_parse_record(watchers_db, wkey,
371 dbwrap_watch_record_stored_fn, msg);
372 if (NT_STATUS_EQUAL(status, NT_STATUS_NOT_FOUND)) {
373 return;
375 if (!NT_STATUS_IS_OK(status)) {
376 DBG_WARNING("dbwrap_parse_record failed: %s\n",
377 nt_errstr(status));
381 void dbwrap_watch_db(struct db_context *db, struct messaging_context *msg)
383 dbwrap_set_stored_callback(db, dbwrap_watch_record_stored, msg);
386 static void dbwrap_record_watch_done(struct tevent_req *subreq)
388 struct tevent_req *req = tevent_req_callback_data(
389 subreq, struct tevent_req);
390 struct messaging_rec *rec;
391 int ret;
393 ret = messaging_filtered_read_recv(subreq, talloc_tos(), &rec);
394 TALLOC_FREE(subreq);
395 if (ret != 0) {
396 tevent_req_nterror(req, map_nt_error_from_unix(ret));
397 return;
399 tevent_req_done(req);
402 static void dbwrap_record_watch_blocker_died(struct tevent_req *subreq)
404 struct tevent_req *req = tevent_req_callback_data(
405 subreq, struct tevent_req);
406 struct dbwrap_record_watch_state *state = tevent_req_data(
407 req, struct dbwrap_record_watch_state);
408 int ret;
410 ret = server_id_watch_recv(subreq, NULL);
411 TALLOC_FREE(subreq);
412 if (ret != 0) {
413 tevent_req_nterror(req, map_nt_error_from_unix(ret));
414 return;
416 state->blockerdead = true;
417 tevent_req_done(req);
420 NTSTATUS dbwrap_record_watch_recv(struct tevent_req *req,
421 TALLOC_CTX *mem_ctx,
422 struct db_record **prec,
423 bool *blockerdead,
424 struct server_id *blocker)
426 struct dbwrap_record_watch_state *state = tevent_req_data(
427 req, struct dbwrap_record_watch_state);
428 NTSTATUS status;
429 TDB_DATA key;
430 struct db_record *rec;
431 bool ok;
433 if (tevent_req_is_nterror(req, &status)) {
434 return status;
436 if (blockerdead != NULL) {
437 *blockerdead = state->blockerdead;
439 if (blocker != NULL) {
440 *blocker = state->blocker;
442 if (prec == NULL) {
443 return NT_STATUS_OK;
446 ok = dbwrap_record_watchers_key_parse(state->w_key, NULL, NULL, &key);
447 if (!ok) {
448 return NT_STATUS_INTERNAL_DB_ERROR;
451 rec = dbwrap_fetch_locked(state->db, mem_ctx, key);
452 if (rec == NULL) {
453 return NT_STATUS_INTERNAL_DB_ERROR;
455 *prec = rec;
456 return NT_STATUS_OK;
459 struct dbwrap_watchers_traverse_read_state {
460 int (*fn)(const uint8_t *db_id, size_t db_id_len, const TDB_DATA key,
461 const struct server_id *watchers, size_t num_watchers,
462 void *private_data);
463 void *private_data;
466 static int dbwrap_watchers_traverse_read_callback(
467 struct db_record *rec, void *private_data)
469 struct dbwrap_watchers_traverse_read_state *state =
470 (struct dbwrap_watchers_traverse_read_state *)private_data;
471 uint8_t *db_id;
472 size_t db_id_len;
473 TDB_DATA w_key, key, w_data;
474 int res;
476 w_key = dbwrap_record_get_key(rec);
477 w_data = dbwrap_record_get_value(rec);
479 if (!dbwrap_record_watchers_key_parse(w_key, &db_id, &db_id_len,
480 &key)) {
481 return 0;
483 if ((w_data.dsize % sizeof(struct server_id)) != 0) {
484 return 0;
486 res = state->fn(db_id, db_id_len, key,
487 (struct server_id *)w_data.dptr,
488 w_data.dsize / sizeof(struct server_id),
489 state->private_data);
490 return res;
493 void dbwrap_watchers_traverse_read(
494 int (*fn)(const uint8_t *db_id, size_t db_id_len, const TDB_DATA key,
495 const struct server_id *watchers, size_t num_watchers,
496 void *private_data),
497 void *private_data)
499 struct dbwrap_watchers_traverse_read_state state;
500 struct db_context *db;
502 db = dbwrap_record_watchers_db();
503 if (db == NULL) {
504 return;
506 state.fn = fn;
507 state.private_data = private_data;
508 dbwrap_traverse_read(db, dbwrap_watchers_traverse_read_callback,
509 &state, NULL);
512 static int dbwrap_wakeall_cb(const uint8_t *db_id, size_t db_id_len,
513 const TDB_DATA key,
514 const struct server_id *watchers,
515 size_t num_watchers,
516 void *private_data)
518 struct messaging_context *msg = talloc_get_type_abort(
519 private_data, struct messaging_context);
520 uint32_t i;
521 DATA_BLOB blob;
523 blob.data = key.dptr;
524 blob.length = key.dsize;
526 for (i=0; i<num_watchers; i++) {
527 messaging_send(msg, watchers[i], MSG_DBWRAP_MODIFIED, &blob);
529 return 0;
532 void dbwrap_watchers_wakeall(struct messaging_context *msg)
534 dbwrap_watchers_traverse_read(dbwrap_wakeall_cb, msg);