Use super() where applicable
[stgit.git] / stgit / lib / transaction.py
blob21bd1a44521bc9d0b91da387727c1610a34e0337
1 """The L{StackTransaction} class makes it possible to make complex
2 updates to an StGit stack in a safe and convenient way."""
5 import atexit
6 from itertools import takewhile
8 from stgit import exception, utils
9 from stgit.config import config
10 from stgit.lib.git import CheckoutException, MergeConflictException, MergeException
11 from stgit.lib.log import log_external_mods, log_stack_state
12 from stgit.out import out
class TransactionException(exception.StgException):
    """Base exception for failures while setting up or running a
    L{StackTransaction}; L{TransactionHalted} derives from it."""
class TransactionHalted(TransactionException):
    """Signals that a L{StackTransaction} stopped before its setup finished.

    It acts as a non-local jump: it unwinds from the point where the
    transaction setup was interrupted back to the caller's C{try} block,
    after which the caller still runs the transaction (see
    L{StackTransaction.run})."""
def _print_current_patch(old_applied, new_applied):
    """Report the topmost patch after the applied list has changed.

    Nothing is printed when the top patch is the same before and after,
    or when no patch was applied in either state.

    @param old_applied: applied patch names before the change
    @param new_applied: applied patch names after the change"""
    if not new_applied:
        # Only worth mentioning if something used to be applied.
        if old_applied:
            out.info('No patch applied')
        return
    new_top = new_applied[-1]
    if old_applied and old_applied[-1] == new_top:
        return
    out.info('Now at patch "%s"' % new_top)
class _TransPatchMap(dict):
    """Maps patch names to Commit objects, falling back to the stack.

    Values stored in the dict itself take precedence, including C{None},
    which the transaction uses to mark a patch for deletion. Any name not
    stored here is resolved via the stack's existing patches."""

    def __init__(self, stack):
        super().__init__()
        self._stack = stack

    def __getitem__(self, pn):
        if pn in self:
            return super().__getitem__(pn)
        stack_patch = self._stack.patches.get(pn)
        return stack_patch.commit
