Fix --abort-error: exiting after running the command is OK.
[gsh.git] / gsh / remote_dispatcher.py
blob725df6d5eaa9ca1666ae8fb5668493042f7eea51
1 # This program is free software; you can redistribute it and/or modify
2 # it under the terms of the GNU General Public License as published by
3 # the Free Software Foundation; either version 2 of the License, or
4 # (at your option) any later version.
6 # This program is distributed in the hope that it will be useful,
7 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # GNU Library General Public License for more details.
11 # You should have received a copy of the GNU General Public License
12 # along with this program; if not, write to the Free Software
13 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
15 # See the COPYING file for license information.
17 # Copyright (c) 2006, 2007 Guillaume Chazarain <guichaz@yahoo.fr>
19 import asyncore
20 import fcntl
21 import os
22 import pty
23 import random
24 import signal
25 import struct
26 import sys
27 import termios
28 import time
30 from gsh.buffered_dispatcher import buffered_dispatcher
31 from gsh.console import console_output
32 from gsh.terminal_size import terminal_size
# Life cycle of a remote shell: not started yet, idle at its prompt, waiting
# for the first line following a command, running a command, terminated.
(STATE_NOT_STARTED,
 STATE_IDLE,
 STATE_EXPECTING_NEXT_LINE,
 STATE_RUNNING,
 STATE_TERMINATED) = range(5)

# Human readable names, indexed by the STATE_* values above
STATE_NAMES = ['not_started', 'idle', 'expecting_next_line', 'running',
               'terminated']
def all_instances():
    """Yield every remote_dispatcher registered in asyncore's socket map,
    skipping any other kind of dispatcher"""
    for dispatcher in asyncore.socket_map.itervalues():
        if not isinstance(dispatcher, remote_dispatcher):
            continue
        yield dispatcher
def make_unique_name(name):
    """Return name if no remote shell displays it yet, otherwise the first
    'name#N' (N = 1, 2, ...) that is not taken"""
    taken = set([shell.display_name for shell in all_instances()])
    suffix = 0
    unique = name
    while unique in taken:
        suffix += 1
        unique = '%s#%d' % (name, suffix)
    return unique
def count_completed_processes():
    """Return the tuple (idle, total), counting only the enabled remote
    shells: idle is how many of them are in STATE_IDLE"""
    enabled = [shell for shell in all_instances() if shell.enabled]
    idle = [shell for shell in enabled if shell.state is STATE_IDLE]
    return len(idle), len(enabled)
def handle_unfinished_lines():
    """Lines are normally printed only once their '\n' arrives. When some
    buffer holds an unfinished line (not starting with an escape character),
    wait up to 0.2s for more output; if nothing shortened that wait, flush
    the unfinished lines of all the shells with an artificial '\n'."""
    pending = False
    for shell in all_instances():
        buf = shell.read_buffer
        if buf and buf[0] != chr(27):
            pending = True
            break
    if not pending:
        # No unfinished lines
        return

    start = time.time()
    asyncore.loop(count=1, timeout=0.2, use_poll=True)
    if time.time() - start < 0.15:
        # The poll returned early, presumably because more data arrived
        return
    for shell in all_instances():
        shell.print_unfinished_line()
def dispatch_termination_to_all():
    """Ask every remote shell to begin its termination procedure"""
    shells = all_instances()
    for shell in shells:
        shell.dispatch_termination()
def all_terminated():
    """Tell whether every remote shell is done. A shell is done when it
    reached STATE_TERMINATED, or when it requested termination but is
    disabled, so it will never receive the acknowledgement."""
    for shell in all_instances():
        finished = shell.state is STATE_TERMINATED
        abandoned = not shell.enabled and shell.termination
        if not (finished or abandoned):
            return False
    return True
