models: do not list a file in both the unmerged and modified sections
[git-cola.git] / cola / models.py
blob 909dc819d71cad8904740f1cac12a6e9625c987d
1 # Copyright (c) 2008 David Aguilar
2 import os
3 import sys
4 import re
5 import time
6 import subprocess
7 from cStringIO import StringIO
9 from cola import git
10 from cola import utils
11 from cola import model
#+-------------------------------------------------------------------------
#+ A regex for matching the output of git(log|rev-list) --pretty=oneline
# Group 1 is the hex sha1, group 2 is the one-line summary.  The raw string
# keeps '\W' from being parsed as an (invalid) string escape.
REV_LIST_REGEX = re.compile(r'([0-9a-f]+)\W(.*)')
class GitCola(git.Git):
    """GitPython throws exceptions by default.
    We suppress exceptions in favor of return values.

    On top of git.Git this class locates the repository: the git
    directory (get_git_dir) and the work tree (get_work_tree), caching
    them in self._git_dir and self._work_tree.
    """
    def __init__(self):
        git.Git.__init__(self)
        # Discover the repository containing the current directory.
        self.load_worktree(os.getcwd())

    def load_worktree(self, path):
        """Re-run repository discovery starting from *path*.

        *path* is stored as the candidate git directory (get_git_dir()
        accepts it if valid, otherwise searches upward from it), and the
        cached work tree is cleared so get_work_tree() recomputes it.
        """
        self._git_dir = path
        self._work_tree = None
        self.get_work_tree()

    def get_work_tree(self):
        """Return the work tree directory, computing it when not cached.

        After resolving the git directory (falling back to the current
        directory when there is none), the checks below run in order:
          - it contains a '.git' git directory: return it as-is;
          - its name ends in '.git' and is longer than '.git' (treated as
            a bare repository): return it as-is;
          - otherwise use $GIT_WORK_TREE if it names an existing
            directory, else the parent of the git directory.
        Only the last case caches the result in self._work_tree (it may
        still be None) and passes it to set_cwd().  set_cwd() is inherited
        from git.Git; presumably it sets the directory git commands run
        in -- confirm.
        """
        if self._work_tree:
            return self._work_tree
        self.get_git_dir()
        if self._git_dir:
            curdir = self._git_dir
        else:
            curdir = os.getcwd()

        if self._is_git_dir(os.path.join(curdir, '.git')):
            return curdir

        # Handle bare repositories
        if (len(os.path.basename(curdir)) > 4
                and curdir.endswith('.git')):
            return curdir
        if 'GIT_WORK_TREE' in os.environ:
            self._work_tree = os.getenv('GIT_WORK_TREE')
        if not self._work_tree or not os.path.isdir(self._work_tree):
            # $GIT_WORK_TREE unset or not a directory: fall back to the
            # directory that contains the git directory.
            if self._git_dir:
                gitparent = os.path.join(os.path.abspath(self._git_dir), '..')
                self._work_tree = os.path.abspath(gitparent)
        self.set_cwd(self._work_tree)
        return self._work_tree

    def is_valid(self):
        """Return a true value when self._git_dir is set and passes
        _is_git_dir(); otherwise a false value (possibly the unset
        self._git_dir itself)."""
        return self._git_dir and self._is_git_dir(self._git_dir)

    def get_git_dir(self):
        """Return the git directory, searching for it when needed.

        A valid stored value is returned immediately.  Otherwise $GIT_DIR,
        when set, replaces the stored candidate, and the search walks from
        that candidate (or the current directory) up to the filesystem
        root, accepting either the directory itself or its '.git'
        subdirectory.  If nothing is found, self._git_dir keeps the
        candidate value and that is returned.
        """
        if self.is_valid():
            return self._git_dir
        if 'GIT_DIR' in os.environ:
            self._git_dir = os.getenv('GIT_DIR')
        if self._git_dir:
            curpath = os.path.abspath(self._git_dir)
        else:
            curpath = os.path.abspath(os.getcwd())
        # Search for a .git directory
        while curpath:
            if self._is_git_dir(curpath):
                self._git_dir = curpath
                break
            gitpath = os.path.join(curpath, '.git')
            if self._is_git_dir(gitpath):
                self._git_dir = gitpath
                break
            curpath, dummy = os.path.split(curpath)
            # split() returns an empty tail once curpath is the root.
            if not dummy:
                break
        return self._git_dir

    def _is_git_dir(self, d):
        """ This is taken from the git setup.c:is_git_directory
        function.

        *d* qualifies when it is a directory containing 'objects' and
        'refs' subdirectories and a 'HEAD' that is either a regular file
        or a symlink whose target starts with 'refs'.
        """
        if (os.path.isdir(d)
                and os.path.isdir(os.path.join(d, 'objects'))
                and os.path.isdir(os.path.join(d, 'refs'))):
            headref = os.path.join(d, 'HEAD')
            return (os.path.isfile(headref)
                    or (os.path.islink(headref)
                        and os.readlink(headref).startswith('refs')))
        return False
def eval_path(path):
    """handles quoted paths.

    Paths wrapped in double quotes are treated as quoted string literals
    (git quotes unusual path names with backslash escapes, octal for
    non-ASCII bytes).  Such paths are unquoted and decoded from utf-8;
    any other path is returned unchanged.

    Raises ValueError/SyntaxError for a quoted value that is not a plain
    string literal.
    """
    if len(path) < 2 or not (path.startswith('"') and path.endswith('"')):
        return path
    # ast.literal_eval() only accepts literals, so a crafted "path" can no
    # longer execute arbitrary code the way eval() could.
    import ast
    unquoted = ast.literal_eval(path)
    if not isinstance(unquoted, type(u'')):
        # Python 2: a byte string holding the utf-8 encoded name.
        return unquoted.decode('utf-8')
    # Python 3: each escaped byte became one code point below 256.
    return unquoted.encode('latin-1').decode('utf-8')
