Merge pull request #1237 from scop/fix/editor-fallback
[git-cola.git] / cola / diffparse.py
blobaacf389929a0da8f7b689c445c9bd1519e75d635
1 from __future__ import absolute_import, division, print_function, unicode_literals
2 import math
3 import re
4 from collections import defaultdict
6 from . import compat
# Matches a two-way unified-diff hunk header "@@ -<old> +<new> @@<heading>".
# Group 1: old range ("start" or "start,count"); group 2: new range;
# group 3: everything after the closing "@@" (may be empty).
# Combined-diff "@@@" headers do not match because "@@" must be followed by " -".
_HUNK_HEADER_RE = re.compile(r'^@@ -([0-9,]+) \+([0-9,]+) @@(.*)')
12 class _DiffHunk(object):
13 def __init__(
14 self, old_start, old_count, new_start, new_count, heading, first_line_idx, lines
16 self.old_start = old_start
17 self.old_count = old_count
18 self.new_start = new_start
19 self.new_count = new_count
20 self.heading = heading
21 self.first_line_idx = first_line_idx
22 self.lines = lines
24 @property
25 def last_line_idx(self):
26 return self.first_line_idx + len(self.lines) - 1
def parse_range_str(range_str):
    """Split a hunk range such as "12,3" or "12" into a (start, count) pair.

    A range without a comma has an implicit count of 1.  Anything after the
    first comma is passed to int() as the count, so malformed input raises
    ValueError.
    """
    start_text, comma, count_text = range_str.partition(',')
    if not comma:
        return int(range_str), 1
    return int(start_text), int(count_text)
36 def _format_range(start, count):
37 if count == 1:
38 return str(start)
39 return '%d,%d' % (start, count)
def _format_hunk_header(old_start, old_count, new_start, new_count, heading=''):
    """Build a newline-terminated "@@ -<old> +<new> @@<heading>" header line"""
    old_range = _format_range(old_start, old_count)
    new_range = _format_range(new_start, new_count)
    return '@@ -%s +%s @@%s\n' % (old_range, new_range, heading)
def _parse_diff(diff_text):
    """Split unified diff text into a list of _DiffHunk objects.

    Lines before the first hunk header are ignored, as are empty lines.
    Each hunk's ``lines`` starts with its header line, and every stored line
    gets its newline re-appended.  ``first_line_idx`` is the header's index
    in ``diff_text.split('\\n')``.
    """
    hunks = []
    current = None
    for idx, text in enumerate(diff_text.split('\n')):
        header = _HUNK_HEADER_RE.match(text)
        if header is None:
            # Body line: attach it to the most recent hunk, if any.
            if text and current is not None:
                current.lines.append(text + '\n')
            continue
        old_start, old_count = parse_range_str(header.group(1))
        new_start, new_count = parse_range_str(header.group(2))
        current = _DiffHunk(
            old_start,
            old_count,
            new_start,
            new_count,
            header.group(3),
            idx,
            lines=[text + '\n'],
        )
        hunks.append(current)
    return hunks
def digits(number):
    """Return the number of digits needed to display a number

    Negative values (e.g. the -1 default max_value of an unused Counter)
    yield 1.  Non-integral values are truncated toward zero first.
    """
    if number < 0:
        return 1
    # Counting the characters is exact.  The previous math.log10() approach
    # raised ValueError for 0 and could round up just below large powers of 10.
    return len(str(int(number)))
class Counter(object):
    """Track the current and largest line number of one diff range column"""

    def __init__(self, value=0, max_value=-1):
        self.value = value
        self.max_value = max_value
        # Kept so reset() can restore the starting maximum.
        self._initial_max_value = max_value

    def reset(self):
        """Restore the initial max_value; returns self so calls can be chained"""
        self.max_value = self._initial_max_value
        return self

    def parse(self, range_str):
        """Start counting at a hunk range and widen max_value to cover its end"""
        start, count = parse_range_str(range_str)
        self.value = start
        range_end = start + count - 1
        self.max_value = max(range_end, self.max_value)

    def tick(self, amount=1):
        """Return the current value, then advance it by ``amount``"""
        current = self.value
        self.value = current + amount
        return current
class DiffLines(object):
    """Parse diffs and gather line numbers

    parse() returns a list of line-number tuples.  A tuple is ``(old, new)``
    for ordinary diffs; once a combined "@@@" hunk header has been seen it
    is ``(ours, theirs, new)``.  Columns without a line number hold EMPTY,
    and hunk header rows hold DASH in every column.
    """

    # Placeholder for a column that has no line number on a row.
    EMPTY = -1
    # Placeholder used in every column of a hunk header row.
    DASH = -2

    def __init__(self):
        # Set to True by parse() when it meets a combined-diff ("@@@") header.
        self.merge = False

        # diff <old> <new>
        # merge <ours> <theirs> <new>
        self.old = Counter()
        self.new = Counter()
        self.ours = Counter()
        self.theirs = Counter()

    def digits(self):
        """Return the width needed for the largest line number seen by parse()

        Counters start (and are reset) with max_value -1.
        """
        return digits(
            max(
                self.old.max_value,
                self.new.max_value,
                self.ours.max_value,
                self.theirs.max_value,
            )
        )

    def parse(self, diff_text):
        """Compute line-number tuples for each line of ``diff_text``

        Every line produces one entry, except empty lines met while inside a
        hunk, which advance the counters but produce no entry.  Lines outside
        a hunk and the "No newline at end of file" marker get all-EMPTY
        entries.
        """
        lines = []
        DIFF_STATE = 1
        state = INITIAL_STATE = 0
        # Merge mode is re-detected on every call.
        merge = self.merge = False
        NO_NEWLINE = r'\ No newline at end of file'

        # reset() restores each counter's initial max_value and returns it.
        old = self.old.reset()
        new = self.new.reset()
        ours = self.ours.reset()
        theirs = self.theirs.reset()

        for text in diff_text.split('\n'):
            # Two-way hunk header: "@@ -<old> +<new> @@ ...".
            # NOTE(review): a "@@ -" line with fewer than four space-separated
            # fields raises IndexError on parts[3].
            if text.startswith('@@ -'):
                parts = text.split(' ', 4)
                if parts[0] == '@@' and parts[3] == '@@':
                    state = DIFF_STATE
                    # [1:] strips the leading '-' / '+' from each range.
                    old.parse(parts[1][1:])
                    new.parse(parts[2][1:])
                    lines.append((self.DASH, self.DASH))
                    continue
            # Combined (merge) hunk header: "@@@ -<ours> -<theirs> +<new> @@@ ...".
            if text.startswith('@@@ -'):
                self.merge = merge = True
                parts = text.split(' ', 5)
                if parts[0] == '@@@' and parts[4] == '@@@':
                    state = DIFF_STATE
                    ours.parse(parts[1][1:])
                    theirs.parse(parts[2][1:])
                    new.parse(parts[3][1:])
                    lines.append((self.DASH, self.DASH, self.DASH))
                    continue
            if state == INITIAL_STATE or text.rstrip() == NO_NEWLINE:
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))
            # Each branch below ticks the counters of the columns that receive
            # a line number and stores EMPTY in the rest.  The order matters:
            # merge prefixes are two characters wide, and ' ' is tested only
            # after ' -' and ' +'.
            elif not merge and text.startswith('-'):
                lines.append((old.tick(), self.EMPTY))
            elif merge and text.startswith('- '):
                lines.append((ours.tick(), self.EMPTY, self.EMPTY))
            elif merge and text.startswith(' -'):
                lines.append((self.EMPTY, theirs.tick(), self.EMPTY))
            elif merge and text.startswith('--'):
                lines.append((ours.tick(), theirs.tick(), self.EMPTY))
            elif not merge and text.startswith('+'):
                lines.append((self.EMPTY, new.tick()))
            elif merge and text.startswith('++'):
                lines.append((self.EMPTY, self.EMPTY, new.tick()))
            elif merge and text.startswith('+ '):
                lines.append((self.EMPTY, theirs.tick(), new.tick()))
            elif merge and text.startswith(' +'):
                lines.append((ours.tick(), self.EMPTY, new.tick()))
            elif not merge and text.startswith(' '):
                lines.append((old.tick(), new.tick()))
            elif merge and text.startswith(' '):
                lines.append((ours.tick(), theirs.tick(), new.tick()))
            elif not text:
                # Empty line inside a hunk: advance every counter, emit nothing.
                # NOTE(review): presumably aimed at the trailing '' produced by
                # split('\n'); confirm how blank lines mid-hunk should behave.
                new.tick()
                old.tick()
                ours.tick()
                theirs.tick()
            else:
                # Any other line ends the current hunk.
                state = INITIAL_STATE
                if merge:
                    lines.append((self.EMPTY, self.EMPTY, self.EMPTY))
                else:
                    lines.append((self.EMPTY, self.EMPTY))

        return lines
