Stack "state" naming
[stgit.git] / stgit / lib / stack.py
blob65faad6e4bdfc265f804b610286f834e02d6cb88
1 """A Python class hierarchy wrapping the StGit on-disk metadata."""
3 import os
4 import shutil
6 from stgit import utils
7 from stgit.compat import fsencode_utf8
8 from stgit.config import config
9 from stgit.exception import StackException
10 from stgit.lib import log, stackupgrade
11 from stgit.lib.git import CommitData, Repository
12 from stgit.lib.git.branch import Branch, BranchException
13 from stgit.lib.objcache import ObjectCache
def _stack_state_ref(stack_name):
    """Return the name of the ref holding a stack's state metadata.

    The state ref is also known as the stack's "log".

    @param stack_name: name of the stack
    @return: ref name of the form C{refs/heads/<stack_name>.stgit}
    """
    return 'refs/heads/{}.stgit'.format(stack_name)
def _patch_ref(stack_name, patch_name):
    """Return the name of the ref pointing at a named patch's commit.

    @param stack_name: name of the stack the patch belongs to
    @param patch_name: name of the patch
    @return: ref name of the form C{refs/patches/<stack_name>/<patch_name>}
    """
    return 'refs/patches/{}/{}'.format(stack_name, patch_name)
def _patch_log_ref(stack_name, patch_name):
    """Return the name of the ref holding a named patch's log.

    @param stack_name: name of the stack the patch belongs to
    @param patch_name: name of the patch
    @return: ref name of the form
        C{refs/patches/<stack_name>/<patch_name>.log}
    """
    return 'refs/patches/{}/{}.log'.format(stack_name, patch_name)
class Patch:
    """A single StGit patch belonging to a stack.

    The class is mostly concerned with the on-disk representation of a
    patch: the ref pointing at its commit, plus the compatibility files and
    the compatibility log ref used by the old infrastructure.
    """

    def __init__(self, stack, name):
        """
        @param stack: the stack that owns this patch
        @param name: the patch name
        """
        self._stack = stack
        self.name = name

    @property
    def _ref(self):
        """Name of the ref pointing at this patch's commit."""
        return _patch_ref(self._stack.name, self.name)

    @property
    def _log_ref(self):
        """Name of the ref holding this patch's compatibility log."""
        return _patch_log_ref(self._stack.name, self.name)

    @property
    def commit(self):
        """The commit the patch ref points at; looked up on every access.

        A C{KeyError} from the ref lookup is what callers in this class
        treat as "the patch does not exist yet".
        """
        return self._stack.repository.refs.get(self._ref)

    @property
    def _compat_dir(self):
        """Directory holding this patch's compatibility files."""
        return os.path.join(self._stack.directory, 'patches', self.name)

    def _write_compat_files(self, new_commit, msg):
        """Write files used by the old infrastructure.

        Must run before the patch ref is moved to C{new_commit}: the
        C{top.old}/C{bottom.old} files are taken from the commit the patch
        ref still points at.

        @param new_commit: the commit the patch is about to point at
        @param msg: message recorded in the patch log entry and its ref update
        """

        def put(filename, value, multiline=False):
            # A falsy value means "no such file": remove any stale copy.
            path = fsencode_utf8(os.path.join(self._compat_dir, filename))
            if value:
                utils.write_string(path, value, multiline)
            elif os.path.isfile(path):
                os.remove(path)

        def append_patchlog():
            # The previous log entry (if there is one) becomes the parent of
            # the new entry.
            try:
                parents = [self._stack.repository.refs.get(self._log_ref)]
            except KeyError:
                parents = []
            entry_data = CommitData(
                tree=new_commit.data.tree,
                parents=parents,
                message='%s\t%s' % (msg, new_commit.sha1),
            )
            entry = self._stack.repository.commit(entry_data)
            self._stack.repository.refs.set(self._log_ref, entry, msg)
            return entry

        data = new_commit.data
        put('authname', data.author.name)
        put('authemail', data.author.email)
        put('authdate', str(data.author.date))
        put('commname', data.committer.name)
        put('commemail', data.committer.email)
        put('description', data.message_str, multiline=True)
        put('log', append_patchlog().sha1)
        put('top', new_commit.sha1)
        put('bottom', data.parent.sha1)

        # No current commit (KeyError) means there is nothing "old" to record.
        try:
            previous_top = self.commit.sha1
            previous_bottom = self.commit.data.parent.sha1
        except KeyError:
            previous_top = None
            previous_bottom = None
        put('top.old', previous_top)
        put('bottom.old', previous_bottom)

    def _delete_compat_files(self):
        """Remove the compatibility directory and the compatibility log ref."""
        compat_dir = self._compat_dir
        if os.path.isdir(compat_dir):
            for entry in os.listdir(compat_dir):
                os.remove(os.path.join(compat_dir, entry))
            os.rmdir(compat_dir)
        try:
            # this compatibility log ref might not exist
            self._stack.repository.refs.delete(self._log_ref)
        except KeyError:
            pass

    def _point_at(self, commit, msg):
        """Write the compat files for C{commit}, then move the patch ref."""
        self._write_compat_files(commit, msg)
        self._stack.repository.refs.set(self._ref, commit, msg)

    def set_commit(self, commit, msg):
        """Make the patch point at C{commit}.

        If the patch previously pointed at a different commit, the notes of
        the old commit are copied to the new one.

        @param commit: the new commit for this patch
        @param msg: message used for the ref update and patch log entry
        """
        try:
            previous_sha1 = self.commit.sha1
        except KeyError:
            previous_sha1 = None
        self._point_at(commit, msg)
        if previous_sha1 and previous_sha1 != commit.sha1:
            self._stack.repository.copy_notes(previous_sha1, commit.sha1)

    def set_name(self, name, msg):
        """Rename the patch, keeping its commit.

        The on-disk data under the old name is deleted and then recreated
        under the new name.

        @param name: the new patch name
        @param msg: message used for the ref update and patch log entry
        """
        kept_commit = self.commit
        self.delete()
        self.name = name
        self._point_at(kept_commit, msg)

    def delete(self):
        """Delete the compatibility files, the log ref and the patch ref."""
        self._delete_compat_files()
        self._stack.repository.refs.delete(self._ref)

    def is_empty(self):
        """Return whether the patch's commit data reports no change."""
        return self.commit.data.is_nochange()

    def files(self):
        """Return the set of files this patch touches."""
        touched = set()
        for _, _, _, _, _, old_name, new_name in (
            self._stack.repository.diff_tree_files(
                self.commit.data.parent.data.tree,
                self.commit.data.tree,
            )
        ):
            touched.add(old_name)
            touched.add(new_name)
        return touched
