Merge pull request #1000 from bensmrs/master
[git-cola.git] / cola / utils.py
blob9de125a26a5e7a04194d0b7fcb0b92e065a12b01
1 # Copyright (C) 2007-2018 David Aguilar and contributors
2 """This module provides miscellaneous utility functions."""
3 from __future__ import division, absolute_import, unicode_literals
4 import copy
5 import os
6 import random
7 import re
8 import shlex
9 import sys
10 import tempfile
11 import time
12 import traceback
14 from . import core
15 from . import compat
# Re-seed the module-global RNG at import time.  random.seed() with no
# argument draws its seed from the operating system's randomness source
# (os.urandom) when one is available and only falls back to the current
# time otherwise.  The previous hash(time.time()) seed was predictable
# and discarded the OS entropy Python had already seeded with.
random.seed()
def asint(obj, default=0):
    """Convert *obj* to an int, falling back to *default*.

    Only TypeError and ValueError raised by ``int(obj)`` are handled;
    in those cases *default* is returned unchanged.
    """
    try:
        return int(obj)
    except (TypeError, ValueError):
        return default
def clamp(value, lo, hi):
    """Restrict *value* to the closed range [lo, hi].

    The lower bound is applied first, so *hi* wins when ``lo > hi``.
    """
    at_least_lo = max(lo, value)
    return min(hi, at_least_lo)
def epoch_millis():
    """Return the current Unix time as an integer number of milliseconds."""
    now_seconds = time.time()
    return int(now_seconds * 1000)
def add_parents(paths):
    """Return a set of every path in *paths* plus all of its parent directories.

    Runs of '/' are collapsed to a single '/' before parents are computed.
    Parents are found by repeatedly dropping the last '/'-separated
    component; an empty parent (e.g. from a leading '/') is not added.
    """
    result = set()
    for raw_path in paths:
        normalized = raw_path
        while '//' in normalized:
            normalized = normalized.replace('//', '/')
        result.add(normalized)

        current = normalized
        while '/' in current:
            current = current.rsplit('/', 1)[0]
            if not current:
                break
            result.add(current)
    return result
def format_exception(e):
    """Describe exception *e* for display.

    Returns a ``(msg, details)`` tuple.  *msg* is ``e.msg`` when *e* has
    that attribute, otherwise the decoded ``repr(e)``.  *details* is the
    decoded traceback of the exception currently being handled (taken from
    ``sys.exc_info()``, not from *e* itself), its entries joined with '\\n'.
    NOTE(review): traceback entries already end in a newline, so the join
    leaves blank lines between them.
    """
    traceback_lines = traceback.format_exception(*sys.exc_info())
    details = '\n'.join(core.decode(line) for line in traceback_lines)
    msg = e.msg if hasattr(e, 'msg') else core.decode(repr(e))
    return (msg, details)
def sublist(a, b):
    """Return the items of list *a* that do not appear in *b* (i.e. a - b).

    Order and duplicates from *a* are kept.  Membership is tested with
    ``in b``, so each lookup is linear when *b* is a list.
    """
    return [entry for entry in a if entry not in b]
__grep_cache = {}


def grep(pattern, items, squash=True):
    """Greps a list (or dict keys) for items whose start matches a pattern

    Compiled patterns are memoized in the module-level ``__grep_cache``.

    For a dict, the result is a new dict of the key/value pairs whose key
    matches, and *squash* is ignored.  For any other iterable, each match
    contributes the whole matched text when the pattern has no groups,
    the single group when it has exactly one, or a list of all groups
    when it has several.

    :param squash: If only one item matches, return just that item
    :returns: List of matching items (dict for dict input)

    """
    regex = __grep_cache.get(pattern)
    if regex is None:
        regex = __grep_cache[pattern] = re.compile(pattern)

    if isinstance(items, dict):
        matched_pairs = {}
        for key in items:
            if regex.match(key):
                matched_pairs[key] = items[key]
        return matched_pairs

    matches = []
    for item in items:
        found = regex.match(item)
        if found is None:
            continue
        groups = found.groups()
        if len(groups) == 0:
            matches.append(found.group(0))
        elif len(groups) == 1:
            matches.append(groups[0])
        else:
            matches.append(list(groups))

    if squash and len(matches) == 1:
        return matches[0]
    return matches
def basename(path):
    """An os.path.basename() implementation that always uses '/'

    Avoid os.path.basename because git's output always
    uses '/' regardless of platform.  Returns the text after the last
    '/', or *path* itself when it contains no '/'.

    """
    _, _, tail = path.rpartition('/')
    return tail
