Allow operations to be aborted.
[rox-archive.git] / support.py
blob7ce496e04775c7fd8c45ff5e6913f31dbf0eaa05
1 #!/usr/bin/env python
3 import findrox
4 from rox import g, saving
6 import sys, os
class ChildError(Exception):
    """Signals that the child process reported a failure.

    PipeThroughCommand raises this when the command writes anything to
    stderr or exits with a non-zero status.
    """
class ChildKilled(saving.AbortSave):
    """Raised when the child process died because kill() was called on it."""

    def __init__(self):
        reason = "Operation aborted at user's request"
        saving.AbortSave.__init__(self, reason)
def escape(text):
    r"""Return *text* with each backslash and each single quote prefixed by a backslash.

    Backslashes are handled first so that the backslashes added in front
    of quotes are not doubled again.
    """
    for special in ('\\', "'"):
        text = text.replace(special, '\\' + special)
    return text
def Tmp(mode = 'w+b'):
    """Create a seekable, randomly named temporary file.

    The file is opened with *mode* (read/write binary by default) and its
    path is available as the ``name`` attribute.  It is deleted
    automatically when closed or garbage-collected.
    """
    import tempfile
    # NamedTemporaryFile creates the file atomically under a unique name and
    # unlinks it on close.  The previous hand-rolled version relied on the
    # public ``tempfile.TemporaryFileWrapper`` class, which no longer exists
    # on Python 2.3+ (it became private), so it failed with AttributeError.
    return tempfile.NamedTemporaryFile(mode = mode, suffix = '-archive')
class PipeThroughCommand:
    """Run a shell command in a child process, piping src into its stdin
    and its stdout into dst.  Anything the command writes to stderr, or a
    non-zero exit status, is reported as a ChildError by wait()."""

    def __init__(self, command, src, dst):
        """Execute 'command' with src as stdin and writing to stream
        dst. src must be a fileno() stream, but dst need not be.
        Either stream may be None if input or output is not required.
        Call the wait() method to wait for the command to finish."""

        assert src is None or hasattr(src, 'fileno')

        # Output to 'dst' directly if it's a fileno stream. Otherwise,
        # send output to a temporary file and copy it into dst once the
        # command has finished successfully (see got_errors).
        tmp_stream = None
        if dst is not None:
            if hasattr(dst, 'fileno'):
                # Flush Python-level buffering so data already written to
                # dst comes before the child's output.
                dst.flush()
                tmp_stream = dst
            else:
                tmp_stream = Tmp()
            fd = tmp_stream.fileno()
        else:
            fd = -1

        # Create a pipe to collect stderr from child
        stderr_r, stderr_w = os.pipe()
        try:
            child = os.fork()
        except:
            # Cleanup only; the exception is re-raised unchanged.
            os.close(stderr_r)
            os.close(stderr_w)
            raise

        if child == 0:
            # This is the child process. It must always leave through
            # os._exit so it never returns into the parent's code.
            try:
                os.setpgid(0, 0)        # Start a new process group

                os.close(stderr_r)
                os.dup2(stderr_w, 2)
                if src is not None:
                    os.dup2(src.fileno(), 0)
                if fd != -1:
                    os.dup2(fd, 1)
                if os.system(command) == 0:
                    os._exit(0)         # No error code or signal
            finally:
                os._exit(1)
            assert 0

        # This is the parent process.
        # Also move the child into its own process group from this side, so
        # that kill() can signal the group even if it runs before the child
        # has executed its own setpgid(). Whichever call comes second is a
        # no-op; failures here are harmless, so they are ignored.
        try:
            os.setpgid(child, child)
        except OSError:
            pass

        os.close(stderr_w)
        self.err_from_child = stderr_r
        self.child = child
        self.command = command
        self.callback = None
        self.killed = 0

        self.dst = dst
        if tmp_stream is not None and tmp_stream is not dst:
            self.tmp_stream = tmp_stream
        else:
            self.tmp_stream = None

        self.errors = ""
        self.tag = g.input_add_full(self.err_from_child,
                                    g.gdk.INPUT_READ, self.got_errors)

    def got_errors(self, source, cond):
        """Input handler for the child's stderr pipe.

        Accumulates stderr text until EOF. At EOF it reaps the child, works
        out the outcome and passes the resulting exception, or None on
        success, to self.callback."""
        got = os.read(self.err_from_child, 100)
        if got:
            self.errors += got
            return

        # EOF: stop watching the pipe and release its descriptor
        # (previously the read end was leaked on every command).
        g.input_remove(self.tag)
        del self.tag
        os.close(self.err_from_child)
        self.err_from_child = -1

        errors = self.errors.strip()

        # Reap zombie
        pid, status = os.waitpid(self.child, 0)
        self.child = -1

        err = None

        if self.killed:
            err = ChildKilled()
        elif errors:
            err = ChildError("Errors from command '%s':\n%s" % (self.command, errors))
        elif status != 0:
            err = ChildError("Command '%s' returned an error code!" % self.command)

        # If dst wasn't a fileno stream, copy from the temp file to it.
        # A failure here must still reach the callback; otherwise wait()
        # would never see a result and would loop forever.
        if err is None and self.tmp_stream is not None:
            try:
                self.tmp_stream.seek(0)
                self.dst.write(self.tmp_stream.read())
            except Exception:
                err = sys.exc_info()[1]
        # Drop the temporary file now rather than when this object dies.
        self.tmp_stream = None

        self.callback(err)

    def wait(self):
        """Run a recursive mainloop until the command terminates.
        Raises an exception on error."""
        done = []
        def set_done(exception):
            done.append(exception)
            g.mainquit()
        self.callback = set_done
        while not done:
            g.mainloop()
        exception, = done
        if exception is not None:
            raise exception

    def kill(self):
        """Send SIGTERM to the command's whole process group.

        wait() will then raise ChildKilled. Must not be called after the
        command has finished."""
        assert self.child != -1
        import signal
        self.killed = 1
        os.kill(-self.child, signal.SIGTERM)