100 class Model(model.Model):
101 """Provides a friendly wrapper for doing commit git operations."""
def clone(self):
    """Return a copy of this model attached to the same work tree."""
    # Capture the work tree before copying the model.
    current_worktree = self.git.get_work_tree()
    duplicate = model.Model.clone(self)
    duplicate.use_worktree(current_worktree)
    return duplicate
def use_worktree(self, worktree):
    """Point self.git at *worktree*.

    The config-derived parameters are reloaded only when the result is a
    valid repository.  Returns the value of self.git.is_valid().
    """
    self.git.load_worktree(worktree)
    valid = self.git.is_valid()
    if not valid:
        return valid
    self.__init_config_data()
    return valid
def init(self):
    """Reads git repository settings and sets several methods
    so that they refer to the git module. This object
    encapsulates cola's interaction with git.

    Creates self.git (a GitCola), the fetch/push/pull helper closures,
    and every model parameter with an empty default value.
    """
    # Initialize the git command object
    self.git = GitCola()
    self.partially_staged = set()

    for attrname, gitaction in (('fetch_helper', self.git.fetch),
                                ('push_helper', self.git.push),
                                ('pull_helper', self.git.pull)):
        setattr(self, attrname, self.gen_remote_helper(gitaction))

    defaults = dict(
        #####################################################
        # Used in various places
        currentbranch='',
        remotes=[],
        remotename='',
        local_branch='',
        remote_branch='',
        search_text='',

        #####################################################
        # Used primarily by the main UI
        commitmsg='',
        modified=[],
        staged=[],
        unstaged=[],
        untracked=[],
        unmerged=[],

        #####################################################
        # Used by the create branch dialog
        revision='',
        local_branches=[],
        remote_branches=[],
        tags=[],

        #####################################################
        # Used by the commit/repo browser
        directory='',
        revisions=[],
        summaries=[],

        # These are parallel lists
        types=[],
        sha1s=[],
        names=[],

        # All items below here are re-calculated in
        # init_browser_data()
        directories=[],
        directory_entries={},

        # These are also parallel lists
        subtree_types=[],
        subtree_sha1s=[],
        subtree_names=[],
    )
    self.create(**defaults)
def __init_config_data(self):
    """Reads git config --list and creates parameters
    for each setting.

    Each config key becomes a 'local_<key>' and/or 'global_<key>'
    parameter (dots replaced by underscores, see config_dict()).  Local
    parameters inherit global values the repository does not override,
    and the hard-coded defaults below fill in anything that is unset.
    """
    # These parameters are saved in .gitconfig,
    # so ideally these should be as short as possible.

    # config items that are controllable globally
    # and per-repository
    self.__local_and_global_defaults = {
        'user_name': '',
        'user_email': '',
        'merge_summary': False,
        'merge_diffstat': True,
        'merge_verbosity': 2,
        'gui_diffcontext': 3,
        'gui_pruneduringfetch': False,
    }
    # config items that are purely git config --global settings
    self.__global_defaults = {
        'cola_geometry': '',
        'cola_fontui': '',
        'cola_fontuisize': 12,
        'cola_fontdiff': '',
        'cola_fontdiffsize': 12,
        'cola_savewindowsettings': False,
        'merge_keepbackup': True,
        'merge_tool': os.getenv('MERGETOOL', 'xxdiff'),
        'gui_editor': os.getenv('EDITOR', 'gvim'),
        'gui_historybrowser': 'gitk',
    }

    local_dict = self.config_dict(local=True)
    global_dict = self.config_dict(local=False)

    for k, v in local_dict.iteritems():
        self.set_param('local_' + k, v)
    for k, v in global_dict.iteritems():
        self.set_param('global_' + k, v)
        # Repository settings inherit global values they don't override.
        if k not in local_dict:
            local_dict[k] = v
            self.set_param('local_' + k, v)

    # Bootstrap the internal font*size variables from the stored font
    # descriptions, whose second comma-separated field is the point size.
    for param in ('global_cola_fontui', 'global_cola_fontdiff'):
        if not hasattr(self, param):
            continue
        font = self.get_param(param)
        if not font:
            continue
        try:
            size = int(font.split(',')[1])
        except (AttributeError, IndexError, ValueError):
            # Malformed font value: keep the default size instead of
            # aborting the whole configuration load.
            continue
        self.set_param(param + 'size', size)
        param = param[len('global_'):]
        global_dict[param] = font
        global_dict[param + 'size'] = size

    # Load defaults for all undefined items
    local_and_global_defaults = self.__local_and_global_defaults
    for k, v in local_and_global_defaults.iteritems():
        if k not in local_dict:
            self.set_param('local_' + k, v)
        if k not in global_dict:
            self.set_param('global_' + k, v)

    global_defaults = self.__global_defaults
    for k, v in global_defaults.iteritems():
        if k not in global_dict:
            self.set_param('global_' + k, v)

    # Load the diff context
    self.diff_context = self.local_gui_diffcontext
def get_global_config(self, key):
    """Return the 'global_' parameter for the dotted git config *key*.

    Raises AttributeError when no such parameter exists.
    """
    attrname = 'global_' + key.replace('.', '_')
    return getattr(self, attrname)
def get_cola_config(self, key):
    """Return the global 'cola.<key>' setting (AttributeError if unset)."""
    attrname = 'global_cola_' + key
    return getattr(self, attrname)
def get_gui_config(self, key):
    """Return the global 'gui.<key>' setting (AttributeError if unset)."""
    attrname = 'global_gui_' + key
    return getattr(self, attrname)
