maint: format code using black
[git-cola.git] / cola / utils.py
blob43838c7515c693192d6721c5be31188943d335b2
1 # Copyright (C) 2007-2018 David Aguilar and contributors
2 """This module provides miscellaneous utility functions."""
3 from __future__ import division, absolute_import, unicode_literals
4 import copy
5 import os
6 import random
7 import re
8 import shlex
9 import sys
10 import tempfile
11 import time
12 import traceback
14 from . import core
15 from . import compat
# Reseed the module-global ``random`` generator from the current time at
# import time.
# NOTE(review): the global generator is already seeded at import (from the
# OS randomness source when one is available); reseeding from
# hash(time.time()) may reduce that entropy and affects every other user of
# the global ``random`` module -- confirm this is still wanted.
random.seed(hash(time.time()))
def asint(obj, default=0):
    """Convert *obj* to an int, returning *default* when that fails

    Only TypeError and ValueError raised by ``int()`` count as failure.
    """
    try:
        return int(obj)
    except (TypeError, ValueError):
        return default
def clamp(value, lo, hi):
    """Limit value to the range [lo, hi]

    The lower bound is applied first, so when lo > hi the result is hi.
    """
    at_least_lo = max(lo, value)
    return min(hi, at_least_lo)
def epoch_millis():
    """Return the current Unix time as a whole number of milliseconds"""
    now_seconds = time.time()
    return int(now_seconds * 1000)
def add_parents(paths):
    """Return a set holding every path in *paths* plus all of its parents

    Runs of '//' are collapsed to '/' before a path is recorded.  Parent
    directories are found with dirname(), walking upward until it returns
    an empty string.
    """
    result = set()
    for item in paths:
        while '//' in item:
            item = item.replace('//', '/')
        result.add(item)
        if '/' not in item:
            continue
        parent = dirname(item)
        while parent:
            result.add(parent)
            parent = dirname(parent)
    return result
def format_exception(e):
    """Return a ``(msg, details)`` pair describing the exception *e*

    ``details`` is the formatted traceback, with each line decoded via
    core.decode().  When called from inside the ``except`` block that caught
    *e*, the active exception from sys.exc_info() is formatted.  Otherwise
    *e* itself is formatted, using the traceback attached to it
    (``e.__traceback__``, when present).  This means exceptions formatted
    after their handler has exited still report the right type and value.

    ``msg`` is ``e.msg`` when that attribute exists, else the decoded
    ``repr(e)``.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    if exc_value is not e:
        # sys.exc_info() describes a different exception (or none at all),
        # so describe e directly instead of reporting an unrelated traceback.
        exc_type = type(e)
        exc_value = e
        exc_tb = getattr(e, '__traceback__', None)
    details = traceback.format_exception(exc_type, exc_value, exc_tb)
    details = '\n'.join(map(core.decode, details))
    if hasattr(e, 'msg'):
        msg = e.msg
    else:
        msg = core.decode(repr(e))
    return (msg, details)
def sublist(a, b):
    """Return the items of a that do not appear in b, keeping a's order

    Duplicates in a are preserved.  Membership is tested with ``in``, so the
    items do not need to be hashable.
    """
    # conceptually, c = a - b
    return [item for item in a if item not in b]
# Compiled regular expressions keyed by their pattern, shared by grep() calls.
__grep_cache = {}


def grep(pattern, items, squash=True):
    """Greps a list for items that match a pattern

    Matching uses ``regex.match()``, i.e. it is anchored at the start of
    each item.  Compiled patterns are memoized in a module-level dict.

    For a dict, the result is a new dict holding only the entries whose key
    matches.  For any other iterable, the result is a list with one entry per
    matching item: the whole match when the pattern has no groups, the lone
    group's value when it has exactly one, or a list of all group values
    otherwise.

    :param squash: If only one item matches, return just that item
        (not applied to dict results)
    :returns: List of matching items, or a dict when *items* is a dict

    """
    try:
        regex = __grep_cache[pattern]
    except KeyError:
        regex = __grep_cache[pattern] = re.compile(pattern)

    if isinstance(items, dict):
        return {key: items[key] for key in items if regex.match(key)}

    found = []
    for item in items:
        match = regex.match(item)
        if match is None:
            continue
        groups = match.groups()
        if not groups:
            found.append(match.group(0))
        elif len(groups) == 1:
            found.append(groups[0])
        else:
            found.append(list(groups))

    if squash and len(found) == 1:
        return found[0]
    return found
def basename(path):
    """An os.path.basename() implementation that always uses '/'

    Avoid os.path.basename because git's output always
    uses '/' regardless of platform.  Returns everything after the last
    '/', or *path* unchanged when it contains no '/'.
    """
    return path.rpartition('/')[2]
def strip_one(path):
    """Strip one level of directory

    Leading and trailing slashes are ignored first.  A path with a single
    component is returned as-is (without the surrounding slashes).
    """
    trimmed = path.strip('/')
    head, sep, tail = trimmed.partition('/')
    return tail if sep else head
def dirname(path, current_dir=''):
    """An os.path.dirname() implementation that always uses '/'

    Avoid os.path.dirname because git's output always
    uses '/' regardless of platform.  Runs of '//' are collapsed first.
    Returns everything before the last '/', or *current_dir* when the path
    contains no '/'.
    """
    while '//' in path:
        path = path.replace('//', '/')
    head, sep, _ = path.rpartition('/')
    if not sep:
        return current_dir
    return head
def splitpath(path):
    """Break *path* into its components on '/', whatever the platform"""
    components = path.split('/')
    return components
def join(*paths):
    """Combine the given path components with '/', whatever the platform"""
    separator = '/'
    return separator.join(paths)
def pathset(path):
    """Return all of the path components for the specified path

    Each entry is a leading sub-path of *path*, shortest first.

    >>> pathset('foo/bar/baz') == ['foo', 'foo/bar', 'foo/bar/baz']
    True

    """
    prefixes = []
    current = None
    for component in splitpath(path):
        if current is None:
            current = component
        else:
            current = current + '/' + component
        prefixes.append(current)
    return prefixes
def select_directory(paths):
    """Pick a directory from *paths*

    Returns the current working directory when *paths* is empty, otherwise
    the first entry that core.isdir() accepts, otherwise the parent of the
    first entry.
    """
    if not paths:
        return core.getcwd()

    for candidate in paths:
        if core.isdir(candidate):
            break
    else:
        candidate = os.path.dirname(paths[0])
    return candidate
def strip_prefix(prefix, string):
    """Return string, without the prefix.

    Raises AssertionError if string doesn't start with prefix.  The check
    is an explicit ``raise`` rather than an ``assert`` statement so that it
    still runs under ``python -O``, where assertions are stripped and a
    wrong slice would otherwise be returned silently.
    """
    if not string.startswith(prefix):
        raise AssertionError('%r does not start with %r' % (string, prefix))
    return string[len(prefix) :]
def sanitize(s):
    """Replace shell metacharacters in a string with underscores

    Whitespace (space, tab, newline), quotes, glob characters, redirection
    and pipes, command separators, command substitution (backtick, ``$``)
    and a few other punctuation characters are each replaced with '_'.
    """
    for c in """ \t\n!@#$%^&*()\\;,<>"'[]{}~|`?""":
        s = s.replace(c, '_')
    return s
