lib/util: add tfork()
[Samba.git] / lib / util / tfork.c
blob27b6cc05fcd4730e33d6b66287e1f4396a2087e5
/*
   fork on steroids to avoid SIGCHLD and waitpid

   Copyright (C) Stefan Metzmacher 2010
   Copyright (C) Ralph Boehme 2017

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
21 #include "replace.h"
22 #include "system/wait.h"
23 #include "system/filesys.h"
24 #include "lib/util/samba_util.h"
25 #include "lib/util/sys_rw.h"
26 #include "lib/util/tfork.h"
27 #include "lib/util/debug.h"
/*
 * Bookkeeping for one tfork() call. tfork() allocates it with
 * anonymous_shared_allocate(); the forked helper processes (levels 1-3)
 * write fields that level 0, the tfork() caller, reads after level 1 exits.
 */
struct tfork_state {
	/* Caller's out-pointer; level 3 stores level0_pid through it (may be NULL). */
	pid_t *parent;

	/* SIGCHLD handler that was installed before tfork() replaced it. */
	void (*old_sig_chld)(int);

	/* Status pipe: [0] read end, [1] write end; -1 when not in use. */
	int status_pipe[2];

	/* Level 0: pid of the tfork() caller and the waitpid() status of level 1. */
	pid_t level0_pid;
	int level0_status;

	/*
	 * Level 1: its pid (reset to -1 by tfork_sig_chld() when that handler
	 * reaps it) and the errno of a failed pipe()/fork() in level 1.
	 */
	pid_t level1_pid;
	int level1_errno;

	/* Level 2: its pid and the errno of its failed fork(). */
	pid_t level2_pid;
	int level2_errno;

	/* Level 3: the child whose pid tfork() hands back to the caller. */
	pid_t level3_pid;
};
/*
 * State of the tfork() call in progress. tfork() allocates it with
 * anonymous_shared_allocate() and frees it again on every return path;
 * tfork_sig_chld() reads it while that handler is installed.
 *
 * TODO: We should make this global thread local
 */