def get_default_remote(self):
    """Return branch.<current branch>.remote from the local config,
    or 'origin' when that setting is absent."""
    branchconfig = 'local_branch_%s_remote' % self.get_currentbranch()
    if branchconfig not in self.get_param_names():
        return 'origin'
    return self.get_param(branchconfig)
def get_corresponding_remote_ref(self):
    """Return '<remote>/<current branch>' if that remote branch exists.

    Otherwise return the first remote branch, or just the remote name
    when there are no remote branches at all.
    """
    remote = self.get_default_remote()
    branch = self.get_currentbranch()
    remote_branches = self.get_remote_branches()
    if not remote_branches:
        return remote
    wanted = '%s/%s' % (remote, branch)
    if wanted in remote_branches:
        return wanted
    return remote_branches[0]
def get_diff_filenames(self, arg):
    """Return the utf-8 decoded paths from 'git diff --name-only -z <arg>'."""
    raw = self.git.diff(arg, name_only=True, z=True).rstrip('\0')
    names = []
    for encoded in raw.split('\0'):
        if encoded:
            names.append(encoded.decode('utf-8'))
    return names
def branch_list(self, remote=False):
    """Return branch names from 'git branch' (or 'git branch -r').

    The leading '* ' current-branch marker is stripped.  For remote
    listings, names ending in '/HEAD' are dropped.
    """
    output = self.git.branch(r=remote)
    names = [line.lstrip('* ') for line in output.splitlines()]
    if not remote:
        return names
    return [name for name in names if not name.endswith('/HEAD')]
def get_config_params(self):
    """Return the names of the config parameters that may be saved.

    These are the 'local_' and 'global_' forms of the shared defaults and
    the 'global_' forms of the global-only defaults.  The derived '*size'
    font parameters are excluded.
    """
    shared_keys = self.__local_and_global_defaults.keys()
    global_keys = self.__global_defaults.keys()
    names = (['local_' + key for key in shared_keys] +
             ['global_' + key for key in shared_keys] +
             ['global_' + key for key in global_keys])
    return [name for name in names if not name.endswith('size')]
def save_config_param(self, param):
    """Write the model parameter *param* back to git config.

    Names not returned by get_config_params() are ignored (returns None).
    'local_*' names go to the repository config and 'global_*' names to
    --global.  Saving local_gui_diffcontext also updates
    self.diff_context.  Returns whatever config_set() returns.
    """
    if param not in self.get_config_params():
        return
    value = self.get_param(param)
    if param == 'local_gui_diffcontext':
        self.diff_context = value

    for prefix, is_local in (('local_', True), ('global_', False)):
        if param.startswith(prefix):
            break
    else:
        raise Exception("Invalid param '%s' passed to " % param
                        + 'save_config_param()')
    key = param[len(prefix):].replace('_', '.')  # model -> git
    return self.config_set(key, value, local=is_local)
def init_browser_data(self):
    """This scans over self.(names, sha1s, types) to generate
    directories, directory_entries, and subtree_*

    The tree of the current branch is listed via parse_ls_tree().  Paths
    under self.directory are then split into:
      - directories: the immediate child directories (plus '..' when
        self.directory is set);
      - directory_entries: for each child directory, its files and
        unique immediate subdirectories;
      - subtree_*: parallel lists describing the files directly inside
        self.directory.
    Does nothing when there is no current branch.
    """
    # Collect data for the model
    if not self.get_currentbranch():
        return

    self.subtree_types = []
    self.subtree_sha1s = []
    self.subtree_names = []
    self.directories = []
    self.directory_entries = {}

    # Lookup the tree info: (mode, type, sha1, path) tuples
    tree_info = self.parse_ls_tree(self.get_currentbranch())

    self.set_types([info[1] for info in tree_info])
    self.set_sha1s([info[2] for info in tree_info])
    self.set_names([info[3] for info in tree_info])

    if self.directory:
        self.directories.append('..')

    dir_entries = self.directory_entries
    dir_regex = re.compile('([^/]+)/')
    dirs_seen = set()
    # (directory, subdirectory) pairs already listed.  Keying on the pair
    # lets 'a/sub/' and 'b/sub/' both be listed.
    subdirs_seen = set()

    for idx, name in enumerate(self.names):
        if not name.startswith(self.directory):
            continue
        name = name[len(self.directory):]
        if '/' not in name:
            # A file directly inside self.directory
            self.subtree_types.append(self.types[idx])
            self.subtree_sha1s.append(self.sha1s[idx])
            self.subtree_names.append(name)
            continue

        # This is a directory...
        match = dir_regex.match(name)
        if not match:
            continue
        dirent = match.group(1) + '/'
        dir_entries.setdefault(dirent, [])
        if dirent not in dirs_seen:
            dirs_seen.add(dirent)
            self.directories.append(dirent)

        # Strip only the leading prefix; str.replace() would also remove
        # later occurrences (e.g. 'x/xx/y' would become 'xy').
        entry = name[len(dirent):]
        entry_match = dir_regex.match(entry)
        if entry_match:
            subdir = entry_match.group(1) + '/'
            if (dirent, subdir) in subdirs_seen:
                continue
            subdirs_seen.add((dirent, subdir))
            dir_entries[dirent].append(subdir)
        else:
            dir_entries[dirent].append(entry)
def add_or_remove(self, to_process):
    """Invokes 'git add' to index the filenames in to_process that exist
    and 'git rm' for those that do not exist.

    Returns the 'git add' output.  When some files had to be removed, the
    'git rm' output is appended after a blank line.  An empty
    *to_process* returns a short message instead.
    """
    if not to_process:
        return 'No files to add or remove.'

    to_add = []
    to_remove = []
    for filename in to_process:
        # Test existence with the utf-8 encoded path, matching how
        # filesystem paths are checked elsewhere in this model.
        if os.path.exists(filename.encode('utf-8')):
            to_add.append(filename)
        else:
            to_remove.append(filename)

    output = ''
    if to_add:
        output = self.git.add(v=True, *to_add)
    if not to_remove:
        # to_process only contained unremoved files
        return output
    # Previously this concatenation was computed and then discarded,
    # so the function returned None.
    return output + '\n\n' + self.git.rm(*to_remove)