def tablength(word, tabwidth):
    """Return length of a word taking tabs into account

    Each tab counts as *tabwidth* columns; every other character counts
    as one.

    >>> tablength("\\t\\t\\t\\tX", 8)
    33

    """
    tab_count = word.count('\t')
    return len(word) - tab_count + tab_count * tabwidth
def _shell_split_py2(s):
    """Shell-split *s* on Python 2, returning a list of unicode strings

    Python 2's shlex.split() only accepts byte strings, so the input is
    encoded before splitting and every resulting word is decoded again.
    Input that shlex rejects with ValueError is split on whitespace instead.
    """
    try:
        words = shlex.split(core.encode(s))
    except ValueError:
        # e.g. an unterminated quote: fall back to plain whitespace splitting
        words = core.encode(s).strip().split()
    return list(map(core.decode, words))
def _shell_split_py3(s):
    """Shell-split *s* on Python 3, where shlex works on unicode directly

    Input that shlex rejects with ValueError (e.g. an unterminated quote) is
    decoded and split on whitespace instead.
    """
    try:
        return shlex.split(s)
    except ValueError:
        return core.decode(s).strip().split()
def shell_split(s):
    """Split *s* into arguments using shell-like quoting rules

    Dispatches to the Python 2 helper (which encodes to bytes for shlex and
    decodes the results) or the Python 3 helper (which needs no encode/decode
    round trip).
    """
    splitter = _shell_split_py2 if compat.PY2 else _shell_split_py3
    return splitter(s)
def tmp_filename(label, suffix=''):
    """Return a unique temporary-file path whose name includes *label*

    '/' and '\\' in *label* are replaced with '-', and the name is prefixed
    with 'git-cola-'.  The file is created by NamedTemporaryFile and then
    closed, which deletes it (delete=True is the default), so only its
    path is returned.
    """
    # NOTE(review): because the file is removed before its name is returned,
    # another process could create the same path first -- the same race that
    # tempfile.mktemp() is deprecated for.
    safe_label = label.replace('/', '-').replace('\\', '-')
    prefix = 'git-cola-' + safe_label + '-'
    handle = tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix)
    handle.close()
    return handle.name
def is_linux():
    """Report whether sys.platform identifies a Linux system"""
    platform_name = sys.platform
    return platform_name.startswith('linux')