class FormatDigits(object):
    """Render DiffLines line numbers as fixed-width, zero-padded columns

    set_digits() sizes the columns; fmt, empty and dash are blank until it
    has been called.
    """

    DASH = DiffLines.DASH
    EMPTY = DiffLines.EMPTY

    def __init__(self, dash='', empty=''):
        # Column templates, filled in by set_digits().
        self.fmt = self.empty = self.dash = ''
        # Fill strings; falsy arguments fall back to compat.uchr(0xB7)
        # (presumably U+00B7 MIDDLE DOT) and a single space respectively.
        self._dash = dash if dash else compat.uchr(0xB7)
        self._empty = empty if empty else ' '

    def set_digits(self, value):
        """Size every column to ``value`` characters"""
        # value=3 gives the format '%03d'.
        self.fmt = '%%0%dd' % value
        self.empty = self._empty * value
        self.dash = self._dash * value

    def value(self, old, new):
        """Format an (old, new) pair of line numbers, space-separated"""
        return '%s %s' % tuple(self._format(item) for item in (old, new))

    def merge_value(self, old, base, new):
        """Format an (old, base, new) triple of line numbers, space-separated"""
        return '%s %s %s' % tuple(self._format(item) for item in (old, base, new))

    def number(self, value):
        """Zero-pad ``value`` using the width chosen by set_digits()"""
        return self.fmt % value

    def _format(self, value):
        # DASH and EMPTY sentinels map to their fill columns; anything else
        # is treated as a real line number.
        if value == self.DASH:
            return self.dash
        if value == self.EMPTY:
            return self.empty
        return self.number(value)
