New release.
[rox-archive.git] / support.py
blob14187306e6d7a93b3adaab5dcb1d7dcd8da855b1
1 #!/usr/bin/env python
3 from rox import g, saving
4 import rox
5 import fcntl
7 try:
8 from rox import processes
9 except ImportError:
10 rox.croak(_('Sorry, this version of Archive requires ROX-Lib 1.9.3 or later'))
12 import sys, os
class ChildError(Exception):
    """Exception signalling that the child process reported a failure.

    The message passed to the constructor describes what went wrong.
    """
class ChildKilled(saving.AbortSave):
    """Exception used when the child died because kill() was called."""

    def __init__(self):
        # The message is fixed; callers construct this with no arguments.
        message = _("Operation aborted at user's request")
        saving.AbortSave.__init__(self, message)
def escape(text):
    r"""Return a copy of text in which every \ and ' is preceded by a backslash."""
    # Double the backslashes first, so that the backslashes added in
    # front of the quotes are not themselves escaped again.
    doubled = text.replace("\\", "\\\\")
    return doubled.replace("'", "\\'")
def Tmp(mode = 'w+b'):
    """Create a seekable, randomly named temp file (deleted automatically after use).

    mode is the mode string the file is opened with. The returned
    file-like object has a 'name' attribute holding its path, which
    ends in '-archive'.

    Errors raised while creating the file (invalid mode, no space
    left, ...) are passed on to the caller.
    """
    import tempfile

    # Only take the fallback path when NamedTemporaryFile really is
    # missing (python2.2). A genuine failure while creating the file
    # must reach the caller instead of being silently swallowed.
    make_named = getattr(tempfile, 'NamedTemporaryFile', None)
    if make_named is not None:
        return make_named(mode, suffix = '-archive')

    import random
    import stat
    name = tempfile.mktemp(repr(random.randint(1, 1000000)) + '-archive')

    # O_EXCL makes the open fail rather than reuse a file that someone
    # else created under the same name. S_IRWXU is owner-only (0700).
    fd = os.open(name, os.O_RDWR|os.O_CREAT|os.O_EXCL, stat.S_IRWXU)
    tmp = tempfile.TemporaryFileWrapper(os.fdopen(fd, mode), name)
    tmp.name = name
    return tmp
def keep_on_exec(fd):
    """Reset the descriptor flags of fd to zero.

    This clears the close-on-exec flag, so fd stays open in a program
    started with exec().
    """
    no_flags = 0
    fcntl.fcntl(fd, fcntl.F_SETFD, no_flags)
class PipeThroughCommand(processes.Process):
    """Run a shell command in a child process, connected to two streams.

    When the child has died, 'callback' is called with None on success
    or with an exception instance (ChildKilled or ChildError)
    describing the failure. wait() installs such a callback and blocks
    in a recursive main loop until it has been called.
    """

    def __init__(self, command, src, dst):
        """Execute 'command' with src as stdin and writing to stream
        dst. src must be a fileno() stream, but dst need not be.
        Either stream may be None if input or output is not required.
        Call the wait() method to wait for the command to finish."""

        assert src is None or hasattr(src, 'fileno')

        processes.Process.__init__(self)

        self.command = command
        self.dst = dst
        self.src = src
        self.tmp_stream = None  # Stream the child's stdout is sent to

        self.callback = None    # Called with the result by child_died()
        self.killed = 0         # Set to 1 by kill()
        self.errors = ""        # Data collected by got_error_output()

        self.start()

    def pre_fork(self):
        """Choose the stream that will receive the child's output."""
        # Output to 'dst' directly if it's a fileno stream. Otherwise,
        # send output to a temporary file.
        assert self.tmp_stream is None

        # Compare with None rather than testing truthiness: a stream
        # object is allowed to evaluate as false.
        if self.dst is not None:
            if hasattr(self.dst, 'fileno'):
                # Push out anything buffered before the child starts
                # writing to the same descriptor.
                self.dst.flush()
                self.tmp_stream = self.dst
            else:
                self.tmp_stream = Tmp()

    def start_error(self):
        """Clean up effects of pre_fork()."""
        self.tmp_stream = None

    def child_run(self):
        """Connect stdin and stdout to the streams, run the command and exit.

        Never returns: the process exits with status 0 if the command
        succeeded and 1 otherwise.
        """
        src = self.src

        if src is not None:
            os.dup2(src.fileno(), 0)
            keep_on_exec(0)
            os.lseek(0, 0, 0)  # OpenBSD needs this, dunno why
        if self.dst is not None:
            os.dup2(self.tmp_stream.fileno(), 1)
            keep_on_exec(1)

        if os.system(self.command) == 0:
            os._exit(0)  # No error code or signal
        os._exit(1)

    def parent_post_fork(self):
        """Forget tmp_stream when the child writes straight to dst."""
        # Nothing needs copying back in child_died() in that case.
        if self.dst is not None and self.tmp_stream is self.dst:
            self.tmp_stream = None

    def got_error_output(self, data):
        """Collect error output; any at all makes the command count as failed."""
        self.errors += data

    def child_died(self, status):
        """Work out how the command ended and report it through callback.

        The callback gets None on success, a ChildKilled instance if
        kill() was used, or a ChildError if there was error output or
        the exit status was non-zero.
        """
        errors = self.errors.strip()

        err = None

        if self.killed:
            # Pass an instance, like the ChildError cases below.
            err = ChildKilled()
        elif errors:
            err = ChildError(_("Errors from command '%s':\n%s") % (self.command, errors))
        elif status != 0:
            err = ChildError(_("Command '%s' returned an error code!") % self.command)

        # If dst wasn't a fileno stream, copy from the temp file to it
        if err is None and self.tmp_stream is not None:
            self.tmp_stream.seek(0)
            self.dst.write(self.tmp_stream.read())
        # Drop our reference in every case; the stream is not used again.
        self.tmp_stream = None

        self.callback(err)

    def wait(self):
        """Run a recursive mainloop until the command terminates.
        Raises an exception on error."""
        done = []
        def set_done(exception):
            done.append(exception)
            g.mainquit()
        self.callback = set_done
        # The main loop may be quit by someone else; keep going until
        # our callback has recorded a result.
        while not done:
            g.mainloop()
        exception, = done
        if exception is not None:
            raise exception

    def kill(self):
        """Kill the child, marking the death as requested by the user."""
        self.killed = 1
        processes.Process.kill(self)
