diffparse: allow partial staging of deleted files
[git-cola.git] / cola / diffparse.py
blob13f6a6f3a42ee12e3814500eaa44e41d40b3ce70
1 from __future__ import division, absolute_import, unicode_literals
2 import math
3 import re
4 from collections import defaultdict
6 from . import compat
# Matches a unified-diff hunk header such as "@@ -1,4 +1,6 @@ heading".
# Group 1 is the old range and group 2 the new range (each either "start"
# or "start,count"); group 3 is whatever text follows the closing "@@".
_HUNK_HEADER_RE = re.compile(r'^@@ -([0-9,]+) \+([0-9,]+) @@(.*)')
12 class _DiffHunk(object):
14 def __init__(self, old_start, old_count, new_start, new_count, heading,
15 first_line_idx, lines):
16 self.old_start = old_start
17 self.old_count = old_count
18 self.new_start = new_start
19 self.new_count = new_count
20 self.heading = heading
21 self.first_line_idx = first_line_idx
22 self.lines = lines
24 @property
25 def last_line_idx(self):
26 return self.first_line_idx + len(self.lines) - 1
def parse_range_str(range_str):
    """Convert a diff range string into a ``(start, count)`` tuple

    Accepts either ``"start,count"`` or a bare ``"start"``; in the latter
    case the count is 1.

    """
    start, comma, count = range_str.partition(',')
    if not comma:
        return int(range_str), 1
    return int(start), int(count)
36 def _format_range(start, count):
37 if count == 1:
38 return str(start)
39 return '%d,%d' % (start, count)
def _format_hunk_header(old_start, old_count, new_start, new_count,
                        heading=''):
    """Build a newline-terminated "@@ -old +new @@heading" header line"""
    old_range = _format_range(old_start, old_count)
    new_range = _format_range(new_start, new_count)
    return '@@ -%s +%s @@%s\n' % (old_range, new_range, heading)
def _parse_diff(diff_text):
    """Split diff text into a list of _DiffHunk objects

    Line indexes count every line of ``diff_text`` split on newlines.
    Non-empty lines that follow a hunk header are attached to that hunk;
    anything before the first header is ignored.

    """
    parsed = []
    for idx, text in enumerate(diff_text.split('\n')):
        header = _HUNK_HEADER_RE.match(text)
        if header is None:
            # Body line: keep it only if it is non-empty and a hunk is open.
            if text and parsed:
                parsed[-1].lines.append(text + '\n')
            continue
        old_range, new_range, heading = header.groups()
        old_start, old_count = parse_range_str(old_range)
        new_start, new_count = parse_range_str(new_range)
        parsed.append(_DiffHunk(old_start, old_count,
                                new_start, new_count,
                                heading, idx, lines=[text + '\n']))
    return parsed
def digits(number):
    """Return the number of digits needed to display a number

    Values below 1 (including zero and negative placeholders such as -1)
    are reported as needing a single digit.

    """
    # log10 is undefined for 0, so only take the logarithm of values >= 1.
    if number >= 1:
        result = int(math.log10(number)) + 1
    else:
        result = 1
    return result
class Counter(object):
    """Track the current line number and the largest one seen in a range"""

    def __init__(self, value=0, max_value=-1):
        self.value = value
        # Remember the starting maximum so reset() can restore it.
        self._initial_max_value = self.max_value = max_value

    def reset(self):
        """Restore the maximum to its initial value; returns self"""
        self.max_value = self._initial_max_value
        return self

    def parse(self, range_str):
        """Load a diff range string and update the running maximum"""
        start, count = parse_range_str(range_str)
        self.value = start
        range_end = start + count
        if range_end > self.max_value:
            self.max_value = range_end

    def tick(self, amount=1):
        """Return the current value, then advance it by ``amount``"""
        current = self.value
        self.value = current + amount
        return current
