1 import os
, sys
, time
, signal
, socket
, re
, fnmatch
, logging
, threading
4 from autotest_lib
.client
.common_lib
import utils
, error
, global_config
5 from autotest_lib
.server
import subcommand
6 from autotest_lib
.server
.hosts
import abstract_ssh
class ParamikoHost(abstract_ssh.AbstractSSHHost):
    """SSH host that talks to the remote machine through an in-process
    paramiko transport instead of an external ssh binary.

    Authentication is public-key only: every key returned by get_user_keys()
    is tried in turn until one authenticates. One paramiko.Transport is owned
    per process; after a fork() the child reopens its own (see _open_channel).
    self.hostname, self.port and self.user are read but not set here --
    presumably provided by the AbstractSSHHost base class (TODO confirm).
    """

    # interval (seconds) handed to paramiko.Transport.set_keepalive()
    KEEPALIVE_TIMEOUT_SECONDS = 30
    # how long to wait for each negotiation step (handshake, then auth)
    CONNECT_TIMEOUT_SECONDS = 30
    # number of negotiation attempts before giving up
    CONNECT_TIMEOUT_RETRIES = 3
    # maximum number of bytes read or written per channel call
    BUFFSIZE = 2**16
    # seconds run() sleeps when a poll iteration found no output to read
    POLL_INTERVAL_SECONDS = 0.1


    def _initialize(self, hostname, *args, **dargs):
        super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs)

        # paramiko is very noisy, tone down the logging
        paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR)

        self.keys = self.get_user_keys(hostname)
        # pid of the process that opened self.transport; None = not connected
        self.pid = None
        self.transport = None


    @staticmethod
    def _load_key(path):
        """Given a path to a private key file, load the appropriate keyfile.

        Tries to load the file first as a DSSKey and then as an RSAKey. If
        the file cannot be loaded as either type (including a file that
        cannot be read), returns None."""
        for key_class in (paramiko.DSSKey, paramiko.RSAKey):
            try:
                return key_class.from_private_key_file(path)
            except (paramiko.SSHException, IOError):
                continue
        return None


    @staticmethod
    def _parse_config_line(line):
        """Given an ssh config line, return a (key, value) tuple for the
        config value listed in the line, or (None, None).

        Accepts both "Key value" and "Key = value" forms; surrounding
        whitespace is stripped from the value and a trailing newline is
        optional (so the last line of a file is parsed too)."""
        match = re.match(r"\s*(\w+)\s*=?\s*(.*?)\s*$", line)
        if match:
            return match.groups()
        return None, None


    @staticmethod
    def get_user_keys(hostname):
        """Returns a mapping of path -> paramiko.PKey entries available for
        this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa)
        as well as any IdentityFile entries in the standard ssh config files.

        If the AUTOSERV/use_sshagent_with_paramiko config option is set, keys
        from the ssh agent are added under the names 'agent-key-<n>'.
        """
        raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"]
        for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"):
            config_path = os.path.expanduser(config_path)
            if not os.path.exists(config_path):
                continue
            # entries before the first "Host" line apply to every host
            host_patterns = ["*"]
            with open(config_path) as config_file:
                for line in config_file:
                    key, value = ParamikoHost._parse_config_line(line)
                    if key is None:
                        continue
                    # ssh_config keywords are case-insensitive
                    key = key.lower()
                    if key == "host":
                        # a Host line may list several whitespace-separated
                        # patterns
                        host_patterns = value.split()
                    elif (key == "identityfile"
                          and any(fnmatch.fnmatch(hostname, pattern)
                                  for pattern in host_patterns)):
                        raw_identity_files.append(value)

        # drop any files that use percent-escapes; we don't support them
        UNSUPPORTED_ESCAPES = ["%d", "%u", "%l", "%h", "%r"]
        identity_files = []
        for path in raw_identity_files:
            # skip this path if it uses % escapes
            if any(escape in path for escape in UNSUPPORTED_ESCAPES):
                continue
            path = os.path.expanduser(path)
            if os.path.exists(path) and path not in identity_files:
                identity_files.append(path)

        # load up all the keys that we can and return them
        user_keys = {}
        for path in identity_files:
            key = ParamikoHost._load_key(path)
            if key is not None:
                user_keys[path] = key

        # load up all the ssh agent keys
        use_sshagent = global_config.global_config.get_config_value(
            'AUTOSERV', 'use_sshagent_with_paramiko', type=bool)
        if use_sshagent:
            ssh_agent = paramiko.Agent()
            for i, key in enumerate(ssh_agent.get_keys()):
                user_keys['agent-key-%d' % i] = key

        return user_keys


    def _check_transport_error(self, transport):
        """If transport has recorded an exception, close it and raise it."""
        exc = transport.get_exception()
        if exc:
            transport.close()
            raise exc


    def _connect_socket(self):
        """Return a socket for use in instantiating a paramiko transport. Does
        not have to be a literal socket, it can be anything that the
        paramiko.Transport constructor accepts."""
        return self.hostname, self.port


    def _connect_transport(self, pkey):
        """Negotiate an SSH session and authenticate with pkey.

        Each attempt waits up to CONNECT_TIMEOUT_SECONDS for the handshake
        and again for authentication; a timed-out attempt is discarded and
        retried up to CONNECT_TIMEOUT_RETRIES times.

        @param pkey: the paramiko key to authenticate with.

        @returns an authenticated paramiko.Transport.

        @raises paramiko.AuthenticationException: if pkey is rejected.
        @raises AutoservSSHTimeout: if every attempt timed out.
        """
        for _ in range(self.CONNECT_TIMEOUT_RETRIES):
            transport = paramiko.Transport(self._connect_socket())
            completed = threading.Event()
            transport.start_client(completed)
            completed.wait(self.CONNECT_TIMEOUT_SECONDS)
            if completed.is_set():
                self._check_transport_error(transport)
                # the same event is reused for the auth step, so it must be
                # reset or the wait below would return immediately
                completed.clear()
                transport.auth_publickey(self.user, pkey, completed)
                completed.wait(self.CONNECT_TIMEOUT_SECONDS)
                if completed.is_set():
                    self._check_transport_error(transport)
                    if not transport.is_authenticated():
                        transport.close()
                        raise paramiko.AuthenticationException()
                    return transport
            logging.warning("SSH negotiation (%s:%d) timed out, retrying",
                            self.hostname, self.port)
            # HACK: we can't count on transport.join not hanging now, either
            transport.join = lambda: None
            # don't leak the half-negotiated transport before retrying
            transport.close()
        logging.error("SSH negotation (%s:%d) has timed out %s times, "
                      "giving up", self.hostname, self.port,
                      self.CONNECT_TIMEOUT_RETRIES)
        raise error.AutoservSSHTimeout("SSH negotiation timed out")


    def _init_transport(self):
        """Connect to the host, trying each available key in turn.

        On success stores the transport in self.transport and records the
        current pid as its owner.

        @raises AutoservSshPermissionDeniedError: if no key authenticates.
        """
        for path, key in self.keys.items():
            try:
                logging.debug("Connecting with %s", path)
                transport = self._connect_transport(key)
            except paramiko.AuthenticationException:
                logging.debug("Authentication failure")
                continue
            transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS)
            self.transport = transport
            self.pid = os.getpid()
            return
        raise error.AutoservSshPermissionDeniedError(
            "Permission denied using all keys available to ParamikoHost",
            utils.CmdResult())


    def _open_channel(self, timeout):
        """Open a new session channel, (re)connecting first if needed.

        The transport is (re)created when this process does not own it --
        either nothing has been opened yet or we are in a forked child.

        @param timeout: seconds after which a failure to open a session is
                reported as a timeout; a false value means no limit.

        @returns a paramiko channel.

        @raises AutoservSSHTimeout: if opening the session failed and
                timeout had already elapsed.
        """
        start_time = time.time()
        if os.getpid() != self.pid:
            if self.pid is not None:
                # HACK: paramiko tries to join() on its worker thread
                # and this just hangs on linux after a fork()
                self.transport.join = lambda: None
                self.transport.atfork()
                # presumably run when the forked subcommand is joined, so the
                # child's transport gets closed (TODO confirm hook semantics)
                join_hook = lambda cmd: self._close_transport()
                subcommand.subcommand.register_join_hook(join_hook)
                logging.debug("Reopening SSH connection after a process fork")
            self._init_transport()

        channel = None
        try:
            channel = self.transport.open_session()
        except (socket.error, paramiko.SSHException, EOFError) as e:
            logging.warning("Exception occured while opening session: %s", e)
            if timeout and time.time() - start_time >= timeout:
                raise error.AutoservSSHTimeout("ssh failed: %s" % e)

        if channel is not None:
            return channel

        # we couldn't get a channel; re-initing transport should fix that
        try:
            self.transport.close()
        except Exception as e:
            logging.debug("paramiko.Transport.close failed with %s", e)
        self._init_transport()
        return self.transport.open_session()


    def _close_transport(self):
        """Close the transport, but only from the process that opened it."""
        if os.getpid() == self.pid:
            self.transport.close()


    def close(self):
        """Close the host via the base class, then close our transport."""
        super(ParamikoHost, self).close()
        self._close_transport()


    @classmethod
    def _exhaust_stream(cls, tee, output_list, recvfunc):
        """Read whatever is left on one channel stream.

        Chunks are appended to output_list and echoed to tee until recvfunc
        returns an empty chunk (EOF) or raises socket.timeout.
        """
        while True:
            try:
                chunk = recvfunc(cls.BUFFSIZE)
            except socket.timeout:
                return
            if not chunk:
                return
            output_list.append(chunk)
            tee.write(chunk)


    @classmethod
    def __send_stdin(cls, channel, stdin):
        """Send as much pending stdin as the channel accepts right now.

        @param channel: the channel running the command.
        @param stdin: data still waiting to be sent (may be None or empty).

        @returns the part of stdin that has not been sent yet.
        """
        if not stdin or not channel.send_ready():
            # nothing more to send or just no space to send now; keep the
            # pending data for the next call
            return stdin

        sent = channel.send(stdin[:cls.BUFFSIZE])
        if not sent:
            logging.warning('Could not send a single stdin byte.')
            return stdin
        stdin = stdin[sent:]
        if not stdin:
            # no more stdin input, close output direction
            channel.shutdown_write()
        return stdin


    def run(self, command, timeout=3600, ignore_status=False,
            stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS,
            connect_timeout=30, stdin=None, verbose=True, args=()):
        """
        Run a command on the remote host.
        @see common_lib.hosts.host.run()

        @param command: the command line to run remotely.
        @param timeout: seconds to wait for the command to finish; a false
                value (0 or None) disables the limit.
        @param ignore_status: do not raise if the command exits non-zero.
        @param stdout_tee: where to echo the command's stdout.
        @param stderr_tee: where to echo the command's stderr.
        @param connect_timeout: connection timeout (in seconds); accepted
                for interface compatibility, not used by this implementation.
        @param stdin: data to feed to the command's standard input.
        @param verbose: log the commands
        @param args: extra arguments; each is shell-escaped, double-quoted
                and appended to command.

        @returns a utils.CmdResult describing the command.

        @raises AutoservRunError: if the command failed, timed out or the
                connection was unexpectedly terminated
        @raises AutoservSSHTimeout: ssh connection has timed out
        """
        stdout = utils.get_stream_tee_file(
                stdout_tee, utils.DEFAULT_STDOUT_LEVEL,
                prefix=utils.STDOUT_PREFIX)
        stderr = utils.get_stream_tee_file(
                stderr_tee, utils.get_stderr_level(ignore_status),
                prefix=utils.STDERR_PREFIX)

        for arg in args:
            command += ' "%s"' % utils.sh_escape(arg)

        if verbose:
            logging.debug("Running (ssh-paramiko) '%s'", command)

        # start up the command
        start_time = time.time()
        channel = None
        try:
            channel = self._open_channel(timeout)
            channel.exec_command(command)
        except (socket.error, paramiko.SSHException, EOFError) as e:
            # This has to match the string from paramiko *exactly*.
            # Tolerating it only makes sense once a channel exists; without
            # one there is nothing to read from below.
            if channel is None or str(e) != 'Channel closed.':
                raise error.AutoservSSHTimeout("ssh failed: %s" % e)

        # pull in all the stdout, stderr until the command terminates
        raw_stdout, raw_stderr = [], []
        timed_out = False
        try:
            while not channel.exit_status_ready():
                got_output = False
                if channel.recv_ready():
                    raw_stdout.append(channel.recv(self.BUFFSIZE))
                    stdout.write(raw_stdout[-1])
                    got_output = True
                if channel.recv_stderr_ready():
                    raw_stderr.append(channel.recv_stderr(self.BUFFSIZE))
                    stderr.write(raw_stderr[-1])
                    got_output = True
                if timeout and time.time() - start_time > timeout:
                    timed_out = True
                    break
                stdin = self.__send_stdin(channel, stdin)
                if not got_output:
                    # nothing to read right now; don't spin on the CPU
                    time.sleep(self.POLL_INTERVAL_SECONDS)

            if timed_out:
                # the remote command is still running: its streams may never
                # reach EOF, so don't try to drain them
                exit_status = -signal.SIGTERM
            else:
                exit_status = channel.recv_exit_status()
                # the command has exited, so the remaining output is bounded;
                # the socket timeout guards against a stream that never EOFs
                channel.settimeout(10)
                self._exhaust_stream(stdout, raw_stdout, channel.recv)
                self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr)
        finally:
            channel.close()
        duration = time.time() - start_time

        # create the appropriate results
        stdout = "".join(raw_stdout)
        stderr = "".join(raw_stderr)
        result = utils.CmdResult(command, stdout, stderr, exit_status,
                                 duration)
        # paramiko reports -1 (== -SIGHUP) when the channel closed without
        # the server sending an exit status
        if exit_status == -signal.SIGHUP:
            msg = "ssh connection unexpectedly terminated"
            raise error.AutoservRunError(msg, result)
        if timed_out:
            logging.warning('Paramiko command timed out after %s sec: %s',
                            timeout, command)
            raise error.AutoservRunError("command timed out", result)
        if not ignore_status and exit_status:
            raise error.AutoservRunError(command, result)
        return result