1 # Copyright 2008 Google Inc, Martin J. Bligh <mbligh@google.com>,
2 # Benjamin Poirier, Ryan Stutsman
3 # Released under the GPL v2
"""
Miscellaneous small functions.

DO NOT import this file directly - it is mixed in by server/utils.py,
import that instead.
"""
11 import atexit
, os
, re
, shutil
, textwrap
, sys
, tempfile
, types
13 from autotest_lib
.client
.common_lib
import barrier
, utils
14 from autotest_lib
.server
import subcommand
17 # A dictionary of pid and a list of tmpdirs for that pid
def scp_remote_escape(filename):
    """Escape special characters from a filename so that it can be passed
    to scp (within double quotes) as a remote file.

    Bis-quoting has to be used with scp for remote files, "bis-quoting"
    as in quoting twice; scp does not support a newline in the filename.

    Args:
            filename: the filename string to escape.

    Returns:
            The escaped filename string. The required englobing double
            quotes are NOT added and so should be added at some point by
            the caller.
    """
    # Characters that must be backslash-escaped inside the scp remote
    # double-quoted context.
    escape_chars = r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'

    # Backslash-escape each special character, pass the rest through
    # unchanged (join over a generator instead of a manual append loop).
    escaped = "".join("\\%s" % (char,) if char in escape_chars else char
                      for char in filename)

    return utils.sh_escape(escaped)
def get(location, local_copy=False):
    """Get a file or directory to a local temporary directory.

    Args:
            location: the source of the material to get. This source may
                    be one of:
                    * a local file or directory
                    * a URL (http or ftp)
                    * a python file-like object
                    * any other string, which is dumped to a file as-is
            local_copy: if True, a local file/directory source is copied
                    into the temporary directory instead of its own path
                    being returned.

    Returns:
            The location of the file or directory where the requested
            content was saved. This will be contained in a temporary
            directory on the local host. If the material to get was a
            directory, the location will contain a trailing '/'.
    """
    tmpdir = get_tmp_dir()

    # location is a file-like object
    if hasattr(location, "read"):
        tmpfile = os.path.join(tmpdir, "file")
        # Use open() instead of the deprecated file() builtin, and close
        # the handle so the copied data is flushed to disk.
        tmpfileobj = open(tmpfile, 'w')
        try:
            shutil.copyfileobj(location, tmpfileobj)
        finally:
            tmpfileobj.close()
        return tmpfile

    if isinstance(location, types.StringTypes):
        # location is a URL
        if location.startswith('http') or location.startswith('ftp'):
            tmpfile = os.path.join(tmpdir, os.path.basename(location))
            utils.urlretrieve(location, tmpfile)
            return tmpfile
        # location is a local path
        elif os.path.exists(os.path.abspath(location)):
            if not local_copy:
                if os.path.isdir(location):
                    return location.rstrip('/') + '/'
                else:
                    return location
            tmpfile = os.path.join(tmpdir, os.path.basename(location))
            if os.path.isdir(location):
                tmpfile += '/'
                shutil.copytree(location, tmpfile, symlinks=True)
                return tmpfile
            shutil.copyfile(location, tmpfile)
            return tmpfile
        # location is just a string, dump it to a file
        else:
            tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir)
            tmpfileobj = os.fdopen(tmpfd, 'w')
            try:
                tmpfileobj.write(location)
            finally:
                tmpfileobj.close()
            return tmpfile
def get_tmp_dir():
    """Return the pathname of a directory on the host suitable
    for temporary file storage.

    The directory and its content will be deleted automatically
    at the end of the program execution if they are still present.
    """
    dir_name = tempfile.mkdtemp(prefix="autoserv-")
    # Record the directory under the current pid so that
    # __clean_tmp_dirs() can remove it at exit for this process only.
    # setdefault replaces the original "if not pid in __tmp_dirs" dance.
    __tmp_dirs.setdefault(os.getpid(), []).append(dir_name)
    return dir_name
