base_job.py: Make TAPReport() work on python 2.4
[autotest-zwu.git] / server / hosts / paramiko_host.py
blob d68b15e0d5c9322e590d27f38aa811cd824dc162
1 import os, sys, time, signal, socket, re, fnmatch, logging, threading
2 import paramiko
4 from autotest_lib.client.common_lib import utils, error, global_config
5 from autotest_lib.server import subcommand
6 from autotest_lib.server.hosts import abstract_ssh
9 class ParamikoHost(abstract_ssh.AbstractSSHHost):
10 KEEPALIVE_TIMEOUT_SECONDS = 30
11 CONNECT_TIMEOUT_SECONDS = 30
12 CONNECT_TIMEOUT_RETRIES = 3
13 BUFFSIZE = 2**16
15 def _initialize(self, hostname, *args, **dargs):
16 super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs)
18 # paramiko is very noisy, tone down the logging
19 paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR)
21 self.keys = self.get_user_keys(hostname)
22 self.pid = None
25 @staticmethod
26 def _load_key(path):
27 """Given a path to a private key file, load the appropriate keyfile.
29 Tries to load the file as both an RSAKey and a DSAKey. If the file
30 cannot be loaded as either type, returns None."""
31 try:
32 return paramiko.DSSKey.from_private_key_file(path)
33 except paramiko.SSHException:
34 try:
35 return paramiko.RSAKey.from_private_key_file(path)
36 except paramiko.SSHException:
37 return None
40 @staticmethod
41 def _parse_config_line(line):
42 """Given an ssh config line, return a (key, value) tuple for the
43 config value listed in the line, or (None, None)"""
44 match = re.match(r"\s*(\w+)\s*=?(.*)\n", line)
45 if match:
46 return match.groups()
47 else:
48 return None, None
51 @staticmethod
52 def get_user_keys(hostname):
53 """Returns a mapping of path -> paramiko.PKey entries available for
54 this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa)
55 as well as any IdentityFile entries in the standard ssh config files.
56 """
57 raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"]
58 for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"):
59 config_path = os.path.expanduser(config_path)
60 if not os.path.exists(config_path):
61 continue
62 host_pattern = "*"
63 config_lines = open(config_path).readlines()
64 for line in config_lines:
65 key, value = ParamikoHost._parse_config_line(line)
66 if key == "Host":
67 host_pattern = value
68 elif (key == "IdentityFile"
69 and fnmatch.fnmatch(hostname, host_pattern)):
70 raw_identity_files.append(value)
72 # drop any files that use percent-escapes; we don't support them
73 identity_files = []
74 UNSUPPORTED_ESCAPES = ["%d", "%u", "%l", "%h", "%r"]
75 for path in raw_identity_files:
76 # skip this path if it uses % escapes
77 if sum((escape in path) for escape in UNSUPPORTED_ESCAPES):
78 continue
79 path = os.path.expanduser(path)
80 if os.path.exists(path):
81 identity_files.append(path)
83 # load up all the keys that we can and return them
84 user_keys = {}
85 for path in identity_files:
86 key = ParamikoHost._load_key(path)
87 if key:
88 user_keys[path] = key
90 # load up all the ssh agent keys
91 use_sshagent = global_config.global_config.get_config_value(
92 'AUTOSERV', 'use_sshagent_with_paramiko', type=bool)
93 if use_sshagent:
94 ssh_agent = paramiko.Agent()
95 for i, key in enumerate(ssh_agent.get_keys()):
96 user_keys['agent-key-%d' % i] = key
98 return user_keys
101 def _check_transport_error(self, transport):
102 error = transport.get_exception()
103 if error:
104 transport.close()
105 raise error
108 def _connect_socket(self):
109 """Return a socket for use in instantiating a paramiko transport. Does
110 not have to be a literal socket, it can be anything that the
111 paramiko.Transport constructor accepts."""
112 return self.hostname, self.port
115 def _connect_transport(self, pkey):
116 for _ in xrange(self.CONNECT_TIMEOUT_RETRIES):
117 transport = paramiko.Transport(self._connect_socket())
118 completed = threading.Event()
119 transport.start_client(completed)
120 completed.wait(self.CONNECT_TIMEOUT_SECONDS)
121 if completed.isSet():
122 self._check_transport_error(transport)
123 completed.clear()
124 transport.auth_publickey(self.user, pkey, completed)
125 completed.wait(self.CONNECT_TIMEOUT_SECONDS)
126 if completed.isSet():
127 self._check_transport_error(transport)
128 if not transport.is_authenticated():
129 transport.close()
130 raise paramiko.AuthenticationException()
131 return transport
132 logging.warn("SSH negotiation (%s:%d) timed out, retrying",
133 self.hostname, self.port)
134 # HACK: we can't count on transport.join not hanging now, either
135 transport.join = lambda: None
136 transport.close()
137 logging.error("SSH negotation (%s:%d) has timed out %s times, "
138 "giving up", self.hostname, self.port,
139 self.CONNECT_TIMEOUT_RETRIES)
140 raise error.AutoservSSHTimeout("SSH negotiation timed out")
143 def _init_transport(self):
144 for path, key in self.keys.iteritems():
145 try:
146 logging.debug("Connecting with %s", path)
147 transport = self._connect_transport(key)
148 transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS)
149 self.transport = transport
150 self.pid = os.getpid()
151 return
152 except paramiko.AuthenticationException:
153 logging.debug("Authentication failure")
154 else:
155 raise error.AutoservSshPermissionDeniedError(
156 "Permission denied using all keys available to ParamikoHost",
157 utils.CmdResult())
160 def _open_channel(self, timeout):
161 start_time = time.time()
162 if os.getpid() != self.pid:
163 if self.pid is not None:
164 # HACK: paramiko tries to join() on its worker thread
165 # and this just hangs on linux after a fork()
166 self.transport.join = lambda: None
167 self.transport.atfork()
168 join_hook = lambda cmd: self._close_transport()
169 subcommand.subcommand.register_join_hook(join_hook)
170 logging.debug("Reopening SSH connection after a process fork")
171 self._init_transport()
173 channel = None
174 try:
175 channel = self.transport.open_session()
176 except (socket.error, paramiko.SSHException, EOFError), e:
177 logging.warn("Exception occured while opening session: %s", e)
178 if time.time() - start_time >= timeout:
179 raise error.AutoservSSHTimeout("ssh failed: %s" % e)
181 if not channel:
182 # we couldn't get a channel; re-initing transport should fix that
183 try:
184 self.transport.close()
185 except Exception, e:
186 logging.debug("paramiko.Transport.close failed with %s", e)
187 self._init_transport()
188 return self.transport.open_session()
189 else:
190 return channel
193 def _close_transport(self):
194 if os.getpid() == self.pid:
195 self.transport.close()
198 def close(self):
199 super(ParamikoHost, self).close()
200 self._close_transport()
203 @classmethod
204 def _exhaust_stream(cls, tee, output_list, recvfunc):
205 while True:
206 try:
207 output_list.append(recvfunc(cls.BUFFSIZE))
208 except socket.timeout:
209 return
210 tee.write(output_list[-1])
211 if not output_list[-1]:
212 return
215 @classmethod
216 def __send_stdin(cls, channel, stdin):
217 if not stdin or not channel.send_ready():
218 # nothing more to send or just no space to send now
219 return
221 sent = channel.send(stdin[:cls.BUFFSIZE])
222 if not sent:
223 logging.warn('Could not send a single stdin byte.')
224 else:
225 stdin = stdin[sent:]
226 if not stdin:
227 # no more stdin input, close output direction
228 channel.shutdown_write()
229 return stdin
232 def run(self, command, timeout=3600, ignore_status=False,
233 stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS,
234 connect_timeout=30, stdin=None, verbose=True, args=()):
236 Run a command on the remote host.
237 @see common_lib.hosts.host.run()
239 @param connect_timeout: connection timeout (in seconds)
240 @param options: string with additional ssh command options
241 @param verbose: log the commands
243 @raises AutoservRunError: if the command failed
244 @raises AutoservSSHTimeout: ssh connection has timed out
247 stdout = utils.get_stream_tee_file(
248 stdout_tee, utils.DEFAULT_STDOUT_LEVEL,
249 prefix=utils.STDOUT_PREFIX)
250 stderr = utils.get_stream_tee_file(
251 stderr_tee, utils.get_stderr_level(ignore_status),
252 prefix=utils.STDERR_PREFIX)
254 for arg in args:
255 command += ' "%s"' % utils.sh_escape(arg)
257 if verbose:
258 logging.debug("Running (ssh-paramiko) '%s'" % command)
260 # start up the command
261 start_time = time.time()
262 try:
263 channel = self._open_channel(timeout)
264 channel.exec_command(command)
265 except (socket.error, paramiko.SSHException, EOFError), e:
266 # This has to match the string from paramiko *exactly*.
267 if str(e) != 'Channel closed.':
268 raise error.AutoservSSHTimeout("ssh failed: %s" % e)
270 # pull in all the stdout, stderr until the command terminates
271 raw_stdout, raw_stderr = [], []
272 timed_out = False
273 while not channel.exit_status_ready():
274 if channel.recv_ready():
275 raw_stdout.append(channel.recv(self.BUFFSIZE))
276 stdout.write(raw_stdout[-1])
277 if channel.recv_stderr_ready():
278 raw_stderr.append(channel.recv_stderr(self.BUFFSIZE))
279 stderr.write(raw_stderr[-1])
280 if timeout and time.time() - start_time > timeout:
281 timed_out = True
282 break
283 stdin = self.__send_stdin(channel, stdin)
284 time.sleep(1)
286 if timed_out:
287 exit_status = -signal.SIGTERM
288 else:
289 exit_status = channel.recv_exit_status()
290 channel.settimeout(10)
291 self._exhaust_stream(stdout, raw_stdout, channel.recv)
292 self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr)
293 channel.close()
294 duration = time.time() - start_time
296 # create the appropriate results
297 stdout = "".join(raw_stdout)
298 stderr = "".join(raw_stderr)
299 result = utils.CmdResult(command, stdout, stderr, exit_status,
300 duration)
301 if exit_status == -signal.SIGHUP:
302 msg = "ssh connection unexpectedly terminated"
303 raise error.AutoservRunError(msg, result)
304 if timed_out:
305 logging.warn('Paramiko command timed out after %s sec: %s', timeout,
306 command)
307 raise error.AutoservRunError("command timed out", result)
308 if not ignore_status and exit_status:
309 raise error.AutoservRunError(command, result)
310 return result