Do not show binary diffs when editing patch
[stgit.git] / stgit / lib / git / repository.py
blob fff3441b0bb698818170ec85aca9eeb2fdae0c20
1 import atexit
2 import io
3 import re
5 from stgit import utils
6 from stgit.exception import StgException
7 from stgit.lib.objcache import ObjectCache
8 from stgit.run import Run, RunException
9 from stgit.utils import add_dict
11 from .iw import Index, IndexAndWorktree, MergeException, TemporaryIndex, Worktree
12 from .objects import Blob, Commit, Tree
class RepositoryException(StgException):
    """Raised when a :class:`Repository` operation fails.

    More specific repository errors (such as :exc:`DetachedHeadException`)
    subclass it, so catching this type handles all of them.

    """
class DetachedHeadException(RepositoryException):
    """Raised when HEAD is detached, so no branch is currently checked out."""

    def __init__(self):
        message = 'Not on any branch'
        super().__init__(message)
class Refs:
    """Accessor for the refs stored in a Git repository.

    Will transparently cache the values of all refs. The cache is filled from
    ``git show-ref`` on first access and dropped again after bulk updates.

    """

    # All-zero object name. Passed to ``git update-ref`` as the expected old
    # value when a ref that is not in the cache gets set.
    empty_id = '0' * 40

    # One line of ``git show-ref`` output: "<sha1> <refname>".
    _show_ref_re = re.compile(r'^([0-9a-f]{40})\s+(\S+)$', re.ASCII)

    def __init__(self, repository):
        self._repository = repository
        # Maps ref name -> sha1 string; None until the cache has been built.
        self._refs = None

    def _ensure_refs_cache(self):
        """(Re-)Build the cache of all refs in the repository.

        Does nothing if the cache is already populated. Raises
        :exc:`RepositoryException` if ``git show-ref`` prints a line that
        cannot be parsed. In that case the cache is left unpopulated, so the
        next access retries.

        """
        if self._refs is not None:
            return
        runner = self._repository.run(['git', 'show-ref'])
        try:
            lines = runner.output_lines()
        except RunException:
            # as this happens both in non-git trees and empty git
            # trees, we silently ignore this error
            self._refs = {}
            return
        refs = {}
        for line in lines:
            m = self._show_ref_re.match(line)
            if m is None:
                # Fail with a clear message instead of an AttributeError on
                # None, and without leaving a half-filled cache behind.
                raise RepositoryException(
                    'Cannot parse git show-ref output: %r' % line
                )
            sha1, ref = m.groups()
            refs[ref] = sha1
        self._refs = refs

    def __iter__(self):
        """Iterate over the names of all refs."""
        self._ensure_refs_cache()
        return iter(self._refs)

    def reset_cache(self):
        """Reset cached refs such that cache is rebuilt on next access.

        Useful if refs are known to have changed due to an external command
        such as `git pull`.

        """
        self._refs = None

    def get(self, ref):
        """Get the :class:`Commit` the given ref points to.

        Throws :exc:`KeyError` if ref does not exist.

        """
        self._ensure_refs_cache()
        return self._repository.get_commit(self._refs[ref])

    def exists(self, ref):
        """Check if the given ref exists."""
        # A plain cache lookup; no need to build a Commit wrapper as get() does.
        self._ensure_refs_cache()
        return ref in self._refs

    def set(self, ref, commit, msg):
        """Write the sha1 of the given :class:`Commit` to the ref.

        The ref may or may not already exist. No git command is run when the
        ref already has that value. The cached old value is passed to
        ``git update-ref`` so git can verify it before updating. ``msg`` is
        the reflog message.

        """
        self._ensure_refs_cache()
        old_sha1 = self._refs.get(ref, self.empty_id)
        new_sha1 = commit.sha1
        if old_sha1 != new_sha1:
            self._repository.run(
                ['git', 'update-ref', '-m', msg, ref, new_sha1, old_sha1]
            ).no_output()
            self._refs[ref] = new_sha1

    def delete(self, ref):
        """Delete the given ref.

        Throws :exc:`KeyError` if ref does not exist.

        """
        self._ensure_refs_cache()
        self._repository.run(
            ['git', 'update-ref', '-d', ref, self._refs[ref]]
        ).no_output()
        del self._refs[ref]

    def _run_update_ref(self, msg, ref_ops):
        """Feed ``ref_ops`` to a single ``git update-ref --stdin``, then drop the cache."""
        (
            self._repository.run(['git', 'update-ref', '-m', msg, '--stdin'])
            .raw_input(''.join(ref_ops))
            .discard_output()
        )
        self.reset_cache()

    def rename(self, msg, *renames):
        """Rename old, new ref pairs.

        All pairs are applied in a single ``git update-ref --stdin`` call.
        Calling it without any pairs does nothing. Throws :exc:`KeyError` if
        an old ref does not exist.

        """
        if not renames:
            # Nothing to do: avoid spawning git and invalidating the cache.
            return
        self._ensure_refs_cache()
        ref_ops = []
        for old_ref, new_ref in renames:
            sha1 = self._refs[old_ref]
            ref_ops.append('create %s %s\n' % (new_ref, sha1))
            ref_ops.append('delete %s %s\n' % (old_ref, sha1))
        self._run_update_ref(msg, ref_ops)

    def batch_update(self, msg, create=(), update=(), delete=()):
        """Batch update/create/delete refs.

        ``create`` and ``update`` are iterables of ``(ref, commit)`` pairs.
        ``delete`` is an iterable of ref names. All operations go to one
        ``git update-ref --stdin`` call, and nothing is run if all three are
        empty. Throws :exc:`KeyError` if a ref to update or delete is not
        known.

        """
        self._ensure_refs_cache()
        ref_ops = []
        for ref, commit in create:
            ref_ops.append('create %s %s\n' % (ref, commit.sha1))
        for ref, commit in update:
            old_sha1 = self._refs[ref]
            ref_ops.append('update %s %s %s\n' % (ref, commit.sha1, old_sha1))
        for ref in delete:
            old_sha1 = self._refs[ref]
            ref_ops.append('delete %s %s\n' % (ref, old_sha1))
        if ref_ops:
            self._run_update_ref(msg, ref_ops)
