Better stack rename abstraction
[stgit.git] / stgit / lib / stack.py
blob0081501ccaff408690595d2a88f4eacf429f4450
1 """A Python class hierarchy wrapping the StGit on-disk metadata."""
3 import os
4 import shutil
6 from stgit import utils
7 from stgit.compat import fsencode_utf8
8 from stgit.config import config
9 from stgit.exception import StackException
10 from stgit.lib import log, stackupgrade
11 from stgit.lib.git import CommitData, Repository
12 from stgit.lib.git.branch import Branch, BranchException
13 from stgit.lib.objcache import ObjectCache
class Patch:
    """A single StGit patch.

    This class reads and writes the patch's on-disk state:
    - the ref ``refs/patches/<stack>/<name>``;
    - its companion ``.log`` ref;
    - the per-patch compatibility files under ``<stack dir>/patches/<name>``,
      which are kept for the old infrastructure.
    """

    def __init__(self, stack, name):
        self._stack = stack
        self.name = name

    @property
    def _ref(self):
        return 'refs/patches/%s/%s' % (self._stack.name, self.name)

    @property
    def _log_ref(self):
        # The patch log lives next to the patch ref, with a '.log' suffix.
        return '%s.log' % self._ref

    @property
    def commit(self):
        """Commit the patch ref points at.

        Callers in this module treat a KeyError from here as "the patch ref
        does not exist".
        """
        refs = self._stack.repository.refs
        return refs.get(self._ref)

    @property
    def _compat_dir(self):
        stack_dir = self._stack.directory
        return os.path.join(stack_dir, 'patches', self.name)

    def _write_compat_files(self, new_commit, msg):
        """Write files used by the old infrastructure.

        As a side effect, this also appends an entry recording new_commit to
        the patch's log ref, using msg as the ref-update message.
        """

        def write(name, val, multiline=False):
            # A falsy value means "no such file": remove any stale copy
            # instead of writing an empty one.
            path = fsencode_utf8(os.path.join(self._compat_dir, name))
            if not val:
                if os.path.isfile(path):
                    os.remove(path)
                return
            utils.write_string(path, val, multiline)

        def write_patchlog():
            # Chain onto the previous log commit when the log ref exists.
            refs = self._stack.repository.refs
            try:
                parents = [refs.get(self._log_ref)]
            except KeyError:
                parents = []
            log_commit = self._stack.repository.commit(
                CommitData(
                    tree=new_commit.data.tree,
                    parents=parents,
                    message='%s\t%s' % (msg, new_commit.sha1),
                )
            )
            refs.set(self._log_ref, log_commit, msg)
            return log_commit

        data = new_commit.data
        author, committer = data.author, data.committer
        write('authname', author.name)
        write('authemail', author.email)
        write('authdate', str(author.date))
        write('commname', committer.name)
        write('commemail', committer.email)
        write('description', data.message_str, multiline=True)
        write('log', write_patchlog().sha1)
        write('top', new_commit.sha1)
        write('bottom', data.parent.sha1)

        # Record where the patch pointed before this update.
        # For a patch without a ref yet, the values are None, so the files
        # are removed.
        try:
            current = self.commit
            old_top_sha1 = current.sha1
            old_bottom_sha1 = current.data.parent.sha1
        except KeyError:
            old_top_sha1 = old_bottom_sha1 = None
        write('top.old', old_top_sha1)
        write('bottom.old', old_bottom_sha1)

    def _delete_compat_files(self):
        """Remove the compat directory (if any) and the patch log ref."""
        compat_dir = self._compat_dir
        if os.path.isdir(compat_dir):
            for entry in os.listdir(compat_dir):
                os.remove(os.path.join(compat_dir, entry))
            os.rmdir(compat_dir)
        try:
            # this compatibility log ref might not exist
            self._stack.repository.refs.delete(self._log_ref)
        except KeyError:
            pass

    def set_commit(self, commit, msg):
        """Point the patch at commit, refreshing the compat files and log.

        Notes are copied from the previous commit when the patch already
        existed and actually moves to a different commit.
        """
        try:
            previous_sha1 = self.commit.sha1
        except KeyError:
            previous_sha1 = None
        self._write_compat_files(commit, msg)
        self._stack.repository.refs.set(self._ref, commit, msg)
        if previous_sha1 and previous_sha1 != commit.sha1:
            self._stack.repository.copy_notes(previous_sha1, commit.sha1)

    def set_name(self, name, msg):
        """Rename the patch to name, keeping it on the same commit.

        The old ref, compat files and log ref are deleted and then recreated
        under the new name.
        """
        commit = self.commit
        self.delete()
        self.name = name
        self._write_compat_files(commit, msg)
        self._stack.repository.refs.set(self._ref, commit, msg)

    def delete(self):
        """Remove the compat files, the log ref and the patch ref."""
        self._delete_compat_files()
        self._stack.repository.refs.delete(self._ref)

    def is_empty(self):
        """Return whether the patch commit is a no-change commit."""
        return self.commit.data.is_nochange()

    def files(self):
        """Return the set of files this patch touches."""
        commit_data = self.commit.data
        touched = set()
        for _, _, _, _, _, oldname, newname in self._stack.repository.diff_tree_files(
            commit_data.parent.data.tree,
            commit_data.tree,
        ):
            touched.update((oldname, newname))
        return touched