def get_editor(self):
    """Return the configured gui.editor command."""
    editor = self.get_gui_config('editor')
    return editor
def get_mergetool(self):
    """Return the configured merge.tool (global setting)."""
    tool = self.get_global_config('merge.tool')
    return tool
def get_history_browser(self):
    """Return the configured gui.historybrowser command."""
    browser = self.get_gui_config('historybrowser')
    return browser
def remember_gui_settings(self):
    """Return the cola.savewindowsettings flag."""
    flag = self.get_cola_config('savewindowsettings')
    return flag
def get_tree_node(self, idx):
    """Return the (type, sha1, name) triple at position *idx* of the
    parallel tree lists."""
    node_type = self.get_types()[idx]
    node_sha1 = self.get_sha1s()[idx]
    node_name = self.get_names()[idx]
    return (node_type, node_sha1, node_name)
def get_subtree_node(self, idx):
    """Return the (type, sha1, name) triple at position *idx* of the
    parallel subtree lists."""
    node_type = self.get_subtree_types()[idx]
    node_sha1 = self.get_subtree_sha1s()[idx]
    node_name = self.get_subtree_names()[idx]
    return (node_type, node_sha1, node_name)
def get_all_branches(self):
    """Return the local branches followed by the remote branches."""
    local = self.get_local_branches()
    remote = self.get_remote_branches()
    return local + remote
def set_remote(self, remote):
    """Select *remote* and load its branches into remote_branches.

    Does nothing when *remote* is empty.
    """
    if remote:
        self.set_param('remote', remote)
        remote_refs = self.branch_list(remote=True)
        matching = utils.grep(r'%s/\S+$' % remote, remote_refs,
                              squash=False)
        self.set_remote_branches(matching)
def add_signoff(self, *rest):
    """Adds a standard Signed-off by: tag to the end
    of the current commit message.

    The trailer uses the local user name/email.  Nothing happens when the
    identical trailer is already present.
    """
    msg = self.get_commitmsg()
    name = self.get_local_user_name()
    email = self.get_local_user_email()
    signoff = '\n\nSigned-off-by: %s <%s>\n' % (name, email)
    if signoff in msg:
        return
    self.set_commitmsg(msg + signoff)
def apply_diff(self, filename):
    """Apply the patch in *filename* to the index only (git apply with the
    index and cached options)."""
    result = self.git.apply(filename, cached=True, index=True)
    return result
def apply_diff_to_worktree(self, filename):
    """Apply the patch in *filename* to the work tree (plain git apply)."""
    result = self.git.apply(filename)
    return result
def load_commitmsg(self, path):
    """Set the commit message to the contents of the file at *path*.

    Byte strings read from the file are decoded as utf-8.  The file is
    closed even when reading fails.
    """
    fh = open(path, 'r')
    try:
        contents = fh.read()
    finally:
        fh.close()
    if hasattr(contents, 'decode'):
        # Python 2 reads byte strings; decode them explicitly.
        contents = contents.decode('utf-8')
    self.set_commitmsg(contents)
def get_prev_commitmsg(self, *rest):
    """Queries git for the latest commit message and sets it in
    self.commitmsg.

    The message is taken from 'git show HEAD': the first four lines are
    skipped, each following line has its indentation stripped, and
    reading stops at the first 'diff --git' line.
    """
    commit_msg = []
    commit_lines = self.git.show('HEAD').decode('utf-8').split('\n')
    for msg in commit_lines[4:]:
        msg = msg.lstrip()
        if msg.startswith('diff --git'):
            # Drop the separator line before the diff.  Guarded: an empty
            # message leaves nothing to pop.
            if commit_msg:
                commit_msg.pop()
            break
        commit_msg.append(msg)
    self.set_commitmsg('\n'.join(commit_msg).rstrip())
def load_commitmsg_template(self):
    """Load the file named by the global commit.template setting into the
    commit message.  Does nothing when the setting is undefined."""
    try:
        template = self.get_global_config('commit.template')
    except AttributeError:
        # commit.template is not configured
        return
    else:
        self.load_commitmsg(template)
def update_status(self, amend=False):
    """Refresh the file lists, branches, remotes and tags from git.

    Observer notification is suspended during the refresh.  Afterwards the
    previous notify state is restored and observers of 'staged' and
    'unstaged' are notified.
    """
    was_notifying = self.get_notify()
    self.set_notify(False)

    state = self.get_workdir_state(amend=amend)
    self.staged, self.modified, self.unmerged, self.untracked = state
    # NOTE: the unstaged list is the concatenation of the modified,
    # unmerged and untracked lists, in that order.
    self.set_unstaged(self.modified + self.unmerged + self.untracked)
    self.set_currentbranch(self.current_branch())
    self.set_remotes(self.git.remote().splitlines())
    self.set_remote_branches(self.branch_list(remote=True))
    self.set_local_branches(self.branch_list(remote=False))
    self.set_tags(self.git.tag().splitlines())
    for clear in (self.set_revision,
                  self.set_local_branch,
                  self.set_remote_branch):
        clear('')

    # Re-enable notifications and emit changes
    self.set_notify(was_notifying)
    self.notify_observers('staged', 'unstaged')
def delete_branch(self, branch):
    """Force-delete *branch* (git branch with the D option)."""
    result = self.git.branch(branch, D=True)
    return result
def get_revision_sha1(self, idx):
    """Return the revision id at position *idx* of the revisions list."""
    revisions = self.get_revisions()
    return revisions[idx]
