Use raw strings for regexps with unescaped \'s
[stgit.git] / stgit / commands / common.py
blob1a17f2ea034147176a6032c92f40caca7cb0d974
1 # -*- coding: utf-8 -*-
2 """Function/variables common to all the commands"""
3 from __future__ import absolute_import, division, print_function
4 import email.utils
5 import os
6 import re
7 import sys
9 from stgit import stack, git
10 from stgit.config import config
11 from stgit.exception import StgException
12 from stgit.lib import git as libgit
13 from stgit.lib import log
14 from stgit.lib import stack as libstack
15 from stgit.out import out
16 from stgit.run import Run, RunException
17 from stgit.utils import (EditorException,
18 add_sign_line,
19 edit_string,
20 get_hook,
21 parse_name_email_date,
22 run_hook_on_string,
23 strip_prefix)
25 __copyright__ = """
26 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
28 This program is free software; you can redistribute it and/or modify
29 it under the terms of the GNU General Public License version 2 as
30 published by the Free Software Foundation.
32 This program is distributed in the hope that it will be useful,
33 but WITHOUT ANY WARRANTY; without even the implied warranty of
34 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
35 GNU General Public License for more details.
37 You should have received a copy of the GNU General Public License
38 along with this program; if not, see http://www.gnu.org/licenses/.
39 """
# Command exception class
class CmdException(StgException):
    """Error raised by the command helpers in this module, e.g. for unknown
    patch names, local changes or unresolved conflicts."""
# Utility functions
def parse_rev(rev):
    """Split a revision specification into its (branch, patch) parts.

    Everything before the first ':' is the branch and the rest is the
    patch. Without a ':' the whole string is the patch and the branch is
    None.

    >>> parse_rev('master:fix')
    ('master', 'fix')
    >>> parse_rev('fix')
    (None, 'fix')
    """
    head, separator, tail = rev.partition(':')
    if not separator:
        return (None, rev)
    return (head, tail)
def git_id(crt_series, rev):
    """Return the SHA1 of the commit that `rev` names.

    `rev` is resolved by git_commit() in the default repository, relative to
    the branch of `crt_series`.
    """
    # TODO: remove this function once all the occurrences were converted
    # to git_commit()
    repo = libstack.Repository.default()
    commit = git_commit(rev, repo, crt_series.get_name())
    return commit.sha1
def get_public_ref(branch_name):
    """Return the public ref of the branch.

    Uses the 'branch.<name>.public' config value when it is set and
    non-empty. Otherwise it falls back to 'refs/heads/<name>.public'.
    """
    configured = config.get('branch.%s.public' % branch_name)
    return configured or 'refs/heads/%s.public' % branch_name
def git_commit(name, repository, branch_name = None):
    """Return a Commit object if 'name' is a patch name or Git commit.

    The patch names allowed are in the form '<branch>:<patch>' and can be
    followed by standard symbols used by git rev-parse. If <patch> is
    '{base}', it represents the bottom of the stack. If <patch> is
    '{public}', it represents the public branch corresponding to the stack
    as described in the 'publish' command.

    Resolution order: '{base}' / '{public}' specials, then the patch ref
    'patches/<branch>/<patch>', then `name` as a plain Git revision.
    Raises CmdException when neither of the last two resolves.
    """
    branch, patch = parse_rev(name)
    # No explicit branch: use the caller's branch, else the current one.
    branch = branch or branch_name or repository.current_branch_name

    if patch.startswith('{base}'):
        stack_base = repository.get_stack(branch).base.sha1
        return repository.rev_parse(stack_base
                                    + strip_prefix('{base}', patch))

    if patch.startswith('{public}'):
        public = get_public_ref(branch)
        return repository.rev_parse(public
                                    + strip_prefix('{public}', patch),
                                    discard_stderr = True)

    patch_ref = 'patches/%s/%s' % (branch, patch)
    try:
        return repository.rev_parse(patch_ref, discard_stderr = True)
    except libgit.RepositoryException:
        pass  # not a known patch; try it as a plain Git revision below

    try:
        return repository.rev_parse(name, discard_stderr = True)
    except libgit.RepositoryException:
        raise CmdException('%s: Unknown patch or revision name' % name)