class PatchOrder:
    """Tracks the order of patches and whether each is applied, unapplied or
    hidden. Deals only in patch names, never in Patch objects."""

    def __init__(self, stack):
        self._stack = stack
        # list name ('applied' / 'unapplied' / 'hidden') -> tuple of names
        self._lists = {}

    def _read_file(self, fn):
        path = os.path.join(self._stack.directory, fn)
        return tuple(utils.read_strings(path))

    def _write_file(self, fn, val):
        path = os.path.join(self._stack.directory, fn)
        utils.write_strings(path, val)

    def _get_list(self, name):
        # Read each list from disk at most once, then serve it from memory.
        try:
            return self._lists[name]
        except KeyError:
            names = self._lists[name] = self._read_file(name)
            return names

    def _set_list(self, name, val):
        # Only touch the file when the list actually changes.
        new_names = tuple(val)
        if self._lists.get(name) == new_names:
            return
        self._lists[name] = new_names
        self._write_file(name, new_names)

    @property
    def applied(self):
        return self._get_list('applied')

    @property
    def unapplied(self):
        return self._get_list('unapplied')

    @property
    def hidden(self):
        return self._get_list('hidden')

    @property
    def all(self):
        """Applied, then unapplied, then hidden patch names."""
        return self.all_visible + self.hidden

    @property
    def all_visible(self):
        """Applied, then unapplied patch names (hidden ones excluded)."""
        return self.applied + self.unapplied

    def set_order(self, applied, unapplied, hidden):
        """Replace all three lists; unchanged lists are not rewritten."""
        for list_name, names in (
            ('applied', applied),
            ('unapplied', unapplied),
            ('hidden', hidden),
        ):
            self._set_list(list_name, names)

    def rename_patch(self, old_name, new_name):
        """Replace old_name with new_name in place, in whichever list holds it.

        Raises AssertionError if old_name is in none of the lists.
        """
        for list_name in ('applied', 'unapplied', 'hidden'):
            names = list(self._get_list(list_name))
            if old_name not in names:
                continue
            names[names.index(old_name)] = new_name
            self._set_list(list_name, names)
            return
        raise AssertionError('"%s" not found in patchorder' % old_name)

    @staticmethod
    def create(stackdir):
        """Create the PatchOrder specific files"""
        for fn in ('applied', 'unapplied', 'hidden'):
            utils.create_empty_file(os.path.join(stackdir, fn))
class Patches:
    """Factory for L{Patch} objects that caches them by name, so callers
    share a single object per patch."""

    def __init__(self, stack):
        self._stack = stack

        def load_existing(name):
            patch = Patch(self._stack, name)
            # Touch the commit so that a missing patch ref raises right away
            # instead of producing a Patch for a nonexistent patch.
            patch.commit
            return patch

        self._patches = ObjectCache(load_existing)  # name -> Patch

    def exists(self, name):
        """Return True if a patch called name can be loaded."""
        try:
            self.get(name)
        except KeyError:
            return False
        return True

    def get(self, name):
        """Return the (cached) Patch called name."""
        return self._patches[name]

    def is_name_valid(self, name):
        """Return True if name is usable as a patch name.

        Names containing '/' are rejected outright. Otherwise the resulting
        ref name is validated with `git check-ref-format`.
        """
        if '/' in name:
            # TODO slashes in patch names could be made to be okay
            return False
        ref_name = 'refs/patches/%s/%s' % (self._stack.name, name)
        proc = self._stack.repository.run(['git', 'check-ref-format', ref_name])
        proc.returns([0, 1]).discard_stderr().discard_output()
        return proc.exitcode == 0

    def new(self, name, commit, msg):
        """Create a new patch called name pointing at commit and cache it."""
        assert name not in self._patches
        assert self.is_name_valid(name)
        patch = Patch(self._stack, name)
        patch.set_commit(commit, msg)
        self._patches[name] = patch
        return patch