def apply_font_size(self, param, default):
    """Rewrite the font description in *param* using the size stored in
    '<param>size'.

    The size replaces the second comma-separated field.  *default* is
    used when *param* is empty.
    """
    font = self.get_param(param) or default
    size = self.get_param(param + 'size')
    fields = font.split(',')
    fields[1] = str(size)
    self.set_param(param, ','.join(fields))
def get_commit_diff(self, sha1):
    """Return 'git show <sha1>' output.

    For merge commits (second line starts with 'Merge:') the diff
    against the first parent is appended, since 'git show' omits it.
    """
    commit = self.git.show(sha1)
    header_end = commit.index('\n')
    is_merge = commit[header_end + 1:].startswith('Merge:')
    if not is_merge:
        return commit
    first_parent_diff = self.diff_helper(commit=sha1,
                                         cached=False,
                                         suppress_header=False)
    return commit + '\n\n' + first_parent_diff
def get_filename(self, idx, staged=True):
    """Return entry *idx* of the staged (or unstaged) list, or None when
    the index is out of range."""
    try:
        if not staged:
            return self.get_unstaged()[idx]
        return self.get_staged()[idx]
    except IndexError:
        return None
def get_diff_details(self, idx, ref, staged=True):
    """Describe entry *idx* of the staged/unstaged list for display.

    Returns (diff text, status label, filename), or (None, None, None)
    when there is no such entry.
    """
    filename = self.get_filename(idx, staged=staged)
    if not filename:
        return (None, None, None)
    encfilename = filename.encode('utf-8')

    if staged:
        if os.path.exists(encfilename):
            status = 'Staged for commit'
        else:
            status = 'Staged for removal'
        diff = self.diff_helper(filename=filename, ref=ref, cached=True)
    elif os.path.isdir(encfilename):
        status = 'Untracked directory'
        diff = '\n'.join(os.listdir(filename))
    elif filename in self.get_unmerged():
        status = 'Unmerged'
        banner = ('@@@+-+-+-+-+-+-+-+-+-+-+ UNMERGED +-+-+-+-+-+-+-+-+-+-+@@@\n\n'
                  '>>> %s is unmerged.\n' % filename +
                  'Right-click on the filename '
                  'to launch "git mergetool".\n\n\n')
        diff = banner + self.diff_helper(filename=filename,
                                         cached=False,
                                         patch_with_raw=False)
    elif filename in self.get_modified():
        status = 'Modified, not staged'
        diff = self.diff_helper(filename=filename, cached=False)
    else:
        status = 'Untracked, not staged'
        diff = 'SHA1: ' + self.git.hash_object(filename)
    return diff, status, filename
def stage_modified(self):
    """Stage every modified file (git add, verbose) and refresh status.

    Returns the git add output.
    """
    paths = self.get_modified()
    result = self.git.add(v=True, *paths)
    self.update_status()
    return result
def stage_untracked(self):
    """Stage every untracked file and refresh status; returns git add's
    output."""
    paths = self.get_untracked()
    result = self.git.add(*paths)
    self.update_status()
    return result
def reset(self, *items):
    """Unstage *items* ('git reset -- items') and refresh status; returns
    git's output."""
    result = self.git.reset('--', *items)
    self.update_status()
    return result
def unstage_all(self):
    """Unstage everything (plain 'git reset') and refresh status."""
    result = self.git.reset()
    self.update_status()
    return result
def stage_all(self):
    """Stage all tracked changes (git add with v and u) and refresh
    status."""
    result = self.git.add(u=True, v=True)
    self.update_status()
    return result
def save_gui_settings(self):
    """Persist the current window geometry as the global cola.geometry."""
    geometry = utils.get_geom()
    self.config_set('cola.geometry', geometry, local=False)
def config_set(self, key=None, value=None, local=True):
    """Run 'git config [--global] <key> <value>'.

    Booleans are written as git's lowercase "true"/"false"; other values
    are converted with unicode().  Returns the git config output.
    Raises Exception when *key* is empty or *value* is None.
    """
    if not key or value is None:
        # The closing parenthesis was missing from this message.
        msg = "oops in config_set(key=%s,value=%s,local=%s)"
        raise Exception(msg % (key, value, local))
    # git config category.key value
    strval = unicode(value)
    if isinstance(value, bool):
        # git uses "true" and "false"
        strval = strval.lower()
    argv = [key, strval]
    if not local:
        argv.insert(0, '--global')
    return self.git.config(*argv)
def config_dict(self, local=True):
    """parses the lines from git config --list into a dictionary

    Keys have '.' replaced by '_' (git -> model naming).  Values are
    decoded from utf-8, and then:
      - 'true'/'false' become bools;
      - decimal integers become ints;
      - a bare key without '=' is taken as True;
      - everything else stays a string.
    Values are never evaluated as Python code.
    """
    kwargs = {
        'list': True,
        'global': not local,  # global is a python keyword
    }
    config_lines = self.git.config(**kwargs).splitlines()
    newdict = {}
    for line in config_lines:
        if not line:
            continue
        if '=' in line:
            k, v = line.split('=', 1)
        else:
            # A key listed without a value; git treats it as true.
            k, v = line, 'true'
        if hasattr(v, 'decode'):
            v = v.decode('utf-8')
        k = k.replace('.', '_')  # git -> model
        if v == 'true' or v == 'false':
            v = (v == 'true')
        else:
            # Previously int(eval(v)): eval() ran arbitrary code taken
            # from the config file.  Parse plain integers instead.
            try:
                v = int(v)
            except ValueError:
                pass
        newdict[k] = v
    return newdict