static struct tfork_state *tfork_global;
51 static void tfork_sig_chld(int signum)
53 if (tfork_global->level1_pid > 0) {
54 int ret = waitpid(tfork_global->level1_pid,
55 &tfork_global->level0_status,
56 WNOHANG);
57 if (ret == tfork_global->level1_pid) {
58 tfork_global->level1_pid = -1;
59 return;
64 * Not our child, forward to old handler
67 if (tfork_global->old_sig_chld == SIG_IGN) {
68 return;
71 if (tfork_global->old_sig_chld == SIG_DFL) {
72 return;
75 tfork_global->old_sig_chld(signum);
78 static pid_t level2_fork_and_wait(int child_ready_fd)
80 int status;
81 ssize_t written;
82 pid_t pid;
83 int fd;
84 bool wait;
87 * Child level 2.
89 * Do a final fork and if the tfork() caller passed a status_fd, wait
90 * for child3 and return its exit status via status_fd.
93 pid = fork();
94 if (pid == 0) {
96 * Child level 3, this one finally returns from tfork() as child
97 * with pid 0.
99 * Cleanup all ressources we allocated before returning.
101 close(child_ready_fd);
102 close(tfork_global->status_pipe[1]);
104 if (tfork_global->parent != NULL) {
106 * we're in the child and return the level0 parent pid
108 *tfork_global->parent = tfork_global->level0_pid;
111 anonymous_shared_free(tfork_global);
112 tfork_global = NULL;
114 return 0;
117 tfork_global->level3_pid = pid;
118 if (tfork_global->level3_pid == -1) {
119 tfork_global->level2_errno = errno;
120 _exit(0);
123 sys_write(child_ready_fd, &(char){0}, 1);
125 if (tfork_global->status_pipe[1] == -1) {
126 _exit(0);
128 wait = true;
131 * We're going to stay around until child3 exits, so lets close all fds
132 * other then the pipe fd we may have inherited from the caller.
134 fd = dup2(tfork_global->status_pipe[1], 0);
135 if (fd == -1) {
136 status = errno;
137 kill(tfork_global->level3_pid, SIGKILL);
138 wait = false;
140 closefrom(1);
142 while (wait) {
143 int ret = waitpid(tfork_global->level3_pid, &status, 0);
144 if (ret == -1) {
145 if (errno == EINTR) {
146 continue;
148 status = errno;
150 break;
153 written = sys_write(fd, &status, sizeof(status));
154 if (written != sizeof(status)) {
155 abort();
158 _exit(0);
161 pid_t tfork(int *status_fd, pid_t *parent)
163 int ret;
164 pid_t pid;
165 pid_t child;
167 tfork_global = (struct tfork_state *)
168 anonymous_shared_allocate(sizeof(struct tfork_state));
169 if (tfork_global == NULL) {
170 return -1;
173 tfork_global->parent = parent;
174 tfork_global->status_pipe[0] = -1;
175 tfork_global->status_pipe[1] = -1;
177 tfork_global->level0_pid = getpid();
178 tfork_global->level0_status = -1;
179 tfork_global->level1_pid = -1;
180 tfork_global->level1_errno = ECANCELED;
181 tfork_global->level2_pid = -1;
182 tfork_global->level2_errno = ECANCELED;
183 tfork_global->level3_pid = -1;
185 if (status_fd != NULL) {
186 ret = pipe(&tfork_global->status_pipe[0]);
187 if (ret != 0) {
188 int saved_errno = errno;
190 anonymous_shared_free(tfork_global);
191 tfork_global = NULL;
192 errno = saved_errno;
193 return -1;
196 *status_fd = tfork_global->status_pipe[0];
200 * We need to set our own signal handler to prevent any existing signal
201 * handler from reaping our child.
203 tfork_global->old_sig_chld = CatchSignal(SIGCHLD, tfork_sig_chld);
205 pid = fork();
206 if (pid == 0) {
207 int level2_pipe[2];
208 char c;
209 ssize_t nread;
212 * Child level 1.
214 * Restore SIGCHLD handler
216 CatchSignal(SIGCHLD, SIG_DFL);
219 * Close read end of the signal pipe, we don't need it anymore
220 * and don't want to leak it into childs.
222 if (tfork_global->status_pipe[0] != -1) {
223 close(tfork_global->status_pipe[0]);
224 tfork_global->status_pipe[0] = -1;
228 * Create a pipe for waiting for the child level 2 to finish
229 * forking.
231 ret = pipe(&level2_pipe[0]);
232 if (ret != 0) {
233 tfork_global->level1_errno = errno;
234 _exit(0);
237 pid = fork();
238 if (pid == 0) {
241 * Child level 2.
244 close(level2_pipe[0]);
245 return level2_fork_and_wait(level2_pipe[1]);
248 tfork_global->level2_pid = pid;
249 if (tfork_global->level2_pid == -1) {
250 tfork_global->level1_errno = errno;
251 _exit(0);
254 close(level2_pipe[1]);
255 level2_pipe[1] = -1;
257 nread = sys_read(level2_pipe[0], &c, 1);
258 if (nread != 1) {
259 abort();
261 _exit(0);
264 tfork_global->level1_pid = pid;
265 if (tfork_global->level1_pid == -1) {
266 int saved_errno = errno;
268 anonymous_shared_free(tfork_global);
269 tfork_global = NULL;
270 errno = saved_errno;
271 return -1;
275 * By using the helper variable pid we avoid a TOCTOU with the signal
276 * handler that will set tfork_global->level1_pid to -1 (which would
277 * cause waitpid() to block waiting for another exitted child).
279 * We can't avoid the race waiting for pid twice (in the signal handler
280 * and then again here in the while loop), but we must avoid waiting for
281 * -1 and this does the trick.
283 pid = tfork_global->level1_pid;
285 while (tfork_global->level1_pid != -1) {
286 ret = waitpid(pid, &tfork_global->level0_status, 0);
287 if (ret == -1 && errno == EINTR) {
288 continue;
291 break;
294 CatchSignal(SIGCHLD, tfork_global->old_sig_chld);
296 if (tfork_global->level0_status != 0) {
297 anonymous_shared_free(tfork_global);
298 tfork_global = NULL;
299 errno = ECHILD;
300 return -1;
303 if (tfork_global->level2_pid == -1) {
304 int saved_errno = tfork_global->level1_errno;
306 anonymous_shared_free(tfork_global);
307 tfork_global = NULL;
308 errno = saved_errno;
309 return -1;
312 if (tfork_global->level3_pid == -1) {
313 int saved_errno = tfork_global->level2_errno;
315 anonymous_shared_free(tfork_global);
316 tfork_global = NULL;
317 errno = saved_errno;
318 return -1;
321 child = tfork_global->level3_pid;
322 anonymous_shared_free(tfork_global);
323 tfork_global = NULL;
325 return child;