class PatchOrder:
    """Tracks the order of patches and which of them are applied.

    Works with patch names, not actual patches. Each of the three lists
    (applied, unapplied, hidden) is stored in a file of the same name in the
    stack directory and cached in memory after the first read.
    """

    _list_names = ('applied', 'unapplied', 'hidden')

    def __init__(self, stack):
        """
        @param stack: the stack whose directory holds the list files
        """
        self._stack = stack
        self._lists = {}

    def _read_file(self, fn):
        """Read list file C{fn} from the stack directory as a tuple."""
        path = os.path.join(self._stack.directory, fn)
        return tuple(utils.read_strings(path))

    def _write_file(self, fn, val):
        """Write the names in C{val} to list file C{fn}."""
        path = os.path.join(self._stack.directory, fn)
        utils.write_strings(path, val)

    def _get_list(self, name):
        """Return the named list, reading its file on first use."""
        cached = self._lists.get(name)
        if cached is None:
            cached = self._lists[name] = self._read_file(name)
        return cached

    def _set_list(self, name, val):
        """Store the named list, writing its file only when it changed."""
        names = tuple(val)
        if names == self._lists.get(name, None):
            return
        self._lists[name] = names
        self._write_file(name, names)

    @property
    def applied(self):
        """Names of the applied patches, as a tuple."""
        return self._get_list('applied')

    @property
    def unapplied(self):
        """Names of the unapplied patches, as a tuple."""
        return self._get_list('unapplied')

    @property
    def hidden(self):
        """Names of the hidden patches, as a tuple."""
        return self._get_list('hidden')

    @property
    def all(self):
        """Applied, then unapplied, then hidden patch names."""
        return self.applied + self.unapplied + self.hidden

    @property
    def all_visible(self):
        """Applied, then unapplied patch names (hidden ones left out)."""
        return self.applied + self.unapplied

    def set_order(self, applied, unapplied, hidden):
        """Replace all three lists.

        @param applied: iterable of applied patch names
        @param unapplied: iterable of unapplied patch names
        @param hidden: iterable of hidden patch names
        """
        self._set_list('applied', applied)
        self._set_list('unapplied', unapplied)
        self._set_list('hidden', hidden)

    def rename_patch(self, old_name, new_name):
        """Replace C{old_name} with C{new_name} in the first list holding it.

        @raise AssertionError: if C{old_name} is in none of the lists
        """
        for list_name in self._list_names:
            names = list(self._get_list(list_name))
            if old_name in names:
                names[names.index(old_name)] = new_name
                self._set_list(list_name, names)
                return
        raise AssertionError('"%s" not found in patchorder' % old_name)

    @staticmethod
    def create(stackdir):
        """Create the PatchOrder specific files"""
        for list_name in PatchOrder._list_names:
            utils.create_empty_file(os.path.join(stackdir, list_name))