class DiffLines(object):
    """Parse diffs and gather line numbers

    ``parse()`` returns one tuple of line numbers per recorded diff line:
    ``(old, new)`` for regular diffs and ``(ours, theirs, new)`` for
    combined (merge) diffs.  ``EMPTY`` marks a column that has no line
    number and ``DASH`` marks a hunk header row.

    """

    EMPTY = -1
    DASH = -2

    def __init__(self):
        self.valid = True
        self.merge = False

        # diff <old> <new>
        # merge <ours> <theirs> <new>
        self.old = Counter()
        self.new = Counter()
        self.ours = Counter()
        self.theirs = Counter()

    def digits(self):
        """Return the digit count needed for the largest line number seen"""
        return digits(max(self.old.max_value, self.new.max_value,
                          self.ours.max_value, self.theirs.max_value))

    def parse(self, diff_text):
        """Return the list of line-number tuples for ``diff_text``

        Lines seen before the first hunk header, and "No newline" marker
        lines, are recorded with every column set to ``EMPTY``.  Sets
        ``self.merge`` to True when a combined-diff ("@@@") header is seen.

        """
        lines = []
        DIFF_STATE = 1
        state = INITIAL_STATE = 0
        merge = self.merge = False
        NO_NEWLINE = r'\ No newline at end of file'

        old = self.old.reset()
        new = self.new.reset()
        ours = self.ours.reset()
        theirs = self.theirs.reset()

        for text in diff_text.split('\n'):
            if text.startswith('@@ -'):
                parts = text.split(' ', 4)
                # A truncated header has too few fields; let it fall through
                # to the generic handling below instead of raising
                # IndexError.
                if (len(parts) > 3
                        and parts[0] == '@@' and parts[3] == '@@'):
                    state = DIFF_STATE
                    # Strip the leading "-" / "+" before parsing each range.
                    old.parse(parts[1][1:])
                    new.parse(parts[2][1:])
                    lines.append((self.DASH, self.DASH))
                    continue
            if text.startswith('@@@ -'):
                self.merge = merge = True
                parts = text.split(' ', 5)
                if (len(parts) > 4
                        and parts[0] == '@@@' and parts[4] == '@@@'):
                    state = DIFF_STATE
                    ours.parse(parts[1][1:])
                    theirs.parse(parts[2][1:])
                    new.parse(parts[3][1:])
                    lines.append((self.DASH, self.DASH, self.DASH))
                    continue
            if state == INITIAL_STATE or text.rstrip() == NO_NEWLINE:
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))
            elif not merge and text.startswith('-'):
                lines.append((old.tick(), self.EMPTY))
            elif merge and text.startswith('- '):
                # "-" in the first column: the line was removed relative to
                # the first parent, so it is numbered in the "ours" column.
                lines.append((ours.tick(), self.EMPTY, self.EMPTY))
            elif merge and text.startswith(' -'):
                lines.append((self.EMPTY, theirs.tick(), self.EMPTY))
            elif merge and text.startswith('--'):
                lines.append((ours.tick(), theirs.tick(), self.EMPTY))
            elif not merge and text.startswith('+'):
                lines.append((self.EMPTY, new.tick()))
            elif merge and text.startswith('++'):
                lines.append((self.EMPTY, self.EMPTY, new.tick()))
            elif merge and text.startswith('+ '):
                lines.append((self.EMPTY, theirs.tick(), new.tick()))
            elif merge and text.startswith(' +'):
                lines.append((ours.tick(), self.EMPTY, new.tick()))
            elif not merge and text.startswith(' '):
                lines.append((old.tick(), new.tick()))
            elif merge and text.startswith(' '):
                lines.append((ours.tick(), theirs.tick(), new.tick()))
            elif not text:
                # Empty text inside a hunk: advance every counter but do
                # not record a row for it.
                new.tick()
                old.tick()
                ours.tick()
                theirs.tick()
            else:
                # Anything else ends the current hunk.
                state = INITIAL_STATE
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))

        return lines
class FormatDigits(object):
    """Render line numbers, and their placeholders, at a fixed width"""

    DASH = DiffLines.DASH
    EMPTY = DiffLines.EMPTY

    def __init__(self, dash='', empty=''):
        # The width-dependent strings stay blank until set_digits() runs.
        self.fmt = ''
        self.empty = ''
        self.dash = ''
        # Single-character fillers that set_digits() repeats to full width.
        self._dash = dash if dash else compat.uchr(0xb7)
        self._empty = empty if empty else ' '

    def set_digits(self, value):
        """Configure the formatter for numbers that are ``value`` wide"""
        width = value
        self.fmt = ('%%0%dd' % width)
        self.empty = (self._empty * width)
        self.dash = (self._dash * width)

    def value(self, old, new):
        """Format an (old, new) pair of line numbers"""
        columns = (self._format(old), self._format(new))
        return '%s %s' % columns

    def merge_value(self, old, base, new):
        """Format a three-column set of line numbers"""
        columns = (self._format(old), self._format(base), self._format(new))
        return '%s %s %s' % columns

    def number(self, value):
        """Format one number using the configured zero-padded width"""
        return self.fmt % value

    def _format(self, value):
        """Map DASH/EMPTY to their filler strings, else format the number"""
        if value == self.DASH:
            return self.dash
        if value == self.EMPTY:
            return self.empty
        return self.number(value)
