status: refactor StatusTreeWidget and add docstrings
[git-cola.git] / cola / diffparse.py
blob fb9895011ce93b85169248d8e217e6205a158de6
1 from __future__ import division, absolute_import, unicode_literals
2 import math
3 import re
4 from collections import defaultdict
6 from . import compat
# Matches a unified-diff hunk header line, e.g. "@@ -12,7 +12,8 @@ def foo():".
# Group 1 is the old range and group 2 the new range, each written as
# "start" or "start,count".  Group 3 is everything after the closing "@@"
# (the optional section heading, including its leading space); it may be empty.
_HUNK_HEADER_RE = re.compile(r'^@@ -([0-9,]+) \+([0-9,]+) @@(.*)')
12 class _DiffHunk(object):
13 def __init__(
14 self, old_start, old_count, new_start, new_count, heading, first_line_idx, lines
16 self.old_start = old_start
17 self.old_count = old_count
18 self.new_start = new_start
19 self.new_count = new_count
20 self.heading = heading
21 self.first_line_idx = first_line_idx
22 self.lines = lines
24 @property
25 def last_line_idx(self):
26 return self.first_line_idx + len(self.lines) - 1
def parse_range_str(range_str):
    """Split a hunk range such as "12,7" into a ``(start, count)`` pair

    A range without a comma (e.g. "12") has an implied count of 1.
    """
    start_text, comma, count_text = range_str.partition(',')
    if not comma:
        return int(start_text), 1
    return int(start_text), int(count_text)
36 def _format_range(start, count):
37 if count == 1:
38 return str(start)
39 return '%d,%d' % (start, count)
def _format_hunk_header(old_start, old_count, new_start, new_count, heading=''):
    """Build a newline-terminated "@@ -old +new @@" hunk header line

    ``heading`` is appended verbatim right after the closing "@@", so any
    separating space must already be part of it.
    """
    old_range = _format_range(old_start, old_count)
    new_range = _format_range(new_start, new_count)
    return '@@ -%s +%s @@%s\n' % (old_range, new_range, heading)
def _parse_diff(diff_text):
    """Split unified diff text into a list of ``_DiffHunk`` objects

    Text before the first hunk header is ignored and empty lines are dropped.
    Line indices are still counted over every line of ``diff_text``, so each
    hunk's ``first_line_idx`` is the position of its header in that text.
    """
    hunks = []
    current = None
    for idx, text in enumerate(diff_text.split('\n')):
        header = _HUNK_HEADER_RE.match(text)
        if header is not None:
            old_start, old_count = parse_range_str(header.group(1))
            new_start, new_count = parse_range_str(header.group(2))
            current = _DiffHunk(
                old_start,
                old_count,
                new_start,
                new_count,
                header.group(3),
                idx,
                lines=[text + '\n'],
            )
            hunks.append(current)
        elif text and current is not None:
            # Body line of the hunk opened most recently.
            current.lines.append(text + '\n')
    return hunks
def digits(number):
    """Return the number of digits needed to display a number

    Zero and negative values (such as the -1 sentinel used by Counter) are
    shown in a single column.
    """
    if number > 0:
        # Count the digits exactly.  math.log10() raises ValueError for 0
        # (which the old ">= 0" guard let through) and rounds up for large
        # values such as 10**15 - 1.
        return len(str(int(number)))
    return 1
class Counter(object):
    """Track the current line number and upper bound for one side of a diff"""

    def __init__(self, value=0, max_value=-1):
        self.value = value
        self.max_value = max_value
        # Remembered so that reset() can restore the constructor's bound.
        self._initial_max_value = max_value

    def reset(self):
        """Restore the initial max_value and return self so calls can chain"""
        self.max_value = self._initial_max_value
        return self

    def parse(self, range_str):
        """Start counting from a hunk range such as "12,7"

        The current value jumps to the range's start, and max_value grows to
        ``start + count`` when that exceeds the previous maximum.
        """
        start, count = parse_range_str(range_str)
        self.value = start
        upper = start + count
        if upper > self.max_value:
            self.max_value = upper

    def tick(self, amount=1):
        """Advance by ``amount`` and return the value from before the advance"""
        current, self.value = self.value, self.value + amount
        return current
class DiffLines(object):
    """Parse diffs and gather line numbers

    parse() produces tuples of ``(old, new)`` for a regular diff, or
    ``(ours, theirs, new)`` once a combined ("@@@") merge hunk is seen.  Each
    entry is a line number, EMPTY when the line has no number on that side,
    or DASH for hunk header lines.
    """

    EMPTY = -1
    DASH = -2

    def __init__(self):
        self.valid = True
        self.merge = False

        # diff <old> <new>
        # merge <ours> <theirs> <new>
        self.old = Counter()
        self.new = Counter()
        self.ours = Counter()
        self.theirs = Counter()

    def digits(self):
        """Return the column width needed for the largest line number bound"""
        return digits(
            max(
                self.old.max_value,
                self.new.max_value,
                self.ours.max_value,
                self.theirs.max_value,
            )
        )

    def parse(self, diff_text):
        """Return a list of line-number tuples for the lines of ``diff_text``

        Lines outside of a hunk and "No newline at end of file" markers get
        all-EMPTY entries; hunk headers get all-DASH entries.  In merge mode
        each body line has two prefix columns: the first refers to "ours"
        and the second to "theirs".
        """
        lines = []
        DIFF_STATE = 1
        state = INITIAL_STATE = 0
        merge = self.merge = False
        NO_NEWLINE = r'\ No newline at end of file'

        old = self.old.reset()
        new = self.new.reset()
        ours = self.ours.reset()
        theirs = self.theirs.reset()

        for text in diff_text.split('\n'):
            if text.startswith('@@ -'):
                parts = text.split(' ', 4)
                if parts[0] == '@@' and parts[3] == '@@':
                    state = DIFF_STATE
                    old.parse(parts[1][1:])
                    new.parse(parts[2][1:])
                    lines.append((self.DASH, self.DASH))
                    continue
            if text.startswith('@@@ -'):
                self.merge = merge = True
                parts = text.split(' ', 5)
                if parts[0] == '@@@' and parts[4] == '@@@':
                    state = DIFF_STATE
                    ours.parse(parts[1][1:])
                    theirs.parse(parts[2][1:])
                    new.parse(parts[3][1:])
                    lines.append((self.DASH, self.DASH, self.DASH))
                    continue
            if state == INITIAL_STATE or text.rstrip() == NO_NEWLINE:
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))
            elif not merge and text.startswith('-'):
                lines.append((old.tick(), self.EMPTY))
            elif merge and text.startswith('- '):
                # Removed line present only in the first parent ("ours");
                # this previously advanced the "theirs" counter by mistake.
                lines.append((ours.tick(), self.EMPTY, self.EMPTY))
            elif merge and text.startswith(' -'):
                # Removed line present only in the second parent ("theirs").
                lines.append((self.EMPTY, theirs.tick(), self.EMPTY))
            elif merge and text.startswith('--'):
                lines.append((ours.tick(), theirs.tick(), self.EMPTY))
            elif not merge and text.startswith('+'):
                lines.append((self.EMPTY, new.tick()))
            elif merge and text.startswith('++'):
                lines.append((self.EMPTY, self.EMPTY, new.tick()))
            elif merge and text.startswith('+ '):
                # Added relative to "ours"; already present in "theirs".
                lines.append((self.EMPTY, theirs.tick(), new.tick()))
            elif merge and text.startswith(' +'):
                # Added relative to "theirs"; already present in "ours".
                lines.append((ours.tick(), self.EMPTY, new.tick()))
            elif not merge and text.startswith(' '):
                lines.append((old.tick(), new.tick()))
            elif merge and text.startswith(' '):
                lines.append((ours.tick(), theirs.tick(), new.tick()))
            elif not text:
                # NOTE(review): counters advance but no entry is appended;
                # this presumably targets the trailing '' from split('\n') —
                # confirm blank lines inside a hunk should not get an entry.
                new.tick()
                old.tick()
                ours.tick()
                theirs.tick()
            else:
                # Any other text ends the current hunk.
                state = INITIAL_STATE
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))

        return lines