class CatFileProcess:
    """Long-running ``git cat-file --batch`` process used to read objects.

    The process is started lazily on first use. It is terminated at
    interpreter exit via :mod:`atexit`.

    """

    def __init__(self, repo):
        self._repository = repo
        self._proc = None
        atexit.register(self._shutdown)

    def _get_process(self):
        """Return the background process, starting it on the first call."""
        if self._proc is None:
            self._proc = (
                self._repository.run(['git', 'cat-file', '--batch'])
                .encoding(None)
                .decoding(None)
                .run_background()
            )
        return self._proc

    def _shutdown(self):
        """Terminate the background process, if one was started."""
        if self._proc is not None:
            with self._proc:
                self._proc.terminate()

    @staticmethod
    def _read_more(p, sha1):
        """Read the next available chunk of output from ``p``.

        Raises :exc:`RepositoryException` at end of stream. Otherwise a dead
        ``git cat-file`` would make the callers' read loops spin forever.

        """
        chunk = p.stdout.read1(io.DEFAULT_BUFFER_SIZE)
        if not chunk:
            raise RepositoryException(
                'git cat-file terminated unexpectedly while reading %s' % sha1
            )
        return chunk

    def cat_file(self, sha1):
        """Return ``(content_type, content)`` for the object named ``sha1``.

        ``content_type`` is the type string from the cat-file header.
        ``content`` is the raw object data as bytes. Raises
        :exc:`RepositoryException` if the object is missing or if the
        process output ends early.

        """
        p = self._get_process()
        p.stdin.write(b'%s\n' % sha1.encode('ascii'))
        p.stdin.flush()

        # Read until we have the entire header line.
        buf = bytearray()
        while b'\n' not in buf:
            buf += self._read_more(p, sha1)
        header_bytes, _, content = buf.partition(b'\n')
        header = header_bytes.decode('utf-8')
        if header == '%s missing' % sha1:
            raise RepositoryException('Cannot cat %s' % sha1)
        name, content_type, size = header.split()
        assert name == sha1
        size = int(size)

        # Read until we have the entire object plus the trailing newline.
        while len(content) < size + 1:
            content += self._read_more(p, sha1)
        del content[size:]
        return content_type, bytes(content)