def update_terminal_size():
    """Align the output prefixes and resize the remote terminals.

    Every shell's prefix becomes its display name padded to the longest
    enabled display name, followed by ': '. The width left for the remote
    terminals is the local width minus that prefix, but at least
    min(width, 10) columns. It is pushed with TIOCSWINSZ to each enabled
    shell whose recorded size differs. Does nothing when no shell is
    enabled."""
    width, height = terminal_size()
    name_lengths = [len(shell.display_name)
                    for shell in all_instances() if shell.enabled]
    if not name_lengths:
        return
    longest = max(name_lengths)
    for shell in all_instances():
        # ljust() never truncates a name longer than the padding width
        prefix = shell.display_name.ljust(longest) + ': '
        if len(prefix) < len(shell.prefix) and not shell.options.interactive:
            # In non-interactive mode, remote processes leave as soon
            # as they are terminated, but we don't want to break the
            # indentation if all the remaining processes have short names.
            # NOTE(review): this leaves the whole function, so the terminal
            # sizes below are not updated either.
            return
        shell.prefix = prefix
    width = max(width - longest - 2, min(width, 10))
    # python bug http://python.org/sf/1112949 on amd64
    # from ajaxterm.py
    set_winsize = struct.unpack('i', struct.pack('I', termios.TIOCSWINSZ))[0]
    winsize = struct.pack('HHHH', height, width, 0, 0)
    new_size = width, height
    for shell in all_instances():
        if shell.enabled and shell.term_size != new_size:
            shell.term_size = new_size
            fcntl.ioctl(shell.fd, set_winsize, winsize)
def format_info(info_list):
    """Turn a table of process information into aligned text lines, in place.

    info_list is a list of rows (lists, as built by
    remote_dispatcher.get_info). It is first sorted by the number following
    the 3-character prefix of each row's second cell ('fd:<n>'). Each cell
    is then converted with str() and padded with spaces to its column's
    width, and each row is replaced by its cells joined with single spaces.
    The row lists themselves also receive the padded cells. Returns None."""
    info_list.sort(key=lambda row: int(row[1][3:]))
    if info_list:
        column_count = len(info_list[0])
    else:
        column_count = 0
    widths = [max([len(str(row[col])) for row in info_list])
              for col in range(column_count)]
    for row_index, row in enumerate(info_list):
        padded = []
        for col, cell in enumerate(row):
            text = str(cell)
            padded.append(text + (widths[col] - len(text)) * ' ')
        row[:] = padded
        info_list[row_index] = ' '.join(row)