def __clean_tmp_dirs():
    """Erase temporary directories that were created by the get_tmp_dir()
    function and that are still present.
    """
    pid = os.getpid()
    if pid not in __tmp_dirs:
        return
    # "tmp_dir" instead of "dir" to avoid shadowing the builtin.
    for tmp_dir in __tmp_dirs[pid]:
        try:
            shutil.rmtree(tmp_dir)
        except OSError as e:
            # errno 2 is ENOENT: the directory is already gone, which is
            # fine for a best-effort cleanup.
            if e.errno == 2:
                pass
    __tmp_dirs[pid] = []
# Remove any leftover temporary directories when this process exits.
atexit.register(__clean_tmp_dirs)
# Also clean up when a subcommand child is joined, not only at interpreter
# exit.
subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())
def unarchive(host, source_material):
    """Uncompress and untar an archive on a host.

    If the "source_material" is compressed (according to the file
    extension) it will be uncompressed. Supported compression formats
    are gzip and bzip2. Afterwards, if the source_material is a tar
    archive, it will be untarred.

    Args:
            host: the host object on which the archive is located
            source_material: the path of the archive on the host

    Returns:
            The file or directory name of the unarchived source material.
            If the material is a tar archive, it will be extracted in the
            directory where it is and the path returned will be the first
            entry in the archive, assuming it is the topmost directory.
            If the material is not an archive, nothing will be done so this
            function is "harmless" when it is "useless".
    """
    # uncompress
    if (source_material.endswith(".gz") or
            source_material.endswith(".gzip")):
        host.run('gunzip "%s"' % (utils.sh_escape(source_material)))
        source_material = ".".join(source_material.split(".")[:-1])
    # ".bz2" with its dot, consistent with the gzip checks above (a bare
    # "bz2" suffix would also match names like "foobz2").
    elif source_material.endswith(".bz2"):
        host.run('bunzip2 "%s"' % (utils.sh_escape(source_material)))
        source_material = ".".join(source_material.split(".")[:-1])

    # untar
    if source_material.endswith(".tar"):
        retval = host.run('tar -C "%s" -xvf "%s"' % (
                utils.sh_escape(os.path.dirname(source_material)),
                utils.sh_escape(source_material),))
        # Assume the first entry listed by tar is the topmost directory.
        source_material = os.path.join(os.path.dirname(source_material),
                                       retval.stdout.split()[0])

    return source_material
def get_server_dir():
    """Return the absolute path of the directory containing the autotest
    server utils module.
    """
    server_utils_module = sys.modules['autotest_lib.server.utils']
    module_dir = os.path.dirname(server_utils_module.__file__)
    return os.path.abspath(module_dir)
def find_pid(command):
    """Return the pid (as an int) of the first running process whose
    command line matches the regular expression "command", or None when
    nothing matches.
    """
    ps_listing = utils.system_output('ps -eo pid,cmd').rstrip()
    for ps_line in ps_listing.split('\n'):
        pid, cmd = ps_line.split(None, 1)
        if re.search(command, cmd):
            return int(pid)
def nohup(command, stdout='/dev/null', stderr='/dev/null', background=True,
          env=None):
    """Run the command in the background with nohup.

    Args:
            command: the shell command line to run.
            stdout: path receiving the command's standard output.
            stderr: path receiving the command's standard error.
            background: if True, append '&' so the shell does not wait.
            env: optional dict of environment variable assignments
                    prefixed to the command as NAME=value pairs.
    """
    # A mutable default argument ({}) would be shared between calls;
    # default to None and substitute an empty dict here instead.
    if env is None:
        env = {}
    cmd = ' '.join(key + '=' + val for key, val in env.iteritems())
    cmd += ' nohup ' + command
    cmd += ' > %s' % stdout
    cmd += ' 2> %s' % stderr
    if background:
        cmd += ' &'
    utils.system(cmd)