class DiffTreeProcesses:
    """Pool of long-running ``git diff-tree --stdin`` processes.

    One process is kept per distinct argument list. Each is started lazily
    and terminated at interpreter exit via :mod:`atexit`.

    """

    def __init__(self, repo):
        self._repository = repo
        # Maps tuple(args) -> background diff-tree process.
        self._procs = {}
        atexit.register(self._shutdown)

    def _get_process(self, args):
        """Return the process for ``args``, starting it on first use."""
        args = tuple(args)
        if args not in self._procs:
            self._procs[args] = (
                self._repository.run(['git', 'diff-tree', '--stdin'] + list(args))
                .encoding(None)
                .decoding(None)
                .run_background()
            )
        return self._procs[args]

    def _shutdown(self):
        """Terminate all started processes."""
        for p in self._procs.values():
            with p:
                p.terminate()

    def diff_trees(self, args, sha1a, sha1b):
        """Return raw ``git diff-tree <args>`` output (bytes) for two objects.

        The query line is followed by a marker line. The marker line is
        echoed back after the diff, which delimits this answer in the shared
        output stream. Raises :exc:`RepositoryException` if the process
        output ends before the marker appears.

        """
        p = self._get_process(args)
        query = ('%s %s\n' % (sha1a, sha1b)).encode('ascii')
        end = b'EOF\n'  # arbitrary string that's not a 40-digit hex number
        p.stdin.write(query + end)
        p.stdin.flush()

        # The echoed marker follows either a newline or, with -z, a NUL.
        # Checking the whole accumulated buffer (instead of only the last
        # one or two chunks) finds the marker however the reads are split.
        terminators = (b'\n' + end, b'\0' + end)
        data = bytearray()
        while not data.endswith(terminators):
            chunk = p.stdout.read1(io.DEFAULT_BUFFER_SIZE)
            if not chunk:
                raise RepositoryException(
                    'git diff-tree terminated unexpectedly'
                )
            data += chunk

        assert data.startswith(query)
        assert data.endswith(end)
        return bytes(data[len(query) : -len(end)])