def strip_one(path):
    """Strip one leading directory level from *path*.

    Leading and trailing '/' are removed first; a single-component path
    is returned as-is (minus those slashes).
    """
    trimmed = path.strip('/')
    _, separator, remainder = trimmed.partition('/')
    if separator:
        return remainder
    return trimmed
def dirname(path, current_dir=''):
    """An os.path.dirname() implementation that always uses '/'

    Avoid os.path.dirname because git's output always
    uses '/' regardless of platform.  Runs of '/' are collapsed first;
    *current_dir* is returned when the path has no '/' at all.

    """
    while '//' in path:
        path = path.replace('//', '/')
    head, separator, _ = path.rpartition('/')
    if not separator:
        return current_dir
    return head
def splitpath(path):
    """Split *path* into components using '/' regardless of platform

    git reports paths with '/' separators on every platform, so
    os.path/os.sep is deliberately not used.
    """
    components = path.split('/')
    return components
def join(*paths):
    """Join path components with '/' regardless of platform

    git paths always use '/', so os.path.join is deliberately avoided.
    """
    separator = '/'
    return separator.join(paths)
def pathset(path):
    """Return all of the path components for the specified path

    Each entry is a '/'-joined prefix of *path*, from the first
    component up to the full path.

    >>> pathset('foo/bar/baz') == ['foo', 'foo/bar', 'foo/bar/baz']
    True

    """
    parts = path.split('/')
    return ['/'.join(parts[:count]) for count in range(1, len(parts) + 1)]
def select_directory(paths):
    """Pick a directory based on *paths*.

    Returns ``core.getcwd()`` when *paths* is empty, otherwise the first
    entry for which ``core.isdir()`` is true, and failing that the
    ``os.path.dirname()`` of the first entry.
    """
    if not paths:
        return core.getcwd()

    first_dir = next((candidate for candidate in paths
                      if core.isdir(candidate)), None)
    if first_dir is not None:
        return first_dir
    return os.path.dirname(paths[0])
def strip_prefix(prefix, string):
    """Return *string* with the leading *prefix* removed.

    :raises AssertionError: if *string* does not start with *prefix*.
        This is an explicit ``raise`` rather than an ``assert`` statement
        so the check still runs under ``python -O``, which strips asserts
        and would otherwise return a silently wrong slice.
    """
    if not string.startswith(prefix):
        raise AssertionError(
            '%r does not start with %r' % (string, prefix))
    return string[len(prefix):]
def sanitize(s):
    """Return *s* with whitespace and shell metacharacters replaced by '_'."""
    unsafe = """ \t!@#$%^&*()\\;,<>"'[]{}~|"""
    return ''.join('_' if char in unsafe else char for char in s)
def tablength(word, tabwidth):
    """Return the length of *word*, counting each tab as *tabwidth* columns

    Every tab adds a full *tabwidth* regardless of its position; tab stops
    are not aligned.

    >>> tablength("\\t\\t\\t\\tX", 8)
    33

    """
    tab_count = word.count('\t')
    return len(word) - tab_count + tab_count * tabwidth
def _shell_split_py2(s):
    """Shell-split *s* on Python 2, returning a list of unicode strings.

    Python 2's shlex.split() requires bytes, so *s* is encoded first.  If
    shlex raises ValueError, the encoded text is stripped and split on
    whitespace instead.  Every resulting argument is decoded to unicode.
    """
    try:
        pieces = shlex.split(core.encode(s))
    except ValueError:
        # Fall back to a naive whitespace split of the encoded bytes.
        pieces = core.encode(s).strip().split()
    return list(map(core.decode, pieces))
def _shell_split_py3(s):
    """Shell-split *s* on Python 3, where shlex.split() takes unicode.

    If shlex raises ValueError, the decoded text is stripped and split on
    whitespace instead.  Either way the result is a list of unicode strings.
    """
    try:
        return shlex.split(s)
    except ValueError:
        return core.decode(s).strip().split()
def shell_split(s):
    """Split *s* into arguments using shell-like quoting rules.

    Delegates to the Python 2 helper (which encodes/decodes around
    shlex) when ``compat.PY2`` is true, and to the Python 3 helper
    otherwise.
    """
    splitter = _shell_split_py2 if compat.PY2 else _shell_split_py3
    return splitter(s)
def tmp_filename(label, suffix=''):
    """Return a temporary file path whose name embeds *label*.

    Any '/' or '\\' in *label* becomes '-', and the basename starts with
    ``git-cola-<label>-`` and ends with *suffix*.  The file is created via
    NamedTemporaryFile and closed straight away (NamedTemporaryFile deletes
    on close by default), so only the name is returned.
    NOTE(review): another process could claim the same name before the
    caller uses it.
    """
    safe_label = label.replace('/', '-').replace('\\', '-')
    prefix = 'git-cola-' + safe_label + '-'
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix) as handle:
        name = handle.name
    return name
