Fixing a missing quote meant to go on r5471
[autotest-zwu.git] / server / subcommand.py
blob8aa2d96affc0df89668fb508e1e519ecbd4f33ff
1 __author__ = """Copyright Andy Whitcroft, Martin J. Bligh - 2006, 2007"""
3 import sys, os, subprocess, time, signal, cPickle, logging
5 from autotest_lib.client.common_lib import error, utils
# entry points that use subcommand must set this to their logging manager
# to get log redirection for subcommands; while it is left as None,
# subcommand.redirect_output() does nothing
logging_manager_object = None
def parallel(tasklist, timeout=None, return_results=False):
    """
    Run a set of predefined subcommands in parallel.

    Every task is forked first; the tasks are then waited for one at a time,
    in list order, and each task's pickled result is read from its
    result_pickle pipe.

    @param tasklist: A list of subcommand instances to execute.
    @param timeout: Number of seconds after which the commands should timeout.
            A false value (None or 0) means wait without a time limit.
    @param return_results: If True instead of an AutoServError being raised
            on any error a list of the results|exceptions from the tasks is
            returned.  [default: False]

    @returns None, or the list of results/exceptions (one per task, in
            tasklist order) when return_results is True.
    @raises error.AutoservError: If any task failed and return_results is
            False.
    """
    run_error = False
    for task in tasklist:
        task.fork_start()

    remaining_timeout = None
    if timeout:
        endtime = time.time() + timeout

    results = []
    for task in tasklist:
        if timeout:
            # The deadline is shared by all tasks, but every task is still
            # given at least one second to be collected.
            remaining_timeout = max(endtime - time.time(), 1)
        try:
            status = task.fork_waitfor(timeout=remaining_timeout)
        except error.AutoservSubcommandError:
            run_error = True
        else:
            # fork_waitfor returns None when the task timed out.
            if status != 0:
                run_error = True

        # Nested try blocks (rather than try/except/finally) keep this valid
        # on old Python 2 releases; the pipe is closed on every path.
        try:
            try:
                result = cPickle.load(task.result_pickle)
            except EOFError:
                # The child exited, or was killed after a timeout, without
                # writing anything.  Record that as this task's failure
                # instead of letting EOFError escape to the caller.
                run_error = True
                result = error.AutoservError(
                        'subcommand exited without reporting a result: %s'
                        % task)
        finally:
            task.result_pickle.close()
        results.append(result)

    if return_results:
        return results
    elif run_error:
        lines = ['task: %s returned/raised: %r\n' % (task, result)
                 for task, result in zip(tasklist, results)]
        message = 'One or more subcommands failed:\n' + ''.join(lines)
        raise error.AutoservError(message)
def parallel_simple(function, arglist, log=True, timeout=None,
                    return_results=False):
    """
    Each element in the arglist used to create a subcommand object,
    where that arg is used both as a subdir name, and a single argument
    to pass to "function".

    We create a subcommand object for each element in the list,
    then execute those subcommand objects in parallel.

    NOTE: As an optimization, if len(arglist) == 1 a subcommand is not used:
    the function is called directly in this process, and neither log nor
    timeout is applied.

    @param function: A callable to run in parallel once per arg in arglist.
    @param arglist: A list of single arguments to be used one per subcommand;
            typically a list of machine names.
    @param log: If True, output will be written to output in a subdirectory
            named after each subcommand's arg.
    @param timeout: Number of seconds after which the commands should timeout.
    @param return_results: If True instead of an AutoServError being raised
            on any error a list of the results|exceptions from the function
            called on each arg is returned.  [default: False]

    @returns None or a list of results/exceptions.
    """
    if not arglist:
        logging.warning('parallel_simple was called with an empty arglist, '
                        'did you forget to pass in a list of machines?')

    # Bypass the multithreading if only one machine.
    if len(arglist) == 1:
        arg = arglist[0]
        if not return_results:
            # Any exception from function propagates to the caller here.
            function(arg)
            return
        try:
            return [function(arg)]
        except Exception:
            # sys.exc_info() fetches the exception without version-specific
            # "except ..., e" / "except ... as e" syntax.
            return [sys.exc_info()[1]]

    subcommands = []
    for arg in arglist:
        if log:
            subdir = str(arg)
        else:
            subdir = None
        subcommands.append(subcommand(function, [arg], subdir))
    return parallel(subcommands, timeout, return_results=return_results)