def default_mappings(machines):
    """
    Returns a simple mapping in which all machines are assigned to the
    same key. Provides the default behavior for
    form_ntuples_from_machines.

    Args:
            machines: a list of machine names.

    Returns:
            A (mappings, failures) tuple where mappings maps the single
            key 'ident' to the full list of machines and failures is an
            empty list.
    """
    failures = []
    # Every machine belongs to the single 'ident' group; copying the list
    # replaces the original first-element + append-the-rest loop and also
    # tolerates an empty input list.
    mappings = {'ident': list(machines)}
    return (mappings, failures)
def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    """Returns a set of ntuples from machines where the machines in an
    ntuple are in the same mapping, and a set of failures which are
    (machine name, reason) tuples.

    Args:
            machines: list of machine names.
            n: size of each tuple.
            mapping_func: callable taking the machine list and returning
                    a (mappings, failures) pair; defaults to
                    default_mappings.

    Returns:
            A (ntuples, failures) tuple.
    """
    ntuples = []
    (mappings, failures) = mapping_func(machines)

    # Run through the mappings and create n-tuples; machines that cannot
    # fill a complete tuple are reported as failures.
    for key in mappings:
        key_machines = mappings[key]

        # form n-tuples (the unused total_machines local was dropped)
        while len(key_machines) >= n:
            ntuples.append(key_machines[0:n])
            key_machines = key_machines[n:]

        # machines are still left over - failure case
        for mach in key_machines:
            failures.append((mach, "machine can not be tupled"))

    return (ntuples, failures)
def parse_machine(machine, user='root', password='', port=22):
    """
    Parse the machine string user:pass@host:port and return it separately,
    if the machine string is not complete, use the default parameters
    when appropriate.

    Args:
            machine: machine specification string of the form
                    [user[:password]@]host[:port].
            user: default user, used when there is no '@' part.
            password: default password, used when the user part has no
                    ':' separator.
            port: default port, used when the host part has no ':'
                    separator.

    Returns:
            A (machine, user, password, port) tuple; port is an int when
            parsed from the string.

    Raises:
            ValueError: if the host or the user is empty.
    """
    if '@' in machine:
        user, machine = machine.split('@', 1)

    if ':' in user:
        user, password = user.split(':', 1)

    if ':' in machine:
        machine, port = machine.split(':', 1)
        port = int(port)

    # Give callers a descriptive error rather than a bare ValueError.
    if not machine or not user:
        raise ValueError("empty machine or user in machine string")

    return machine, user, password, port
def get_public_key():
    """
    Return a valid string ssh public key for the user executing autoserv or
    autotest. If there's no DSA or RSA public key, create a DSA keypair with
    ssh-keygen and return it.
    """
    ssh_conf_path = os.path.expanduser('~/.ssh')

    # Candidate key pair locations; DSA takes precedence over RSA.
    dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub')
    dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')
    rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub')
    rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')

    def _keypair_present(public_path, private_path):
        # A usable keypair needs both halves on disk.
        return os.path.isfile(public_path) and os.path.isfile(private_path)

    if _keypair_present(dsa_public_key_path, dsa_private_key_path):
        print('DSA keypair found, using it')
        public_key_path = dsa_public_key_path
    elif _keypair_present(rsa_public_key_path, rsa_private_key_path):
        print('RSA keypair found, using it')
        public_key_path = rsa_public_key_path
    else:
        print('Neither RSA nor DSA keypair found, creating DSA ssh key pair')
        utils.system('ssh-keygen -t dsa -q -N "" -f %s' % dsa_private_key_path)
        public_key_path = dsa_public_key_path

    public_key = open(public_key_path, 'r')
    public_key_str = public_key.read()
    public_key.close()

    return public_key_str