def commit_with_msg(self, msg, amend=False):
    """Creates a git commit.

    *msg* is newline-terminated, encoded as utf-8 if it is text, and
    written to a temporary file.  That file is passed to git commit
    (F, v and amend options) and deleted afterwards, even when the commit
    raises.  Returns (status, stdout + stderr).
    """
    if not msg.endswith('\n'):
        msg += '\n'
    if isinstance(msg, type(u'')):
        # Encode text explicitly; writing non-ASCII text to a byte file
        # would otherwise fail.
        msg = msg.encode('utf-8')
    # Sure, this is a potential "security risk," but if someone
    # is trying to intercept/re-write commit messages on your system,
    # then you probably have bigger problems to worry about.
    tmpfile = self.get_tmp_filename()
    try:
        # Create the commit message file
        fh = open(tmpfile, 'wb')
        try:
            fh.write(msg)
        finally:
            fh.close()
        # Run 'git commit'
        (status, stdout, stderr) = self.git.commit(F=tmpfile,
                                                   v=True,
                                                   amend=amend,
                                                   with_extended_output=True)
    finally:
        if os.path.exists(tmpfile):  # open() itself may have failed
            os.unlink(tmpfile)
    return (status, stdout + stderr)
def diffindex(self):
    """Return the diffstat of staged changes (git diff with stat and
    cached), using the configured context size."""
    return self.git.diff(cached=True,
                         stat=True,
                         unified=self.diff_context)
def get_tmp_dir(self):
    """Return $TMP, else $TMPDIR, else '/tmp'."""
    for variable in ('TMP', 'TMPDIR'):
        if variable in os.environ:
            return os.environ[variable]
    return '/tmp'
def get_tmp_file_pattern(self):
    """Return a glob matching this process' temporary cola files."""
    tmpdir = self.get_tmp_dir()
    pattern = '*.git-cola.%s.*' % os.getpid()
    return os.path.join(tmpdir, pattern)
def get_tmp_filename(self, prefix=''):
    """Return a temporary path named
    '<prefix>.git-cola.<pid>.<time>' inside get_tmp_dir().

    Any '/' or '\\' in the name is replaced by '-'.
    """
    basename = prefix + '.git-cola.%s.%s' % (os.getpid(), time.time())
    for separator in ('/', '\\'):
        basename = basename.replace(separator, '-')
    return os.path.join(self.get_tmp_dir(), basename)
def log_helper(self, all=False):
    """
    Returns a pair of parallel arrays listing the revision sha1's
    and commit summaries.

    Lines of 'git log --pretty=oneline' that do not match REV_LIST_REGEX
    are skipped.
    """
    revs, summaries = [], []
    for line in self.git.log(pretty='oneline', all=all).splitlines():
        match = REV_LIST_REGEX.match(line)
        if not match:
            continue
        revs.append(match.group(1))
        summaries.append(match.group(2))
    return (revs, summaries)
def parse_rev_list(self, raw_revs):
    """Return (sha1, summary) tuples parsed from --pretty=oneline output,
    skipping lines that do not match REV_LIST_REGEX."""
    matches = [REV_LIST_REGEX.match(line) for line in raw_revs.splitlines()]
    return [(m.group(1), m.group(2)) for m in matches if m]
def rev_list_range(self, start, end):
    """Return (sha1, summary) tuples for the commits in start..end."""
    revrange = '%s..%s' % (start, end)
    return self.parse_rev_list(self.git.rev_list(revrange, pretty='oneline'))
def diff_helper(self,
                commit=None,
                branch=None,
                ref=None,
                endref=None,
                filename=None,
                cached=True,
                with_diff_header=False,
                suppress_header=True,
                reverse=False,
                patch_with_raw=True):
    """Invokes git diff on a filepath.

    commit           -- diff commit^..commit (overrides ref/endref)
    ref, endref      -- diff ref..endref, or just ref without endref
    branch           -- diff against branch when no ref is given
    filename         -- a path, or a list of paths, to limit the diff to
    cached, reverse, patch_with_raw
                     -- forwarded to self.git.diff() (reverse as R)
    with_diff_header -- also return the lines before the first hunk
    suppress_header  -- omit those lines from the returned diff

    Returns the diff text (one trailing newline per line), or a
    (headers, diff) tuple when with_diff_header is true.
    """
    if commit:
        ref, endref = commit + '^', commit
    argv = []
    if ref and endref:
        argv.append('%s..%s' % (ref, endref))
    elif ref:
        argv.append(ref)
    elif branch:
        argv.append(branch)

    if filename:
        argv.append('--')
        if isinstance(filename, list):
            argv.extend(filename)
        else:
            argv.append(filename)

    del_tag = 'deleted file mode '
    # A cached diff of a single path that is gone from disk is a staged
    # removal; its "deleted file mode" header is kept.  Only a single
    # path can be checked.  filename may be None or a list, which used
    # to crash here.
    deleted = False
    if cached and filename is not None and not isinstance(filename, list):
        deleted = not os.path.exists(filename.encode('utf-8'))

    diffoutput = self.git.diff(R=reverse,
                               cached=cached,
                               patch_with_raw=patch_with_raw,
                               unified=self.diff_context,
                               with_raw_output=True,
                               *argv)
    headers = []
    body = []
    start = False
    for line in diffoutput.splitlines():
        if hasattr(line, 'decode'):
            line = line.decode('utf-8')
        # The first hunk header ("@@ ... @@") starts the diff body.
        if not start and line[:2] == '@@' and '@@' in line[2:]:
            start = True
        if start or (deleted and del_tag in line):
            body.append(line)
        elif with_diff_header:
            headers.append(line)
        elif not suppress_header:
            body.append(line)

    result = u''.join([line + u'\n' for line in body])
    if with_diff_header:
        return ('\n'.join(headers), result)
    return result
def git_repo_path(self, *subpaths):
    """Return the canonical path of *subpaths* inside the git directory."""
    joined = os.path.join(self.git.get_git_dir(), *subpaths)
    return os.path.realpath(joined)