class Patches:
    """Factory for L{Patch} objects.

    Patches are handed out through an object cache so that there is only
    one L{Patch} object per patch name.
    """

    def __init__(self, stack):
        """
        @param stack: the stack whose patches are managed
        """
        self._stack = stack

        def load_patch(name):
            patch = Patch(self._stack, name)
            patch.commit  # raise exception if the patch doesn't exist
            return patch

        self._patches = ObjectCache(load_patch)  # name -> Patch

    def exists(self, name):
        """Return whether a patch called C{name} can be looked up."""
        try:
            self.get(name)
        except KeyError:
            return False
        return True

    def get(self, name):
        """Return the L{Patch} called C{name} from the cache."""
        return self._patches[name]

    def is_name_valid(self, name):
        """Return whether C{name} is acceptable as a patch name.

        Names containing a slash are rejected outright; anything else is
        valid when C{git check-ref-format} accepts the resulting patch ref.
        """
        if '/' in name:
            # TODO slashes in patch names could be made to be okay
            return False
        ref = _patch_ref(self._stack.name, name)
        check = self._stack.repository.run(['git', 'check-ref-format', ref])
        check.returns([0, 1]).discard_stderr().discard_output()
        return check.exitcode == 0

    def new(self, name, commit, msg):
        """Create the patch C{name} pointing at C{commit} and cache it.

        @param name: name of the new patch; must be valid and not yet cached
        @param commit: the commit the new patch points at
        @param msg: message used for the ref update and patch log entry
        @return: the new L{Patch}
        """
        assert name not in self._patches
        assert self.is_name_valid(name)
        patch = Patch(self._stack, name)
        patch.set_commit(commit, msg)
        self._patches[name] = patch
        return patch