class StackTransaction:
    """A stack transaction, used for making complex updates to an StGit
    stack in one single operation that will either succeed or fail
    cleanly.

    The basic theory of operation is the following:

      1. Create a transaction object.

      2. Inside a::

           try
             ...
           except TransactionHalted:
             pass

      block, update the transaction with e.g. methods like
      L{pop_patches} and L{push_patch}. This may create new git
      objects such as commits, but will not write any refs; this means
      that in case of a fatal error we can just walk away, no clean-up
      required.

      (Some operations may need to touch your index and working tree,
      though. But they are cleaned up when needed.)

      3. After the C{try} block -- whether or not the setup ran to
      completion or halted part-way through by raising a
      L{TransactionHalted} exception -- call the transaction's L{run}
      method. This will either succeed in writing the updated state to
      your refs and index+worktree, or fail without having done
      anything."""

    def __init__(
        self,
        stack,
        msg,
        discard_changes=False,
        allow_conflicts=False,
        allow_bad_head=False,
        check_clean_iw=None,
    ):
        """Create a new L{StackTransaction}.

        @param stack: the stack to update
        @param msg: message used by L{run} when writing the branch head,
            the patches and the stack log
        @type msg: str
        @param discard_changes: Discard any changes in index+worktree
        @type discard_changes: bool
        @param allow_conflicts: Whether to allow pre-existing conflicts
        @type allow_conflicts: bool or function of L{StackTransaction}
        @param allow_bad_head: if false, abort unless the stack's HEAD
            equals its top
        @type allow_bad_head: bool
        @param check_clean_iw: if given, an index+worktree that must be
            clean, otherwise the transaction halts
        @raise TransactionException: HEAD and top differ and
            C{allow_bad_head} is false
        @raise TransactionHalted: C{check_clean_iw} is not clean"""
        self.stack = stack
        self._msg = msg
        self.patches = _TransPatchMap(stack)
        self._applied = list(self.stack.patchorder.applied)
        self._unapplied = list(self.stack.patchorder.unapplied)
        self._hidden = list(self.stack.patchorder.hidden)
        self._error = None
        self._current_tree = self.stack.head.data.tree
        self._base = self.stack.base
        self._discard_changes = discard_changes
        self._bad_head = None
        self._conflicts = None
        if isinstance(allow_conflicts, bool):
            self._allow_conflicts = lambda trans: allow_conflicts
        else:
            self._allow_conflicts = allow_conflicts
        # The temporary index is created lazily by the temp_index property.
        # temp_index_tree is the tree last loaded into it (None if unknown).
        self._temp_index = self.temp_index_tree = None
        if not allow_bad_head:
            self._assert_head_top_equal()
        if check_clean_iw:
            self._assert_index_worktree_clean(check_clean_iw)

    @property
    def applied(self):
        """Names of the applied patches, bottom to top."""
        return self._applied

    @applied.setter
    def applied(self, value):
        self._applied = list(value)

    @property
    def unapplied(self):
        """Names of the unapplied patches."""
        return self._unapplied

    @unapplied.setter
    def unapplied(self, value):
        self._unapplied = list(value)

    @property
    def hidden(self):
        """Names of the hidden patches."""
        return self._hidden

    @hidden.setter
    def hidden(self, value):
        self._hidden = list(value)

    @property
    def all_patches(self):
        """New list of applied, then unapplied, then hidden patch names."""
        return self._applied + self._unapplied + self._hidden

    @property
    def base(self):
        """Commit the bottommost applied patch is based on."""
        return self._base

    @base.setter
    def base(self, value):
        assert not self._applied or self.patches[self.applied[0]].data.parent == value
        self._base = value

    @property
    def temp_index(self):
        """Temporary index used for merges; created on first use, with its
        deletion registered to run at interpreter exit."""
        if not self._temp_index:
            self._temp_index = self.stack.repository.temp_index()
            atexit.register(self._temp_index.delete)
        return self._temp_index

    @property
    def top(self):
        """Commit of the topmost applied patch, or the base if none."""
        if self._applied:
            return self.patches[self._applied[-1]]
        else:
            return self._base

    @property
    def head(self):
        """Commit that L{run} makes the branch head: an explicitly set head
        if there is one, otherwise L{top}."""
        if self._bad_head:
            return self._bad_head
        else:
            return self.top

    @head.setter
    def head(self, value):
        self._bad_head = value

    def _assert_head_top_equal(self):
        """Abort the transaction unless the stack's HEAD equals its top."""
        if not self.stack.head_top_equal():
            out.error(
                'HEAD and top are not the same.',
                'This can happen if you modify a branch with git.',
                '"stg repair --help" explains more about what to do next.',
            )
            self._abort()

    def _assert_index_worktree_clean(self, iw):
        """Halt the transaction if C{iw}'s worktree or index is dirty."""
        if not iw.worktree_clean():
            self._halt('Worktree not clean. Use "refresh" or "reset --hard"')
        if not iw.index.is_clean(self.stack.head):
            self._halt('Index not clean. Use "refresh" or "reset --hard"')

    def _checkout(self, tree, iw, allow_bad_head):
        """Check C{tree} out into C{iw} and record it as the current tree.

        If C{tree} is already the current tree (and changes are not being
        discarded) nothing is checked out, but the transaction aborts when
        C{iw} has conflicts that are not allowed."""
        if not allow_bad_head:
            self._assert_head_top_equal()
        if self._current_tree == tree and not self._discard_changes:
            # No tree change, but we still want to make sure that
            # there are no unresolved conflicts. Conflicts
            # conceptually "belong" to the topmost patch, and just
            # carrying them along to another patch is confusing.
            if self._allow_conflicts(self) or iw is None or not iw.index.conflicts():
                return
            out.error('Need to resolve conflicts first')
            self._abort()
        assert iw is not None
        if self._discard_changes:
            iw.checkout_hard(tree)
        else:
            iw.checkout(self._current_tree, tree)
        self._current_tree = tree

    @staticmethod
    def _abort():
        """Raise L{TransactionException}; nothing has been written yet."""
        raise TransactionException('Command aborted (all changes rolled back)')

    def _check_consistency(self):
        """Assert that every patch recorded in L{patches} is accounted for.

        A patch marked for deletion (C{None}) must exist in the stack; any
        other patch must be in the applied, unapplied or hidden list."""
        known = set(self.all_patches)
        for pn, commit in self.patches.items():
            if commit is None:
                assert self.stack.patches.exists(pn)
            else:
                assert pn in known

    def abort(self, iw=None):
        """Restore C{iw} (if given) to the stack's current head tree.

        The only state we need to restore is index+worktree; refs are
        only written by L{run}."""
        if iw:
            self._checkout(self.stack.head.data.tree, iw, allow_bad_head=True)

    def run(
        self, iw=None, set_head=True, allow_bad_head=False, print_current_patch=True
    ):
        """Execute the transaction. Will either succeed, or fail (with an
        exception) and do nothing.

        @param iw: index+worktree to check the new head out into, if any
        @param set_head: whether to check out and write the new branch head
        @param allow_bad_head: if false, the checkout aborts unless HEAD
            equals the stack top
        @param print_current_patch: whether to report the new top patch
        @return: C{utils.STGIT_CONFLICT} if the setup was halted (see
            L{_halt}), otherwise C{utils.STGIT_SUCCESS}
        @raise TransactionException: the checkout failed or was refused"""
        self._check_consistency()
        log_external_mods(self.stack)
        new_head = self.head

        # Set branch head.
        if set_head:
            if iw:
                try:
                    self._checkout(new_head.data.tree, iw, allow_bad_head)
                except CheckoutException:
                    # We have to abort the transaction.
                    self.abort(iw)
                    self._abort()
            self.stack.set_head(new_head, self._msg)

        if self._error:
            # List any merge conflicts after the error message itself.
            out.error(self._error, *(self._conflicts or []))

        old_applied = self.stack.patchorder.applied
        msg = self._msg + (' (CONFLICT)' if self._conflicts else '')

        # Write patches.
        for pn, commit in self.patches.items():
            if self.stack.patches.exists(pn):
                p = self.stack.patches.get(pn)
                if commit is None:
                    p.delete()
                else:
                    p.set_commit(commit, msg)
            else:
                self.stack.patches.new(pn, commit, msg)
        self.stack.patchorder.set_order(self._applied, self._unapplied, self._hidden)
        log_stack_state(self.stack, msg)

        if print_current_patch:
            _print_current_patch(old_applied, self._applied)

        if self._error:
            return utils.STGIT_CONFLICT
        else:
            return utils.STGIT_SUCCESS

    def _halt(self, msg):
        """Record C{msg} as the transaction error and raise
        L{TransactionHalted}."""
        self._error = msg
        raise TransactionHalted(msg)

    @staticmethod
    def _print_popped(popped):
        """Report the popped patches (given bottom to top), if any."""
        if not popped:
            return
        if len(popped) == 1:
            out.info('Popped %s' % popped[0])
        else:
            out.info('Popped %s -- %s' % (popped[-1], popped[0]))

    def _pop_applied_from_first(self, p):
        """Remove and return the applied patches from the lowest one for
        which C{p} is true up to the top; return [] if none match."""
        for i, pn in enumerate(self._applied):
            if p(pn):
                popped = self._applied[i:]
                del self._applied[i:]
                return popped
        return []

    @staticmethod
    def _commit_data_changed(cd, orig_cd):
        """Return whether C{cd} differs from C{orig_cd} in parent, tree,
        author or message."""
        return any(
            getattr(cd, a) != getattr(orig_cd, a)
            for a in ('parent', 'tree', 'author', 'message')
        )

    def _mark_applied(self, pn):
        """Move C{pn} from the hidden or unapplied list to the top of the
        applied list."""
        source = self._hidden if pn in self._hidden else self._unapplied
        source.remove(pn)
        self._applied.append(pn)

    def pop_patches(self, p):
        """Pop all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds.

        The popped patches go to the front of the unapplied list: those
        for which p is false first, then those for which it is true."""
        popped = self._pop_applied_from_first(p)
        popped1 = [pn for pn in popped if not p(pn)]
        popped2 = [pn for pn in popped if p(pn)]
        self.unapplied = popped1 + popped2 + self.unapplied
        self._print_popped(popped)
        return popped1

    def delete_patches(self, p, quiet=False):
        """Delete all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds.

        @param quiet: suppress the per-patch "Deleted" message"""
        all_patches = self.all_patches
        popped = [pn for pn in self._pop_applied_from_first(p) if not p(pn)]
        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
        self.hidden = [pn for pn in self.hidden if not p(pn)]
        self._print_popped(popped)
        for pn in all_patches:
            if p(pn):
                s = ' (empty)' if self.patches[pn].data.is_nochange() else ''
                # None marks the patch for deletion when run() writes patches.
                self.patches[pn] = None
                if not quiet:
                    out.info('Deleted %s%s' % (pn, s))
        return popped

    def push_patch(self, pn, iw=None, allow_interactive=False, already_merged=False):
        """Attempt to push the named patch. If this results in conflicts,
        halts the transaction. If index+worktree are given, spill any
        conflicts to them.

        @param allow_interactive: allow an interactive merge when the
            C{stgit.autoimerge} config option is set
        @param already_merged: the patch is known to be merged already, so
            the pushed patch gets its new parent's tree (i.e. is empty)"""
        out.start('Pushing patch "%s"' % pn)
        orig_cd = self.patches[pn].data
        cd = orig_cd.set_committer(None)
        oldparent = cd.parent
        cd = cd.set_parent(self.top)
        if already_merged:
            # the resulting patch is empty
            tree = cd.parent.data.tree
        else:
            base = oldparent.data.tree
            ours = cd.parent.data.tree
            theirs = cd.tree
            tree, self.temp_index_tree = self.temp_index.merge(
                base, ours, theirs, self.temp_index_tree
            )
        s = ''
        merge_conflict = False
        if not tree:
            # The merge in the temporary index failed; retry it in the
            # real index+worktree, where conflicts can be left behind.
            if iw is None:
                self._halt('%s does not apply cleanly' % pn)
            try:
                self._checkout(ours, iw, allow_bad_head=False)
            except CheckoutException:
                self._halt('Index/worktree dirty')
            try:
                interactive = allow_interactive and config.getbool('stgit.autoimerge')
                iw.merge(base, ours, theirs, interactive=interactive)
                tree = iw.index.write_tree()
                self._current_tree = tree
                s = 'modified'
            except MergeConflictException as e:
                tree = ours
                merge_conflict = True
                self._conflicts = e.conflicts
                s = 'conflict'
            except MergeException as e:
                self._halt(str(e))
        cd = cd.set_tree(tree)
        if self._commit_data_changed(cd, orig_cd):
            comm = self.stack.repository.commit(cd)
            if merge_conflict:
                # The halt below ends the setup, and run() then checks out
                # self.head; pin it to this conflicting commit explicitly.
                self.head = comm
        else:
            comm = None
            s = 'unmodified'
        if already_merged:
            s = 'merged'
        elif not merge_conflict and cd.is_nochange():
            s = 'empty'
        out.done(s)

        if merge_conflict:
            # We've just caused conflicts, so we must allow them in
            # the final checkout.
            self._allow_conflicts = lambda trans: True

        # Update the stack state
        if comm:
            self.patches[pn] = comm
        self._mark_applied(pn)

        if merge_conflict:
            self._halt("%d merge conflict(s)" % len(self._conflicts))

    def push_tree(self, pn):
        """Push the named patch without updating its tree."""
        orig_cd = self.patches[pn].data
        cd = orig_cd.set_committer(None).set_parent(self.top)

        s = ''
        if self._commit_data_changed(cd, orig_cd):
            self.patches[pn] = self.stack.repository.commit(cd)
        else:
            s = ' (unmodified)'
        if cd.is_nochange():
            s = ' (empty)'
        out.info('Pushed %s%s' % (pn, s))

        self._mark_applied(pn)

    def reorder_patches(
        self, applied, unapplied, hidden=None, iw=None, allow_interactive=False
    ):
        """Push and pop patches to attain the given ordering.

        Only the applied patches above the longest common prefix of the
        current and requested applied lists are popped and re-pushed."""
        if hidden is None:
            # Deliberately the live list, not a copy: pushing a hidden
            # patch below removes it from this same list.
            hidden = self.hidden
        common = len(
            list(takewhile(lambda a: a[0] == a[1], zip(self.applied, applied)))
        )
        to_pop = set(self.applied[common:])
        self.pop_patches(lambda pn: pn in to_pop)
        for pn in applied[common:]:
            self.push_patch(pn, iw, allow_interactive=allow_interactive)

        # We only get here if all the pushes succeeded.
        assert self.applied == applied
        assert set(self.unapplied + self.hidden) == set(unapplied + hidden)
        self.unapplied = unapplied
        self.hidden = hidden

    def check_merged(self, patches, tree=None, quiet=False):
        """Return a subset of patches already merged.

        A patch counts as merged when its diff can be applied in reverse
        to C{tree} (default: the stack head's tree) in the temporary index.
        Patches that change nothing are never reported."""
        if not quiet:
            out.start('Checking for patches merged upstream')
        merged = []
        if tree:
            self.temp_index.read_tree(tree)
            self.temp_index_tree = tree
        else:
            head_tree = self.stack.head.data.tree
            if self.temp_index_tree != head_tree:
                self.temp_index.read_tree(head_tree)
                self.temp_index_tree = head_tree
        for pn in reversed(patches):
            # check whether patch changes can be reversed in the current index
            cd = self.patches[pn].data
            if cd.is_nochange():
                continue
            try:
                self.temp_index.apply_treediff(
                    cd.tree,
                    cd.parent.data.tree,
                    quiet=True,
                )
            except MergeException:
                # The reverse diff does not apply: not merged.
                continue
            merged.append(pn)
            # The self.temp_index was modified by apply_treediff() so
            # force read_tree() the next time merge() is used.
            self.temp_index_tree = None
        if not quiet:
            out.done('%d found' % len(merged))
        return merged