Absolute Worktree.directory
[stgit.git] / stgit / utils.py
blobf6f17dace10b1270a023e9bf6ea8176d8797b123
1 """Common utility functions"""
3 import os
4 import re
5 import tempfile
6 from io import open
8 from stgit.compat import environ_get
9 from stgit.config import config
10 from stgit.exception import StgException
11 from stgit.out import out
12 from stgit.run import Run
14 __copyright__ = """
15 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
17 This program is free software; you can redistribute it and/or modify
18 it under the terms of the GNU General Public License version 2 as
19 published by the Free Software Foundation.
21 This program is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU General Public License for more details.
26 You should have received a copy of the GNU General Public License
27 along with this program; if not, see http://www.gnu.org/licenses/.
28 """
def strip_prefix(prefix, string):
    """Return the part of string that follows prefix.

    The caller must guarantee that string begins with prefix; this is
    checked with an assertion.

    """
    assert string.startswith(prefix)
    prefix_length = len(prefix)
    return string[prefix_length:]
class EditorException(StgException):
    """Raised when the interactive editor exits with a failure status."""
def get_editor():
    """Return the editor command to use.

    The first non-empty value wins, in this order: the GIT_EDITOR
    environment variable, the stgit.editor and core.editor config
    values, the VISUAL and EDITOR environment variables, and finally
    the literal 'vi'.

    """
    candidates = (
        environ_get('GIT_EDITOR'),
        config.get('stgit.editor'),  # legacy
        config.get('core.editor'),
        environ_get('VISUAL'),
        environ_get('EDITOR'),
        'vi',
    )
    # 'vi' is always truthy, so a result is guaranteed.
    return next(candidate for candidate in candidates if candidate)
def call_editor(filename):
    """Launch the configured editor on filename and wait for it.

    The command line is the editor command followed by the filename,
    executed through os.system(). A non-zero status from os.system()
    raises EditorException.

    """
    command_line = '%s %s' % (get_editor(), filename)
    out.start('Invoking the editor: "%s"' % command_line)
    status = os.system(command_line)
    if status:
        raise EditorException('editor failed, exit code: %d' % status)
    out.done()
def get_hooks_path(repository):
    """Return the directory in which the repository's hooks are looked up.

    Without a core.hookspath config value this is the 'hooks' directory
    under repository.directory. An absolute core.hookspath is returned
    unchanged; a relative one is joined onto the default worktree's
    directory.

    """
    configured = config.get('core.hookspath')
    if configured is None:
        return os.path.join(repository.directory, 'hooks')
    if os.path.isabs(configured):
        return configured
    return os.path.join(repository.default_worktree.directory, configured)
def get_hook(repository, hook_name, extra_env=None):
    """Return a function that runs the named hook, or None.

    None is returned when the hook file does not exist in the hooks
    directory or is not executable.

    The returned function accepts the hook's command line parameters
    and runs the hook through repository.default_iw. The hook's
    environment is extra_env (if given) plus GIT_PREFIX, which is the
    current working directory relative to the default worktree with a
    trailing separator, or the empty string when the two are the same.

    """
    hook_path = os.path.join(get_hooks_path(repository), hook_name)
    if not (os.path.isfile(hook_path) and os.access(hook_path, os.X_OK)):
        return None

    prefix_dir = os.path.relpath(os.getcwd(), repository.default_worktree.directory)
    if prefix_dir == os.curdir:
        prefix = ''
    else:
        # Joining with '' appends the trailing path separator.
        prefix = os.path.join(prefix_dir, '')

    # A None sentinel replaces the former mutable `{}` default; add_dict
    # builds a new dict, so the caller's mapping is never modified.
    base_env = {} if extra_env is None else extra_env
    extra_env = add_dict(base_env, {'GIT_PREFIX': prefix})

    def hook(*parameters):
        argv = [hook_path]
        argv.extend(parameters)

        # On Windows, run the hook using "bash" explicitly
        if os.name != 'posix':
            argv.insert(0, 'bash')

        repository.default_iw.run(argv, extra_env).run()

    hook.__name__ = str(hook_name)
    return hook