class Stack(Branch):
    """Represents an StGit stack (that is, a git branch with some extra
    metadata).

    The extra metadata handled here is: the patch order and compatibility
    files under L{directory}, the patch refs, the stack state ref
    (L{state_ref}) and the C{branch.<name>.stgit} configuration section.
    """

    _repo_subdir = 'patches'

    def __init__(self, repository, name):
        """Open the stack for an existing branch.

        @param repository: the repository the branch lives in
        @param name: name of the branch (and stack)
        @raise StackException: if the format version update reports that the
            branch is not an initialised stack
        """
        super().__init__(repository, name)
        self.patchorder = PatchOrder(self)
        self.patches = Patches(self)
        if not stackupgrade.update_to_current_format_version(repository, name):
            raise StackException('%s: branch not initialized' % name)

    @property
    def directory(self):
        """Directory holding this stack's on-disk metadata."""
        return os.path.join(self.repository.directory, self._repo_subdir, self.name)

    @property
    def base(self):
        """Parent of the bottommost applied patch's commit, or the branch
        head if no patches are applied."""
        applied = self.patchorder.applied
        if applied:
            return self.patches.get(applied[0]).commit.data.parent
        return self.head

    @property
    def top(self):
        """Commit of the topmost patch, or the stack base if no patches are
        applied."""
        applied = self.patchorder.applied
        if applied:
            return self.patches.get(applied[-1]).commit
        # When no patches are applied, base == head.
        return self.head

    def head_top_equal(self):
        """Return whether the branch head is the topmost applied patch.

        Trivially true when no patches are applied.
        """
        applied = self.patchorder.applied
        if not applied:
            return True
        return self.head == self.patches.get(applied[-1]).commit

    def set_parents(self, remote, branch):
        """Record the parent remote and/or parent branch of the stack.

        Falsy arguments are skipped. The parent branch is additionally
        stored under C{branch.<name>.stgit.parentbranch}.
        """
        if remote:
            self.set_parent_remote(remote)
        if branch:
            self.set_parent_branch(branch)
            config.set('branch.%s.stgit.parentbranch' % self.name, branch)

    @property
    def protected(self):
        """Whether C{branch.<name>.stgit.protect} is set for this stack."""
        return config.getbool('branch.%s.stgit.protect' % self.name)

    @protected.setter
    def protected(self, protect):
        protect_key = 'branch.%s.stgit.protect' % self.name
        if protect:
            config.set(protect_key, 'true')
        elif self.protected:
            # Only unset a key that is currently set.
            config.unset(protect_key)

    @property
    def state_ref(self):
        """Name of the ref holding this stack's state metadata."""
        return _stack_state_ref(self.name)

    def cleanup(self):
        """Remove all StGit metadata of this stack.

        Deletes every patch (applied, unapplied and hidden), the stack state
        ref, the stack directory and the C{branch.<name>.stgit} config
        section.

        @raise AssertionError: if the stack is protected
        """
        # Raised explicitly instead of using an ``assert`` statement so that
        # the guard on this destructive operation is not stripped when Python
        # runs with -O.
        if self.protected:
            raise AssertionError('attempt to delete protected stack')
        for pn in self.patchorder.all:
            self.patches.get(pn).delete()
        self.repository.refs.delete(self.state_ref)
        shutil.rmtree(self.directory)
        config.remove_section('branch.%s.stgit' % self.name)

    def clear_log(self, msg='clear log'):
        """Replace the stack log with a single entry for the current state.

        The new state is created with no predecessor (C{prev=None}).

        @param msg: message for the new state commit and the ref update
        """
        new_stack_state = log.StackState.from_stack(prev=None, stack=self, message=msg)
        new_stack_state.write_commit()
        self.repository.refs.set(self.state_ref, new_stack_state.commit, msg=msg)

    def rename(self, new_name):
        """Rename the branch together with all of its StGit metadata.

        Besides the branch itself this renames every patch ref and patch log
        ref, the stack state ref, the C{branch.<name>.stgit} config section
        and the stack directory.

        @param new_name: the new branch (and stack) name
        """
        # Capture these first; the branch rename below is what changes them.
        old_name = self.name
        patch_names = self.patchorder.all
        super().rename(new_name)

        renames = []
        for pn in patch_names:
            renames.append((_patch_ref(old_name, pn), _patch_ref(new_name, pn)))
            renames.append((_patch_log_ref(old_name, pn), _patch_log_ref(new_name, pn)))
        renames.append((_stack_state_ref(old_name), _stack_state_ref(new_name)))

        self.repository.refs.rename('rename %s to %s' % (old_name, new_name), *renames)

        config.rename_section(
            'branch.%s.stgit' % old_name,
            'branch.%s.stgit' % new_name,
        )

        utils.rename(
            os.path.join(self.repository.directory, self._repo_subdir),
            old_name,
            new_name,
        )

    def rename_patch(self, old_name, new_name, msg='rename'):
        """Rename a patch of this stack.

        @param old_name: current name of the patch
        @param new_name: new name of the patch
        @param msg: message used for the ref update and patch log entry
        @raise StackException: if the names are equal, C{new_name} already
            exists or is invalid, or C{old_name} is unknown
        """
        if new_name == old_name:
            raise StackException('New patch name same as old: "%s"' % new_name)
        elif self.patches.exists(new_name):
            raise StackException('Patch already exists: "%s"' % new_name)
        elif not self.patches.is_name_valid(new_name):
            raise StackException('Invalid patch name: "%s"' % new_name)
        elif not self.patches.exists(old_name):
            raise StackException('Unknown patch name: "%s"' % old_name)
        self.patchorder.rename_patch(old_name, new_name)
        self.patches.get(old_name).set_name(new_name, msg)

    def clone(self, clone_name, msg):
        """Create a new stack that is a copy of this one.

        The clone is created at this stack's base with this stack as its
        parent branch. All visible patches are copied and left unapplied in
        the clone; hidden patches are not copied. The C{branch.<name>.*}
        configuration is copied and the clone's state ref is set to this
        stack's state.

        @param clone_name: name of the new stack
        @param msg: message used for the newly created log and ref update
        @return: the new L{Stack}
        """
        clone = self.create(
            self.repository,
            name=clone_name,
            msg=msg,
            create_at=self.base,
            parent_remote=self.parent_remote,
            parent_branch=self.name,
        )

        for pn in self.patchorder.all_visible:
            patch = self.patches.get(pn)
            clone.patches.new(pn, patch.commit, 'clone from %s' % self.name)

        clone.patchorder.set_order(
            applied=[],
            unapplied=self.patchorder.all_visible,
            hidden=[],
        )

        prefix = 'branch.%s.' % self.name
        clone_prefix = 'branch.%s.' % clone_name
        # Materialise the matches first since config is written in the loop.
        for k, v in list(config.getstartswith(prefix)):
            clone_key = k.replace(prefix, clone_prefix, 1)
            config.set(clone_key, v)

        self.repository.refs.set(
            clone.state_ref,
            self.repository.refs.get(self.state_ref),
            msg=msg,
        )

        return clone

    @classmethod
    def initialise(cls, repository, name=None, msg='initialise', switch_to=False):
        """Initialise a Git branch to handle patch series.

        @param repository: The L{Repository} where the L{Stack} will be created
        @param name: The name of the L{Stack}; defaults to the current branch
        @param msg: Message to use in newly created log
        @param switch_to: Whether to switch to the branch before setting it up
        @return: The L{Stack} obtained from the repository for C{name}
        @raise StackException: If the stack state ref or the stack directory
            already exists
        """
        if not name:
            name = repository.current_branch_name
        # make sure that the corresponding Git branch exists
        branch = Branch(repository, name)

        stack_state_ref = _stack_state_ref(name)
        if repository.refs.exists(stack_state_ref):
            raise StackException('%s: stack already initialized' % name)

        stack_dir = os.path.join(repository.directory, cls._repo_subdir, name)
        if os.path.exists(stack_dir):
            raise StackException('%s: branch already initialized' % name)

        if switch_to:
            branch.switch_to()

        # create the stack directory and files
        utils.create_dirs(stack_dir)
        compat_dir = os.path.join(stack_dir, 'patches')
        utils.create_dirs(compat_dir)
        PatchOrder.create(stack_dir)
        config.set(
            stackupgrade.format_version_key(name), str(stackupgrade.FORMAT_VERSION)
        )

        new_stack_state = log.StackState(
            repository,
            prev=None,
            head=branch.head,
            applied=[],
            unapplied=[],
            hidden=[],
            patches={},
            message=msg,
        )
        new_stack_state.write_commit()
        repository.refs.set(stack_state_ref, new_stack_state.commit, msg)

        return repository.get_stack(name)

    @classmethod
    def create(
        cls,
        repository,
        name,
        msg,
        create_at=None,
        parent_remote=None,
        parent_branch=None,
        switch_to=False,
    ):
        """Create and initialise a Git branch returning the L{Stack} object.

        If initialisation fails with a L{BranchException} or
        L{StackException}, the freshly created branch is deleted again and
        the exception is re-raised.

        @param repository: The L{Repository} where the L{Stack} will be created
        @param name: The name of the L{Stack}
        @param msg: Message to use in newly created log
        @param create_at: The Git id used as the base for the newly created Git branch
        @param parent_remote: The name of the remote Git branch
        @param parent_branch: The name of the parent Git branch
        @param switch_to: Whether to switch to the new branch
        """
        branch = Branch.create(repository, name, create_at=create_at)
        try:
            stack = cls.initialise(repository, name, msg, switch_to=switch_to)
        except (BranchException, StackException):
            branch.delete()
            raise
        stack.set_parents(parent_remote, parent_branch)
        return stack
class StackRepository(Repository):
    """A git L{Repository<Repository>} with some added StGit-specific
    operations.

    L{Stack} objects are created on first request and then cached by name.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to L{Repository} and start an empty cache."""
        super().__init__(*args, **kwargs)
        self._stacks = {}  # name -> Stack

    @property
    def current_stack(self):
        """The L{Stack} of the current branch."""
        return self.get_stack()

    def get_stack(self, name=None):
        """Return the L{Stack} called C{name}.

        @param name: stack name; a falsy value selects the current branch
        @return: the cached L{Stack}, created on first request
        """
        stack_name = name or self.current_branch_name
        stack = self._stacks.get(stack_name)
        if stack is None:
            stack = self._stacks[stack_name] = Stack(self, stack_name)
        return stack