Avoid importing invalid and duplicate patch names
[stgit.git] / stgit / lib / stack.py
blob1ca97b03e96c235099c6396004d1b3db389b42d4
1 """A Python class hierarchy wrapping the StGit on-disk metadata."""
2 import re
4 from stgit.config import config
5 from stgit.exception import StackException
6 from stgit.lib import log, stackupgrade
7 from stgit.lib.git import Repository
8 from stgit.lib.git.branch import Branch, BranchException
9 from stgit.lib.objcache import ObjectCache
12 def _stack_state_ref(stack_name):
13 """Reference to stack state metadata. A.k.a. the stack's "log"."""
14 return 'refs/heads/%s.stgit' % (stack_name,)
17 def _patch_ref(stack_name, patch_name):
18 """Reference to a named patch's commit."""
19 return 'refs/patches/%s/%s' % (stack_name, patch_name)
def _patches_ref_prefix(stack_name):
    """Return the prefix shared by every patch ref of the given stack.

    The prefix is simply the patch ref built with an empty patch name.
    """
    empty_patch_name = ''
    return _patch_ref(stack_name, empty_patch_name)
class Patch:
    """A single StGit patch, identified by its name within a stack.

    The patch's commit is not stored on the instance; it is looked up through
    the repository ref derived from the stack name and the patch name.
    """

    def __init__(self, stack, name):
        self._stack = stack
        self.name = name

    @property
    def _ref(self):
        """Name of the Git ref that points at this patch's commit."""
        return _patch_ref(self._stack.name, self.name)

    @property
    def commit(self):
        """The commit currently stored under this patch's ref."""
        repo_refs = self._stack.repository.refs
        return repo_refs.get(self._ref)

    def set_commit(self, commit, msg):
        """Point the patch's ref at ``commit`` using ``msg`` as the ref message.

        When the patch previously referenced a different commit, the notes of
        that commit are copied to the new one.
        """
        try:
            previous_sha1 = self.commit.sha1
        except KeyError:
            # Looking up the current commit failed, so there is nothing to
            # copy notes from.
            previous_sha1 = None
        repository = self._stack.repository
        repository.refs.set(self._ref, commit, msg)
        if previous_sha1 and previous_sha1 != commit.sha1:
            repository.copy_notes(previous_sha1, commit.sha1)

    def set_name(self, name, msg):
        """Rename the patch: drop the old ref and store the commit under the new one."""
        kept_commit = self.commit
        self.delete()
        # self._ref is derived from self.name, so it changes with this assignment.
        self.name = name
        self._stack.repository.refs.set(self._ref, kept_commit, msg)

    def delete(self):
        """Delete the ref backing this patch."""
        repo_refs = self._stack.repository.refs
        repo_refs.delete(self._ref)

    def is_empty(self):
        """Return the ``is_nochange()`` result of the patch's commit data."""
        commit_data = self.commit.data
        return commit_data.is_nochange()

    def files(self):
        """Return the set of files this patch touches."""
        commit_data = self.commit.data
        diff_entries = self._stack.repository.diff_tree_files(
            commit_data.parent.data.tree,
            commit_data.tree,
        )
        touched = set()
        # Only the last two fields (old and new file name) of each entry are used.
        for _, _, _, _, _, oldname, newname in diff_entries:
            touched.update((oldname, newname))
        return touched