class FormatDigits(object):
    """Format numbers for use in diff line numbers

    set_digits() must be called with the column width before formatting.
    """

    DASH = DiffLines.DASH
    EMPTY = DiffLines.EMPTY

    def __init__(self, dash='', empty=''):
        # Width-dependent strings, filled in by set_digits().
        self.fmt = ''
        self.empty = ''
        self.dash = ''
        # Fill strings repeated across DASH / EMPTY cells.  The dash default
        # is code point 0xB7 (U+00B7 MIDDLE DOT, assuming compat.uchr maps a
        # code point to its character).
        self._dash = dash or compat.uchr(0xB7)
        self._empty = empty or ' '

    def set_digits(self, value):
        """Configure every format for a column ``value`` characters wide"""
        self.fmt = '%%0%dd' % value
        self.empty = value * self._empty
        self.dash = value * self._dash

    def value(self, old, new):
        """Format a two-column (old, new) pair of line numbers"""
        return '%s %s' % (self._format(old), self._format(new))

    def merge_value(self, old, base, new):
        """Format a three-column merge triple of line numbers"""
        return '%s %s %s' % tuple(self._format(num) for num in (old, base, new))

    def number(self, value):
        """Zero-pad a line number to the configured width"""
        return self.fmt % value

    def _format(self, value):
        if value == self.DASH:
            return self.dash
        if value == self.EMPTY:
            return self.empty
        return self.number(value)