class Stack(Branch):
    """Represents an StGit stack (that is, a git branch with some extra
    metadata)."""

    # Subdirectory of the repository directory that holds one metadata
    # directory per stack.
    _repo_subdir = 'patches'

    def __init__(self, repository, name):
        Branch.__init__(self, repository, name)
        self.patchorder = PatchOrder(self)
        self.patches = Patches(self)
        if not stackupgrade.update_to_current_format_version(repository, name):
            raise StackException('%s: branch not initialized' % name)

    @property
    def directory(self):
        """Path of this stack's metadata directory."""
        return os.path.join(self.repository.directory, self._repo_subdir, self.name)

    @property
    def base(self):
        """Parent of the bottommost applied patch, or the branch head if no
        patches are applied."""
        if self.patchorder.applied:
            return self.patches.get(self.patchorder.applied[0]).commit.data.parent
        else:
            return self.head

    @property
    def top(self):
        """Commit of the topmost patch, or the stack base if no patches are
        applied."""
        if self.patchorder.applied:
            return self.patches.get(self.patchorder.applied[-1]).commit
        else:
            # When no patches are applied, base == head.
            return self.head

    def head_top_equal(self):
        """Return True if the branch head equals the topmost applied patch.

        This is trivially True when no patches are applied.
        """
        if not self.patchorder.applied:
            return True
        top = self.patches.get(self.patchorder.applied[-1]).commit
        return self.head == top

    def set_parents(self, remote, branch):
        """Record the parent remote and/or parent branch of this stack.

        Falsy arguments are ignored. The parent branch is additionally
        stored as branch.<name>.stgit.parentbranch.
        """
        if remote:
            self.set_parent_remote(remote)
        if branch:
            self.set_parent_branch(branch)
            config.set('branch.%s.stgit.parentbranch' % self.name, branch)

    @property
    def protected(self):
        """Whether branch.<name>.stgit.protect is set to true."""
        return config.getbool('branch.%s.stgit.protect' % self.name)

    @protected.setter
    def protected(self, protect):
        protect_key = 'branch.%s.stgit.protect' % self.name
        if protect:
            config.set(protect_key, 'true')
        elif self.protected:
            # Only unset the key when it is currently set.
            config.unset(protect_key)

    def cleanup(self):
        """Remove all StGit metadata of this stack.

        This deletes every patch, the stack metadata directory and the
        branch.<name>.stgit config section. It refuses (via assert) to run
        on a protected stack.
        """
        assert not self.protected, 'attempt to delete protected stack'
        for pn in self.patchorder.all:
            patch = self.patches.get(pn)
            patch.delete()
        shutil.rmtree(self.directory)
        config.remove_section('branch.%s.stgit' % self.name)

    def rename(self, new_name):
        """Rename the branch and move all StGit metadata along with it.

        The following are moved to new_name:
        - the patch refs and their '.log' refs;
        - the stack log ref;
        - the branch.<name>.stgit config section;
        - the stack metadata directory.
        """
        old_name = self.name
        # Read the patch names before renaming: the patchorder files live in
        # a directory derived from self.name.
        patch_names = self.patchorder.all
        super(Stack, self).rename(new_name)
        renames = []
        for pn in patch_names:
            renames.append(
                (
                    'refs/patches/%s/%s' % (old_name, pn),
                    'refs/patches/%s/%s' % (new_name, pn),
                )
            )
            # NOTE(review): Patch treats the '.log' ref as possibly missing;
            # confirm that refs.rename tolerates a missing source ref.
            renames.append(
                (
                    'refs/patches/%s/%s.log' % (old_name, pn),
                    'refs/patches/%s/%s.log' % (new_name, pn),
                )
            )

        renames.append((log.log_ref(old_name), log.log_ref(new_name)))

        self.repository.refs.rename('rename %s to %s' % (old_name, new_name), *renames)

        config.rename_section(
            'branch.%s.stgit' % old_name,
            'branch.%s.stgit' % new_name,
        )

        utils.rename(
            os.path.join(self.repository.directory, self._repo_subdir),
            old_name,
            new_name,
        )

    def rename_patch(self, old_name, new_name, msg='rename'):
        """Rename patch old_name to new_name.

        Raises StackException if:
        - the names are equal;
        - new_name already exists or is invalid;
        - old_name does not exist.
        """
        if new_name == old_name:
            raise StackException('New patch name same as old: "%s"' % new_name)
        elif self.patches.exists(new_name):
            raise StackException('Patch already exists: "%s"' % new_name)
        elif not self.patches.is_name_valid(new_name):
            raise StackException('Invalid patch name: "%s"' % new_name)
        elif not self.patches.exists(old_name):
            raise StackException('Unknown patch name: "%s"' % old_name)
        self.patchorder.rename_patch(old_name, new_name)
        self.patches.get(old_name).set_name(new_name, msg)
        # Patches caches Patch objects by name. set_name() mutated the cached
        # object's name, but it is still stored under old_name, so
        # exists(old_name) would keep answering True (breaking e.g. a rename
        # back to old_name). Start a fresh cache so that lookups reflect the
        # refs on disk again.
        self.patches = Patches(self)

    @classmethod
    def initialise(cls, repository, name=None, switch_to=False):
        """Initialise a Git branch to handle patch series.

        @param repository: The L{Repository} where the L{Stack} will be created
        @param name: The name of the L{Stack}; defaults to the current branch
        @param switch_to: Whether to check out the branch before initialising
        @raise StackException: if the stack metadata directory already exists
        """
        if not name:
            name = repository.current_branch_name
        # make sure that the corresponding Git branch exists
        branch = Branch(repository, name)

        # Named stack_dir (not dir) to avoid shadowing the builtin.
        stack_dir = os.path.join(repository.directory, cls._repo_subdir, name)
        if os.path.exists(stack_dir):
            raise StackException('%s: branch already initialized' % name)

        if switch_to:
            branch.switch_to()

        # create the stack directory and files
        utils.create_dirs(stack_dir)
        compat_dir = os.path.join(stack_dir, 'patches')
        utils.create_dirs(compat_dir)
        PatchOrder.create(stack_dir)
        config.set(
            stackupgrade.format_version_key(name), str(stackupgrade.FORMAT_VERSION)
        )

        return repository.get_stack(name)

    @classmethod
    def create(
        cls,
        repository,
        name,
        create_at=None,
        parent_remote=None,
        parent_branch=None,
        switch_to=False,
    ):
        """Create and initialise a Git branch returning the L{Stack} object.

        @param repository: The L{Repository} where the L{Stack} will be created
        @param name: The name of the L{Stack}
        @param create_at: The Git id used as the base for the newly created
            Git branch
        @param parent_remote: The name of the remote Git branch
        @param parent_branch: The name of the parent Git branch
        @param switch_to: Whether to check out the new branch
        """
        branch = Branch.create(repository, name, create_at=create_at)
        try:
            stack = cls.initialise(repository, name, switch_to=switch_to)
        except (BranchException, StackException):
            # Do not leave a half-created branch behind.
            branch.delete()
            raise
        stack.set_parents(parent_remote, parent_branch)
        return stack
class StackRepository(Repository):
    """A git L{Repository<Repository>} extended with StGit-specific
    operations, chiefly access to cached L{Stack} objects."""

    def __init__(self, *args, **kwargs):
        Repository.__init__(self, *args, **kwargs)
        # Stack objects created so far, keyed by branch name.
        self._stacks = {}

    @property
    def current_stack(self):
        """The L{Stack} of the currently checked-out branch."""
        return self.get_stack()

    def get_stack(self, name=None):
        """Return the L{Stack} for branch name (default: current branch).

        Each stack object is created once and then reused.
        """
        stack_name = name or self.current_branch_name
        stack = self._stacks.get(stack_name)
        if stack is None:
            stack = Stack(self, stack_name)
            self._stacks[stack_name] = stack
        return stack