From a1a8e13fcf98b106a409673d895a45babd4bedea Mon Sep 17 00:00:00 2001 From: Petr Baudis Date: Wed, 7 Mar 2012 16:49:28 +0100 Subject: [PATCH] Rewrite the communication protocol sendmsg() with only ancilliary message does not work, apparently. Therefore, to make things cleaner, pass command/reply directly using sendmsg() instead of newline-terminated strings. --- README | 6 ++-- compctl.c | 91 ++++++++++++++++++++++++++++++++++++++-------------------- compctld.c | 97 ++++++++++++++++++++++++++++++++++++++++++-------------------- 3 files changed, 129 insertions(+), 65 deletions(-) diff --git a/README b/README index 92f0096..812101b 100644 --- a/README +++ b/README @@ -10,9 +10,9 @@ tweaking the cgroup limits. The client compctl interface simply queries the server using a synchronous protocol over a UNIX socket. First, the client -sends a SCM_CREDENTIALS ancilliary message. Then, it follows -with a CRLF-terminated command string and receives a CRLF-terminated -reply string. Connection is closed immediately on breach of protocol. +sends a command string message coupled with a SCM_CREDENTIALS +ancilliary message. Then, it receives a reply message. +Connection is closed immediately on breach of protocol. You can tweak some simple compile-time configuration variables diff --git a/compctl.c b/compctl.c index dd73291..486280d 100644 --- a/compctl.c +++ b/compctl.c @@ -1,5 +1,6 @@ #define _GNU_SOURCE /* struct ucred */ #include +#include #include #include #include @@ -11,8 +12,8 @@ #include "common.h" -FILE * -connectd(void) +char * +daemon_chat(char *cmd) { int s = socket(AF_UNIX, SOCK_STREAM, 0); struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE }; @@ -21,29 +22,67 @@ connectd(void) exit(EXIT_FAILURE); } - /* Send message with credentials. */ + /* Send command. */ + struct iovec iov_cmd = { + .iov_base = cmd, + .iov_len = strlen(cmd), + }; + struct msghdr msg = { + .msg_iov = &iov_cmd, + .msg_iovlen = 1, + }; + + /* Include credentials in the message. */ struct ucred cred = { .pid = getpid(), .uid = getuid(), .gid = getgid(), }; - char cbuf[CMSG_SPACE(sizeof(cred))]; - struct msghdr msg = { - .msg_control = cbuf, - .msg_controllen = sizeof(cbuf), - }; - + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_CREDENTIALS; cmsg->cmsg_len = CMSG_LEN(sizeof(cred)); memcpy(CMSG_DATA(cmsg), &cred, sizeof(cred)); - sendmsg(s, &msg, 0); + ssize_t sent = sendmsg(s, &msg, 0); + if (sent < 0) { + perror("sendmsg"); + exit(EXIT_FAILURE); + } + if (sent < msg.msg_iov->iov_len) { + fprintf(stderr, "incomplete send %zd < %zu, FIXME\n", sent, msg.msg_iov->iov_len); + exit(EXIT_FAILURE); + } + + /* Receive reply. */ + + char reply[1024]; + struct iovec iov_reply = { + .iov_base = reply, + .iov_len = sizeof(reply), + }; + msg.msg_iov = &iov_reply; + msg.msg_iovlen = 1; + +recvagain:; + int replylen = recvmsg(s, &msg, 0); + if (replylen < 0) { + if (errno == EAGAIN) + goto recvagain; + perror("recvmsg"); + exit(EXIT_FAILURE); + } + if (replylen >= 1024) { + fprintf(stderr, "too long reply from the server\n"); + exit(EXIT_FAILURE); + } + reply[replylen] = 0; - return fdopen(s, "rw"); + return strdup(reply); } @@ -58,15 +97,12 @@ help(FILE *f) int run(int argc, char *argv[]) { - FILE *f = connectd(); - fputs("blessme\r\n", f); - char line[1024]; - fgets(line, sizeof(line), f); - fclose(f); + char *line = daemon_chat("blessme"); if (line[0] != '1') { fputs(*line ? line : "unexpected hangup\n", stderr); return EXIT_FAILURE; } + free(line); char *argvx[argc + 1]; for (int i = 0; i < argc; i++) @@ -91,45 +127,38 @@ screen(int argc, char *argv[]) void stop(pid_t pid) { - FILE *f = connectd(); - fprintf(f, "stop %d\r\n", pid); - char line[1024]; - fgets(line, sizeof(line), f); - fclose(f); + char cmd[256]; snprintf(cmd, sizeof(cmd), "stop %d", pid); + char *line = daemon_chat(cmd); if (line[0] != '1') { fputs(*line ? line : "unexpected hangup\n", stderr); exit(EXIT_FAILURE); } + free(line); } void stop_all(void) { - FILE *f = connectd(); - fputs("stopall\r\n", f); - char line[1024]; - fgets(line, sizeof(line), f); - fclose(f); + char *line = daemon_chat("stopall"); if (line[0] != '1') { fputs(*line ? line : "unexpected hangup\n", stderr); exit(EXIT_FAILURE); } fputs(line + 2, stdout); + free(line); } void limit_mem(size_t limit) { - FILE *f = connectd(); - fprintf(f, "limitmem %zu\r\n", limit); - char line[1024]; - fgets(line, sizeof(line), f); - fclose(f); + char cmd[256]; snprintf(cmd, sizeof(cmd), "limitmem %zu", limit); + char *line = daemon_chat(cmd); if (line[0] != '1') { /* TODO: Error message postprocessing. */ fputs(*line ? line : "unexpected hangup\n", stderr); exit(EXIT_FAILURE); } + free(line); } void diff --git a/compctld.c b/compctld.c index 6ca9dfd..abe33e5 100644 --- a/compctld.c +++ b/compctld.c @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -77,6 +78,36 @@ cgroup_init(void) } +void +mprintf(int fd, char *fmt, ...) +{ + char buf[1024]; + + va_list v; + va_start(v, fmt); + vsnprintf(buf, sizeof(buf), fmt, v); + va_end(v); + + struct iovec iov = { + .iov_base = buf, + .iov_len = strlen(buf), + }; + struct msghdr msg = { + .msg_iov = &iov, + .msg_iovlen = 1, + }; + ssize_t sent = sendmsg(fd, &msg, 0); + if (sent < 0) { + logperror("sendmsg"); + return; + } + if (sent < iov.iov_len) { + syslog(LOG_INFO, "incomplete send %zd < %zu, FIXME", sent, iov.iov_len); + return; + } +} + + int main(int argc, char *argv[]) { @@ -100,8 +131,6 @@ main(int argc, char *argv[]) setsid(); int s = socket(AF_UNIX, SOCK_STREAM, 0); - int on = 1; setsockopt(s, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on)); - /* TODO: Protect against double execution? */ unlink(SOCKFILE); struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE }; @@ -119,17 +148,27 @@ main(int argc, char *argv[]) * the daemon, this is just an attack vector we ignore. */ /* TODO: alarm() to wake from stuck clients. */ - /* Decode the message with credentials. */ + /* Decode the message with command and credentials. */ + + int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on)); struct ucred *cred; char cbuf[CMSG_SPACE(sizeof(*cred))]; + char line[1024]; + struct iovec iov = { + .iov_base = line, + .iov_len = sizeof(line), + }; struct msghdr msg = { + .msg_iov = &iov, + .msg_iovlen = 1, .msg_control = cbuf, .msg_controllen = sizeof(cbuf), }; char *errmsg; -recvagain: - if (recvmsg(fd, &msg, MSG_WAITALL) < 0) { +recvagain:; + int replylen = recvmsg(fd, &msg, 0); + if (replylen < 0) { if (errno == EAGAIN) goto recvagain; errmsg = "recvmsg"; @@ -151,24 +190,20 @@ sockerror: } cred = (struct ucred *) CMSG_DATA(cmsg); - FILE *f = fdopen(fd, "r"); - char line[1024]; - fgets(line, sizeof(line), f); - size_t linelen = strlen(line); - if (linelen < 2 || strcmp(&line[linelen - 2], "\r\n")) { + line[replylen] = 0; + if (replylen < 2) { syslog(LOG_WARNING, "protocol error (%s)", line); - fclose(f); + close(fd); continue; } - line[linelen - 2] = 0; /* Analyze command */ if (!strcmp("blessme", line)) { syslog(LOG_INFO, "new computation process %d", cred->pid); if (cgroup_add_task(chier, cgroup, cred->pid) < 0) - fprintf(f, "0 error: %s\r\n", strerror(errno)); + mprintf(fd, "0 error: %s", strerror(errno)); else - fputs("1 blessed\r\n", f); + mprintf(fd, "1 blessed"); } else if (begins_with("stop ", line)) { pid_t pid = atoi(line + sizeof("stop ")); @@ -176,27 +211,27 @@ sockerror: /* Sanity check. */ if (pid < 10 || pid > 32768) { syslog(LOG_WARNING, "stop: invalid pid (%d)", pid); - fputs("0 invalid pid\r\n", f); - fclose(f); + mprintf(fd, "0 invalid pid"); + close(fd); continue; } if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) { - fputs("0 task not marked as computation\r\n", f); - fclose(f); + mprintf(fd, "0 task not marked as computation"); + close(fd); continue; } syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid); kill(pid, SIGTERM); /* TODO: Grace period and then kill with SIGKILL. */ - fputs("1 task stopped\r\n", f); + mprintf(fd, "1 task stopped"); } else if (!strcmp("stopall", line)) { pid_t *tasks; int tasks_n = cgroup_task_list(chier, cgroup, &tasks); if (tasks_n < 0) { - fprintf(f, "0 error: %s\r\n", strerror(errno)); - fclose(f); + mprintf(fd, "0 error: %s\r\n", strerror(errno)); + close(fd); continue; } for (int i = 0; i < tasks_n; i++) { @@ -204,7 +239,7 @@ sockerror: kill(tasks[i], SIGTERM); } /* TODO: Grace period and then kill with SIGKILL. */ - fprintf(f, "1 %d tasks stopped\r\n", tasks_n); + mprintf(fd, "1 %d tasks stopped", tasks_n); free(tasks); } else if (begins_with("limitmem ", line)) { @@ -215,33 +250,33 @@ sockerror: /* Sanity check. */ if (limit < 1024 || limit > total) { syslog(LOG_WARNING, "limitmem: invalid limit (%zu)", limit); - fputs("0 invalid limit value\r\n", f); - fclose(f); + mprintf(fd, "0 invalid limit value"); + close(fd); continue; } if (limit < mincomp) { - fprintf(f, "-1 at least %zuM must remain available for computations.\r\n", mincomp / 1048576); - fclose(f); + mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576); + close(fd); continue; } if (total - limit < minuser) { - fprintf(f, "-2 at least %zuM must remain available for users.\r\n", minuser / 1048576); - fclose(f); + mprintf(fd, "-2 at least %zuM must remain available for users.", minuser / 1048576); + close(fd); continue; } syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid); if (cgroup_set_mem_limit(chier, cgroup, limit) < 0) - fprintf(f, "0 error: %s\r\n", strerror(errno)); + mprintf(fd, "0 error: %s", strerror(errno)); else - fputs("1 limit set\r\n", f); + mprintf(fd, "1 limit set"); } else { syslog(LOG_WARNING, "invalid command (%s)", line); } - fclose(f); + close(fd); } return EXIT_FAILURE; -- 2.11.4.GIT