class PatchOrder:
    """Track the order of patches and which of them are applied.

    Only patch names are handled here, never actual patch objects.

    """

    # Attribute names of the three name tuples, in the order they are searched.
    _LIST_ATTRS = ('_applied', '_unapplied', '_hidden')

    def __init__(self, state):
        """Copy the applied, unapplied and hidden name sequences from ``state``."""
        self._applied = tuple(state.applied)
        self._unapplied = tuple(state.unapplied)
        self._hidden = tuple(state.hidden)

    @property
    def applied(self):
        """Tuple of applied patch names."""
        return self._applied

    @property
    def unapplied(self):
        """Tuple of unapplied patch names."""
        return self._unapplied

    @property
    def hidden(self):
        """Tuple of hidden patch names."""
        return self._hidden

    @property
    def all(self):
        """Applied, then unapplied, then hidden patch names as one tuple."""
        return self.all_visible + self.hidden

    @property
    def all_visible(self):
        """Applied followed by unapplied patch names as one tuple."""
        return self.applied + self.unapplied

    def set_order(self, applied, unapplied, hidden):
        """Replace all three name sequences at once."""
        self._applied, self._unapplied, self._hidden = (
            tuple(applied),
            tuple(unapplied),
            tuple(hidden),
        )

    def rename_patch(self, old_name, new_name):
        """Replace the first occurrence of ``old_name`` with ``new_name``.

        The applied, unapplied and hidden sequences are searched in that order
        and only the first sequence containing ``old_name`` is modified.

        :raises AssertionError: if ``old_name`` is in none of the sequences
        """
        for attr in self._LIST_ATTRS:
            names = list(getattr(self, attr))
            if old_name not in names:
                continue
            names[names.index(old_name)] = new_name
            setattr(self, attr, tuple(names))
            return
        raise AssertionError('"%s" not found in patchorder' % old_name)
class Patches:
    """Manage the set of :class:`Patch` objects.

    Ensures a single :class:`Patch` instance per patch.

    """

    def __init__(self, stack, state):
        """Set up a lazily populated name -> :class:`Patch` cache for ``stack``.

        :param stack: the stack the patches belong to
        :param state: accepted as part of the constructor signature; not used here
        """
        self._stack = stack

        def create_patch(name):
            p = Patch(self._stack, name)
            p.commit  # raise exception if the patch doesn't exist
            return p

        self._patches = ObjectCache(create_patch)  # name -> Patch

    def exists(self, name):
        """Return True if a patch called ``name`` can be looked up.

        A :class:`KeyError` from the lookup is interpreted as "does not exist".
        """
        try:
            self.get(name)
        except KeyError:
            return False
        return True

    def get(self, name):
        """Return the :class:`Patch` instance for ``name`` from the cache."""
        return self._patches[name]

    def is_name_valid(self, name):
        """Return True if ``name`` is acceptable as a patch name.

        Names containing ``/`` are always rejected; otherwise the full patch ref
        is handed to ``git check-ref-format`` and an exit code of 0 means valid.
        """
        if '/' in name:
            # TODO slashes in patch names could be made to be okay
            return False
        ref = _patch_ref(self._stack.name, name)
        p = self._stack.repository.run(['git', 'check-ref-format', ref])
        p.returns([0, 1]).discard_stderr().discard_output()
        return p.exitcode == 0

    def new(self, name, commit, msg):
        """Create, register and return a new patch ``name`` pointing at ``commit``.

        The caller is expected to pass a valid name that is not yet in the
        cache; both conditions are checked with assertions only.
        """
        assert name not in self._patches
        assert self.is_name_valid(name)
        p = Patch(self._stack, name)
        p.set_commit(commit, msg)
        self._patches[name] = p
        return p

    @staticmethod
    def _sanitize_name_part(part):
        """Normalise one slash-separated component of a raw patch name.

        The substitutions are repeated until the text stops changing. A single
        pass is not enough because a later step can re-expose something an
        earlier step was meant to remove: e.g. ``'foo.-'`` loses its trailing
        dash only after the trailing-dot rule has already run, leaving
        ``'foo.'``. Text that is already clean is returned unchanged.

        The result may be an empty string.
        """
        while True:
            # fmt: off
            cleaned = re.sub(r'\.lock$', '', part)         # Disallowed in Git refs
            cleaned = re.sub(r'^\.+|\.+$', '', cleaned)    # Cannot start or end with '.'
            cleaned = re.sub(r'\.+', '.', cleaned)         # No consecutive '.'
            cleaned = re.sub(r'[^\w.]+', '-', cleaned)     # Non-word and whitespace to dashes
            cleaned = re.sub(r'-+', '-', cleaned)          # Squash consecutive dashes
            cleaned = re.sub(r'^-+|-+$', '', cleaned)      # Remove leading and trailing dashes
            # fmt: on
            if cleaned == part:
                return cleaned
            part = cleaned

    def make_name(self, raw, unique=True, lower=True):
        """Make a unique and valid patch name from provided raw name.

        The raw name may come from a filename, commit message, or email subject line.

        The generated patch name will meet the rules of `git check-ref-format` along
        with some additional StGit patch name rules.

        :param raw: text to derive the name from; only its first non-empty line is used
        :param unique: when true, a numeric suffix is added or incremented until
            no existing patch has the resulting name
        :param lower: when true, the name is lower-cased
        :returns: the generated patch name

        """
        default_name = 'patch'

        # Use the first non-empty line, or the default when there is none.
        line = next((ln for ln in raw.split('\n') if ln), default_name)

        if lower:
            line = line.lower()

        parts = [
            part
            for part in map(self._sanitize_name_part, line.split('/'))
            if part
        ]

        # TODO: slashes could be allowed in the future.
        long_name = '-'.join(parts)

        if not long_name:
            long_name = default_name

        assert self.is_name_valid(long_name)

        name_len = config.getint('stgit.namelength')

        # Add whole dash-separated words while the name stays within name_len.
        # The first word is always kept, even if it alone is longer.
        words = long_name.split('-')
        short_name = words[0]
        for word in words[1:]:
            new_name = '%s-%s' % (short_name, word)
            if name_len <= 0 or len(new_name) <= name_len:
                short_name = new_name
            else:
                break

        # Cutting at a word boundary can expose a trailing '.' or '.lock' that
        # was in the middle of long_name (e.g. 'foo.-bar' cut down to 'foo.').
        short_name = self._sanitize_name_part(short_name) or default_name
        assert self.is_name_valid(short_name)

        if not unique:
            return short_name

        unique_name = short_name
        while self.exists(unique_name):
            # Bump an existing numeric suffix, keeping its dash if it has one;
            # otherwise start numbering at -1.
            m = re.match(r'(.*?)(-)?(\d+)$', unique_name)
            if m:
                base, sep, n_str = m.groups()
                n = int(n_str) + 1
                if sep:
                    unique_name = '%s%s%d' % (base, sep, n)
                else:
                    unique_name = '%s%d' % (base, n)
            else:
                unique_name = '%s-1' % unique_name

        assert self.is_name_valid(unique_name)
        return unique_name