def get_sync_control_file(control, host_name, host_num,
                          instance, num_jobs, port_base=63100):
    """
    This function is used when there is a need to run more than one
    job simultaneously starting exactly at the same time. It basically returns
    a modified control file (containing the synchronization code prepended)
    whenever it is ready to run the control file. The synchronization
    is done using barriers to make sure that the jobs start at the same time.

    Here is how the synchronization is done to make sure that the tests
    start at exactly the same time on the client.
    sc_bar is a server barrier and s_bar, c_bar are the normal barriers

              Job1           Job2          ......      JobN
    Server:   |   sc_bar
    Server:   |   s_bar        s_bar       ......      s_bar
    Server:   |   at.run()     at.run()    ......      at.run()
    ----------|------------------------------------------------------
    Client    |   c_bar        c_bar       ......      c_bar
    Client    |   <run test>   <run test>  ......      <run test>

    @param control: The control file which to which the above synchronization
            code will be prepended.
    @param host_name: The host name on which the job is going to run.
    @param host_num: (non negative) A number to identify the machine so that
            we have different sets of s_bar_ports for each of the machines.
    @param instance: The number of the job
    @param num_jobs: Total number of jobs that are going to run in parallel
            with this job starting at the same time.
    @param port_base: Port number that is used to derive the actual barrier
            ports.

    @returns The modified control file.
    """
    sc_bar_port = port_base
    c_bar_port = port_base
    if host_num < 0:
        print("Please provide a non negative number for the host")
        return None
    s_bar_port = port_base + 1 + host_num  # The set of s_bar_ports are
                                           # the same for a given machine

    sc_bar_timeout = 180
    s_bar_timeout = c_bar_timeout = 120

    # The barrier code snippet is prepended into the control file
    # dynamically before at.run() is called finally.

    control_new = []

    # jobid is the unique name used to identify the processes
    # trying to reach the barriers
    jobid = "%s#%d" % (host_name, instance)

    rendv = []
    # rendvstr is a temp holder for the rendezvous list of the processes
    for n in range(num_jobs):
        rendv.append("'%s#%d'" % (host_name, n))
    rendvstr = ",".join(rendv)

    if instance == 0:
        # Do the setup and wait at the server barrier
        # Clean up the tmp and the control dirs for the first instance
        control_new.append('if os.path.exists(job.tmpdir):')
        # BUGFIX: a space is required between "/dev/null" and "2>"; the
        # previous string concatenation produced "/dev/null2> /dev/null",
        # breaking the stderr redirection in the generated command.
        control_new.append("\t system('umount -f %s > /dev/null "
                           "2> /dev/null' % job.tmpdir,"
                           "ignore_status=True)")
        control_new.append("\t system('rm -rf ' + job.tmpdir)")
        control_new.append(
                'b0 = job.barrier("%s", "sc_bar", %d, port=%d)'
                % (jobid, sc_bar_timeout, sc_bar_port))
        control_new.append(
                'b0.rendezvous_servers("PARALLEL_MASTER", "%s")'
                % jobid)

    elif instance == 1:
        # Wait at the server barrier to wait for instance=0
        # process to complete setup
        b0 = barrier.barrier("PARALLEL_MASTER", "sc_bar", sc_bar_timeout,
                             port=sc_bar_port)
        b0.rendezvous_servers("PARALLEL_MASTER", jobid)

        if num_jobs > 2:
            b1 = barrier.barrier(jobid, "s_bar", s_bar_timeout,
                                 port=s_bar_port)
            b1.rendezvous(rendvstr)

    else:
        # For the rest of the clients
        b2 = barrier.barrier(jobid, "s_bar", s_bar_timeout, port=s_bar_port)
        b2.rendezvous(rendvstr)

    # Client side barrier for all the tests to start at the same time
    control_new.append('b1 = job.barrier("%s", "c_bar", %d, port=%d)'
                       % (jobid, c_bar_timeout, c_bar_port))
    control_new.append("b1.rendezvous(%s)" % rendvstr)

    # Stick in the rest of the control file
    control_new.append(control)

    return "\n".join(control_new)