Allow overriding CFLAGS
[vd_agent/hramrach.git] / udscs.c
blob6f936487c5ca627fc6fef8cf7ae9b1d40d6a5ca0
1 /* udscs.c Unix Domain Socket Client Server framework. A framework for quickly
2 creating select() based servers capable of handling multiple clients and
3 matching select() based clients using variable size messages.
5 Copyright 2010 Red Hat, Inc.
7 Red Hat Authors:
8 Hans de Goede <hdegoede@redhat.com>
10 This program is free software: you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation, either version 3 of the License, or
13 (at your option) any later version.
15 This program is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
20 You should have received a copy of the GNU General Public License
21 along with this program. If not, see <http://www.gnu.org/licenses/>.
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <sys/socket.h>
#include <sys/un.h>
#include "udscs.h"
/* A heap byte buffer with a progress cursor.
   buf  - the bytes themselves (owned, released with free())
   pos  - how many bytes have been transferred so far
   size - total number of bytes in buf
   next - following buffer when used as a node of a write queue */
struct udscs_buf {
    uint8_t *buf;
    size_t pos;
    size_t size;

    struct udscs_buf *next;
};
41 struct udscs_connection {
42 int fd;
43 const char * const *type_to_string;
44 int no_types;
45 FILE *logfile;
46 FILE *errfile;
47 struct ucred peer_cred;
48 void *user_data;
50 /* Read stuff, single buffer, separate header and data buffer */
51 int header_read;
52 struct udscs_message_header header;
53 struct udscs_buf data;
55 /* Writes are stored in a linked list of buffers, with both the header
56 + data for a single message in 1 buffer. */
57 struct udscs_buf *write_buf;
59 /* Callbacks */
60 udscs_read_callback read_callback;
61 udscs_disconnect_callback disconnect_callback;
63 struct udscs_connection *next;
64 struct udscs_connection *prev;
67 struct udscs_server {
68 int fd;
69 const char * const *type_to_string;
70 int no_types;
71 FILE *logfile;
72 FILE *errfile;
73 struct udscs_connection connections_head;
74 udscs_connect_callback connect_callback;
75 udscs_read_callback read_callback;
76 udscs_disconnect_callback disconnect_callback;
/* Socket I/O handlers used by udscs_client_handle_fds(); defined below. */
static void udscs_do_read(struct udscs_connection **connp);
static void udscs_do_write(struct udscs_connection **connp);
83 struct udscs_server *udscs_create_server(const char *socketname,
84 udscs_connect_callback connect_callback,
85 udscs_read_callback read_callback,
86 udscs_disconnect_callback disconnect_callback,
87 const char * const type_to_string[], int no_types,
88 FILE *logfile, FILE *errfile)
90 int c;
91 struct sockaddr_un address;
92 struct udscs_server *server;
94 server = calloc(1, sizeof(*server));
95 if (!server)
96 return NULL;
98 server->logfile = logfile;
99 server->errfile = errfile;
100 server->type_to_string = type_to_string;
101 server->no_types = no_types;
103 server->fd = socket(PF_UNIX, SOCK_STREAM, 0);
104 if (server->fd == -1) {
105 fprintf(server->errfile, "creating unix domain socket: %s\n",
106 strerror(errno));
107 free(server);
108 return NULL;
111 c = unlink(socketname);
112 if (c != 0 && errno != ENOENT) {
113 fprintf(server->errfile, "unlink %s: %s\n", socketname,
114 strerror(errno));
115 free(server);
116 return NULL;
119 address.sun_family = AF_UNIX;
120 snprintf(address.sun_path, sizeof(address.sun_path), "%s", socketname);
121 c = bind(server->fd, (struct sockaddr *)&address, sizeof(address));
122 if (c != 0) {
123 fprintf(server->errfile, "bind %s: %s\n", socketname, strerror(errno));
124 free(server);
125 return NULL;
128 c = listen(server->fd, 5);
129 if (c != 0) {
130 fprintf(server->errfile, "listen: %s\n", strerror(errno));
131 free(server);
132 return NULL;
135 server->connect_callback = connect_callback;
136 server->read_callback = read_callback;
137 server->disconnect_callback = disconnect_callback;
139 return server;
142 void udscs_destroy_server(struct udscs_server *server)
144 struct udscs_connection *conn, *next_conn;
146 if (!server)
147 return;
149 conn = server->connections_head.next;
150 while (conn) {
151 next_conn = conn->next;
152 udscs_destroy_connection(&conn);
153 conn = next_conn;
155 close(server->fd);
156 free(server);
159 struct udscs_connection *udscs_connect(const char *socketname,
160 udscs_read_callback read_callback,
161 udscs_disconnect_callback disconnect_callback,
162 const char * const type_to_string[], int no_types,
163 FILE *logfile, FILE *errfile)
165 int c;
166 struct sockaddr_un address;
167 struct udscs_connection *conn;
169 conn = calloc(1, sizeof(*conn));
170 if (!conn)
171 return NULL;
173 conn->logfile = logfile;
174 conn->errfile = errfile;
175 conn->type_to_string = type_to_string;
176 conn->no_types = no_types;
178 conn->fd = socket(PF_UNIX, SOCK_STREAM, 0);
179 if (conn->fd == -1) {
180 fprintf(conn->errfile, "creating unix domain socket: %s\n",
181 strerror(errno));
182 free(conn);
183 return NULL;
186 address.sun_family = AF_UNIX;
187 snprintf(address.sun_path, sizeof(address.sun_path), "%s", socketname);
188 c = connect(conn->fd, (struct sockaddr *)&address, sizeof(address));
189 if (c != 0) {
190 fprintf(conn->errfile, "connect %s: %s\n", socketname,
191 strerror(errno));
192 free(conn);
193 return NULL;
196 conn->read_callback = read_callback;
197 conn->disconnect_callback = disconnect_callback;
199 if (conn->logfile)
200 fprintf(conn->logfile, "%p connected to %s\n", conn, socketname);
202 return conn;
205 void udscs_destroy_connection(struct udscs_connection **connp)
207 struct udscs_buf *wbuf, *next_wbuf;
208 struct udscs_connection *conn = *connp;
210 if (!conn)
211 return;
213 if (conn->disconnect_callback)
214 conn->disconnect_callback(conn);
216 wbuf = conn->write_buf;
217 while (wbuf) {
218 next_wbuf = wbuf->next;
219 free(wbuf->buf);
220 free(wbuf);
221 wbuf = next_wbuf;
224 free(conn->data.buf);
226 if (conn->prev)
227 conn->prev->next = conn->next;
229 close(conn->fd);
231 if (conn->logfile)
232 fprintf(conn->logfile, "%p disconnected\n", conn);
234 free(conn);
235 *connp = NULL;
238 struct ucred udscs_get_peer_cred(struct udscs_connection *conn)
240 return conn->peer_cred;
243 int udscs_server_fill_fds(struct udscs_server *server, fd_set *readfds,
244 fd_set *writefds)
246 struct udscs_connection *conn;
247 int nfds = server->fd + 1;
249 if (!server)
250 return -1;
252 FD_SET(server->fd, readfds);
254 conn = server->connections_head.next;
255 while (conn) {
256 int conn_nfds = udscs_client_fill_fds(conn, readfds, writefds);
257 if (conn_nfds > nfds)
258 nfds = conn_nfds;
260 conn = conn->next;
263 return nfds;
266 int udscs_client_fill_fds(struct udscs_connection *conn, fd_set *readfds,
267 fd_set *writefds)
269 if (!conn)
270 return -1;
272 FD_SET(conn->fd, readfds);
273 if (conn->write_buf)
274 FD_SET(conn->fd, writefds);
276 return conn->fd + 1;
279 static void udscs_server_accept(struct udscs_server *server) {
280 struct udscs_connection *new_conn, *conn;
281 struct sockaddr_un address;
282 socklen_t length = sizeof(address);
283 int r, fd;
285 fd = accept(server->fd, (struct sockaddr *)&address, &length);
286 if (fd == -1) {
287 if (errno == EINTR)
288 return;
289 fprintf(server->errfile, "accept: %s\n", strerror(errno));
290 return;
293 new_conn = calloc(1, sizeof(*conn));
294 if (!new_conn) {
295 fprintf(server->errfile, "out of memory, disconnecting new client\n");
296 close(fd);
297 return;
300 new_conn->fd = fd;
301 new_conn->logfile = server->logfile;
302 new_conn->errfile = server->errfile;
303 new_conn->type_to_string = server->type_to_string;
304 new_conn->no_types = server->no_types;
305 new_conn->read_callback = server->read_callback;
306 new_conn->disconnect_callback = server->disconnect_callback;
308 length = sizeof(new_conn->peer_cred);
309 r = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &new_conn->peer_cred, &length);
310 if (r != 0) {
311 fprintf(server->errfile,
312 "Could not get peercred, disconnecting new client\n");
313 close(fd);
314 free(new_conn);
315 return;
318 conn = &server->connections_head;
319 while (conn->next)
320 conn = conn->next;
322 new_conn->prev = conn;
323 conn->next = new_conn;
325 if (server->logfile)
326 fprintf(server->logfile, "new client accepted: %p, pid: %d\n",
327 new_conn, (int)new_conn->peer_cred.pid);
329 if (server->connect_callback)
330 server->connect_callback(new_conn);
333 void udscs_server_handle_fds(struct udscs_server *server, fd_set *readfds,
334 fd_set *writefds)
336 struct udscs_connection *conn, *next_conn;
338 if (!server)
339 return;
341 if (FD_ISSET(server->fd, readfds))
342 udscs_server_accept(server);
344 conn = server->connections_head.next;
345 while (conn) {
346 /* conn maybe destroyed by udscs_client_handle_fds (when disconnected),
347 so get the next connection first. */
348 next_conn = conn->next;
349 udscs_client_handle_fds(&conn, readfds, writefds);
350 conn = next_conn;
354 void udscs_client_handle_fds(struct udscs_connection **connp, fd_set *readfds,
355 fd_set *writefds)
357 if (!*connp)
358 return;
360 if (FD_ISSET((*connp)->fd, readfds))
361 udscs_do_read(connp);
363 if (*connp && FD_ISSET((*connp)->fd, writefds))
364 udscs_do_write(connp);
367 int udscs_write(struct udscs_connection *conn, uint32_t type, uint32_t opaque,
368 const uint8_t *data, uint32_t size)
370 struct udscs_buf *wbuf, *new_wbuf;
371 struct udscs_message_header header;
373 new_wbuf = malloc(sizeof(*new_wbuf));
374 if (!new_wbuf)
375 return -1;
377 new_wbuf->pos = 0;
378 new_wbuf->size = sizeof(header) + size;
379 new_wbuf->next = NULL;
380 new_wbuf->buf = malloc(new_wbuf->size);
381 if (!new_wbuf->buf) {
382 free(new_wbuf);
383 return -1;
386 header.type = type;
387 header.opaque = opaque;
388 header.size = size;
390 memcpy(new_wbuf->buf, &header, sizeof(header));
391 memcpy(new_wbuf->buf + sizeof(header), data, size);
393 if (conn->logfile) {
394 if (type < conn->no_types)
395 fprintf(conn->logfile, "%p sent %s, opaque: %u, size %u\n",
396 conn, conn->type_to_string[type], opaque, size);
397 else
398 fprintf(conn->logfile,
399 "%p sent invalid message %u, opaque: %u, size %u\n",
400 conn, type, opaque, size);
403 if (!conn->write_buf) {
404 conn->write_buf = new_wbuf;
405 return 0;
408 /* maybe we should limit the write_buf stack depth ? */
409 wbuf = conn->write_buf;
410 while (wbuf->next)
411 wbuf = wbuf->next;
413 wbuf->next = new_wbuf;
415 return 0;
418 int udscs_server_write_all(struct udscs_server *server,
419 uint32_t type, uint32_t opaque,
420 const uint8_t *data, uint32_t size)
422 struct udscs_connection *conn;
424 conn = server->connections_head.next;
425 while (conn) {
426 if (udscs_write(conn, type, opaque, data, size))
427 return -1;
428 conn = conn->next;
431 return 0;
434 int udscs_server_for_all_clients(struct udscs_server *server,
435 udscs_for_all_clients_callback func, void *priv)
437 int r = 0;
438 struct udscs_connection *conn, *next_conn;
440 if (!server)
441 return 0;
443 conn = server->connections_head.next;
444 while (conn) {
445 /* Get next conn as func may destroy the current conn */
446 next_conn = conn->next;
447 r += func(&conn, priv);
448 conn = next_conn;
450 return r;
453 static void udscs_read_complete(struct udscs_connection **connp)
455 struct udscs_connection *conn = *connp;
457 if (conn->logfile) {
458 if (conn->header.type < conn->no_types)
459 fprintf(conn->logfile, "%p received %s, opaque: %u, size %u\n",
460 conn, conn->type_to_string[conn->header.type],
461 conn->header.opaque, conn->header.size);
462 else
463 fprintf(conn->logfile,
464 "%p received invalid message %u, opaque: %u, size %u\n",
465 conn, conn->header.type, conn->header.opaque,
466 conn->header.size);
469 if (conn->read_callback) {
470 conn->read_callback(connp, &conn->header, conn->data.buf);
471 if (!*connp) /* Was the connection disconnected by the callback ? */
472 return;
475 free(conn->data.buf);
476 conn->header_read = 0;
477 memset(&conn->data, 0, sizeof(conn->data));
480 static void udscs_do_read(struct udscs_connection **connp)
482 ssize_t n;
483 size_t to_read;
484 uint8_t *dest;
485 struct udscs_connection *conn = *connp;
487 if (conn->header_read < sizeof(conn->header)) {
488 to_read = sizeof(conn->header) - conn->header_read;
489 dest = (uint8_t *)&conn->header + conn->header_read;
490 } else {
491 to_read = conn->data.size - conn->data.pos;
492 dest = conn->data.buf + conn->data.pos;
495 n = read(conn->fd, dest, to_read);
496 if (n < 0) {
497 if (errno == EINTR)
498 return;
499 fprintf(conn->errfile,
500 "reading unix domain socket: %s, disconnecting %p\n",
501 strerror(errno), conn);
503 if (n <= 0) {
504 udscs_destroy_connection(connp);
505 return;
508 if (conn->header_read < sizeof(conn->header)) {
509 conn->header_read += n;
510 if (conn->header_read == sizeof(conn->header)) {
511 if (conn->header.size == 0) {
512 udscs_read_complete(connp);
513 return;
515 conn->data.pos = 0;
516 conn->data.size = conn->header.size;
517 conn->data.buf = malloc(conn->data.size);
518 if (!conn->data.buf) {
519 fprintf(conn->errfile, "out of memory, disconnecting %p\n",
520 conn);
521 udscs_destroy_connection(connp);
522 return;
525 } else {
526 conn->data.pos += n;
527 if (conn->data.pos == conn->data.size)
528 udscs_read_complete(connp);
532 static void udscs_do_write(struct udscs_connection **connp)
534 ssize_t n;
535 size_t to_write;
536 struct udscs_connection *conn = *connp;
538 struct udscs_buf* wbuf = conn->write_buf;
539 if (!wbuf) {
540 fprintf(conn->errfile,
541 "%p do_write called on a connection without a write buf ?!\n",
542 conn);
543 return;
546 to_write = wbuf->size - wbuf->pos;
547 n = write(conn->fd, wbuf->buf + wbuf->pos, to_write);
548 if (n < 0) {
549 if (errno == EINTR)
550 return;
551 fprintf(conn->errfile,
552 "writing to unix domain socket: %s, disconnecting %p\n",
553 strerror(errno), conn);
554 udscs_destroy_connection(connp);
555 return;
558 wbuf->pos += n;
559 if (wbuf->pos == wbuf->size) {
560 conn->write_buf = wbuf->next;
561 free(wbuf->buf);
562 free(wbuf);
566 void udscs_set_user_data(struct udscs_connection *conn, void *data)
568 conn->user_data = data;
571 void *udscs_get_user_data(struct udscs_connection *conn)
573 if (!conn)
574 return NULL;
576 return conn->user_data;