152 class remote_dispatcher(buffered_dispatcher):
153 """A remote_dispatcher is a ssh process we communicate with"""
155 def __init__(self, options, hostname):
156 self.pid, fd = pty.fork()
157 if self.pid == 0:
158 # Child
159 self.launch_ssh(options, hostname)
160 sys.exit(1)
161 # Parent
162 self.hostname = hostname
163 buffered_dispatcher.__init__(self, fd)
164 self.options = options
165 self.log_path = None
166 self.active = True # deactived shells are dead forever
167 self.enabled = True # shells can be enabled and disabled
168 self.state = STATE_NOT_STARTED
169 self.termination = None
170 self.term_size = (-1, -1)
171 self.prefix = ''
172 self.change_name(hostname)
173 self.set_prompt()
174 self.pending_rename = None
175 if options.command:
176 self.dispatch_write(options.command + '\n')
177 self.dispatch_termination()
178 self.options.interactive = False
179 else:
180 self.options.interactive = sys.stdin.isatty()
182 def launch_ssh(self, options, name):
183 """Launch the ssh command in the child process"""
184 evaluated = options.ssh % {'host': name}
185 shell = os.environ.get('SHELL', '/bin/sh')
186 if options.quick_sh:
187 evaluated = '%s -t %s sh' % (evaluated, name)
188 elif evaluated == options.ssh:
189 evaluated = '%s %s' % (evaluated, name)
190 os.execlp(shell, shell, '-c', evaluated)
192 def set_enabled(self, enabled):
193 self.enabled = enabled
194 update_terminal_size()
196 def change_state(self, state):
197 """Change the state of the remote process, logging the change"""
198 if state is not self.state:
199 if self.is_logging(debug=True):
200 self.log('state => %s\n' % (STATE_NAMES[state]), debug=True)
201 self.state = state
203 def disconnect(self):
204 """We are no more interested in this remote process"""
205 self.read_buffer = ''
206 self.write_buffer = ''
207 self.active = False
208 self.set_enabled(False)
209 if self.options.abort_error and self.state is STATE_NOT_STARTED:
210 raise asyncore.ExitNow
212 def reconnect(self):
213 """Relaunch and reconnect to this same remote process"""
214 try:
215 os.kill(self.pid, signal.SIGKILL)
216 except OSError:
217 # The process was already dead, no problem
218 pass
219 self.close()
220 remote_dispatcher(self.options, self.hostname)
222 def dispatch_termination(self):
223 """Start the termination procedure on this remote process, using the
224 same trick as the prompt to hide it"""
225 if not self.termination:
226 self.term1 = '[gsh termination ' + str(random.random())[2:]
227 self.term2 = str(random.random())[2:] + ']'
228 self.termination = self.term1 + self.term2
229 self.dispatch_write('echo "%s""%s"\n' % (self.term1, self.term2))
230 if self.state is not STATE_NOT_STARTED:
231 self.change_state(STATE_EXPECTING_NEXT_LINE)
233 def set_prompt(self):
234 """The prompt is important because we detect the readyness of a process
235 by waiting for its prompt. The prompt is built in two parts for it not
236 to appear in its building"""
237 # No right prompt
238 self.dispatch_write('RPS1=\n')
239 self.dispatch_write('RPROMPT=\n')
240 self.dispatch_write('TERM=ansi\n')
241 prompt1 = '[gsh prompt ' + str(random.random())[2:]
242 prompt2 = str(random.random())[2:] + ']'
243 self.prompt = prompt1 + prompt2
244 self.dispatch_write('PS1="%s""%s\n"\n' % (prompt1, prompt2))
246 def readable(self):
247 """We are always interested in reading from active remote processes if
248 the buffer is OK"""
249 return self.active and buffered_dispatcher.readable(self)
251 def handle_error(self):
252 """An exception may or may not lead to a disconnection"""
253 if buffered_dispatcher.handle_error(self):
254 console_output('Error talking to %s\n ' % (self.display_name),
255 sys.stderr)
256 self.disconnect()
258 def handle_read_fast_case(self, data):
259 """If we are in a fast case we'll avoid the long processing of each
260 line"""
261 if self.prompt in data or self.state is not STATE_RUNNING or \
262 self.termination and (self.term1 in data or self.term2 in data) or \
263 self.pending_rename and self.pending_rename in data:
264 # Slow case :-(
265 return False
267 last_nl = data.rfind('\n')
268 if last_nl == -1:
269 # No '\n' in data => slow case
270 return False
271 self.read_buffer = data[last_nl + 1:]
272 data = data[:last_nl].strip('\n').replace('\r', '\n')
273 while True:
274 no_empty_lines = data.replace('\n\n', '\n')
275 if len(no_empty_lines) == len(data):
276 break
277 data = no_empty_lines
278 if not data:
279 return True
280 if self.is_logging():
281 self.log(data + '\n')
282 console_output(self.prefix + \
283 data.replace('\n', '\n' + self.prefix) + '\n')
284 return True
286 def handle_read(self):
287 """We got some output from a remote shell, this is one of the state
288 machine"""
289 if not self.active:
290 return
291 new_data = buffered_dispatcher.handle_read(self)
292 if self.is_logging(debug=True):
293 self.log('==> ' + new_data, debug=True)
294 if self.handle_read_fast_case(self.read_buffer):
295 return
296 lf_pos = new_data.find('\n')
297 if lf_pos >= 0:
298 # Optimization: we knew there were no '\n' in the previous read
299 # buffer, so we searched only in the new_data and we offset the
300 # found index by the length of the previous buffer
301 lf_pos += len(self.read_buffer) - len(new_data)
302 limit = buffered_dispatcher.MAX_BUFFER_SIZE / 10
303 if lf_pos < 0 and len(self.read_buffer) > limit:
304 # A large unfinished line is treated as a complete line
305 # Or maybe there is a '\r' to break the line
306 lf_pos = max(new_data.find('\r'), limit)
308 while lf_pos >= 0:
309 # For each line in the buffer
310 line = self.read_buffer[:lf_pos + 1]
311 if self.prompt in line:
312 if self.options.interactive:
313 self.change_state(STATE_IDLE)
314 else:
315 self.change_state(STATE_EXPECTING_NEXT_LINE)
316 elif self.termination and self.termination in line:
317 self.change_state(STATE_TERMINATED)
318 self.disconnect()
319 elif self.termination and self.term1 in line and self.term2 in line:
320 # Just ignore this line
321 pass
322 elif self.pending_rename and self.pending_rename in line:
323 self.received_rename(line)
324 elif self.state is STATE_EXPECTING_NEXT_LINE:
325 self.change_state(STATE_RUNNING)
326 elif self.state is STATE_RUNNING:
327 line = line.replace('\r', '\n')
328 if line[-1] != '\n':
329 line += '\n'
330 if self.is_logging():
331 self.log(line)
332 if line.strip():
333 console_output(self.prefix + line)
335 # Go to the next line in the buffer
336 self.read_buffer = self.read_buffer[lf_pos + 1:]
337 if self.handle_read_fast_case(self.read_buffer):
338 return
339 lf_pos = self.read_buffer.find('\n')
341 def print_unfinished_line(self):
342 """The unfinished line stayed long enough in the buffer to be printed"""
343 if self.state is STATE_RUNNING:
344 line = self.read_buffer + '\n'
345 self.read_buffer = ''
346 if self.is_logging():
347 self.log(line)
348 console_output(self.prefix + line)
350 def writable(self):
351 """Do we want to write something?"""
352 return self.active and buffered_dispatcher.writable(self)
354 def is_logging(self, debug=False):
355 if debug:
356 return self.options.debug
357 return self.log_path is not None
359 def log(self, msg, debug=False):
360 """Log some information, either to a file or on the console"""
361 if self.log_path is None:
362 if debug and self.options.debug:
363 state = STATE_NAMES[self.state]
364 console_output('[dbg] %s[%s]: %s' %
365 (self.display_name, state, msg))
366 else:
367 # None != False, that's why we use 'not'
368 if (not debug) == (not self.options.debug):
369 log = os.open(self.log_path,
370 os.O_WRONLY|os.O_APPEND|os.O_CREAT, 0664)
371 os.write(log, msg)
372 os.close(log)
374 def get_info(self):
375 """Return a list will all information available about this process"""
376 if self.active:
377 state = STATE_NAMES[self.state]
378 else:
379 state = ''
381 return [self.display_name, 'fd:%d' % (self.fd),
382 'r:%d' % (len(self.read_buffer)),
383 'w:%d' % (len(self.write_buffer)),
384 self.active and 'active' or 'dead',
385 self.enabled and 'enabled' or 'disabled',
386 state]
388 def dispatch_write(self, buf):
389 """There is new stuff to write when possible"""
390 if self.active and self.enabled:
391 if self.is_logging(debug=True):
392 self.log('<== ' + buf, debug=True)
393 buffered_dispatcher.dispatch_write(self, buf)
395 def change_name(self, name):
396 self.display_name = None
397 self.display_name = make_unique_name(name)
398 update_terminal_size()
399 if self.options.log_dir:
400 # The log file
401 filename = self.display_name.replace('/', '_')
402 log_path = os.path.join(self.options.log_dir, filename)
403 if self.log_path:
404 # Rename the previous log
405 os.rename(self.log_path, log_path)
406 self.log_path = log_path
408 def rename(self, string):
409 previous_name = self.display_name
410 if string:
411 pending_rename1 = str(random.random())[2:] + ','
412 pending_rename2 = str(random.random())[2:] + ':'
413 self.pending_rename = pending_rename1 + pending_rename2
414 self.dispatch_write('echo "%s""%s" %s\n' %
415 (pending_rename1, pending_rename2, string))
416 self.change_state(STATE_EXPECTING_NEXT_LINE)
417 else:
418 self.change_name(self.hostname)
420 def received_rename(self, line):
421 new_name = line[len(self.pending_rename) + 1:-1]
422 self.change_name(new_name)
423 self.pending_rename = None