udscs: Fix a potential NULL pointer dereference
[vd_agent.git] / src / udscs.c
blob fdd75a482d3a8299fcfe7c6b718f70f68b05b23e
/* udscs.c Unix Domain Socket Client Server framework. A framework for quickly
   creating select() based servers capable of handling multiple clients and
   matching select() based clients using variable size messages.

   Copyright 2010 Red Hat, Inc.

   Red Hat Authors:
   Hans de Goede <hdegoede@redhat.com>

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>

#include "udscs.h"
/* A heap buffer plus a cursor into it.  Used both for the payload of the
   message currently being read and as a node of the pending-write queue. */
struct udscs_buf {
    uint8_t *buf;             /* malloc'd storage */
    size_t pos;               /* bytes transferred so far */
    size_t size;              /* total number of bytes in buf */

    struct udscs_buf *next;   /* next queued write buffer, or NULL */
};
44 struct udscs_connection {
45 int fd;
46 const char * const *type_to_string;
47 int no_types;
48 int debug;
49 void *user_data;
50 #ifndef UDSCS_NO_SERVER
51 struct ucred peer_cred;
52 #endif
54 /* Read stuff, single buffer, separate header and data buffer */
55 int header_read;
56 struct udscs_message_header header;
57 struct udscs_buf data;
59 /* Writes are stored in a linked list of buffers, with both the header
60 + data for a single message in 1 buffer. */
61 struct udscs_buf *write_buf;
63 /* Callbacks */
64 udscs_read_callback read_callback;
65 udscs_disconnect_callback disconnect_callback;
67 struct udscs_connection *next;
68 struct udscs_connection *prev;
71 struct udscs_connection *udscs_connect(const char *socketname,
72 udscs_read_callback read_callback,
73 udscs_disconnect_callback disconnect_callback,
74 const char * const type_to_string[], int no_types, int debug)
76 int c;
77 struct sockaddr_un address;
78 struct udscs_connection *conn;
80 conn = calloc(1, sizeof(*conn));
81 if (!conn)
82 return NULL;
84 conn->type_to_string = type_to_string;
85 conn->no_types = no_types;
86 conn->debug = debug;
88 conn->fd = socket(PF_UNIX, SOCK_STREAM, 0);
89 if (conn->fd == -1) {
90 syslog(LOG_ERR, "creating unix domain socket: %m");
91 free(conn);
92 return NULL;
95 address.sun_family = AF_UNIX;
96 snprintf(address.sun_path, sizeof(address.sun_path), "%s", socketname);
97 c = connect(conn->fd, (struct sockaddr *)&address, sizeof(address));
98 if (c != 0) {
99 if (conn->debug) {
100 syslog(LOG_DEBUG, "connect %s: %m", socketname);
102 free(conn);
103 return NULL;
106 conn->read_callback = read_callback;
107 conn->disconnect_callback = disconnect_callback;
109 if (conn->debug)
110 syslog(LOG_DEBUG, "%p connected to %s", conn, socketname);
112 return conn;
115 void udscs_destroy_connection(struct udscs_connection **connp)
117 struct udscs_buf *wbuf, *next_wbuf;
118 struct udscs_connection *conn = *connp;
120 if (!conn)
121 return;
123 if (conn->disconnect_callback)
124 conn->disconnect_callback(conn);
126 wbuf = conn->write_buf;
127 while (wbuf) {
128 next_wbuf = wbuf->next;
129 free(wbuf->buf);
130 free(wbuf);
131 wbuf = next_wbuf;
134 free(conn->data.buf);
135 conn->data.buf = NULL;
137 if (conn->next)
138 conn->next->prev = conn->prev;
139 if (conn->prev)
140 conn->prev->next = conn->next;
142 close(conn->fd);
144 if (conn->debug)
145 syslog(LOG_DEBUG, "%p disconnected", conn);
147 free(conn);
148 *connp = NULL;
151 void udscs_set_user_data(struct udscs_connection *conn, void *data)
153 conn->user_data = data;
156 void *udscs_get_user_data(struct udscs_connection *conn)
158 if (!conn)
159 return NULL;
161 return conn->user_data;
164 int udscs_write(struct udscs_connection *conn, uint32_t type, uint32_t arg1,
165 uint32_t arg2, const uint8_t *data, uint32_t size)
167 struct udscs_buf *wbuf, *new_wbuf;
168 struct udscs_message_header header;
170 new_wbuf = malloc(sizeof(*new_wbuf));
171 if (!new_wbuf)
172 return -1;
174 new_wbuf->pos = 0;
175 new_wbuf->size = sizeof(header) + size;
176 new_wbuf->next = NULL;
177 new_wbuf->buf = malloc(new_wbuf->size);
178 if (!new_wbuf->buf) {
179 free(new_wbuf);
180 return -1;
183 header.type = type;
184 header.arg1 = arg1;
185 header.arg2 = arg2;
186 header.size = size;
188 memcpy(new_wbuf->buf, &header, sizeof(header));
189 memcpy(new_wbuf->buf + sizeof(header), data, size);
191 if (conn->debug) {
192 if (type < conn->no_types)
193 syslog(LOG_DEBUG, "%p sent %s, arg1: %u, arg2: %u, size %u",
194 conn, conn->type_to_string[type], arg1, arg2, size);
195 else
196 syslog(LOG_DEBUG,
197 "%p sent invalid message %u, arg1: %u, arg2: %u, size %u",
198 conn, type, arg1, arg2, size);
201 if (!conn->write_buf) {
202 conn->write_buf = new_wbuf;
203 return 0;
206 /* maybe we should limit the write_buf stack depth ? */
207 wbuf = conn->write_buf;
208 while (wbuf->next)
209 wbuf = wbuf->next;
211 wbuf->next = new_wbuf;
213 return 0;
216 /* A helper for udscs_do_read() */
217 static void udscs_read_complete(struct udscs_connection **connp)
219 struct udscs_connection *conn = *connp;
221 if (conn->debug) {
222 if (conn->header.type < conn->no_types)
223 syslog(LOG_DEBUG,
224 "%p received %s, arg1: %u, arg2: %u, size %u",
225 conn, conn->type_to_string[conn->header.type],
226 conn->header.arg1, conn->header.arg2, conn->header.size);
227 else
228 syslog(LOG_DEBUG,
229 "%p received invalid message %u, arg1: %u, arg2: %u, size %u",
230 conn, conn->header.type, conn->header.arg1, conn->header.arg2,
231 conn->header.size);
234 if (conn->read_callback) {
235 conn->read_callback(connp, &conn->header, conn->data.buf);
236 if (!*connp) /* Was the connection disconnected by the callback ? */
237 return;
240 free(conn->data.buf);
241 memset(&conn->data, 0, sizeof(conn->data)); /* data.buf = NULL */
242 conn->header_read = 0;
245 /* A helper for udscs_client_handle_fds() */
246 static void udscs_do_read(struct udscs_connection **connp)
248 ssize_t n;
249 size_t to_read;
250 uint8_t *dest;
251 struct udscs_connection *conn = *connp;
253 if (conn->header_read < sizeof(conn->header)) {
254 to_read = sizeof(conn->header) - conn->header_read;
255 dest = (uint8_t *)&conn->header + conn->header_read;
256 } else {
257 to_read = conn->data.size - conn->data.pos;
258 dest = conn->data.buf + conn->data.pos;
261 n = read(conn->fd, dest, to_read);
262 if (n < 0) {
263 if (errno == EINTR)
264 return;
265 syslog(LOG_ERR, "reading unix domain socket: %m, disconnecting %p",
266 conn);
268 if (n <= 0) {
269 udscs_destroy_connection(connp);
270 return;
273 if (conn->header_read < sizeof(conn->header)) {
274 conn->header_read += n;
275 if (conn->header_read == sizeof(conn->header)) {
276 if (conn->header.size == 0) {
277 udscs_read_complete(connp);
278 return;
280 conn->data.pos = 0;
281 conn->data.size = conn->header.size;
282 conn->data.buf = malloc(conn->data.size);
283 if (!conn->data.buf) {
284 syslog(LOG_ERR, "out of memory, disconnecting %p", conn);
285 udscs_destroy_connection(connp);
286 return;
289 } else {
290 conn->data.pos += n;
291 if (conn->data.pos == conn->data.size)
292 udscs_read_complete(connp);
296 /* A helper for udscs_client_handle_fds() */
297 static void udscs_do_write(struct udscs_connection **connp)
299 ssize_t n;
300 size_t to_write;
301 struct udscs_connection *conn = *connp;
303 struct udscs_buf* wbuf = conn->write_buf;
304 if (!wbuf) {
305 syslog(LOG_ERR,
306 "%p do_write called on a connection without a write buf ?!",
307 conn);
308 return;
311 to_write = wbuf->size - wbuf->pos;
312 n = write(conn->fd, wbuf->buf + wbuf->pos, to_write);
313 if (n < 0) {
314 if (errno == EINTR)
315 return;
316 syslog(LOG_ERR, "writing to unix domain socket: %m, disconnecting %p",
317 conn);
318 udscs_destroy_connection(connp);
319 return;
322 wbuf->pos += n;
323 if (wbuf->pos == wbuf->size) {
324 conn->write_buf = wbuf->next;
325 free(wbuf->buf);
326 free(wbuf);
330 void udscs_client_handle_fds(struct udscs_connection **connp, fd_set *readfds,
331 fd_set *writefds)
333 if (!*connp)
334 return;
336 if (FD_ISSET((*connp)->fd, readfds))
337 udscs_do_read(connp);
339 if (*connp && FD_ISSET((*connp)->fd, writefds))
340 udscs_do_write(connp);
343 int udscs_client_fill_fds(struct udscs_connection *conn, fd_set *readfds,
344 fd_set *writefds)
346 if (!conn)
347 return -1;
349 FD_SET(conn->fd, readfds);
350 if (conn->write_buf)
351 FD_SET(conn->fd, writefds);
353 return conn->fd + 1;
357 #ifndef UDSCS_NO_SERVER
359 /* ---------- Server-side implementation ---------- */
361 struct udscs_server {
362 int fd;
363 const char * const *type_to_string;
364 int no_types;
365 int debug;
366 struct udscs_connection connections_head;
367 udscs_connect_callback connect_callback;
368 udscs_read_callback read_callback;
369 udscs_disconnect_callback disconnect_callback;
372 struct udscs_server *udscs_create_server(const char *socketname,
373 udscs_connect_callback connect_callback,
374 udscs_read_callback read_callback,
375 udscs_disconnect_callback disconnect_callback,
376 const char * const type_to_string[], int no_types, int debug)
378 int c;
379 struct sockaddr_un address;
380 struct udscs_server *server;
382 server = calloc(1, sizeof(*server));
383 if (!server)
384 return NULL;
386 server->type_to_string = type_to_string;
387 server->no_types = no_types;
388 server->debug = debug;
390 server->fd = socket(PF_UNIX, SOCK_STREAM, 0);
391 if (server->fd == -1) {
392 syslog(LOG_ERR, "creating unix domain socket: %m");
393 free(server);
394 return NULL;
397 address.sun_family = AF_UNIX;
398 snprintf(address.sun_path, sizeof(address.sun_path), "%s", socketname);
399 c = bind(server->fd, (struct sockaddr *)&address, sizeof(address));
400 if (c != 0) {
401 syslog(LOG_ERR, "bind %s: %m", socketname);
402 free(server);
403 return NULL;
406 c = listen(server->fd, 5);
407 if (c != 0) {
408 syslog(LOG_ERR, "listen: %m");
409 free(server);
410 return NULL;
413 server->connect_callback = connect_callback;
414 server->read_callback = read_callback;
415 server->disconnect_callback = disconnect_callback;
417 return server;
420 void udscs_destroy_server(struct udscs_server *server)
422 struct udscs_connection *conn, *next_conn;
424 if (!server)
425 return;
427 conn = server->connections_head.next;
428 while (conn) {
429 next_conn = conn->next;
430 udscs_destroy_connection(&conn);
431 conn = next_conn;
433 close(server->fd);
434 free(server);
437 struct ucred udscs_get_peer_cred(struct udscs_connection *conn)
439 return conn->peer_cred;
442 static void udscs_server_accept(struct udscs_server *server) {
443 struct udscs_connection *new_conn, *conn;
444 struct sockaddr_un address;
445 socklen_t length = sizeof(address);
446 int r, fd;
448 fd = accept(server->fd, (struct sockaddr *)&address, &length);
449 if (fd == -1) {
450 if (errno == EINTR)
451 return;
452 syslog(LOG_ERR, "accept: %m");
453 return;
456 new_conn = calloc(1, sizeof(*conn));
457 if (!new_conn) {
458 syslog(LOG_ERR, "out of memory, disconnecting new client");
459 close(fd);
460 return;
463 new_conn->fd = fd;
464 new_conn->type_to_string = server->type_to_string;
465 new_conn->no_types = server->no_types;
466 new_conn->debug = server->debug;
467 new_conn->read_callback = server->read_callback;
468 new_conn->disconnect_callback = server->disconnect_callback;
470 length = sizeof(new_conn->peer_cred);
471 r = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &new_conn->peer_cred, &length);
472 if (r != 0) {
473 syslog(LOG_ERR, "Could not get peercred, disconnecting new client");
474 close(fd);
475 free(new_conn);
476 return;
479 conn = &server->connections_head;
480 while (conn->next)
481 conn = conn->next;
483 new_conn->prev = conn;
484 conn->next = new_conn;
486 if (server->debug)
487 syslog(LOG_DEBUG, "new client accepted: %p, pid: %d",
488 new_conn, (int)new_conn->peer_cred.pid);
490 if (server->connect_callback)
491 server->connect_callback(new_conn);
494 int udscs_server_fill_fds(struct udscs_server *server, fd_set *readfds,
495 fd_set *writefds)
497 struct udscs_connection *conn;
498 int nfds;
500 if (!server)
501 return -1;
503 nfds = server->fd + 1;
504 FD_SET(server->fd, readfds);
506 conn = server->connections_head.next;
507 while (conn) {
508 int conn_nfds = udscs_client_fill_fds(conn, readfds, writefds);
509 if (conn_nfds > nfds)
510 nfds = conn_nfds;
512 conn = conn->next;
515 return nfds;
518 void udscs_server_handle_fds(struct udscs_server *server, fd_set *readfds,
519 fd_set *writefds)
521 struct udscs_connection *conn, *next_conn;
523 if (!server)
524 return;
526 if (FD_ISSET(server->fd, readfds))
527 udscs_server_accept(server);
529 conn = server->connections_head.next;
530 while (conn) {
531 /* conn maybe destroyed by udscs_client_handle_fds (when disconnected),
532 so get the next connection first. */
533 next_conn = conn->next;
534 udscs_client_handle_fds(&conn, readfds, writefds);
535 conn = next_conn;
539 int udscs_server_write_all(struct udscs_server *server,
540 uint32_t type, uint32_t arg1, uint32_t arg2,
541 const uint8_t *data, uint32_t size)
543 struct udscs_connection *conn;
545 conn = server->connections_head.next;
546 while (conn) {
547 if (udscs_write(conn, type, arg1, arg2, data, size))
548 return -1;
549 conn = conn->next;
552 return 0;
555 int udscs_server_for_all_clients(struct udscs_server *server,
556 udscs_for_all_clients_callback func, void *priv)
558 int r = 0;
559 struct udscs_connection *conn, *next_conn;
561 if (!server)
562 return 0;
564 conn = server->connections_head.next;
565 while (conn) {
566 /* Get next conn as func may destroy the current conn */
567 next_conn = conn->next;
568 r += func(&conn, priv);
569 conn = next_conn;
571 return r;
574 #endif