147 def test():
148 "Check that this module works."
150 def show():
151 error = sys.exc_info()[1]
152 print "(error reported was '%s')" % error
154 def pipe_through_command(command, src, dst): PipeThroughCommand(command, src, dst).wait()
156 print "Test escape()..."
158 assert escape(''' a test ''') == ' a test '
159 assert escape(''' "a's test" ''') == ''' "a\\'s test" '''
160 assert escape(''' "a\\'s test" ''') == ''' "a\\\\\\'s test" '''
162 print "Test Tmp()..."
164 file = Tmp()
165 file.write('Hello')
166 print >>file, ' ',
167 file.flush()
168 os.write(file.fileno(), 'World')
170 file.seek(0)
171 assert file.read() == 'Hello World'
173 print "Test pipe_through_command():"
175 print "Try an invalid command..."
176 try:
177 pipe_through_command('bad_command_1234', None, None)
178 assert 0
179 except ChildError:
180 show()
181 else:
182 assert 0
184 print "Try a valid command..."
185 pipe_through_command('exit 0', None, None)
187 print "Writing to a non-fileno stream..."
188 from cStringIO import StringIO
189 a = StringIO()
190 pipe_through_command('echo Hello', None, a)
191 assert a.getvalue() == 'Hello\n'
193 print "Reading from a stream to a StringIO..."
194 file.seek(1)
195 pipe_through_command('cat', file, a)
196 assert a.getvalue() == 'Hello\nello World'
198 print "Writing to a fileno stream..."
199 file.seek(0)
200 file.truncate(0)
201 pipe_through_command('echo Foo', None, file)
202 file.seek(0)
203 assert file.read() == 'Foo\n'
205 print "Read and write fileno streams..."
206 src = Tmp()
207 src.write('123')
208 src.seek(0)
209 file.seek(0)
210 file.truncate(0)
211 pipe_through_command('cat', src, file)
212 file.seek(0)
213 assert file.read() == '123'
215 print "Detect non-zero exit value..."
216 try:
217 pipe_through_command('exit 1', None, None)
218 except ChildError:
219 show()
220 else:
221 assert 0
223 print "Detect writes to stderr..."
224 try:
225 pipe_through_command('echo one >&2; sleep 2; echo two >&2', None, None)
226 except ChildError:
227 show()
228 else:
229 assert 0
231 print "Check tmp file is deleted..."
232 name = file.name
233 assert os.path.exists(name)
234 file = None
235 assert not os.path.exists(name)
237 print "Check we can kill a runaway proces..."
238 ptc = PipeThroughCommand('sleep 100; exit 1', None, None)
239 def stop():
240 ptc.kill()
241 g.timeout_add(2000, stop)
242 try:
243 ptc.wait()
244 assert 0
245 except ChildKilled:
246 pass
248 print "All tests passed!"
if '__main__' == __name__:
    # Executed as a script rather than imported: run the self-test.
    test()