def get_merge_message_path(self):
    """Return the path of MERGE_MSG, else SQUASH_MSG, in the git
    directory, or None when neither exists."""
    for msgfile in ('MERGE_MSG', 'SQUASH_MSG'):
        candidate = self.git_repo_path(msgfile)
        if os.path.exists(candidate):
            return candidate
    return None
def get_merge_message(self):
    """Return 'git fmt-merge-msg --file <git dir>/FETCH_HEAD' output."""
    fetch_head = self.git_repo_path('FETCH_HEAD')
    return self.git.fmt_merge_msg('--file', fetch_head)
def abort_merge(self):
    """Abort an in-progress merge.

    Resets the index and work tree to HEAD (git read-tree HEAD with the
    reset, u and v options), then deletes MERGE_HEAD and every
    MERGE_MSG/SQUASH_MSG file in the git directory.  Returns the
    read-tree output, which used to be computed and discarded.
    """
    # Reset the worktree
    output = self.git.read_tree('HEAD', reset=True, u=True, v=True)
    # remove MERGE_HEAD
    merge_head = self.git_repo_path('MERGE_HEAD')
    if os.path.exists(merge_head):
        os.unlink(merge_head)
    # remove MERGE_MSG, SQUASH_MSG
    merge_msg_path = self.get_merge_message_path()
    while merge_msg_path:
        os.unlink(merge_msg_path)
        merge_msg_path = self.get_merge_message_path()
    return output
def get_workdir_state(self, amend=False):
    """RETURNS: A tuple of staged, modified, unmerged, and untracked
    file lists.

    Also rebuilds self.partially_staged: staged ('M') paths that still
    differ between the index and the work tree.  With amend=True the
    comparison base is HEAD^ instead of HEAD.  A path is never listed as
    both unmerged and modified.
    """
    self.partially_staged = set()
    head = 'HEAD'
    if amend:
        head = 'HEAD^'
    (staged, modified, unmerged, untracked) = ([], [], [], [])
    try:
        # diff-index without --cached: <head> vs. the work tree
        for name in self.git.diff_index(head).splitlines():
            rest, name = name.split('\t')
            status = rest[-1]
            name = eval_path(name)
            if status == 'M' or status == 'D':
                modified.append(name)
    except Exception:
        # handle git init
        for name in (self.git.ls_files(modified=True, z=True)
                     .split('\0')):
            if name:
                modified.append(name.decode('utf-8'))

    try:
        # diff-index --cached: <head> vs. the index
        for name in (self.git.diff_index(head, cached=True)
                     .splitlines()):
            rest, name = name.split('\t')
            status = rest[-1]
            name = eval_path(name)
            if status == 'M':
                staged.append(name)
                # is this file partially staged?
                diff = self.git.diff('--', name, name_only=True, z=True)
                if not diff.strip():
                    # Guarded: a missing entry used to raise ValueError,
                    # which fell into the git-init fallback below.
                    if name in modified:
                        modified.remove(name)
                else:
                    self.partially_staged.add(name)
            elif status == 'A':
                staged.append(name)
            elif status == 'D':
                staged.append(name)
                if name in modified:
                    modified.remove(name)
            elif status == 'U':
                unmerged.append(name)
    except Exception:
        # handle git init
        for name in self.git.ls_files(z=True).strip('\0').split('\0'):
            if name:
                staged.append(name.decode('utf-8'))

    for name in self.git.ls_files(others=True, exclude_standard=True,
                                  z=True).split('\0'):
        if name:
            untracked.append(name.decode('utf-8'))

    # Do not list a file in both the unmerged and modified sections;
    # filtering drops every duplicate entry, not just the first.
    unmerged_set = set(unmerged)
    modified = [name for name in modified if name not in unmerged_set]

    return (staged, modified, unmerged, untracked)
def reset_helper(self, args):
    """Removes files from the index.
    This handles the git init case, which is why it's not
    just git.reset(name).
    For the git init case this fall back to git rm --cached.

    Returns the output of the last git command run.
    """
    output = self.git.reset('--', *args)
    # handle git init -- anything still staged must be rm --cached
    still_staged = self.get_workdir_state()[0]
    leftovers = [arg for arg in args if arg in still_staged]
    if leftovers:
        output = self.git.rm('--', cached=True, *leftovers)
    return output
def remote_url(self, name):
    """Return the configured remote.<name>.url."""
    key = 'remote.%s.url' % name
    return self.git.config(key, get=True)
def get_remote_args(self, remote,
                    local_branch='', remote_branch='',
                    ffwd=True, tags=False):
    """Build (args, kwargs) for git fetch/push/pull.

    A 'remote_branch:local_branch' refspec is added only when both
    branches are given.  It gets a '+' prefix (git's forced-update
    refspec form) when ffwd is false.
    """
    args = [remote]
    if local_branch and remote_branch:
        refspec = '%s:%s' % (remote_branch, local_branch)
        if not ffwd:
            refspec = '+' + refspec
        args.append(refspec)
    kwargs = {'verbose': True, 'tags': tags}
    return (args, kwargs)
def gen_remote_helper(self, gitaction):
    """Generates a closure that calls git fetch, push or pull

    The returned function takes a remote name plus get_remote_args()
    keyword options and invokes *gitaction* with the resulting arguments.
    """
    def remote_helper(remote, **kwargs):
        positional, keywords = self.get_remote_args(remote, **kwargs)
        return gitaction(*positional, **keywords)
    return remote_helper
def parse_ls_tree(self, rev):
    """Returns a list of(mode, type, sha1, path) tuples.

    Parsed from 'git ls-tree -r <rev>'; non-matching lines are skipped.
    """
    entry_regex = re.compile(r'^(\d+)\W(\w+)\W(\w+)[ \t]+(.*)$')
    entries = []
    for line in self.git.ls_tree(rev, r=True).splitlines():
        match = entry_regex.match(line)
        if match:
            entries.append(match.groups())
    return entries