def color_diff_flags():
    """Return the git flags for coloured diff output if the configuration and
    stdout allows.

    The 'color.diff' setting is evaluated with a hint of whether stdout is
    a terminal. The result is ['--color'] or an empty list.
    """
    tty_hint = 'true' if sys.stdout.isatty() else 'false'
    wants_color = config.get_colorbool('color.diff', tty_hint) == 'true'
    return ['--color'] if wants_color else []
def check_local_changes():
    """Raise CmdException if git reports local changes in the tree."""
    if not git.local_changes():
        return
    raise CmdException('local changes in the tree. Use "refresh" or'
                       ' "reset --hard"')
def check_head_top_equal(crt_series):
    """Raise CmdException unless HEAD matches the top patch of crt_series."""
    if crt_series.head_top_equal():
        return
    raise CmdException('HEAD and top are not the same. This can happen'
                       ' if you modify a branch with git. "stg repair'
                       ' --help" explains more about what to do next.')
def check_conflicts():
    """Raise CmdException if git reports any unresolved conflicts."""
    if not git.get_conflicts():
        return
    raise CmdException('Unsolved conflicts. Please fix the conflicts'
                       ' then use "git add --update <files>" or revert the'
                       ' changes with "reset --hard".')
def print_crt_patch(crt_series, branch = None):
    """Report the current (top) patch of crt_series.

    When `branch` is given, the series of that branch is inspected instead.
    """
    series = stack.Series(branch) if branch else crt_series
    top = series.get_current()
    if not top:
        out.info('No patches applied')
        return
    out.info('Now at patch "%s"' % top)
def resolved_all(reset = None):
    """Mark every file git currently reports as conflicted as resolved.

    `reset` is passed through unchanged to git.resolved().
    """
    git.resolved(git.get_conflicts(), reset)
def push_patches(crt_series, patches, check_merged = False):
    """Push multiple patches onto the stack. This function is shared
    between the push and pull commands.

    crt_series.forward_patches() reports how many of the leading patches
    it fast-forwarded. The remaining ones are pushed one at a time. With
    `check_merged`, those found by crt_series.merged_patches() are pushed
    as empty patches instead.
    """
    n_forwarded = crt_series.forward_patches(patches)
    if n_forwarded == 1:
        out.info('Fast-forwarded patch "%s"' % patches[0])
    elif n_forwarded > 1:
        out.info('Fast-forwarded patches "%s" - "%s"'
                 % (patches[0], patches[n_forwarded - 1]))

    to_push = patches[n_forwarded:]

    merged = []
    if to_push and check_merged:
        out.start('Checking for patches merged upstream')
        merged = crt_series.merged_patches(to_push)
        out.done('%d found' % len(merged))

    for name in to_push:
        out.start('Pushing patch "%s"' % name)

        if name in merged:
            crt_series.push_empty_patch(name)
            out.done('merged upstream')
            continue

        was_modified = crt_series.push_patch(name)
        if crt_series.empty_patch(name):
            out.done('empty patch')
        elif was_modified:
            out.done('modified')
        else:
            out.done()
def pop_patches(crt_series, patches, keep = False):
    """Pop the patches in the list from the stack.

    It is assumed that the patches are listed in the stack reverse order.
    Only the last entry of the list is passed to crt_series.pop_patch(),
    together with `keep`.
    """
    if not patches:
        out.info('Nothing to push/pop')
        return

    last = patches[-1]
    if len(patches) == 1:
        out.start('Popping patch "%s"' % last)
    else:
        out.start('Popping patches "%s" - "%s"' % (patches[0], last))
    crt_series.pop_patch(last, keep)
    out.done()
def get_patch_from_list(part_name, patch_list):
    """Return the single patch in patch_list whose name contains part_name.

    Returns None when nothing matches. When several patches match, it lists
    them and raises CmdException.
    """
    matches = [name for name in patch_list if part_name in name]
    if not matches:
        return None
    if len(matches) == 1:
        return matches[0]
    out.info('Possible patches:\n %s' % '\n '.join(matches))
    raise CmdException('Ambiguous patch name "%s"' % part_name)