class DiffParser(object):
    """Parse and rewrite diffs to produce edited patches

    This parser is used for modifying the worktree and index by constructing
    temporary patches that are applied using "git apply".
    """

    # First character of each kind of line inside a hunk body.
    _ADDITION = '+'
    _DELETION = '-'
    _CONTEXT = ' '
    _NO_NEWLINE = '\\'
    # Line-type swap applied when generating a reversed patch.
    _REVERSED = {_ADDITION: _DELETION, _DELETION: _ADDITION}

    def __init__(self, filename, diff_text):
        self.filename = filename
        self.hunks = _parse_diff(diff_text)

    def generate_patch(self, first_line_idx, last_line_idx, reverse=False):
        """Return a patch containing a subset of the diff

        Only changes on lines ``first_line_idx`` through ``last_line_idx``
        (inclusive indices into the parsed diff text) are kept.  Unselected
        additions are dropped and unselected deletions become context.  With
        ``reverse`` the roles of additions and deletions are swapped first.
        Returns None when no hunk keeps any addition or deletion.
        """
        patch = ['--- a/%s\n' % self.filename, '+++ b/%s\n' % self.filename]
        wrote_hunk = False
        # Net lines added by earlier emitted hunks; shifts later new_start values.
        shift = 0

        for hunk in self.hunks:
            if hunk.last_line_idx < first_line_idx:
                # Entirely before the selection.
                continue
            if hunk.first_line_idx > last_line_idx:
                # Entirely after the selection, as is every later hunk.
                break

            body, tally = self._filter_hunk_lines(
                hunk, first_line_idx, last_line_idx, reverse
            )
            added = tally.get(self._ADDITION, 0)
            deleted = tally.get(self._DELETION, 0)
            if added == 0 and deleted == 0:
                # Context-only hunks are left out of the patch.
                continue

            context = tally.get(self._CONTEXT, 0)
            old_count = context + deleted
            new_count = context + added

            old_start = hunk.new_start if reverse else hunk.old_start
            new_start = old_start + shift
            # Presumably compensates for the unified-diff convention that an
            # empty side's start names the line before it.
            if old_count == 0:
                new_start += 1
            if new_count == 0:
                new_start -= 1
            shift += added - deleted

            patch.append(
                _format_hunk_header(
                    old_start, old_count, new_start, new_count, hunk.heading
                )
            )
            patch.extend(body)
            wrote_hunk = True

        return ''.join(patch) if wrote_hunk else None

    def _filter_hunk_lines(self, hunk, first_line_idx, last_line_idx, reverse):
        """Return (kept_lines, counts_by_line_type) for one hunk's body"""
        kept = []
        tally = {}
        dropped_addition = False
        body_lines = hunk.lines[1:]  # lines[0] is the "@@" header
        for idx, raw in enumerate(body_lines, hunk.first_line_idx + 1):
            kind, text = raw[:1], raw[1:]
            if reverse:
                kind = self._REVERSED.get(kind, kind)
            if not first_line_idx <= idx <= last_line_idx:
                if kind == self._ADDITION:
                    # An addition outside the selection is left out entirely.
                    dropped_addition = True
                    continue
                if kind == self._DELETION:
                    # A deletion outside the selection is kept as context.
                    kind = self._CONTEXT
            if kind == self._NO_NEWLINE and dropped_addition:
                # The "No newline" marker belonged to the dropped addition.
                continue
            kept.append(kind + text)
            tally[kind] = tally.get(kind, 0) + 1
            dropped_addition = False
        return kept, tally

    def generate_hunk_patch(self, line_idx, reverse=False):
        """Return a patch containing the hunk for the specified line only

        The first hunk ending at or after ``line_idx`` is used, falling back
        to the last hunk when ``line_idx`` lies past all of them.  Returns
        None when there are no hunks (or when generate_patch() returns None).
        """
        if not self.hunks:
            return None
        target = next(
            (hunk for hunk in self.hunks if line_idx <= hunk.last_line_idx),
            self.hunks[-1],
        )
        return self.generate_patch(
            target.first_line_idx, target.last_line_idx, reverse=reverse
        )