245 class Repository:
246 """Represents a Git repository."""
248 def __init__(self, directory):
249 self._git_dir = directory
250 self.refs = Refs(self)
251 self._blobs = ObjectCache(lambda sha1: Blob(self, sha1))
252 self._trees = ObjectCache(lambda sha1: Tree(self, sha1))
253 self._commits = ObjectCache(lambda sha1: Commit(self, sha1))
254 self._default_index = None
255 self._default_worktree = None
256 self._default_iw = None
257 self._catfile = CatFileProcess(self)
258 self._difftree = DiffTreeProcesses(self)
260 @property
261 def env(self):
262 return {'GIT_DIR': self._git_dir}
264 @classmethod
265 def default(cls):
266 """Return the default repository."""
267 try:
268 return cls(Run('git', 'rev-parse', '--git-dir').output_one_line())
269 except RunException:
270 raise RepositoryException('Cannot find git repository')
272 @property
273 def current_branch_name(self):
274 """Return the name of the current branch."""
275 return utils.strip_prefix('refs/heads/', self.head_ref)
277 @property
278 def default_index(self):
279 """An :class:`Index` representing the default index file for the repository."""
280 if self._default_index is None:
281 self._default_index = Index.default(self)
282 return self._default_index
284 def temp_index(self):
285 """Return an :class:`Index` representing a new temporary index file."""
286 return TemporaryIndex(self)
288 @property
289 def default_worktree(self):
290 """A :class:`Worktree` representing the default work tree."""
291 if self._default_worktree is None:
292 self._default_worktree = Worktree.default()
293 return self._default_worktree
295 @property
296 def default_iw(self):
297 """:class:`IndexAndWorktree` for repository's default index and work tree."""
298 if self._default_iw is None:
299 self._default_iw = IndexAndWorktree(
300 self.default_index, self.default_worktree
302 return self._default_iw
304 @property
305 def directory(self):
306 return self._git_dir
308 def run(self, args, env=()):
309 return Run(*args).env(add_dict(self.env, env))
311 def cat_object(self, sha1):
312 return self._catfile.cat_file(sha1)
314 def rev_parse(self, rev, discard_stderr=False, object_type='commit'):
315 try:
316 sha1 = (
317 self.run(['git', 'rev-parse', '%s^{%s}' % (rev, object_type)])
318 .discard_stderr(discard_stderr)
319 .output_one_line()
321 except RunException:
322 raise RepositoryException('%s: No such %s' % (rev, object_type))
323 else:
324 return self.get_object(object_type, sha1)
326 def get_blob(self, sha1):
327 return self._blobs[sha1]
329 def get_tree(self, sha1):
330 return self._trees[sha1]
332 def get_commit(self, sha1):
333 return self._commits[sha1]
335 def get_object(self, object_type, sha1):
336 return {
337 Blob.typename: self.get_blob,
338 Tree.typename: self.get_tree,
339 Commit.typename: self.get_commit,
340 }[object_type](sha1)
342 def commit(self, objectdata):
343 return objectdata.commit(self)
345 @property
346 def head_ref(self):
347 try:
348 return self.run(['git', 'symbolic-ref', '-q', 'HEAD']).output_one_line()
349 except RunException:
350 raise DetachedHeadException()
352 def set_head_ref(self, ref, msg):
353 self.run(['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]).no_output()
355 def get_merge_bases(self, commit1, commit2):
356 """Return a list of merge bases of two commits."""
357 sha1_list = self.run(
358 ['git', 'merge-base', '--all', commit1.sha1, commit2.sha1]
359 ).output_lines()
360 return [self.get_commit(sha1) for sha1 in sha1_list]
362 def describe(self, commit):
363 """Use git describe --all on the given commit."""
364 return (
365 self.run(['git', 'describe', '--all', commit.sha1])
366 .discard_stderr()
367 .discard_exitcode()
368 .raw_output()
371 def simple_merge(self, base, ours, theirs):
372 with self.temp_index() as index:
373 result, index_tree = index.merge(base, ours, theirs)
374 return result
376 def apply(self, tree, patch_bytes, quiet):
377 """Apply patch to given tree.
379 Given a :class:`Tree` and a patch, either returns the new :class:`Tree`
380 resulting from successful application of the patch, or None if the patch
381 could not be applied.
384 assert isinstance(tree, Tree)
385 if not patch_bytes:
386 return tree
387 with self.temp_index() as index:
388 index.read_tree(tree)
389 try:
390 index.apply(patch_bytes, quiet)
391 return index.write_tree()
392 except MergeException:
393 return None
395 def submodules(self, tree):
396 """Return list of submodule paths for the given :class:`Tree`."""
397 assert isinstance(tree, Tree)
398 # A simple regex to match submodule entries
399 regex = re.compile(r'160000 commit [0-9a-f]{40}\t(.*)$')
400 # First, use ls-tree to get all the trees and links
401 files = self.run(['git', 'ls-tree', '-d', '-r', '-z', tree.sha1]).output_lines(
402 '\0'
404 # Then extract the paths of any submodules
405 return set(m.group(1) for m in map(regex.match, files) if m)
407 def diff_tree(
408 self,
411 diff_opts=(),
412 pathlimits=(),
413 binary=True,
414 stat=False,
415 full_index=False,
417 """Produce patch (diff) between two trees.
419 Given two :class:`Tree`s ``t1`` and ``t2``, return the patch that takes
420 ``t1`` to ``t2``.
423 assert isinstance(t1, Tree)
424 assert isinstance(t2, Tree)
425 if stat:
426 args = ['--stat', '--summary']
427 args.extend(o for o in diff_opts if o != '--binary')
428 else:
429 args = ['--patch']
430 if binary and '--binary' not in diff_opts:
431 args.append('--binary')
432 if full_index:
433 args.append('--full-index')
434 args.extend(diff_opts)
435 if pathlimits:
436 args.append('--')
437 args.extend(pathlimits)
438 return self._difftree.diff_trees(args, t1.sha1, t2.sha1)
440 def diff_tree_files(self, t1, t2):
441 """Iterate files that differ between two trees.
443 Given two :class:`Tree`s ``t1`` and ``t2``, iterate over all files that differ
444 between the two trees.
446 For each differing file, yield a tuple with the old file mode, the new file
447 mode, the old blob, the new blob, the status, the old filename, and the new
448 filename.
450 Except in case of a copy or a rename, the old and new filenames are identical.
453 assert isinstance(t1, Tree)
454 assert isinstance(t2, Tree)
455 dt = self._difftree.diff_trees(['-r', '-z'], t1.sha1, t2.sha1)
456 i = iter(dt.decode('utf-8').split('\0'))
457 try:
458 while True:
459 x = next(i)
460 if not x:
461 continue
462 omode, nmode, osha1, nsha1, status = x[1:].split(' ')
463 fn1 = next(i)
464 if status[0] in ['C', 'R']:
465 fn2 = next(i)
466 else:
467 fn2 = fn1
468 yield (
469 omode,
470 nmode,
471 self.get_blob(osha1),
472 self.get_blob(nsha1),
473 status,
474 fn1,
475 fn2,
477 except StopIteration:
478 pass
480 def repack(self):
481 """Repack all objects into a single pack."""
482 self.run(['git', 'repack', '-a', '-d', '-f']).run()
484 def copy_notes(self, old_sha1, new_sha1):
485 """Copy Git notes from the old object to the new one."""
486 p = self.run(['git', 'notes', 'copy', old_sha1, new_sha1])
487 p.discard_exitcode().discard_stderr().discard_output()