def is_linux():
    """Return True when sys.platform starts with 'linux'."""
    platform_name = sys.platform
    return platform_name.startswith('linux')
def is_debian():
    """Guess whether this is Debian by checking for /usr/bin/apt-get."""
    apt_get_path = '/usr/bin/apt-get'
    return os.path.exists(apt_get_path)
def is_darwin():
    """Return True on macOS (sys.platform is 'darwin')."""
    platform_name = sys.platform
    return platform_name == 'darwin'
def is_win32():
    """Return True on native Windows ('win32') or Cygwin ('cygwin')."""
    return sys.platform in ('win32', 'cygwin')
def expandpath(path):
    """Expand environment $variables, then a leading ~ or ~user."""
    expanded = os.path.expandvars(path)
    if not expanded.startswith('~'):
        return expanded
    return os.path.expanduser(expanded)
class Group(object):
    """Broadcast method calls to a collection of objects as a single unit.

    Looking up an attribute that Group does not otherwise have returns a
    relay function.  Calling the relay invokes the same-named method on
    every member, in order, with the same arguments; member return values
    are discarded and the relay returns None.  The relay is stored on the
    instance, so later lookups of that name bypass ``__getattr__``.
    """

    def __init__(self, *members):
        self._members = members

    def __getattr__(self, name):
        """Create, cache and return a relay for method *name*."""
        def relay(*args, **kwargs):
            # _members is read at call time, not when the relay is built.
            for target in self._members:
                getattr(target, name)(*args, **kwargs)

        setattr(self, name, relay)
        return relay
class Proxy(object):
    """Wrap an object, overriding selected attributes.

    Keyword arguments given to the constructor are set as attributes on
    the proxy.  Any other attribute lookup falls through ``__getattr__``
    (only consulted when normal lookup fails) to the wrapped object.
    """

    def __init__(self, obj, **overrides):
        self._obj = obj
        for attr_name, attr_value in overrides.items():
            setattr(self, attr_name, attr_value)

    def __getattr__(self, name):
        wrapped = self._obj
        return getattr(wrapped, name)
def slice_fn(input_items, map_fn):
    """Slice input_items and call map_fn over every slice

    This exists because of "errno: Argument list too long".

    *map_fn* receives one slice at a time and must return a
    ``(status, out, err)`` tuple.  The result is ``(status, out, err)``
    where *out* and *err* are the per-slice values joined with '\\n'.
    *status* starts at 0; a negative slice status lowers it via ``min()``
    and a non-negative one raises it via ``max()``.  NOTE: this makes the
    result order-dependent, since a later positive status replaces an
    earlier negative one.
    """
    # This comment appeared near the top of include/linux/binfmts.h
    # in the Linux source tree:
    #
    # /*
    #  * MAX_ARG_PAGES defines the number of pages allocated for arguments
    #  * and envelope for the new program. 32 should suffice, this gives
    #  * a maximum env+arg of 128kB w/4KB pages!
    #  */
    # #define MAX_ARG_PAGES 32
    #
    # The slice size is a heuristic that keeps the number of slices small.
    # Running "getconf ARG_MAX" would give a better guess, but it is not
    # worth the complexity (or the extra process, which Windows can't run
    # anyway).  getconf ARG_MAX reported 262144 on Mac OS X Mountain Lion
    # and 2097152 on Debian/Linux-x86_64; the value below is safely under
    # both.

    # 4K pages x 32 MAX_ARG_PAGES
    max_arg_len = (32 * 4096) // 4  # allow plenty of space for the environment
    max_filename_len = 256
    chunk_size = max_arg_len // max_filename_len

    status = 0
    outs = []
    errs = []

    # Slice a shallow copy so that changes map_fn might make to
    # input_items cannot affect which items are still to be processed.
    items = copy.copy(input_items)
    for start in range(0, len(items), chunk_size):
        stat, out, err = map_fn(items[start:start + chunk_size])
        if stat < 0:
            status = min(stat, status)
        else:
            status = max(stat, status)
        outs.append(out)
        errs.append(err)

    return (status, '\n'.join(outs), '\n'.join(errs))
class seq(object):
    """Wrap a sequence, adding an ``index()`` that does not raise.

    ``index(item)`` returns *default* (``-1`` unless given) instead of
    raising ValueError when *item* is absent.  Item access is delegated
    to the wrapped sequence.
    """

    def __init__(self, sequence):
        self.seq = sequence

    def index(self, item, default=-1):
        """Return the position of *item*, or *default* if it is missing."""
        try:
            return self.seq.index(item)
        except ValueError:
            return default

    def __getitem__(self, idx):
        return self.seq[idx]