def is_debian():
    """Guess whether this is a Debian-like system by looking for apt-get"""
    apt_get = '/usr/bin/apt-get'
    return os.path.exists(apt_get)
def is_darwin():
    """Report whether we are running on macOS (sys.platform == 'darwin')"""
    platform_name = sys.platform
    return platform_name == 'darwin'
def is_win32():
    """Report whether sys.platform is 'win32' or 'cygwin'"""
    return sys.platform in ('win32', 'cygwin')
def expandpath(path):
    """Expand environment $variables, then a leading ~ or ~user"""
    expanded = os.path.expandvars(path)
    if not expanded.startswith('~'):
        return expanded
    return os.path.expanduser(expanded)
class Group(object):
    """Operate on a collection of objects as a single unit

    ``group.method(*args, **kwargs)`` calls ``method`` with the same
    arguments on every member, in order, and returns None.
    """

    def __init__(self, *members):
        self._members = members

    def __getattr__(self, name):
        """Return a function that relays calls to the group

        The relay is cached on the instance, so later lookups of the same name
        do not reach __getattr__ again.  Special (dunder) names and
        ``_members`` raise AttributeError instead of producing a relay, so that
        hasattr() probes and protocol lookups (e.g. by copy or pickle) are not
        fooled into finding attributes that do not exist.
        """
        if name == '_members' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)

        def relay(*args, **kwargs):
            for member in self._members:
                method = getattr(member, name)
                method(*args, **kwargs)

        setattr(self, name, relay)
        return relay
class Proxy(object):
    """Wrap an object and override attributes

    Keyword arguments given to the constructor become instance attributes
    that shadow the wrapped object's.  Attributes not found on the proxy
    itself are looked up on the wrapped object.
    """

    def __init__(self, obj, **overrides):
        self._obj = obj
        for k, v in overrides.items():
            setattr(self, k, v)

    def __getattr__(self, name):
        # When _obj has not been set (e.g. an instance created without
        # __init__ by copy or pickle), looking up self._obj would re-enter
        # __getattr__ forever; fail with AttributeError instead.
        if name == '_obj':
            raise AttributeError(name)
        return getattr(self._obj, name)
def slice_fn(input_items, map_fn):
    """Slice input_items and call map_fn over every slice

    This exists because of "errno: Argument list too long"

    ``map_fn`` is called with successive slices of at most ``size`` items
    and must return a ``(status, out, err)`` tuple.  The result is a
    ``(status, out, err)`` tuple where ``out`` and ``err`` are the per-slice
    outputs joined with newlines.  A negative per-slice status is folded in
    with ``min()``, any other with ``max()``; an empty input yields
    ``(0, '', '')``.

    """
    # This comment appeared near the top of include/linux/binfmts.h
    # in the Linux source tree:
    #
    # /*
    #  * MAX_ARG_PAGES defines the number of pages allocated for arguments
    #  * and envelope for the new program. 32 should suffice, this gives
    #  * a maximum env+arg of 128kB w/4KB pages!
    #  */
    # #define MAX_ARG_PAGES 32
    #
    # 'size' is a heuristic to keep things highly performant by minimizing
    # the number of slices. If we wanted it to run as few commands as
    # possible we could call "getconf ARG_MAX" and make a better guess,
    # but it's probably not worth the complexity (and the extra call to
    # getconf that we can't do on Windows anyways).
    #
    # In my testing, getconf ARG_MAX on Mac OS X Mountain Lion reported
    # 262144 and Debian/Linux-x86_64 reported 2097152.
    #
    # The hard-coded max_arg_len value is safely below both of these
    # real-world values.

    # 4K pages x 32 MAX_ARG_PAGES
    max_arg_len = (32 * 4096) // 4  # allow plenty of space for the environment
    max_filename_len = 256
    size = max_arg_len // max_filename_len

    status = 0
    outs = []
    errs = []

    # Work on a shallow copy of the caller's sequence.
    items = copy.copy(input_items)
    # Step through by index rather than repeatedly re-slicing the remaining
    # tail, which would copy O(n**2 / size) items in total.
    for start in range(0, len(items), size):
        stat, out, err = map_fn(items[start : start + size])
        # NOTE(review): this aggregation is order-dependent -- a later
        # positive status replaces an earlier negative one via max().
        if stat < 0:
            status = min(stat, status)
        else:
            status = max(stat, status)
        outs.append(out)
        errs.append(err)

    return (status, '\n'.join(outs), '\n'.join(errs))
class seq(object):
    """Wrap a sequence, adding an index() lookup that falls back to a default"""

    def __init__(self, sequence):
        self.seq = sequence

    def index(self, item, default=-1):
        """Return the position of *item*, or *default* when it is absent"""
        try:
            return self.seq.index(item)
        except ValueError:
            return default

    def __getitem__(self, idx):
        return self.seq[idx]