def run_hook_on_bytes(hook, byte_data, *args):
    """Run hook on a temporary file holding byte_data; return the result.

    byte_data is written to a new temporary file, which is closed before
    the hook is called with the file's path followed by any extra
    positional args. The file's contents after the hook has run are
    returned as bytes. The temporary file is removed in all cases,
    including when the hook raises.

    """
    # delete=False: the file must survive being closed so that the hook
    # and the read-back below can reopen it by name.
    temp = tempfile.NamedTemporaryFile('wb', prefix='stgit-hook', delete=False)
    try:
        with temp:
            temp.write(byte_data)
        # Forward the extra arguments; previously they were accepted but
        # silently dropped.
        hook(temp.name, *args)
        with open(temp.name, 'rb') as data_file:
            return data_file.read()
    finally:
        os.unlink(temp.name)
def edit_string(s, filename, encoding='utf-8'):
    """Let the user edit the text s in an editor; return the edited text.

    s is written to filename using encoding, the editor is invoked on
    that file, and the file is read back with the same encoding and then
    removed.

    """
    with open(filename, 'w', encoding=encoding) as outfile:
        outfile.write(s)

    call_editor(filename)

    with open(filename, encoding=encoding) as infile:
        edited = infile.read()
    os.remove(filename)
    return edited
def edit_bytes(s, filename):
    """Let the user edit the bytes s in an editor; return the edited bytes.

    s is written to filename in binary mode, the editor is invoked on
    that file, and the file is read back in binary mode and then
    removed.

    """
    with open(filename, 'wb') as outfile:
        outfile.write(s)

    call_editor(filename)

    with open(filename, 'rb') as infile:
        edited = infile.read()
    os.remove(filename)
    return edited
def add_trailer(message, trailer, name, email):
    """Return message with a '<trailer>: <name> <<email>>' trailer added.

    The message is passed as raw input to 'git interpret-trailers' and
    that command's raw output is returned.

    """
    trailer_line = '%s: %s <%s>' % (trailer, name, email)
    interpret_trailers = Run('git', 'interpret-trailers', '--trailer', trailer_line)
    return interpret_trailers.raw_input(message).raw_output()
def parse_name_email(address):
    """Split an address string into a (name, email) tuple.

    Both the 'name <email>' and the 'email (name)' forms are accepted;
    None is returned when the string matches neither. Backslashes and
    double quotes in the address are backslash-escaped before matching.

    Whitespace between the two parts is kept at the end of the first
    part, e.g. 'A B <a@b>' yields ('A B ', 'a@b').

    """
    escaped = re.sub(r'[\\"]', r'\\\g<0>', address)

    angle_form = re.match(r'^(.*)\s*<(.*)>\s*$', escaped)
    if angle_form is not None:
        return angle_form.groups()

    paren_form = re.match(r'^(.*)\s*\((.*)\)\s*$', escaped)
    if paren_form is None:
        return None
    # In this form the email comes first, so swap into (name, email).
    email, name = paren_form.groups()
    return (name, email)
# Process exit codes.
STGIT_SUCCESS = 0  # no error
STGIT_GENERAL_ERROR = 1  # an error that appears not to be command-specific
STGIT_COMMAND_ERROR = 2  # a command appears to have failed
STGIT_CONFLICT = 3  # a merge conflict occurred; otherwise OK
STGIT_BUG_ERROR = 4  # a bug in StGit itself
def add_dict(d1, d2):
    """Merge d1 and d2 into a newly created dict and return it.

    Neither argument is modified. Where both define the same key, the
    value from d2 is used.

    """
    merged = {}
    for source in (d1, d2):
        merged.update(source)
    return merged