def format_patch_helper(self, to_export, revs, output='patches'):
    """writes patches named by to_export to the output directory.

    Revisions that are consecutive in *revs* are exported together as
    one numbered patch series via export_patchset().  *output* is now
    honored (it used to be hard-coded to "patches").  Returns the
    format-patch outputs joined by newlines, or '' when *to_export* is
    empty.
    """
    if not to_export:
        return ''

    # Group the patches into continuous sets
    cur_master_idx = revs.index(to_export[0])
    patches_to_export = [[to_export[0]]]
    for rev in to_export[1:]:
        # Limit the search to the current neighborhood for efficiency
        master_idx = revs.index(rev, cur_master_idx)
        if master_idx == cur_master_idx + 1:
            patches_to_export[-1].append(rev)
        else:
            patches_to_export.append([rev])
        cur_master_idx = master_idx

    # Export each patchset
    outlines = []
    for patchset in patches_to_export:
        cmdoutput = self.export_patchset(patchset[0],
                                         patchset[-1],
                                         output=output,
                                         n=len(patchset) > 1,
                                         thread=True,
                                         patch_with_stat=True)
        outlines.append(cmdoutput)
    return '\n'.join(outlines)
def export_patchset(self, start, end, output="patches", **kwargs):
    """Run 'git format-patch -o <output> start^..end' with extra options."""
    revision_range = '%s^..%s' % (start, end)
    return self.git.format_patch('-o', output, revision_range, **kwargs)
def current_branch(self):
    """Parses 'git symbolic-ref' to find the current branch.

    Returns the branch name without 'refs/heads/', a fixed message when
    git reports a fatal error (e.g. detached HEAD), or the raw output.
    """
    headref = self.git.symbolic_ref('HEAD')
    prefix = 'refs/heads/'
    if headref.startswith(prefix):
        return headref[len(prefix):]
    if headref.startswith('fatal: '):
        return 'Not currently on any branch'
    return headref
def create_branch(self, name, base, track=False):
    """Creates a branch starting from base. Pass track=True
    to create a remote tracking branch."""
    branch_args = (name, base)
    return self.git.branch(track=track, *branch_args)
def cherry_pick_list(self, revs, **kwargs):
    """Cherry-picks each revision into the current branch.

    Returns the outputs of the 'git cherry-pick' runs (one per revision)
    joined with newlines.  NOTE: an empty *revs* returns an empty list,
    not an empty string.
    """
    if not revs:
        return []
    return '\n'.join([self.git.cherry_pick(rev, **kwargs) for rev in revs])
def parse_stash_list(self, revids=False):
    """Parses "git stash list" and returns a list of stashes.

    With revids=True each entry is the text before the first ':' (the
    stash ref); otherwise it is the text after it.
    """
    stashes = self.git.stash("list").splitlines()
    cut_points = [entry.index(':') for entry in stashes]
    if revids:
        return [entry[:cut] for entry, cut in zip(stashes, cut_points)]
    return [entry[cut + 1:] for entry, cut in zip(stashes, cut_points)]
def diffstat(self):
    """Return the diffstat of HEAD^ against the work tree."""
    context = self.diff_context
    return self.git.diff('HEAD^', stat=True, unified=context)
def pad(self, pstr, num=22):
    """Right-pad *pstr* with spaces to *num* characters; longer strings
    are returned unchanged."""
    return pstr.ljust(num)
def describe(self, revid, descr):
    """Return '<git describe output> - <descr>' for *revid*."""
    version = self.git.describe(revid, tags=True, always=True, abbrev=4)
    return ' - '.join((version, descr))
def update_revision_lists(self, filename=None, show_versions=False):
    """Load up to get_num_results() commits into the revision lists.

    Commits come from 'git log --pretty=oneline' (restricted to
    *filename* when given, otherwise --all) and are stored oldest first.
    Descriptions are the utf-8 decoded summaries.  With show_versions
    they are prefixed through describe().  Returns the commit ids.
    """
    num_results = self.get_num_results()
    if filename:
        rev_list = self.git.log('--', filename,
                                max_count=num_results,
                                pretty='oneline')
    else:
        rev_list = self.git.log(max_count=num_results,
                                pretty='oneline', all=True)

    commit_list = self.parse_rev_list(rev_list)
    commit_list.reverse()
    commits = [sha1 for sha1, summary in commit_list]
    descriptions = [summary.decode('utf-8') for sha1, summary in commit_list]
    if show_versions:
        # Pass the decoded summaries so both branches produce text; the
        # raw byte summaries used to be passed here.
        descriptions = [self.describe(sha1, descr)
                        for sha1, descr in zip(commits, descriptions)]
    self.set_descriptions_start(descriptions)
    self.set_descriptions_end(descriptions)

    self.set_revisions_start(commits)
    self.set_revisions_end(commits)

    return commits
def get_changed_files(self, start, end):
    """Return the utf-8 decoded paths changed between *start* and *end*."""
    revrange = '%s..%s' % (start, end)
    raw = self.git.diff(revrange, name_only=True, z=True).strip('\0')
    changed = []
    for encoded in raw.split('\0'):
        if encoded:
            changed.append(encoded.decode('utf-8'))
    return changed
def get_renamed_files(self, start, end):
    """Return the original paths of files renamed in start..end.

    Runs git diff with rename detection (M=True) and unquotes the path on
    every 'rename from <path>' line with eval_path().
    """
    marker = 'rename from '
    difflines = self.git.diff('%s..%s' % (start, end), M=True).splitlines()
    return [eval_path(line[len(marker):].rstrip())
            for line in difflines if line.startswith(marker)]