class DiffParser(object):
    """Parse and rewrite diffs to produce edited patches

    This parser is used for modifying the worktree and index by constructing
    temporary patches that are applied using "git apply".

    """

    def __init__(self, filename, diff_text):
        """Store ``filename`` and split ``diff_text`` into hunks"""
        self.filename = filename
        self.hunks = _parse_diff(diff_text)

    def generate_patch(self, first_line_idx, last_line_idx,
                       reverse=False):
        """Return a patch containing a subset of the diff

        ``first_line_idx`` and ``last_line_idx`` are inclusive indexes into
        the lines of the diff text given to the constructor and delimit the
        selected lines.  When ``reverse`` is true, additions and deletions
        are swapped and each hunk starts from its "new" line number.

        Returns the patch as a single string, or None when no hunk
        contributes a selected addition or deletion.

        """

        ADDITION = '+'
        DELETION = '-'
        CONTEXT = ' '
        NO_NEWLINE = '\\'

        lines = ['--- a/%s\n' % self.filename,
                 '+++ b/%s\n' % self.filename]

        # Net lines added by the hunks emitted so far; shifts the new-side
        # start of every later hunk.
        start_offset = 0

        for hunk in self.hunks:
            # skip hunks until we get to the one that contains the first
            # selected line
            if hunk.last_line_idx < first_line_idx:
                continue
            # once we have processed the hunk that contains the last selected
            # line, we can stop
            if hunk.first_line_idx > last_line_idx:
                break

            prev_skipped = False
            counts = defaultdict(int)
            filtered_lines = []

            # hunk.lines[0] is the header itself, so body lines start at
            # first_line_idx + 1.
            for line_idx, line in enumerate(hunk.lines[1:],
                                            start=hunk.first_line_idx + 1):
                line_type, line_content = line[:1], line[1:]

                if reverse:
                    if line_type == ADDITION:
                        line_type = DELETION
                    elif line_type == DELETION:
                        line_type = ADDITION

                if not first_line_idx <= line_idx <= last_line_idx:
                    if line_type == ADDITION:
                        # Skip additions that are not selected.
                        prev_skipped = True
                        continue
                    elif line_type == DELETION:
                        # Change deletions that are not selected to context.
                        line_type = CONTEXT
                if line_type == NO_NEWLINE and prev_skipped:
                    # If the line immediately before a "No newline" line was
                    # skipped (because it was an unselected addition) skip
                    # the "No newline" line as well.
                    continue
                filtered_lines.append(line_type + line_content)
                counts[line_type] += 1
                prev_skipped = False

            # Do not include hunks that, after filtering, have only context
            # lines (no additions or deletions).
            if not counts[ADDITION] and not counts[DELETION]:
                continue

            old_count = counts[CONTEXT] + counts[DELETION]
            new_count = counts[CONTEXT] + counts[ADDITION]

            if reverse:
                old_start = hunk.new_start
            else:
                old_start = hunk.old_start
            new_start = old_start + start_offset
            # NOTE(review): these adjustments presumably follow the unified
            # diff convention for ranges whose count is zero -- confirm.
            if old_count == 0:
                new_start += 1
            if new_count == 0:
                new_start -= 1

            start_offset += counts[ADDITION] - counts[DELETION]

            lines.append(_format_hunk_header(old_start, old_count,
                                             new_start, new_count,
                                             hunk.heading))
            lines.extend(filtered_lines)

        # If there are only two lines, that means we did not include any hunks,
        # so return None.
        if len(lines) == 2:
            return None
        return ''.join(lines)

    def generate_hunk_patch(self, line_idx, reverse=False):
        """Return a patch containing the hunk for the specified line only

        The first hunk whose last line is at or after ``line_idx`` is used.
        If no hunk satisfies that, the loop variable is left on the final
        hunk and that hunk is used.  Returns None when there are no hunks,
        or when generate_patch() returns None for the chosen hunk.

        """
        hunk = None
        for hunk in self.hunks:
            if line_idx <= hunk.last_line_idx:
                break
        if hunk is None:
            return None
        return self.generate_patch(hunk.first_line_idx, hunk.last_line_idx,
                                   reverse=reverse)