def parse_patches(patch_args, patch_list, boundary = 0, ordered = False):
    """Parse patch_args list for patch names in patch_list and return
    a list. The names can be individual patches and/or in the
    patch1..patch2 format.

    Ranges include both ends; a range whose end precedes its start comes
    back reversed. Either end of a range may be omitted. The open side then
    stops at `boundary` (0 means len(patch_list)), staying on the same side
    of `boundary` as the explicit end. With `ordered`, the result follows
    the order of patch_list rather than of patch_args.

    Raises CmdException for unknown, malformed or duplicated names.
    """
    # accept tuples (and other iterables) as well as lists
    patch_list = list(patch_list)
    limit = boundary or len(patch_list)

    def resolve_range(start_name, end_name):
        # first is inclusive, last is exclusive; -1 marks an omitted end
        first = patch_list.index(start_name) if start_name else -1
        last = patch_list.index(end_name) + 1 if end_name else -1
        # only cross the boundary if explicitly asked
        if first < 0:
            first = 0 if last <= limit else limit
        if last < 0:
            last = limit if first < limit else len(patch_list)
        if last > first:
            return patch_list[first:last]
        backwards = patch_list[(last - 1):(first + 1)]
        backwards.reverse()
        return backwards

    selected_all = []
    for arg in patch_args:
        ends = arg.split('..')
        unknown = [p for p in ends if p and p not in patch_list]
        if unknown:
            raise CmdException('Unknown patch name: %s' % unknown[0])

        if len(ends) == 1:
            selected = ends
        elif len(ends) == 2:
            selected = resolve_range(ends[0], ends[1])
        else:
            raise CmdException('Malformed patch name: %s' % arg)

        for p in selected:
            if p in selected_all:
                raise CmdException('Duplicate patch name: %s' % p)
        selected_all.extend(selected)

    if ordered:
        selected_all = [p for p in patch_list if p in selected_all]
    return selected_all
def name_email(address):
    """Parse `address` with email.utils.parseaddr and return (name, email).

    Raises CmdException when no e-mail part could be extracted.
    """
    parsed = email.utils.parseaddr(address)
    if not parsed[1]:
        raise CmdException('Incorrect "name <email>"/"email (name)" string: %s'
                           % address)
    return parsed
def name_email_date(address):
    """Parse a "name <email> date" string with parse_name_email_date().

    Returns the parsed result. Raises CmdException if parsing yields
    nothing.
    """
    parsed = parse_name_email_date(address)
    if not parsed:
        raise CmdException('Incorrect "name <email> date" string: %s' % address)
    return parsed
def address_or_alias(addr_pair):
    """Return a name-email tuple if the e-mail address is valid, or look up
    the aliases in the config files.

    `addr_pair` is presumably a (name, email) pair as produced by
    name_email(). If its second element contains '@', the pair is returned
    unchanged. Otherwise that element is looked up as 'mail.alias.<value>'
    in the configuration, and the alias text is parsed with name_email().
    Raises CmdException for an unknown alias.
    """
    addr = addr_pair[1]
    if '@' not in addr:
        alias = config.get('mail.alias.' + addr)
        if not alias:
            raise CmdException('unknown e-mail alias: %s' % addr)
        return name_email(alias)
    # it's an e-mail address
    return addr_pair
def prepare_rebase(crt_series):
    """Pop all applied patches of crt_series and return the list of them.

    The returned list is what post_rebase() can push back afterwards.
    """
    applied = crt_series.get_applied()
    if applied:
        out.start('Popping all applied patches')
        # popping the bottom-most applied patch takes the whole stack off
        crt_series.pop_patch(applied[0])
        out.done()
    return applied
def rebase(crt_series, target):
    """Rebase the current branch to `target` with git.rebase().

    `target` is first resolved to a commit id through git_id(). If that
    fails, it is passed through unchanged, because a custom rebase command
    may accept its own kind of target.
    """
    try:
        tree_id = git_id(crt_series, target)
    except Exception:
        # It might be that we use a custom rebase command with its own
        # target type. Only ordinary errors are tolerated here, so that
        # KeyboardInterrupt and SystemExit still propagate (a bare
        # 'except:' used to swallow them too).
        tree_id = target
    if target:
        out.start('Rebasing to "%s"' % target)
    else:
        out.start('Rebasing to the default target')
    git.rebase(tree_id = tree_id)
    out.done()