class Stack(Branch):
    """Represents a StGit stack.

    A StGit stack is a Git branch with extra metadata for patch stack state.

    """

    def __init__(self, repository, name):
        """Load the stack called ``name`` from ``repository``.

        :raises StackException: if the format upgrade check reports that the
            branch is not an initialized stack
        """
        super().__init__(repository, name)
        upgraded = stackupgrade.update_to_current_format_version(repository, name)
        if not upgraded:
            raise StackException('%s: branch not initialized' % name)
        stack_state = log.get_stack_state(self.repository, self.state_ref)
        # Bring the patch refs in line with the stack state before any patch
        # is looked up through them.
        self._ensure_patch_refs(repository, name, stack_state)
        self.patchorder = PatchOrder(stack_state)
        self.patches = Patches(self, stack_state)

    @property
    def base(self):
        """Parent of the bottommost applied patch, or the head if none is applied."""
        applied = self.patchorder.applied
        if not applied:
            return self.head
        bottom_patch = self.patches.get(applied[0])
        return bottom_patch.commit.data.parent

    @property
    def top(self):
        """Commit of the topmost patch, or the stack base if no patches are applied."""
        applied = self.patchorder.applied
        if not applied:
            # When no patches are applied, base == head.
            return self.head
        return self.patches.get(applied[-1]).commit

    def head_top_equal(self):
        """Return True if no patch is applied or the head equals the topmost patch."""
        applied = self.patchorder.applied
        if not applied:
            return True
        top_commit = self.patches.get(applied[-1]).commit
        return self.head == top_commit

    def set_parents(self, remote, branch):
        """Record the parent remote and/or parent branch; falsy values are skipped."""
        if remote:
            self.set_parent_remote(remote)
        if not branch:
            return
        self.set_parent_branch(branch)
        config.set('branch.%s.stgit.parentbranch' % self.name, branch)

    @property
    def protected(self):
        """Boolean value of the stack's ``protect`` configuration key."""
        return config.getbool('branch.%s.stgit.protect' % self.name)

    @protected.setter
    def protected(self, protect):
        key = 'branch.%s.stgit.protect' % self.name
        if protect:
            config.set(key, 'true')
            return
        # Only unset the key when the stack is currently reported as protected.
        if self.protected:
            config.unset(key)

    @property
    def state_ref(self):
        """Name of the ref holding this stack's state."""
        return _stack_state_ref(self.name)

    def cleanup(self):
        """Remove all StGit metadata of this stack.

        Deletes every patch ref, the state ref and the stack's config section.
        The stack must not be protected.
        """
        assert not self.protected, 'attempt to delete protected stack'
        for patch_name in self.patchorder.all:
            self.patches.get(patch_name).delete()
        self.repository.refs.delete(self.state_ref)
        config.remove_section('branch.%s.stgit' % self.name)

    def clear_log(self, msg='clear log'):
        """Replace the stack state ref with a fresh state that has no predecessor."""
        fresh_state = log.StackState.from_stack(prev=None, stack=self, message=msg)
        state_commit = fresh_state.commit_state()
        self.repository.refs.set(self.state_ref, state_commit, msg=msg)

    def rename(self, new_name):
        """Rename the branch together with its patch refs, state ref and config."""
        old_name = self.name
        # Captured before the branch rename so it is independent of that call.
        patch_names = self.patchorder.all
        super().rename(new_name)

        renames = [
            (_patch_ref(old_name, pn), _patch_ref(new_name, pn)) for pn in patch_names
        ]
        renames.append((_stack_state_ref(old_name), _stack_state_ref(new_name)))

        self.repository.refs.rename('rename %s to %s' % (old_name, new_name), *renames)

        config.rename_section(
            'branch.%s.stgit' % old_name,
            'branch.%s.stgit' % new_name,
        )

    def rename_patch(self, old_name, new_name, msg='rename'):
        """Rename patch ``old_name`` to ``new_name``.

        :raises StackException: if both names are equal, ``new_name`` already
            exists or is not a valid patch name, or ``old_name`` is unknown
        """
        if new_name == old_name:
            raise StackException('New patch name same as old: "%s"' % new_name)
        if self.patches.exists(new_name):
            raise StackException('Patch already exists: "%s"' % new_name)
        if not self.patches.is_name_valid(new_name):
            raise StackException('Invalid patch name: "%s"' % new_name)
        if not self.patches.exists(old_name):
            raise StackException('Unknown patch name: "%s"' % old_name)
        self.patchorder.rename_patch(old_name, new_name)
        self.patches.get(old_name).set_name(new_name, msg)

    def clone(self, clone_name, msg):
        """Create and return a new stack ``clone_name`` copied from this one.

        The clone is created at this stack's base with this stack as its parent
        branch. All visible patches are copied and left unapplied in the clone;
        hidden patches are not copied. The ``branch.<name>.*`` configuration is
        duplicated for the clone, and the clone's state ref is set to this
        stack's state.
        """
        source_name = self.name
        clone = self.create(
            self.repository,
            name=clone_name,
            msg=msg,
            create_at=self.base,
            parent_remote=self.parent_remote,
            parent_branch=source_name,
        )

        for pn in self.patchorder.all_visible:
            source_patch = self.patches.get(pn)
            clone.patches.new(pn, source_patch.commit, 'clone from %s' % self.name)

        clone.patchorder.set_order(
            applied=[],
            unapplied=self.patchorder.all_visible,
            hidden=[],
        )

        prefix = 'branch.%s.' % self.name
        clone_prefix = 'branch.%s.' % clone_name
        # Materialise the matches first; config is written to inside the loop.
        for key, value in list(config.getstartswith(prefix)):
            config.set(key.replace(prefix, clone_prefix, 1), value)

        refs = self.repository.refs
        refs.set(clone.state_ref, refs.get(self.state_ref), msg=msg)

        return clone

    @classmethod
    def initialise(cls, repository, name=None, msg='initialise', switch_to=False):
        """Initialise a Git branch to handle patch stack.

        :param repository: :class:`Repository` where the :class:`Stack` will be created
        :param name: the name of the :class:`Stack`; the repository's current
            branch name is used when this is falsy
        :param msg: message used for the initial stack state
        :param switch_to: when true, switch to the branch before writing the state
        :raises StackException: if the stack state ref already exists

        """
        if not name:
            name = repository.current_branch_name
        # make sure that the corresponding Git branch exists
        branch = Branch(repository, name)

        state_ref = _stack_state_ref(name)
        if repository.refs.exists(state_ref):
            raise StackException('%s: stack already initialized' % name)

        if switch_to:
            branch.switch_to()

        empty_state = log.StackState(
            repository,
            prev=None,
            head=branch.head,
            applied=[],
            unapplied=[],
            hidden=[],
            patches={},
            message=msg,
        )
        repository.refs.set(state_ref, empty_state.commit_state(), msg)

        return repository.get_stack(name)

    @classmethod
    def create(
        cls,
        repository,
        name,
        msg,
        create_at=None,
        parent_remote=None,
        parent_branch=None,
        switch_to=False,
    ):
        """Create and initialise a Git branch returning the :class:`Stack` object.

        :param repository: :class:`Repository` where the :class:`Stack` will be created
        :param name: name of the :class:`Stack`
        :param msg: message to use in newly created log
        :param create_at: Git id used as the base for the newly created Git branch
        :param parent_remote: name of the parent remote Git branch
        :param parent_branch: name of the parent Git branch
        :param switch_to: passed through to :meth:`initialise`

        """
        new_branch = Branch.create(repository, name, create_at=create_at)
        try:
            stack = cls.initialise(repository, name, msg, switch_to=switch_to)
        except (BranchException, StackException):
            # Do not leave the freshly created branch behind.
            new_branch.delete()
            raise
        stack.set_parents(parent_remote, parent_branch)
        return stack

    @staticmethod
    def _ensure_patch_refs(repository, stack_name, state):
        """Ensure patch refs in repository match those from stack state."""
        prefix = _patches_ref_prefix(stack_name)

        # ref name -> commit, as recorded in the stack state
        wanted = {
            _patch_ref(stack_name, pn): commit for pn, commit in state.patches.items()
        }
        wanted_refs = set(wanted)
        present_refs = {ref for ref in repository.refs if ref.startswith(prefix)}

        stale_refs = present_refs - wanted_refs
        missing_refs = wanted_refs - present_refs
        # Refs present on both sides that point at a different commit.
        changed_refs = {
            ref
            for ref in wanted_refs - missing_refs
            if wanted[ref].sha1 != repository.refs.get(ref).sha1
        }

        if not (missing_refs or changed_refs or stale_refs):
            return

        repository.refs.batch_update(
            msg='restore from stack state',
            create=[(ref, wanted[ref]) for ref in missing_refs],
            update=[(ref, wanted[ref]) for ref in changed_refs],
            delete=stale_refs,
        )
class StackRepository(Repository):
    """A Git :class:`Repository` with some added StGit-specific operations."""

    def __init__(self, directory):
        super().__init__(directory)
        self._stacks = {}  # name -> Stack

    @property
    def current_stack(self):
        """The :class:`Stack` for the current branch."""
        return self.get_stack()

    def get_stack(self, name=None):
        """Return the :class:`Stack` called ``name``, creating it on first use.

        When ``name`` is falsy the current branch name is used. Each stack is
        constructed once and reused on later calls.
        """
        stack_name = name or self.current_branch_name
        stack = self._stacks.get(stack_name)
        if stack is None:
            stack = self._stacks[stack_name] = Stack(self, stack_name)
        return stack