class DiffParser(object):
    """Parse and rewrite diffs to produce edited patches

    This parser is used for modifying the worktree and index by constructing
    temporary patches that are applied using "git apply".
    """

    def __init__(self, filename, diff_text):
        # filename is written into the "--- a/" and "+++ b/" patch headers.
        self.filename = filename
        self.hunks = _parse_diff(diff_text)

    def generate_patch(self, first_line_idx, last_line_idx, reverse=False):
        """Return a patch containing a subset of the diff

        ``first_line_idx`` and ``last_line_idx`` form an inclusive range of
        line indices into the diff text, counted the same way as each hunk's
        ``first_line_idx``.  Additions and deletions inside the range are
        kept.  Outside it, additions are dropped and deletions become context.
        With ``reverse=True``, '+' and '-' are swapped first and each hunk's
        new_start is used as the base line number.

        Returns the patch text, or None when no hunk retains a change.
        """

        ADDITION = '+'
        DELETION = '-'
        CONTEXT = ' '
        NO_NEWLINE = '\\'

        lines = ['--- a/%s\n' % self.filename, '+++ b/%s\n' % self.filename]

        # Net (added - deleted) line count of hunks already emitted; shifts
        # the new-side start line of every later hunk.
        start_offset = 0

        for hunk in self.hunks:
            # skip hunks until we get to the one that contains the first
            # selected line
            if hunk.last_line_idx < first_line_idx:
                continue
            # once we have processed the hunk that contains the last selected
            # line, we can stop
            if hunk.first_line_idx > last_line_idx:
                break

            prev_skipped = False
            counts = defaultdict(int)
            filtered_lines = []

            # hunk.lines[0] is the "@@" header, so body lines are numbered
            # from the line after it.
            for line_idx, line in enumerate(
                hunk.lines[1:], start=hunk.first_line_idx + 1
            ):
                line_type, line_content = line[:1], line[1:]

                if reverse:
                    if line_type == ADDITION:
                        line_type = DELETION
                    elif line_type == DELETION:
                        line_type = ADDITION

                if not first_line_idx <= line_idx <= last_line_idx:
                    if line_type == ADDITION:
                        # Skip additions that are not selected.
                        prev_skipped = True
                        continue
                    if line_type == DELETION:
                        # Change deletions that are not selected to context.
                        line_type = CONTEXT
                if line_type == NO_NEWLINE and prev_skipped:
                    # If the line immediately before a "No newline" line was
                    # skipped (because it was an unselected addition) skip
                    # the "No newline" line as well.
                    continue
                filtered_lines.append(line_type + line_content)
                counts[line_type] += 1
                prev_skipped = False

            # Do not include hunks that, after filtering, have only context
            # lines (no additions or deletions).
            if not counts[ADDITION] and not counts[DELETION]:
                continue

            old_count = counts[CONTEXT] + counts[DELETION]
            new_count = counts[CONTEXT] + counts[ADDITION]

            if reverse:
                old_start = hunk.new_start
            else:
                old_start = hunk.old_start
            new_start = old_start + start_offset
            # NOTE(review): presumably follows the unified-diff convention
            # that a zero-length range is numbered by the line before it.
            if old_count == 0:
                new_start += 1
            if new_count == 0:
                new_start -= 1

            start_offset += counts[ADDITION] - counts[DELETION]

            lines.append(
                _format_hunk_header(
                    old_start, old_count, new_start, new_count, hunk.heading
                )
            )
            lines.extend(filtered_lines)

        # If there are only two lines, that means we did not include any hunks,
        # so return None.
        if len(lines) == 2:
            return None
        return ''.join(lines)

    def generate_hunk_patch(self, line_idx, reverse=False):
        """Return a patch containing the hunk for the specified line only

        The chosen hunk is the first whose last line is at or after
        ``line_idx``.  If ``line_idx`` lies past every hunk, the loop variable
        is left on the last hunk and that hunk is used.  Returns None when
        there are no hunks, or when generate_patch() finds no change to keep.
        """
        hunk = None
        for hunk in self.hunks:
            if line_idx <= hunk.last_line_idx:
                break
        # Still None only when self.hunks is empty.
        if hunk is None:
            return None
        return self.generate_patch(
            hunk.first_line_idx, hunk.last_line_idx, reverse=reverse
        )