def post_rebase(crt_series, applied, nopush, merged):
    """Record the new base and, unless `nopush`, push `applied` back.

    `merged` is forwarded as the check_merged flag of push_patches().
    """
    # memorize that we rebased to here
    crt_series._set_field('orig-base', git.get_head())
    if nopush:
        return
    push_patches(crt_series, applied, check_merged = merged)
# Patch description/e-mail/diff parsing

def __end_descr(line):
    """Return a true value (a match object) if `line` ends a patch description.

    A description stops at the first line that looks like the start of a
    diff. That is a '---' separator optionally followed by whitespace, a
    'diff -' header, an 'Index: ' header, or a '--- <name>' header.
    Otherwise None is returned.
    """
    # Raw strings throughout: the previous plain literal '--- \w' relied on
    # the invalid escape sequence '\w' being kept verbatim, which Python
    # reports as a DeprecationWarning (SyntaxWarning since 3.12).
    return (re.match(r'---\s*$', line) or re.match(r'diff -', line) or
            re.match(r'Index: ', line) or re.match(r'--- \w', line))
def __split_descr_diff(string):
    """Return the description and the diff from the given string.

    The description is every line before the first one accepted by
    __end_descr(), with trailing whitespace stripped. The diff is that line
    and everything after it, each line terminated by '\\n'.
    """
    lines = string.split('\n')
    cut = len(lines)
    for index, candidate in enumerate(lines):
        if __end_descr(candidate):
            cut = index
            break

    descr = ''.join(l + '\n' for l in lines[:cut])
    diff = ''.join(l + '\n' for l in lines[cut:])
    return (descr.rstrip(), diff)
def __parse_description(descr):
    """Parse the patch description and return the new description and
    author information (if any).

    Non-empty lines matching "From:"/"Author:" or "Date:" (case-insensitive,
    optionally indented) are consumed as headers until the first non-header
    line after the subject. The first other non-empty line is the subject.
    A line 'commit <40 hex digits>' (as in 'git show' output) is skipped,
    and from then on the subject and body lines have their first 4
    characters dropped.

    Returns (description, authname, authemail, authdate). The author fields
    are None when no corresponding header was seen. A malformed author
    header makes name_email() raise CmdException.
    """
    subject = body = ''
    authname = authemail = authdate = None

    descr_lines = [line.rstrip() for line in descr.split('\n')]
    # NOTE(review): str.split() always returns at least one element, so this
    # check never fires; an empty descr yields ('\n', None, None, None).
    if not descr_lines:
        raise CmdException("Empty patch description")

    # index of the first line belonging to the body
    lasthdr = 0
    end = len(descr_lines)
    # number of leading characters to drop from subject/body lines
    descr_strip = 0

    # Parse the patch header
    for pos in range(0, end):
        if not descr_lines[pos]:
            continue
        # check for a "From|Author:" line
        if re.match(r'\s*(?:from|author):\s+', descr_lines[pos], re.I):
            auth = re.findall(r'^.*?:\s+(.*)$', descr_lines[pos])[0]
            authname, authemail = name_email(auth)
            lasthdr = pos + 1
            continue
        # check for a "Date:" line
        if re.match(r'\s*date:\s+', descr_lines[pos], re.I):
            authdate = re.findall(r'^.*?:\s+(.*)$', descr_lines[pos])[0]
            lasthdr = pos + 1
            continue
        # a second non-header line after the subject starts the body
        if subject:
            break
        # get the subject
        subject = descr_lines[pos][descr_strip:]
        if re.match(r'commit [\da-f]{40}$', subject):
            # 'git show' output, look for the real subject
            subject = ''
            descr_strip = 4
        # the body starts after the subject (or 'commit' line) just consumed
        lasthdr = pos + 1

    # get the body
    if lasthdr < end:
        body = '\n' + '\n'.join(l[descr_strip:] for l in descr_lines[lasthdr:])

    return (subject + body, authname, authemail, authdate)
