models: use symbolic-ref instead of git branch in current_branch()
[git-cola.git] / cola / models.py
blobd55d2d408b98bebc78e58d2fc429bd7fb8084951
1 # Copyright (c) 2008 David Aguilar
2 import os
3 import sys
4 import re
5 import time
6 import subprocess
7 from cStringIO import StringIO
9 from cola import git
10 from cola import utils
11 from cola import model
13 #+-------------------------------------------------------------------------
14 #+ A regex for matching the output of git(log|rev-list) --pretty=oneline
15 REV_LIST_REGEX = re.compile('([0-9a-f]+)\W(.*)')
17 class GitCola(git.Git):
18 """GitPython throws exceptions by default.
19 We suppress exceptions in favor of return values.
20 """
21 def __init__(self):
22 git.Git.__init__(self)
23 self.load_worktree(os.getcwd())
25 def load_worktree(self, path):
26 self._git_dir = path
27 self._work_tree = None
28 self.get_work_tree()
30 def get_work_tree(self):
31 if self._work_tree:
32 return self._work_tree
33 self.get_git_dir()
34 if self._git_dir:
35 curdir = self._git_dir
36 else:
37 curdir = os.getcwd()
39 if self._is_git_dir(os.path.join(curdir, '.git')):
40 return curdir
42 # Handle bare repositories
43 if (len(os.path.basename(curdir)) > 4
44 and curdir.endswith('.git')):
45 return curdir
46 if 'GIT_WORK_TREE' in os.environ:
47 self._work_tree = os.getenv('GIT_WORK_TREE')
48 if not self._work_tree or not os.path.isdir(self._work_tree):
49 if self._git_dir:
50 gitparent = os.path.join(os.path.abspath(self._git_dir), '..')
51 self._work_tree = os.path.abspath(gitparent)
52 self.set_cwd(self._work_tree)
53 return self._work_tree
55 def is_valid(self):
56 return self._git_dir and self._is_git_dir(self._git_dir)
58 def get_git_dir(self):
59 if self.is_valid():
60 return self._git_dir
61 if 'GIT_DIR' in os.environ:
62 self._git_dir = os.getenv('GIT_DIR')
63 if self._git_dir:
64 curpath = os.path.abspath(self._git_dir)
65 else:
66 curpath = os.path.abspath(os.getcwd())
67 # Search for a .git directory
68 while curpath:
69 if self._is_git_dir(curpath):
70 self._git_dir = curpath
71 break
72 gitpath = os.path.join(curpath, '.git')
73 if self._is_git_dir(gitpath):
74 self._git_dir = gitpath
75 break
76 curpath, dummy = os.path.split(curpath)
77 if not dummy:
78 break
79 return self._git_dir
81 def _is_git_dir(self, d):
82 """ This is taken from the git setup.c:is_git_directory
83 function."""
84 if (os.path.isdir(d)
85 and os.path.isdir(os.path.join(d, 'objects'))
86 and os.path.isdir(os.path.join(d, 'refs'))):
87 headref = os.path.join(d, 'HEAD')
88 return (os.path.isfile(headref)
89 or (os.path.islink(headref)
90 and os.readlink(headref).startswith('refs')))
91 return False
93 def eval_path(path):
94 """handles quoted paths."""
95 if path.startswith('"') and path.endswith('"'):
96 return eval(path).decode('utf-8')
97 else:
98 return path
100 class Model(model.Model):
101 """Provides a friendly wrapper for doing commit git operations."""
103 def clone(self):
104 worktree = self.git.get_work_tree()
105 clone = model.Model.clone(self)
106 clone.use_worktree(worktree)
107 return clone
109 def use_worktree(self, worktree):
110 self.git.load_worktree(worktree)
111 is_valid = self.git.is_valid()
112 if is_valid:
113 self.__init_config_data()
114 return is_valid
116 def init(self):
117 """Reads git repository settings and sets several methods
118 so that they refer to the git module. This object
119 encapsulates cola's interaction with git."""
121 # Initialize the git command object
122 self.git = GitCola()
123 self.partially_staged = set()
125 self.fetch_helper = self.gen_remote_helper(self.git.fetch)
126 self.push_helper = self.gen_remote_helper(self.git.push)
127 self.pull_helper = self.gen_remote_helper(self.git.pull)
129 self.create(
130 #####################################################
131 # Used in various places
132 currentbranch = '',
133 remotes = [],
134 remotename = '',
135 local_branch = '',
136 remote_branch = '',
137 search_text = '',
139 #####################################################
140 # Used primarily by the main UI
141 commitmsg = '',
142 modified = [],
143 staged = [],
144 unstaged = [],
145 untracked = [],
146 unmerged = [],
147 show_untracked = True,
149 #####################################################
150 # Used by the create branch dialog
151 revision = '',
152 local_branches = [],
153 remote_branches = [],
154 tags = [],
156 #####################################################
157 # Used by the commit/repo browser
158 directory = '',
159 revisions = [],
160 summaries = [],
162 # These are parallel lists
163 types = [],
164 sha1s = [],
165 names = [],
167 # All items below here are re-calculated in
168 # init_browser_data()
169 directories = [],
170 directory_entries = {},
172 # These are also parallel lists
173 subtree_types = [],
174 subtree_sha1s = [],
175 subtree_names = [],
178 def __init_config_data(self):
179 """Reads git config --list and creates parameters
180 for each setting."""
181 # These parameters are saved in .gitconfig,
182 # so ideally these should be as short as possible.
184 # config items that are controllable globally
185 # and per-repository
186 self.__local_and_global_defaults = {
187 'user_name': '',
188 'user_email': '',
189 'merge_summary': False,
190 'merge_diffstat': True,
191 'merge_verbosity': 2,
192 'gui_diffcontext': 3,
193 'gui_pruneduringfetch': False,
195 # config items that are purely git config --global settings
196 self.__global_defaults = {
197 'cola_geometry':'',
198 'cola_fontui': '',
199 'cola_fontuisize': 12,
200 'cola_fontdiff': '',
201 'cola_fontdiffsize': 12,
202 'cola_savewindowsettings': False,
203 'merge_keepbackup': True,
204 'merge_tool': os.getenv('MERGETOOL', 'xxdiff'),
205 'gui_editor': os.getenv('EDITOR', 'gvim'),
206 'gui_historybrowser': 'gitk',
209 local_dict = self.config_dict(local=True)
210 global_dict = self.config_dict(local=False)
212 for k,v in local_dict.iteritems():
213 self.set_param('local_'+k, v)
214 for k,v in global_dict.iteritems():
215 self.set_param('global_'+k, v)
216 if k not in local_dict:
217 local_dict[k]=v
218 self.set_param('local_'+k, v)
220 # Bootstrap the internal font*size variables
221 for param in ('global_cola_fontui', 'global_cola_fontdiff'):
222 setdefault = True
223 if hasattr(self, param):
224 font = self.get_param(param)
225 if font:
226 setdefault = False
227 size = int(font.split(',')[1])
228 self.set_param(param+'size', size)
229 param = param[len('global_'):]
230 global_dict[param] = font
231 global_dict[param+'size'] = size
233 # Load defaults for all undefined items
234 local_and_global_defaults = self.__local_and_global_defaults
235 for k,v in local_and_global_defaults.iteritems():
236 if k not in local_dict:
237 self.set_param('local_'+k, v)
238 if k not in global_dict:
239 self.set_param('global_'+k, v)
241 global_defaults = self.__global_defaults
242 for k,v in global_defaults.iteritems():
243 if k not in global_dict:
244 self.set_param('global_'+k, v)
246 # Load the diff context
247 self.diff_context = self.local_gui_diffcontext
249 def get_global_config(self, key):
250 return getattr(self, 'global_'+key.replace('.', '_'))
252 def get_cola_config(self, key):
253 return getattr(self, 'global_cola_'+key)
255 def get_gui_config(self, key):
256 return getattr(self, 'global_gui_'+key)
258 def get_default_remote(self):
259 branch = self.get_currentbranch()
260 branchconfig = 'local_branch_%s_remote' % branch
261 if branchconfig in self.get_param_names():
262 remote = self.get_param(branchconfig)
263 else:
264 remote = 'origin'
265 return remote
267 def get_corresponding_remote_ref(self):
268 remote = self.get_default_remote()
269 branch = self.get_currentbranch()
270 best_match = '%s/%s' % (remote, branch)
271 remote_branches = self.get_remote_branches()
272 if not remote_branches:
273 return remote
274 for rb in remote_branches:
275 if rb == best_match:
276 return rb
277 return remote_branches[0]
279 def get_diff_filenames(self, arg):
280 diff_zstr = self.git.diff(arg, name_only=True, z=True).rstrip('\0')
281 return [ f.decode('utf-8') for f in diff_zstr.split('\0') if f ]
283 def branch_list(self, remote=False):
284 branches = map(lambda x: x.lstrip('* '),
285 self.git.branch(r=remote).splitlines())
286 if remote:
287 remotes = []
288 for branch in branches:
289 if branch.endswith('/HEAD'):
290 continue
291 remotes.append(branch)
292 return remotes
293 return branches
295 def get_config_params(self):
296 params = []
297 params.extend(map(lambda x: 'local_' + x,
298 self.__local_and_global_defaults.keys()))
299 params.extend(map(lambda x: 'global_' + x,
300 self.__local_and_global_defaults.keys()))
301 params.extend(map(lambda x: 'global_' + x,
302 self.__global_defaults.keys()))
303 return [ p for p in params if not p.endswith('size') ]
305 def save_config_param(self, param):
306 if param not in self.get_config_params():
307 return
308 value = self.get_param(param)
309 if param == 'local_gui_diffcontext':
310 self.diff_context = value
311 if param.startswith('local_'):
312 param = param[len('local_'):]
313 is_local = True
314 elif param.startswith('global_'):
315 param = param[len('global_'):]
316 is_local = False
317 else:
318 raise Exception("Invalid param '%s' passed to " % param
319 +'save_config_param()')
320 param = param.replace('_', '.') # model -> git
321 return self.config_set(param, value, local=is_local)
323 def init_browser_data(self):
324 """This scans over self.(names, sha1s, types) to generate
325 directories, directory_entries, and subtree_*"""
327 # Collect data for the model
328 if not self.get_currentbranch(): return
330 self.subtree_types = []
331 self.subtree_sha1s = []
332 self.subtree_names = []
333 self.directories = []
334 self.directory_entries = {}
336 # Lookup the tree info
337 tree_info = self.parse_ls_tree(self.get_currentbranch())
339 self.set_types(map( lambda(x): x[1], tree_info ))
340 self.set_sha1s(map( lambda(x): x[2], tree_info ))
341 self.set_names(map( lambda(x): x[3], tree_info ))
343 if self.directory: self.directories.append('..')
345 dir_entries = self.directory_entries
346 dir_regex = re.compile('([^/]+)/')
347 dirs_seen = {}
348 subdirs_seen = {}
350 for idx, name in enumerate(self.names):
351 if not name.startswith(self.directory):
352 continue
353 name = name[ len(self.directory): ]
354 if name.count('/'):
355 # This is a directory...
356 match = dir_regex.match(name)
357 if not match:
358 continue
359 dirent = match.group(1) + '/'
360 if dirent not in self.directory_entries:
361 self.directory_entries[dirent] = []
363 if dirent not in dirs_seen:
364 dirs_seen[dirent] = True
365 self.directories.append(dirent)
367 entry = name.replace(dirent, '')
368 entry_match = dir_regex.match(entry)
369 if entry_match:
370 subdir = entry_match.group(1) + '/'
371 if subdir in subdirs_seen:
372 continue
373 subdirs_seen[subdir] = True
374 dir_entries[dirent].append(subdir)
375 else:
376 dir_entries[dirent].append(entry)
377 else:
378 self.subtree_types.append(self.types[idx])
379 self.subtree_sha1s.append(self.sha1s[idx])
380 self.subtree_names.append(name)
382 def add_or_remove(self, *to_process):
383 """Invokes 'git add' to index the filenames in to_process that exist
384 and 'git rm' for those that do not exist."""
386 if not to_process:
387 return 'No files to add or remove.'
389 to_add = []
390 to_remove = []
392 for filename in to_process:
393 encfilename = filename.encode('utf-8')
394 if os.path.exists(encfilename):
395 to_add.append(filename)
397 if to_add:
398 output = self.git.add(v=True, *to_add)
399 else:
400 output = ''
402 if len(to_add) == len(to_process):
403 # to_process only contained unremoved files --
404 # short-circuit the removal checks
405 return output
407 # Process files to remote
408 for filename in to_process:
409 if not os.path.exists(filename):
410 to_remove.append(filename)
411 output + '\n\n' + self.git.rm(*to_remove)
413 def get_editor(self):
414 return self.get_gui_config('editor')
416 def get_mergetool(self):
417 return self.get_global_config('merge.tool')
419 def get_history_browser(self):
420 return self.get_gui_config('historybrowser')
422 def remember_gui_settings(self):
423 return self.get_cola_config('savewindowsettings')
425 def get_tree_node(self, idx):
426 return (self.get_types()[idx],
427 self.get_sha1s()[idx],
428 self.get_names()[idx] )
430 def get_subtree_node(self, idx):
431 return (self.get_subtree_types()[idx],
432 self.get_subtree_sha1s()[idx],
433 self.get_subtree_names()[idx] )
435 def get_all_branches(self):
436 return (self.get_local_branches() + self.get_remote_branches())
438 def set_remote(self, remote):
439 if not remote:
440 return
441 self.set_param('remote', remote)
442 branches = utils.grep('%s/\S+$' % remote,
443 self.branch_list(remote=True),
444 squash=False)
445 self.set_remote_branches(branches)
447 def add_signoff(self,*rest):
448 """Adds a standard Signed-off by: tag to the end
449 of the current commit message."""
450 msg = self.get_commitmsg()
451 signoff =('\n\nSigned-off-by: %s <%s>\n'
452 % (self.get_local_user_name(), self.get_local_user_email()))
453 if signoff not in msg:
454 self.set_commitmsg(msg + signoff)
456 def apply_diff(self, filename):
457 return self.git.apply(filename, index=True, cached=True)
459 def apply_diff_to_worktree(self, filename):
460 return self.git.apply(filename)
462 def load_commitmsg(self, path):
463 file = open(path, 'r')
464 contents = file.read().decode('utf-8')
465 file.close()
466 self.set_commitmsg(contents)
468 def get_prev_commitmsg(self,*rest):
469 """Queries git for the latest commit message and sets it in
470 self.commitmsg."""
471 commit_msg = []
472 commit_lines = self.git.show('HEAD').decode('utf-8').split('\n')
473 for idx, msg in enumerate(commit_lines):
474 if idx < 4:
475 continue
476 msg = msg.lstrip()
477 if msg.startswith('diff --git'):
478 commit_msg.pop()
479 break
480 commit_msg.append(msg)
481 self.set_commitmsg('\n'.join(commit_msg).rstrip())
483 def load_commitmsg_template(self):
484 try:
485 template = self.get_global_config('commit.template')
486 except AttributeError:
487 return
488 self.load_commitmsg(template)
490 def update_status(self, amend=False):
491 # This allows us to defer notification until the
492 # we finish processing data
493 notify_enabled = self.get_notify()
494 self.set_notify(False)
496 # Reset the staged and unstaged model lists
497 # NOTE: the model's unstaged list is used to
498 # hold both modified and untracked files.
499 self.staged = []
500 self.modified = []
501 self.untracked = []
503 # Read git status items
504 (staged_items,
505 modified_items,
506 untracked_items,
507 unmerged_items) = self.get_workdir_state(amend=amend)
509 # Gather items to be committed
510 for staged in staged_items:
511 if staged not in self.get_staged():
512 self.add_staged(staged)
514 # Gather unindexed items
515 for modified in modified_items:
516 if modified not in self.get_modified():
517 self.add_modified(modified)
519 # Gather untracked items
520 for untracked in untracked_items:
521 if untracked not in self.get_untracked():
522 self.add_untracked(untracked)
524 # Gather unmerged items
525 for unmerged in unmerged_items:
526 if unmerged not in self.get_unmerged():
527 self.add_unmerged(unmerged)
529 self.set_currentbranch(self.current_branch())
530 if self.get_show_untracked():
531 self.set_unstaged(self.get_modified() + self.get_unmerged() +
532 self.get_untracked())
533 else:
534 self.set_unstaged(self.get_modified() + self.get_unmerged())
535 self.set_remotes(self.git.remote().splitlines())
536 self.set_remote_branches(self.branch_list(remote=True))
537 self.set_local_branches(self.branch_list(remote=False))
538 self.set_tags(self.git.tag().splitlines())
539 self.set_revision('')
540 self.set_local_branch('')
541 self.set_remote_branch('')
542 # Re-enable notifications and emit changes
543 self.set_notify(notify_enabled)
544 self.notify_observers('staged','unstaged')
546 def delete_branch(self, branch):
547 return self.git.branch(branch, D=True)
549 def get_revision_sha1(self, idx):
550 return self.get_revisions()[idx]
552 def apply_font_size(self, param, default):
553 old_font = self.get_param(param)
554 if not old_font:
555 old_font = default
556 size = self.get_param(param+'size')
557 props = old_font.split(',')
558 props[1] = str(size)
559 new_font = ','.join(props)
561 self.set_param(param, new_font)
563 def get_commit_diff(self, sha1):
564 commit = self.git.show(sha1)
565 first_newline = commit.index('\n')
566 if commit[first_newline+1:].startswith('Merge:'):
567 return (commit + '\n\n'
568 + self.diff_helper(commit=sha1,
569 cached=False,
570 suppress_header=False))
571 else:
572 return commit
574 def get_filename(self, idx, staged=True):
575 try:
576 if staged:
577 return self.get_staged()[idx]
578 else:
579 return self.get_unstaged()[idx]
580 except IndexError:
581 return None
583 def get_diff_details(self, idx, ref, staged=True):
584 filename = self.get_filename(idx, staged=staged)
585 if not filename:
586 return (None, None, None)
587 encfilename = filename.encode('utf-8')
588 if staged:
589 if os.path.exists(encfilename):
590 status = 'Staged for commit'
591 else:
592 status = 'Staged for removal'
593 diff = self.diff_helper(filename=filename,
594 ref=ref,
595 cached=True)
596 else:
597 if os.path.isdir(encfilename):
598 status = 'Untracked directory'
599 diff = '\n'.join(os.listdir(filename))
601 elif filename in self.get_unmerged():
602 status = 'Unmerged'
603 diff = ('@@@+-+-+-+-+-+-+-+-+-+-+ UNMERGED +-+-+-+-+-+-+-+-+-+-+@@@\n\n'
604 '>>> %s is unmerged.\n' % filename +
605 'Right-click on the filename '
606 'to launch "git mergetool".\n\n\n')
607 diff += self.diff_helper(filename=filename,
608 cached=False,
609 patch_with_raw=False)
610 elif filename in self.get_modified():
611 status = 'Modified, not staged'
612 diff = self.diff_helper(filename=filename,
613 cached=False)
614 else:
615 status = 'Untracked, not staged'
616 diff = 'SHA1: ' + self.git.hash_object(filename)
617 return diff, status, filename
619 def stage_modified(self):
620 output = self.git.add(v=True, *self.get_modified())
621 self.update_status()
622 return output
624 def stage_untracked(self):
625 output = self.git.add(self.get_untracked())
626 self.update_status()
627 return output
629 def reset(self, *items):
630 output = self.git.reset('--', *items)
631 self.update_status()
632 return output
634 def unstage_all(self):
635 self.git.reset('--', *self.get_staged())
636 self.update_status()
638 def save_gui_settings(self):
639 self.config_set('cola.geometry', utils.get_geom(), local=False)
641 def config_set(self, key=None, value=None, local=True):
642 if key and value is not None:
643 # git config category.key value
644 strval = unicode(value)
645 if type(value) is bool:
646 # git uses "true" and "false"
647 strval = strval.lower()
648 if local:
649 argv = [ key, strval ]
650 else:
651 argv = [ '--global', key, strval ]
652 return self.git.config(*argv)
653 else:
654 msg = "oops in config_set(key=%s,value=%s,local=%s"
655 raise Exception(msg % (key, value, local))
657 def config_dict(self, local=True):
658 """parses the lines from git config --list into a dictionary"""
660 kwargs = {
661 'list': True,
662 'global': not local, # global is a python keyword
664 config_lines = self.git.config(**kwargs).splitlines()
665 newdict = {}
666 for line in config_lines:
667 k, v = line.split('=', 1)
668 v = v.decode('utf-8')
669 k = k.replace('.','_') # git -> model
670 if v == 'true' or v == 'false':
671 v = bool(eval(v.title()))
672 try:
673 v = int(eval(v))
674 except:
675 pass
676 newdict[k]=v
677 return newdict
679 def commit_with_msg(self, msg, amend=False):
680 """Creates a git commit."""
682 if not msg.endswith('\n'):
683 msg += '\n'
684 # Sure, this is a potential "security risk," but if someone
685 # is trying to intercept/re-write commit messages on your system,
686 # then you probably have bigger problems to worry about.
687 tmpfile = self.get_tmp_filename()
689 # Create the commit message file
690 fh = open(tmpfile, 'w')
691 fh.write(msg)
692 fh.close()
694 # Run 'git commit'
695 (status, stdout, stderr) = self.git.commit(F=tmpfile,
696 v=True,
697 amend=amend,
698 with_extended_output=True)
699 os.unlink(tmpfile)
701 return (status, stdout+stderr)
704 def diffindex(self):
705 return self.git.diff(unified=self.diff_context,
706 stat=True,
707 cached=True)
709 def get_tmp_dir(self):
710 # Allow TMPDIR/TMP with a fallback to /tmp
711 return os.environ.get('TMP', os.environ.get('TMPDIR', '/tmp'))
713 def get_tmp_file_pattern(self):
714 return os.path.join(self.get_tmp_dir(), '*.git-cola.%s.*' % os.getpid())
716 def get_tmp_filename(self, prefix=''):
717 basename = ((prefix+'.git-cola.%s.%s'
718 % (os.getpid(), time.time())))
719 basename = basename.replace('/', '-')
720 basename = basename.replace('\\', '-')
721 tmpdir = self.get_tmp_dir()
722 return os.path.join(tmpdir, basename)
724 def log_helper(self, all=False):
726 Returns a pair of parallel arrays listing the revision sha1's
727 and commit summaries.
729 revs = []
730 summaries = []
731 regex = REV_LIST_REGEX
732 output = self.git.log(pretty='oneline', all=all)
733 for line in output.splitlines():
734 match = regex.match(line)
735 if match:
736 revs.append(match.group(1))
737 summaries.append(match.group(2))
738 return (revs, summaries)
740 def parse_rev_list(self, raw_revs):
741 revs = []
742 for line in raw_revs.splitlines():
743 match = REV_LIST_REGEX.match(line)
744 if match:
745 rev_id = match.group(1)
746 summary = match.group(2)
747 revs.append((rev_id, summary,))
748 return revs
750 def rev_list_range(self, start, end):
751 range = '%s..%s' % (start, end)
752 raw_revs = self.git.rev_list(range, pretty='oneline')
753 return self.parse_rev_list(raw_revs)
755 def diff_helper(self,
756 commit=None,
757 branch=None,
758 ref = None,
759 endref = None,
760 filename=None,
761 cached=True,
762 with_diff_header=False,
763 suppress_header=True,
764 reverse=False,
765 patch_with_raw=True):
766 "Invokes git diff on a filepath."
767 if commit:
768 ref, endref = commit+'^', commit
769 argv = []
770 if ref and endref:
771 argv.append('%s..%s' % (ref, endref))
772 elif ref:
773 argv.append(ref)
774 elif branch:
775 argv.append(branch)
777 if filename:
778 argv.append('--')
779 if type(filename) is list:
780 argv.extend(filename)
781 else:
782 argv.append(filename)
784 output = StringIO()
785 start = False
786 del_tag = 'deleted file mode '
788 headers = []
789 deleted = cached and not os.path.exists(filename.encode('utf-8'))
791 diffoutput = self.git.diff(R=reverse,
792 cached=cached,
793 patch_with_raw=patch_with_raw,
794 unified=self.diff_context,
795 with_raw_output=True,
796 *argv)
797 diff = diffoutput.splitlines()
798 for line in diff:
799 line = unicode(line.decode('utf-8'))
800 if not start and '@@' == line[:2] and '@@' in line[2:]:
801 start = True
802 if start or(deleted and del_tag in line):
803 output.write(line.encode('utf-8') + '\n')
804 else:
805 if with_diff_header:
806 headers.append(line)
807 elif not suppress_header:
808 output.write(line.encode('utf-8') + '\n')
810 result = output.getvalue().decode('utf-8')
811 output.close()
813 if with_diff_header:
814 return('\n'.join(headers), result)
815 else:
816 return result
818 def git_repo_path(self, *subpaths):
819 paths = [ self.git.get_git_dir() ]
820 paths.extend(subpaths)
821 return os.path.realpath(os.path.join(*paths))
823 def get_merge_message_path(self):
824 for file in ('MERGE_MSG', 'SQUASH_MSG'):
825 path = self.git_repo_path(file)
826 if os.path.exists(path):
827 return path
828 return None
830 def get_merge_message(self):
831 return self.git.fmt_merge_msg('--file',
832 self.git_repo_path('FETCH_HEAD'))
834 def abort_merge(self):
835 # Reset the worktree
836 output = self.git.read_tree('HEAD', reset=True, u=True, v=True)
837 # remove MERGE_HEAD
838 merge_head = self.git_repo_path('MERGE_HEAD')
839 if os.path.exists(merge_head):
840 os.unlink(merge_head)
841 # remove MERGE_MESSAGE, etc.
842 merge_msg_path = self.get_merge_message_path()
843 while merge_msg_path:
844 os.unlink(merge_msg_path)
845 merge_msg_path = self.get_merge_message_path()
847 def get_workdir_state(self, amend=False):
848 """RETURNS: A tuple of staged, unstaged untracked, and unmerged
849 file lists.
851 self.partially_staged = set()
852 head = 'HEAD'
853 if amend:
854 head = 'HEAD^'
855 (staged, unstaged, unmerged, untracked) = ([], [], [], [])
857 for idx, line in enumerate(self.git.diff_index(head).splitlines()):
858 rest, name = line.split('\t')
859 status = rest[-1]
860 name = eval_path(name)
861 if status == 'M' or status == 'D':
862 unstaged.append(name)
864 for idx, line in enumerate(self.git.diff_index(head, cached=True)
865 .splitlines()):
866 rest, name = line.split('\t')
867 status = rest[-1]
868 name = eval_path(name)
869 if status == 'M':
870 staged.append(name)
871 # is this file partially staged?
872 diff = self.git.diff('--', name, name_only=True, z=True)
873 if not diff.strip():
874 unstaged.remove(name)
875 else:
876 self.partially_staged.add(name)
877 elif status == 'A':
878 staged.append(name)
879 elif status == 'D':
880 staged.append(name)
881 unstaged.remove(name)
882 elif status == 'U':
883 unmerged.append(name)
885 for line in self.git.ls_files(others=True, exclude_standard=True,
886 z=True).split('\0'):
887 if line:
888 untracked.append(line.decode('utf-8'))
890 return (staged, unstaged, untracked, unmerged)
892 def reset_helper(self, *args, **kwargs):
893 return self.git.reset('--', *args, **kwargs)
895 def remote_url(self, name):
896 return self.git.config('remote.%s.url' % name, get=True)
898 def get_remote_args(self, remote,
899 local_branch='', remote_branch='',
900 ffwd=True, tags=False):
901 if ffwd:
902 branch_arg = '%s:%s' % ( remote_branch, local_branch )
903 else:
904 branch_arg = '+%s:%s' % ( remote_branch, local_branch )
905 args = [remote]
906 if local_branch and remote_branch:
907 args.append(branch_arg)
908 kwargs = {
909 'verbose': True,
910 'tags': tags,
912 return (args, kwargs)
914 def gen_remote_helper(self, gitaction):
915 """Generates a closure that calls git fetch, push or pull
917 def remote_helper(remote, **kwargs):
918 args, kwargs = self.get_remote_args(remote, **kwargs)
919 return gitaction(*args, **kwargs)
920 return remote_helper
922 def parse_ls_tree(self, rev):
923 """Returns a list of(mode, type, sha1, path) tuples."""
924 lines = self.git.ls_tree(rev, r=True).splitlines()
925 output = []
926 regex = re.compile('^(\d+)\W(\w+)\W(\w+)[ \t]+(.*)$')
927 for line in lines:
928 match = regex.match(line)
929 if match:
930 mode = match.group(1)
931 objtype = match.group(2)
932 sha1 = match.group(3)
933 filename = match.group(4)
934 output.append((mode, objtype, sha1, filename,) )
935 return output
937 def format_patch_helper(self, to_export, revs, output='patches'):
938 """writes patches named by to_export to the output directory."""
940 outlines = []
942 cur_rev = to_export[0]
943 cur_master_idx = revs.index(cur_rev)
945 patches_to_export = [ [cur_rev] ]
946 patchset_idx = 0
948 # Group the patches into continuous sets
949 for idx, rev in enumerate(to_export[1:]):
950 # Limit the search to the current neighborhood for efficiency
951 master_idx = revs[ cur_master_idx: ].index(rev)
952 master_idx += cur_master_idx
953 if master_idx == cur_master_idx + 1:
954 patches_to_export[ patchset_idx ].append(rev)
955 cur_master_idx += 1
956 continue
957 else:
958 patches_to_export.append([ rev ])
959 cur_master_idx = master_idx
960 patchset_idx += 1
962 # Export each patchsets
963 for patchset in patches_to_export:
964 cmdoutput = self.export_patchset(patchset[0],
965 patchset[-1],
966 output="patches",
967 n=len(patchset) > 1,
968 thread=True,
969 patch_with_stat=True)
970 outlines.append(cmdoutput)
971 return '\n'.join(outlines)
973 def export_patchset(self, start, end, output="patches", **kwargs):
974 revarg = '%s^..%s' % (start, end)
975 return self.git.format_patch("-o", output, revarg, **kwargs)
977 def current_branch(self):
978 """Parses 'git symbolic-ref' to find the current branch."""
979 headref = self.git.symbolic_ref('HEAD')
980 if headref.startswith('refs/heads/'):
981 return headref[11:]
982 elif headref.startswith('fatal: '):
983 return 'Not currently on any branch'
984 return headref
986 def create_branch(self, name, base, track=False):
987 """Creates a branch starting from base. Pass track=True
988 to create a remote tracking branch."""
989 return self.git.branch(name, base, track=track)
991 def cherry_pick_list(self, revs, **kwargs):
992 """Cherry-picks each revision into the current branch.
993 Returns a list of command output strings (1 per cherry pick)"""
994 if not revs:
995 return []
996 cherries = []
997 for rev in revs:
998 cherries.append(self.git.cherry_pick(rev, **kwargs))
999 return '\n'.join(cherries)
1001 def parse_stash_list(self, revids=False):
1002 """Parses "git stash list" and returns a list of stashes."""
1003 stashes = self.git.stash("list").splitlines()
1004 if revids:
1005 return [ s[:s.index(':')] for s in stashes ]
1006 else:
1007 return [ s[s.index(':')+1:] for s in stashes ]
1009 def diffstat(self):
1010 return self.git.diff(
1011 'HEAD^',
1012 unified=self.diff_context,
1013 stat=True)
1015 def pad(self, pstr, num=22):
1016 topad = num-len(pstr)
1017 if topad > 0:
1018 return pstr + ' '*topad
1019 else:
1020 return pstr
1022 def describe(self, revid, descr):
1023 version = self.git.describe(revid, tags=True, always=True,
1024 abbrev=4)
1025 return version + ' - ' + descr
1027 def update_revision_lists(self, filename=None, show_versions=False):
1028 num_results = self.get_num_results()
1029 if filename:
1030 rev_list = self.git.log('--', filename,
1031 max_count=num_results,
1032 pretty='oneline')
1033 else:
1034 rev_list = self.git.log(max_count=num_results,
1035 pretty='oneline', all=True)
1037 commit_list = self.parse_rev_list(rev_list)
1038 commit_list.reverse()
1039 commits = map(lambda x: x[0], commit_list)
1040 descriptions = map(lambda x: x[1].decode('utf-8'), commit_list)
1041 if show_versions:
1042 fancy_descr_list = map(lambda x: self.describe(*x), commit_list)
1043 self.set_descriptions_start(fancy_descr_list)
1044 self.set_descriptions_end(fancy_descr_list)
1045 else:
1046 self.set_descriptions_start(descriptions)
1047 self.set_descriptions_end(descriptions)
1049 self.set_revisions_start(commits)
1050 self.set_revisions_end(commits)
1052 return commits
1054 def get_changed_files(self, start, end):
1055 zfiles_str = self.git.diff('%s..%s' % (start, end),
1056 name_only=True, z=True).strip('\0')
1057 return [ enc.decode('utf-8')
1058 for enc in zfiles_str.split('\0') if enc ]
1060 def get_renamed_files(self, start, end):
1061 files = []
1062 difflines = self.git.diff('%s..%s' % (start, end), M=True).splitlines()
1063 return [ eval_path(r[12:].rstrip())
1064 for r in difflines if r.startswith('rename from ') ]