def test():
    """Check that this module works.

    Exercises escape(), Tmp() and PipeThroughCommand in sequence. The
    steps share state (the 'file' temp file and the 'a' StringIO), so
    their order matters. Failures show up as AssertionError or as an
    uncaught exception.
    """

    def show():
        # Print the exception that is currently being handled.
        error = sys.exc_info()[1]
        print "(error reported was '%s')" % error

    # Run the command and block until it has finished.
    def pipe_through_command(command, src, dst): PipeThroughCommand(command, src, dst).wait()

    print "Test escape()..."

    assert escape(''' a test ''') == ' a test '
    assert escape(''' "a's test" ''') == ''' "a\\'s test" '''
    assert escape(''' "a\\'s test" ''') == ''' "a\\\\\\'s test" '''

    print "Test Tmp()..."

    # Mix buffered writes with a write on the raw descriptor; the
    # flush() in between keeps the data in order.
    file = Tmp()
    file.write('Hello')
    print >>file, ' ',
    file.flush()
    os.write(file.fileno(), 'World')

    file.seek(0)
    assert file.read() == 'Hello World'

    print "Test pipe_through_command():"

    print "Try an invalid command..."
    try:
        pipe_through_command('bad_command_1234', None, None)
        assert 0
    except ChildError:
        show()
    else:
        assert 0

    print "Try a valid command..."
    pipe_through_command('exit 0', None, None)

    print "Writing to a non-fileno stream..."
    from cStringIO import StringIO
    a = StringIO()
    pipe_through_command('echo Hello', None, a)
    assert a.getvalue() == 'Hello\n'

    print "Reading from a stream to a StringIO..."
    # Start at offset 1; the expected result shows that the child
    # nevertheless reads the file from its beginning.
    file.seek(1)
    pipe_through_command('cat', file, a)
    assert a.getvalue() == 'Hello\nello World'

    print "Writing to a fileno stream..."
    file.seek(0)
    file.truncate(0)
    pipe_through_command('echo Foo', None, file)
    file.seek(0)
    assert file.read() == 'Foo\n'

    print "Read and write fileno streams..."
    src = Tmp()
    src.write('123')
    src.seek(0)
    file.seek(0)
    file.truncate(0)
    pipe_through_command('cat', src, file)
    file.seek(0)
    assert file.read() == '123'

    print "Detect non-zero exit value..."
    try:
        pipe_through_command('exit 1', None, None)
    except ChildError:
        show()
    else:
        assert 0

    print "Detect writes to stderr..."
    # The command exits with status 0, so only its error output can
    # make it count as failed.
    try:
        pipe_through_command('echo one >&2; sleep 2; echo two >&2', None, None)
    except ChildError:
        show()
    else:
        assert 0

    print "Check tmp file is deleted..."
    name = file.name
    assert os.path.exists(name)
    # Dropping this reference is expected to remove the file from disk.
    file = None
    assert not os.path.exists(name)

    print "Check we can kill a runaway proces..."
    ptc = PipeThroughCommand('sleep 100; exit 1', None, None)
    def stop():
        ptc.kill()
    # Kill the child after two seconds, while wait() runs the main loop.
    g.timeout_add(2000, stop)
    try:
        ptc.wait()
        assert 0
    except ChildKilled:
        pass

    print "All tests passed!"
# Run the self-test when this file is executed as a script.
if __name__ == '__main__':
    test()