def parse_mail(msg):
    """Parse the message object and return (description, authname,
    authemail, authdate, diff).

    `msg` is presumably an email.message.Message, since it is indexed by
    header name and walked with msg.walk(). The description is the Subject
    header without a leading '[...PATCH...]' tag. It is followed by any
    text that precedes the diff in the text/plain and
    application/octet-stream parts. Author/Date lines found in that
    description override the From and Date headers of the message.

    Raises CmdException if the Subject header is missing or empty, or if a
    header cannot be decoded.
    """
    from email.header import decode_header, make_header

    def __decode_header(header):
        """Decode a qp-encoded e-mail header as per rfc2047"""
        try:
            words_enc = decode_header(header)
            hobj = make_header(words_enc)
        except Exception as ex:
            raise CmdException('header decoding error: %s' % str(ex))
        # NOTE(review): `unicode` exists only on Python 2.
        return unicode(hobj).encode('utf-8')

    # parse the headers
    if 'from' in msg:
        authname, authemail = name_email(__decode_header(msg['from']))
    else:
        authname = authemail = None

    # Check for the header before decoding it. msg['subject'] is None when
    # the header is absent, and decoding None does not give an empty
    # subject. Depending on the Python version it either fails with a
    # misleading 'header decoding error' or yields the text 'None'.
    if 'subject' not in msg:
        raise CmdException('Subject: line not found')

    # '\n\t' can be found on multi-line headers
    descr = __decode_header(msg['subject'])
    descr = re.sub(r'\n[ \t]*', ' ', descr)
    authdate = msg['date']

    # remove the '[*PATCH*]' expression in the subject
    if descr:
        descr = re.findall(r'^(\[.*?[Pp][Aa][Tt][Cc][Hh].*?\])?\s*(.*)$',
                           descr)[0][1]
    else:
        raise CmdException('Subject: line not found')

    # the rest of the message
    msg_text = ''
    for part in msg.walk():
        if part.get_content_type() in ['text/plain',
                                       'application/octet-stream']:
            msg_text += part.get_payload(decode = True)

    rem_descr, diff = __split_descr_diff(msg_text)
    if rem_descr:
        descr += '\n\n' + rem_descr

    # parse the description for author information
    descr, descr_authname, descr_authemail, descr_authdate = \
        __parse_description(descr)
    if descr_authname:
        authname = descr_authname
    if descr_authemail:
        authemail = descr_authemail
    if descr_authdate:
        authdate = descr_authdate

    return (descr, authname, authemail, authdate, diff)
def parse_patch(text, contains_diff):
    """Parse the input text and return (description, authname,
    authemail, authdate, diff).

    When `contains_diff` is true, the text is first split into description
    and diff with __split_descr_diff(). Otherwise the whole text is the
    description and `diff` is None. Author and date come from the
    description via __parse_description().
    """
    diff = None
    if contains_diff:
        text, diff = __split_descr_diff(text)
    descr, authname, authemail, authdate = __parse_description(text)
    return (descr, authname, authemail, authdate, diff)
def readonly_constant_property(f):
    """Decorator turning a method into a read-only, computed-once property.

    On first access f(self) is called and its result is stored on the
    instance under the attribute name '__' + f.__name__. Later accesses
    return that stored value without calling f again. The property has no
    setter.
    """
    def cached_getter(self):
        slot = '__{}'.format(f.__name__)
        if not hasattr(self, slot):
            setattr(self, slot, f(self))
        return getattr(self, slot)
    return property(cached_getter)
def run_commit_msg_hook(repo, cd, editor_is_used=True):
    """Run the commit-msg hook (if any) on a commit.

    @param cd: The L{CommitData<stgit.lib.git.CommitData>} to run the
               hook on.
    @param editor_is_used: when false, GIT_EDITOR is set to ':' in the
               hook's environment.

    A RunException from the hook is re-raised as EditorException.

    Return the new L{CommitData<stgit.lib.git.CommitData>}."""
    if editor_is_used:
        hook_env = dict(cd.env)
    else:
        # ':' is the shell no-op, presumably so the hook cannot open an
        # interactive editor
        hook_env = dict(cd.env, GIT_EDITOR=':')
    hook = get_hook(repo, 'commit-msg', hook_env)

    try:
        message = run_hook_on_string(hook, cd.message)
    except RunException as err:
        raise EditorException(str(err))
    return cd.set_message(message)
