1 """The L{StackTransaction} class makes it possible to make complex
2 updates to an StGit stack in a safe and convenient way."""
6 from itertools
import takewhile
8 from stgit
import exception
, utils
9 from stgit
.config
import config
10 from stgit
.lib
.git
import CheckoutException
, MergeConflictException
, MergeException
11 from stgit
.lib
.log
import log_external_mods
, log_stack_state
12 from stgit
.out
import out
class TransactionException(exception.StgException):
    """Exception raised when something goes wrong with a
    L{StackTransaction}."""
class TransactionHalted(TransactionException):
    """Exception raised when a L{StackTransaction} stops part-way through.

    Used to make a non-local jump from the transaction setup to the
    part of the transaction code where the transaction is run."""
26 def _print_current_patch(old_applied
, new_applied
):
28 out
.info('Now at patch "%s"' % pn
)
30 if not old_applied
and not new_applied
:
33 now_at(new_applied
[-1])
35 out
.info('No patch applied')
36 elif old_applied
[-1] == new_applied
[-1]:
39 now_at(new_applied
[-1])
42 class _TransPatchMap(dict):
43 """Maps patch names to Commit objects."""
45 def __init__(self
, stack
):
49 def __getitem__(self
, pn
):
51 return super().__getitem
__(pn
)
53 return self
._stack
.patches
.get(pn
).commit
56 class StackTransaction
:
57 """A stack transaction, used for making complex updates to an StGit
58 stack in one single operation that will either succeed or fail
61 The basic theory of operation is the following:
63 1. Create a transaction object.
69 except TransactionHalted:
72 block, update the transaction with e.g. methods like
73 L{pop_patches} and L{push_patch}. This may create new git
74 objects such as commits, but will not write any refs; this means
75 that in case of a fatal error we can just walk away, no clean-up
78 (Some operations may need to touch your index and working tree,
79 though. But they are cleaned up when needed.)
81 3. After the C{try} block -- whether or not the setup ran to
82 completion or halted part-way through by raising a
83 L{TransactionHalted} exception -- call the transaction's L{run}
84 method. This will either succeed in writing the updated state to
85 your refs and index+worktree, or fail without having done
92 discard_changes
=False,
93 allow_conflicts
=False,
97 """Create a new L{StackTransaction}.
99 @param discard_changes: Discard any changes in index+worktree
100 @type discard_changes: bool
101 @param allow_conflicts: Whether to allow pre-existing conflicts
102 @type allow_conflicts: bool or function of L{StackTransaction}"""
105 self
.patches
= _TransPatchMap(stack
)
106 self
._applied
= list(self
.stack
.patchorder
.applied
)
107 self
._unapplied
= list(self
.stack
.patchorder
.unapplied
)
108 self
._hidden
= list(self
.stack
.patchorder
.hidden
)
110 self
._current
_tree
= self
.stack
.head
.data
.tree
111 self
._base
= self
.stack
.base
112 self
._discard
_changes
= discard_changes
113 self
._bad
_head
= None
114 self
._conflicts
= None
115 if isinstance(allow_conflicts
, bool):
116 self
._allow
_conflicts
= lambda trans
: allow_conflicts
118 self
._allow
_conflicts
= allow_conflicts
119 self
._temp
_index
= self
.temp_index_tree
= None
120 if not allow_bad_head
:
121 self
._assert
_head
_top
_equal
()
123 self
._assert
_index
_worktree
_clean
(check_clean_iw
)
130 def applied(self
, value
):
131 self
._applied
= list(value
)
135 return self
._unapplied
138 def unapplied(self
, value
):
139 self
._unapplied
= list(value
)
146 def hidden(self
, value
):
147 self
._hidden
= list(value
)
150 def all_patches(self
):
151 return self
._applied
+ self
._unapplied
+ self
._hidden
158 def base(self
, value
):
159 assert not self
._applied
or self
.patches
[self
.applied
[0]].data
.parent
== value
163 def temp_index(self
):
164 if not self
._temp
_index
:
165 self
._temp
_index
= self
.stack
.repository
.temp_index()
166 atexit
.register(self
._temp
_index
.delete
)
167 return self
._temp
_index
172 return self
.patches
[self
._applied
[-1]]
179 return self
._bad
_head
184 def head(self
, value
):
185 self
._bad
_head
= value
187 def _assert_head_top_equal(self
):
188 if not self
.stack
.head_top_equal():
190 'HEAD and top are not the same.',
191 'This can happen if you modify a branch with git.',
192 '"stg repair --help" explains more about what to do next.',
196 def _assert_index_worktree_clean(self
, iw
):
197 if not iw
.worktree_clean():
198 self
._halt
('Worktree not clean. Use "refresh" or "reset --hard"')
199 if not iw
.index
.is_clean(self
.stack
.head
):
200 self
._halt
('Index not clean. Use "refresh" or "reset --hard"')
202 def _checkout(self
, tree
, iw
, allow_bad_head
):
203 if not allow_bad_head
:
204 self
._assert
_head
_top
_equal
()
205 if self
._current
_tree
== tree
and not self
._discard
_changes
:
206 # No tree change, but we still want to make sure that
207 # there are no unresolved conflicts. Conflicts
208 # conceptually "belong" to the topmost patch, and just
209 # carrying them along to another patch is confusing.
210 if self
._allow
_conflicts
(self
) or iw
is None or not iw
.index
.conflicts():
212 out
.error('Need to resolve conflicts first')
214 assert iw
is not None
215 if self
._discard
_changes
:
216 iw
.checkout_hard(tree
)
218 iw
.checkout(self
._current
_tree
, tree
)
219 self
._current
_tree
= tree
223 raise TransactionException('Command aborted (all changes rolled back)')
225 def _check_consistency(self
):
226 remaining
= set(self
.all_patches
)
227 for pn
, commit
in self
.patches
.items():
229 assert self
.stack
.patches
.exists(pn
)
231 assert pn
in remaining
233 def abort(self
, iw
=None):
234 # The only state we need to restore is index+worktree.
236 self
._checkout
(self
.stack
.head
.data
.tree
, iw
, allow_bad_head
=True)
239 self
, iw
=None, set_head
=True, allow_bad_head
=False, print_current_patch
=True
241 """Execute the transaction. Will either succeed, or fail (with an
242 exception) and do nothing."""
243 self
._check
_consistency
()
244 log_external_mods(self
.stack
)
251 self
._checkout
(new_head
.data
.tree
, iw
, allow_bad_head
)
252 except CheckoutException
:
253 # We have to abort the transaction.
256 self
.stack
.set_head(new_head
, self
._msg
)
260 out
.error(*([self
._error
] + self
._conflicts
))
262 out
.error(self
._error
)
264 old_applied
= self
.stack
.patchorder
.applied
265 msg
= self
._msg
+ (' (CONFLICT)' if self
._conflicts
else '')
268 for pn
, commit
in self
.patches
.items():
269 if self
.stack
.patches
.exists(pn
):
270 p
= self
.stack
.patches
.get(pn
)
274 p
.set_commit(commit
, msg
)
276 self
.stack
.patches
.new(pn
, commit
, msg
)
277 self
.stack
.patchorder
.set_order(self
._applied
, self
._unapplied
, self
._hidden
)
278 log_stack_state(self
.stack
, msg
)
280 if print_current_patch
:
281 _print_current_patch(old_applied
, self
._applied
)
284 return utils
.STGIT_CONFLICT
286 return utils
.STGIT_SUCCESS
288 def _halt(self
, msg
):
290 raise TransactionHalted(msg
)
293 def _print_popped(popped
):
296 elif len(popped
) == 1:
297 out
.info('Popped %s' % popped
[0])
299 out
.info('Popped %s -- %s' % (popped
[-1], popped
[0]))
301 def pop_patches(self
, p
):
302 """Pop all patches pn for which p(pn) is true. Return the list of
303 other patches that had to be popped to accomplish this. Always
306 for i
in range(len(self
.applied
)):
307 if p(self
.applied
[i
]):
308 popped
= self
.applied
[i
:]
311 popped1
= [pn
for pn
in popped
if not p(pn
)]
312 popped2
= [pn
for pn
in popped
if p(pn
)]
313 self
.unapplied
= popped1
+ popped2
+ self
.unapplied
314 self
._print
_popped
(popped
)
317 def delete_patches(self
, p
, quiet
=False):
318 """Delete all patches pn for which p(pn) is true. Return the list of
319 other patches that had to be popped to accomplish this. Always
322 all_patches
= self
.applied
+ self
.unapplied
+ self
.hidden
323 for i
in range(len(self
.applied
)):
324 if p(self
.applied
[i
]):
325 popped
= self
.applied
[i
:]
328 popped
= [pn
for pn
in popped
if not p(pn
)]
329 self
.unapplied
= popped
+ [pn
for pn
in self
.unapplied
if not p(pn
)]
330 self
.hidden
= [pn
for pn
in self
.hidden
if not p(pn
)]
331 self
._print
_popped
(popped
)
332 for pn
in all_patches
:
334 s
= ['', ' (empty)'][self
.patches
[pn
].data
.is_nochange()]
335 self
.patches
[pn
] = None
337 out
.info('Deleted %s%s' % (pn
, s
))
340 def push_patch(self
, pn
, iw
=None, allow_interactive
=False, already_merged
=False):
341 """Attempt to push the named patch. If this results in conflicts,
342 halts the transaction. If index+worktree are given, spill any
343 conflicts to them."""
344 out
.start('Pushing patch "%s"' % pn
)
345 orig_cd
= self
.patches
[pn
].data
346 cd
= orig_cd
.set_committer(None)
347 oldparent
= cd
.parent
348 cd
= cd
.set_parent(self
.top
)
350 # the resulting patch is empty
351 tree
= cd
.parent
.data
.tree
353 base
= oldparent
.data
.tree
354 ours
= cd
.parent
.data
.tree
356 tree
, self
.temp_index_tree
= self
.temp_index
.merge(
357 base
, ours
, theirs
, self
.temp_index_tree
360 merge_conflict
= False
363 self
._halt
('%s does not apply cleanly' % pn
)
365 self
._checkout
(ours
, iw
, allow_bad_head
=False)
366 except CheckoutException
:
367 self
._halt
('Index/worktree dirty')
369 interactive
= allow_interactive
and config
.getbool('stgit.autoimerge')
370 iw
.merge(base
, ours
, theirs
, interactive
=interactive
)
371 tree
= iw
.index
.write_tree()
372 self
._current
_tree
= tree
374 except MergeConflictException
as e
:
376 merge_conflict
= True
377 self
._conflicts
= e
.conflicts
379 except MergeException
as e
:
381 cd
= cd
.set_tree(tree
)
383 getattr(cd
, a
) != getattr(orig_cd
, a
)
384 for a
in ['parent', 'tree', 'author', 'message']
386 comm
= self
.stack
.repository
.commit(cd
)
388 # When we produce a conflict, we'll run the update()
389 # function defined below _after_ having done the
390 # checkout in run(). To make sure that we check out
391 # the real stack top (as it will look after update()
392 # has been run), set it hard here.
399 elif not merge_conflict
and cd
.is_nochange():
404 # We've just caused conflicts, so we must allow them in
405 # the final checkout.
406 self
._allow
_conflicts
= lambda trans
: True
408 # Update the stack state
410 self
.patches
[pn
] = comm
411 if pn
in self
.hidden
:
416 self
.applied
.append(pn
)
419 self
._halt
("%d merge conflict(s)" % len(self
._conflicts
))
421 def push_tree(self
, pn
):
422 """Push the named patch without updating its tree."""
423 orig_cd
= self
.patches
[pn
].data
424 cd
= orig_cd
.set_committer(None).set_parent(self
.top
)
428 getattr(cd
, a
) != getattr(orig_cd
, a
)
429 for a
in ['parent', 'tree', 'author', 'message']
431 self
.patches
[pn
] = self
.stack
.repository
.commit(cd
)
436 out
.info('Pushed %s%s' % (pn
, s
))
438 if pn
in self
.hidden
:
443 self
.applied
.append(pn
)
446 self
, applied
, unapplied
, hidden
=None, iw
=None, allow_interactive
=False
448 """Push and pop patches to attain the given ordering."""
452 list(takewhile(lambda a
: a
[0] == a
[1], zip(self
.applied
, applied
)))
454 to_pop
= set(self
.applied
[common
:])
455 self
.pop_patches(lambda pn
: pn
in to_pop
)
456 for pn
in applied
[common
:]:
457 self
.push_patch(pn
, iw
, allow_interactive
=allow_interactive
)
459 # We only get here if all the pushes succeeded.
460 assert self
.applied
== applied
461 assert set(self
.unapplied
+ self
.hidden
) == set(unapplied
+ hidden
)
462 self
.unapplied
= unapplied
465 def check_merged(self
, patches
, tree
=None, quiet
=False):
466 """Return a subset of patches already merged."""
468 out
.start('Checking for patches merged upstream')
471 self
.temp_index
.read_tree(tree
)
472 self
.temp_index_tree
= tree
473 elif self
.temp_index_tree
!= self
.stack
.head
.data
.tree
:
474 self
.temp_index
.read_tree(self
.stack
.head
.data
.tree
)
475 self
.temp_index_tree
= self
.stack
.head
.data
.tree
476 for pn
in reversed(patches
):
477 # check whether patch changes can be reversed in the current index
478 cd
= self
.patches
[pn
].data
482 self
.temp_index
.apply_treediff(
488 # The self.temp_index was modified by apply_treediff() so
489 # force read_tree() the next time merge() is used.
490 self
.temp_index_tree
= None
491 except MergeException
:
494 out
.done('%d found' % len(merged
))