compctld: Do not require 15-bit pids
[compctl.git] / compctld.c
blob 27c551191ef949b984499795fb95a8e0507fcca8
1 #define _GNU_SOURCE /* struct ucred */
2 #include <assert.h>
3 #include <errno.h>
4 #include <signal.h>
5 #include <stdarg.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <string.h>
9 #include <sys/stat.h>
10 #include <sys/socket.h>
11 #include <sys/un.h>
12 #include <syslog.h>
13 #include <unistd.h>
15 #include "cgroup.h"
16 #include "common.h"
/* Nonzero iff STR starts with PREFIX. A function rather than a macro so
 * that each argument is type-checked and evaluated exactly once. */
static inline int
begins_with(const char *prefix, const char *str)
{
	return !strncmp(prefix, str, strlen(prefix));
}
/* syslog(3) counterpart of perror(3): log S, a colon, and the text for the
 * current errno at LOG_ERR priority. main() also installs it as
 * cgroup_perror once the standard streams are closed. */
void
logperror(const char *s)
{
	const char *why = strerror(errno);

	syslog(LOG_ERR, "%s: %s", s, why);
}
28 void
29 memory_limits(size_t *minuser, size_t *mincomp, size_t *maxcomp, size_t *total)
31 FILE *f = fopen("/proc/meminfo", "r");
32 char line[1024];
33 while (fgets(line, sizeof(line), f)) {
34 if (begins_with("MemTotal:", line)) {
35 *total = 0;
36 sscanf(line, "MemTotal:%zu", total);
37 *total *= 1024;
38 break;
41 fclose(f);
43 *minuser = *total * split_ratio;
44 if (*minuser < static_minfree)
45 *minuser = static_minfree;
46 if (*minuser > static_maxfree)
47 *minuser = static_maxfree;
49 *mincomp = static_minfree;
50 ssize_t smaxcomp = *total - *minuser;
51 *maxcomp = smaxcomp > 0 ? smaxcomp : 0;
52 /* maxcomp < mincomp may happen; they are used in different
53 * settings. */
/* Limit applied to a newly created computation cgroup: the maxcomp value
 * reported by memory_limits(). */
size_t
get_default_mem_limit(void)
{
	size_t reserved_user, floor_comp, ceiling_comp, ram;

	memory_limits(&reserved_user, &floor_comp, &ceiling_comp, &ram);
	return ceiling_comp;
}
65 void
66 cgroup_init(void)
68 if (cgroup_setup(chier, "memory") < 0)
69 exit(EXIT_FAILURE);
70 int ret = cgroup_create(chier, cgroup);
71 if (ret < 0)
72 exit(EXIT_FAILURE);
73 if (ret > 0) {
74 /* CGroup newly created, set limit. */
75 if (cgroup_set_mem_limit(chier, cgroup, get_default_mem_limit()) < 0)
76 exit(EXIT_FAILURE);
/* Format a reply printf(3)-style and send it over the client socket FD.
 * The formatted reply is truncated to 1023 bytes. Partial sends are resumed
 * until the whole reply is written, and EINTR is retried. Other send errors
 * are logged to syslog and the reply is dropped. MSG_NOSIGNAL keeps a
 * client that hung up from killing the daemon with SIGPIPE. */
void mprintf(int fd, char *fmt, ...) __attribute__((format(printf, 2, 3)));

void
mprintf(int fd, char *fmt, ...)
{
	char buf[1024];

	va_list v;
	va_start(v, fmt);
	int n = vsnprintf(buf, sizeof(buf), fmt, v);
	va_end(v);
	if (n < 0) {
		/* buf contents are unspecified after an encoding error. */
		syslog(LOG_ERR, "mprintf: formatting failed");
		return;
	}

	size_t len = strlen(buf);
	size_t off = 0;
	while (off < len) {
		struct iovec iov = {
			.iov_base = buf + off,
			.iov_len = len - off,
		};
		struct msghdr msg = {
			.msg_iov = &iov,
			.msg_iovlen = 1,
		};
		ssize_t sent = sendmsg(fd, &msg, MSG_NOSIGNAL);
		if (sent < 0) {
			if (errno == EINTR)
				continue;
			syslog(LOG_ERR, "sendmsg: %s", strerror(errno));
			return;
		}
		off += (size_t) sent;
	}
}
112 main(int argc, char *argv[])
114 /* Do this while everyone can still see the error. */
115 cgroup_init();
117 pid_t p = fork();
118 if (p < 0) {
119 perror("fork");
120 exit(EXIT_FAILURE);
122 if (p > 0)
123 exit(EXIT_SUCCESS);
125 fclose(stderr);
126 fclose(stdout);
127 fclose(stdin);
128 openlog("compctl", LOG_PID, LOG_DAEMON);
129 cgroup_perror = logperror;
131 setsid();
133 int s = socket(AF_UNIX, SOCK_STREAM, 0);
134 /* TODO: Protect against double execution? */
135 unlink(SOCKFILE);
136 struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
137 if (bind(s, (struct sockaddr *) &sun, sizeof(sun.sun_family) + strlen(sun.sun_path) + 1) < 0) {
138 logperror(SOCKFILE);
139 exit(EXIT_FAILURE);
141 chmod(SOCKFILE, 0777);
142 listen(s, 10);
144 int fd;
145 while ((fd = accept(s, NULL, NULL)) >= 0) {
146 /* We handle only a single client at a time. This means
147 * that it is rather easy to write a script that will DOS
148 * the daemon, this is just an attack vector we ignore. */
149 /* TODO: alarm() to wake from stuck clients. */
151 /* Decode the message with command and credentials. */
153 int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
155 struct ucred *cred;
156 char cbuf[CMSG_SPACE(sizeof(*cred))];
157 char line[1024];
158 struct iovec iov = {
159 .iov_base = line,
160 .iov_len = sizeof(line),
162 struct msghdr msg = {
163 .msg_iov = &iov,
164 .msg_iovlen = 1,
165 .msg_control = cbuf,
166 .msg_controllen = sizeof(cbuf),
168 char *errmsg;
169 recvagain:;
170 int replylen = recvmsg(fd, &msg, 0);
171 if (replylen < 0) {
172 if (errno == EAGAIN)
173 goto recvagain;
174 errmsg = "recvmsg";
175 sockerror:
176 logperror(errmsg);
177 close(fd);
178 continue;
180 struct cmsghdr *cmsg;
181 cmsg = CMSG_FIRSTHDR(&msg);
182 if (cmsg == NULL || cmsg->cmsg_len != CMSG_LEN(sizeof(*cred))) {
183 syslog(LOG_INFO, "want %zu", CMSG_LEN(sizeof(*cred)));
184 errmsg = "cmsg";
185 goto sockerror;
187 if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDENTIALS) {
188 errmsg = "cmsg designation";
189 goto sockerror;
191 cred = (struct ucred *) CMSG_DATA(cmsg);
193 line[replylen] = 0;
194 if (replylen < 2) {
195 syslog(LOG_WARNING, "protocol error (%s)", line);
196 close(fd);
197 continue;
200 /* Analyze command */
201 if (!strcmp("blessme", line)) {
202 syslog(LOG_INFO, "new computation process %d", cred->pid);
203 if (cgroup_add_task(chier, cgroup, cred->pid) < 0)
204 mprintf(fd, "0 error: %s", strerror(errno));
205 else
206 mprintf(fd, "1 blessed");
208 } else if (begins_with("stop ", line)) {
209 pid_t pid = atoi(line + strlen("stop "));
211 /* Sanity check. */
212 if (pid < 10) {
213 syslog(LOG_WARNING, "stop: invalid pid (%d)", pid);
214 mprintf(fd, "0 invalid pid");
215 close(fd);
216 continue;
218 if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) {
219 mprintf(fd, "0 task not marked as computation");
220 close(fd);
221 continue;
224 syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid);
225 kill(pid, SIGTERM);
226 /* TODO: Grace period and then kill with SIGKILL. */
227 mprintf(fd, "1 task stopped");
229 } else if (!strcmp("stopall", line)) {
230 pid_t *tasks;
231 int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
232 if (tasks_n < 0) {
233 mprintf(fd, "0 error: %s\r\n", strerror(errno));
234 close(fd);
235 continue;
237 for (int i = 0; i < tasks_n; i++) {
238 syslog(LOG_INFO, "stopping process %d (mass request by pid %d uid %d)", tasks[i], cred->pid, cred->uid);
239 kill(tasks[i], SIGTERM);
241 /* TODO: Grace period and then kill with SIGKILL. */
242 mprintf(fd, "1 %d tasks stopped", tasks_n);
243 free(tasks);
245 } else if (begins_with("limitmem ", line)) {
246 size_t limit = atol(line + strlen("limitmem "));
247 size_t minuser, mincomp, maxcomp, total;
248 memory_limits(&minuser, &mincomp, &maxcomp, &total);
250 if (limit < mincomp) {
251 mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576);
252 close(fd);
253 continue;
255 if (limit > total || total - limit < minuser) {
256 mprintf(fd, "-2 at least %zuM must remain available for users; maximum limit for computations is %zuM.", minuser / 1048576, (total - minuser) / 1048576);
257 close(fd);
258 continue;
261 syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid);
262 if (cgroup_set_mem_limit(chier, cgroup, limit) < 0)
263 mprintf(fd, "0 error: %s", strerror(errno));
264 else
265 mprintf(fd, "1 limit set");
267 } else {
268 syslog(LOG_WARNING, "invalid command (%s)", line);
271 close(fd);
274 return EXIT_FAILURE;