def update_commit_data(cd, options):
    """Return a new CommitData object updated according to the command line
    options.

    Steps, in order:
    1. Apply options.message.
    2. Apply the author override via options.author().
    3. Add a sign line using options.sign_str, falling back to the
       'stgit.autosign' config.
    4. Open an interactive edit of the message, unless --save-template or
       --message was given.
    """
    if options.message is not None:
        cd = cd.set_message(options.message)

    cd = cd.set_author(options.author(cd.author))

    # Add Signed-off-by: or similar.
    sign_str = options.sign_str
    if sign_str is None:
        sign_str = config.get("stgit.autosign")
    if sign_str is not None:
        signed_message = add_sign_line(cd.message, sign_str,
                                       cd.committer.name, cd.committer.email)
        cd = cd.set_message(signed_message)

    wants_edit = (not getattr(options, 'save_template', None)
                  and options.message is None)
    if wants_edit:
        cd = cd.set_message(edit_string(cd.message, '.stgit-new.txt'))

    return cd
class DirectoryException(StgException):
    """Raised by the Directory* helpers when the required git repository or
    worktree cannot be found."""
class _Directory(object):
    """Base class for the Directory* command environments.

    Subclasses provide setup(). This base exposes lazily computed, cached
    facts about the surrounding git repository, obtained from
    'git rev-parse'.
    """
    def __init__(self, needs_current_series = True, log = True):
        # NB: the `log` parameter shadows the module only inside __init__.
        self.needs_current_series = needs_current_series
        self.log = log

    @readonly_constant_property
    def git_dir(self):
        """Output of 'git rev-parse --git-dir'; DirectoryException if none."""
        try:
            return Run('git', 'rev-parse', '--git-dir'
                       ).discard_stderr().output_one_line()
        except RunException:
            raise DirectoryException('No git repository found')

    @readonly_constant_property
    def __topdir_path(self):
        """Output of 'git rev-parse --show-cdup', or '.' when it is empty."""
        try:
            cdup = Run('git', 'rev-parse', '--show-cdup'
                       ).discard_stderr().output_lines()
            if not cdup:
                return '.'
            if len(cdup) > 1:
                raise RunException('Too much output')
            return cdup[0]
        except RunException:
            raise DirectoryException('No git repository found')

    @readonly_constant_property
    def is_inside_git_dir(self):
        """Boolean form of 'git rev-parse --is-inside-git-dir'."""
        answer = Run('git', 'rev-parse', '--is-inside-git-dir'
                     ).output_one_line()
        return {'true': True, 'false': False}[answer]

    @readonly_constant_property
    def is_inside_worktree(self):
        """Boolean form of 'git rev-parse --is-inside-work-tree'."""
        answer = Run('git', 'rev-parse', '--is-inside-work-tree'
                     ).output_one_line()
        return {'true': True, 'false': False}[answer]

    def cd_to_topdir(self):
        """Change the process working directory to the worktree top."""
        os.chdir(self.__topdir_path)

    def write_log(self, msg):
        """Record `msg` with log.compat_log_entry() if logging is enabled."""
        if not self.log:
            return
        log.compat_log_entry(msg)
class DirectoryAnywhere(_Directory):
    """Environment for commands that can run outside any repository."""
    def setup(self):
        """Nothing to check or prepare."""
class DirectoryHasRepository(_Directory):
    """Environment for commands that need a git repository."""
    def setup(self):
        # Evaluated only for its check: raises DirectoryException when no
        # git repository is found.
        self.git_dir
        log.compat_log_external_mods()
class DirectoryInWorktree(DirectoryHasRepository):
    """Environment for commands that must run inside a git worktree."""
    def setup(self):
        DirectoryHasRepository.setup(self)
        if self.is_inside_worktree:
            return
        raise DirectoryException('Not inside a git worktree')
class DirectoryGotoToplevel(DirectoryInWorktree):
    """Worktree environment that also changes into the worktree top."""
    def setup(self):
        DirectoryInWorktree.setup(self)
        self.cd_to_topdir()
class DirectoryHasRepositoryLib(_Directory):
    """For commands that use the new infrastructure in stgit.lib.*."""
    def __init__(self):
        # _Directory.__init__ is not called; both attributes are set here.
        self.log = False  # stgit.lib.transaction handles logging
        self.needs_current_series = False

    def setup(self):
        # This will throw an exception if we don't have a repository.
        self.repository = libstack.Repository.default()