class subcommand(object):
    """
    A function call that is run in a forked child process.

    The parent starts the child with fork_start(), waits for it with
    fork_waitfor() (or poll()/wait()), and reads the pickled return value, or
    the pickled exception, from the result_pickle file object that
    fork_start() sets up.
    """

    # The hook lists live on the class, so hooks registered through the
    # classmethods below are shared by every subcommand instance.
    fork_hooks, join_hooks = [], []


    def __init__(self, func, args, subdir=None):
        """
        @param func: The callable to run in the child process.
        @param args: A sequence of positional arguments; the child calls
                func(*args).
        @param subdir: The subdirectory to log results in.  It and a 'debug'
                directory inside it are created if they do not exist.  With
                a false value no directories are used.
        """
        if subdir:
            self.subdir = os.path.abspath(subdir)
            if not os.path.exists(self.subdir):
                os.mkdir(self.subdir)
            self.debug = os.path.join(self.subdir, 'debug')
            if not os.path.exists(self.debug):
                os.mkdir(self.debug)
        else:
            self.subdir = None
            self.debug = None

        self.func = func
        self.args = args
        self.lambda_function = lambda: func(*args)
        self.pid = None
        self.returncode = None


    def __str__(self):
        return ('subcommand(func=%s, args=%s, subdir=%s)' %
                (self.func, self.args, self.subdir))


    @classmethod
    def register_fork_hook(cls, hook):
        """ Register a function to be called from the child process after
        forking.  The hook is called with the subcommand instance. """
        cls.fork_hooks.append(hook)


    @classmethod
    def register_join_hook(cls, hook):
        """ Register a function to be called from the child process
        just before the child process terminates (joins to the parent).
        The hook is called with the subcommand instance. """
        cls.join_hooks.append(hook)


    def redirect_output(self):
        """ Tee output into this subcommand's debug directory.  Does nothing
        unless a subdir was given and logging_manager_object has been set. """
        if self.subdir and logging_manager_object:
            tag = os.path.basename(self.subdir)
            logging_manager_object.tee_redirect_debug_dir(self.debug, tag=tag)


    def _pickle_failure(self, failure):
        """ Pickle an exception raised in the child; if the exception itself
        cannot be pickled, pickle an AutoservError describing it instead. """
        try:
            return cPickle.dumps(failure, cPickle.HIGHEST_PROTOCOL)
        except Exception:
            substitute = error.AutoservError(
                    'unpicklable exception from subcommand: %r' % (failure,))
            return cPickle.dumps(substitute, cPickle.HIGHEST_PROTOCOL)


    def fork_start(self):
        """
        Fork a child process that runs the subcommand.

        In the parent this stores the child's pid in self.pid, opens the read
        end of the result pipe as self.result_pickle and returns.

        The child never returns from this method: it runs the fork hooks and
        the function, writes the pickled return value (or the pickled
        exception) to the pipe, runs the join hooks and exits with status 0
        on success or 1 on failure.
        """
        sys.stdout.flush()
        sys.stderr.flush()
        r, w = os.pipe()
        self.returncode = None
        self.pid = os.fork()

        if self.pid:                            # I am the parent
            os.close(w)
            # The pickle stream is binary data, so read it in binary mode.
            self.result_pickle = os.fdopen(r, 'rb')
            return

        # We are the child from this point on. Never return.  The outer
        # try/finally guarantees os._exit() is reached whatever fails, so the
        # child can never unwind into the parent's call stack.
        exit_code = 1
        try:
            os.close(r)
            try:
                signal.signal(signal.SIGTERM, signal.SIG_DFL) # clear handler
                if self.subdir:
                    os.chdir(self.subdir)
                self.redirect_output()

                for hook in self.fork_hooks:
                    hook(self)
                result = self.lambda_function()
                # NOTE(review): the parent reads the pipe only after waiting
                # for this process, so a result larger than the pipe buffer
                # would presumably block here -- confirm.
                os.write(w, cPickle.dumps(result, cPickle.HIGHEST_PROTOCOL))
                exit_code = 0
            except Exception:
                # sys.exc_info() fetches the exception without
                # version-specific "except ..., e" syntax.
                failure = sys.exc_info()[1]
                logging.exception('function failed')
                exit_code = 1
                os.write(w, self._pickle_failure(failure))

            os.close(w)

            for hook in self.join_hooks:
                hook(self)
        finally:
            try:
                sys.stdout.flush()
                sys.stderr.flush()
            finally:
                os._exit(exit_code)


    def _handle_exitstatus(self, sts):
        """
        This is partially borrowed from subprocess.Popen.

        Decode a status value from os.waitpid() into self.returncode: the
        exit status for a normal exit, or the negated signal number when the
        child was terminated by a signal.

        @param sts: The status value to decode.
        @raises RuntimeError: If the status is neither an exit nor a signal.
        @raises error.AutoservSubcommandError: If the return code is nonzero.
                A failure report is printed first, including the contents of
                debug/autoserv.stderr when that file exists.
        """
        if os.WIFSIGNALED(sts):
            self.returncode = -os.WTERMSIG(sts)
        elif os.WIFEXITED(sts):
            self.returncode = os.WEXITSTATUS(sts)
        else:
            # Should never happen
            raise RuntimeError("Unknown child exit status!")

        if self.returncode != 0:
            print("subcommand failed pid %d" % self.pid)
            print("%s" % (self.func,))
            print("rc=%d" % self.returncode)
            print("")
            if self.debug:
                stderr_file = os.path.join(self.debug, 'autoserv.stderr')
                if os.path.exists(stderr_file):
                    stderr_log = open(stderr_file)
                    # try/finally so the log file is always closed.
                    try:
                        for line in stderr_log:
                            sys.stdout.write(line)
                    finally:
                        stderr_log.close()
            print("\n--------------------------------------------\n")
            raise error.AutoservSubcommandError(self.func, self.returncode)


    def poll(self):
        """
        This is borrowed from subprocess.Popen.

        Check, without blocking, whether the child has exited.

        @returns The return code, or None if the child has not been reaped.
        @raises error.AutoservSubcommandError: From _handle_exitstatus, if
                the child finished with a nonzero return code.
        """
        if self.returncode is None:
            try:
                pid, sts = os.waitpid(self.pid, os.WNOHANG)
                if pid == self.pid:
                    self._handle_exitstatus(sts)
            except os.error:
                # Best effort, as in subprocess.Popen: a failing waitpid()
                # leaves returncode as None.
                pass
        return self.returncode


    def wait(self):
        """
        This is borrowed from subprocess.Popen.

        Block until the child exits.

        @returns The return code of the child.
        @raises error.AutoservSubcommandError: From _handle_exitstatus, if
                the child finished with a nonzero return code.
        """
        if self.returncode is None:
            pid, sts = os.waitpid(self.pid, 0)
            self._handle_exitstatus(sts)
        return self.returncode


    def fork_waitfor(self, timeout=None):
        """
        Wait for the child to finish, optionally giving up after a timeout.

        @param timeout: Number of seconds to wait.  With a false value
                (None or 0) this blocks until the child exits.
        @returns The return code of the child, or None if it did not finish
                in time; in that case utils.nuke_pid() is called on it and a
                timeout report is printed.
        @raises error.AutoservSubcommandError: From _handle_exitstatus, if
                the child finished with a nonzero return code.
        """
        if not timeout:
            return self.wait()

        end_time = time.time() + timeout
        while time.time() <= end_time:
            returncode = self.poll()
            if returncode is not None:
                return returncode
            time.sleep(1)

        # The child may have finished during the last sleep; check once more
        # so a completed child is not killed and reported as a timeout.
        returncode = self.poll()
        if returncode is not None:
            return returncode

        utils.nuke_pid(self.pid)
        print("subcommand failed pid %d" % self.pid)
        print("%s" % (self.func,))
        print("timeout after %ds" % timeout)
        print("")
        return None