[PATCH] Make the ProgramError class printable.
[git/vmiklos.git] / gitMergeCommon.py
blobce9694b15ebe131fea9536dc7b2a0f04d9ceca02
1 import sys, re, os, traceback
2 from sets import Set
if tuple(sys.version_info[:2]) < (2, 4):
    # Refuse to run on interpreters older than 2.4; the subprocess module
    # imported below is new in Python 2.4.
    sys.stdout.write('Python version 2.4 required, found ' +
                     '.'.join([str(v) for v in sys.version_info[:3]]) + '\n')
    sys.exit(1)
11 import subprocess
def die(*args):
    '''Write args to stderr, separated by spaces, and exit with status 2.'''
    printList(args, file=sys.stderr)
    raise SystemExit(2)
17 # Debugging machinery
18 # -------------------
# Set DEBUG to a true value to enable output from debug().
DEBUG = 0
# Names of the functions whose debug() calls are printed; filled in by
# addDebug().  The builtin set (Python 2.4+) replaces the deprecated
# sets.Set.
functionsToDebug = set()
def addDebug(func):
    '''Enable debug() output for a function.

    func is either a function object or a function name given as a string.
    debug() prints only when DEBUG is true and it is called directly from a
    function whose name has been registered here.
    '''
    if isinstance(func, str):
        name = func
    else:
        # __name__ is available on functions in every Python version;
        # func_name only exists in Python 2.
        name = func.__name__
    functionsToDebug.add(name)
def debug(*args):
    '''Print args (via printList) when debugging is enabled for the caller.

    Output appears only if DEBUG is true and the function that called
    debug() has been registered with addDebug().
    '''
    if not DEBUG:
        return
    # extract_stack()[-1] is this frame and [-2] is our caller; index 2 of
    # a stack entry is the function name.
    caller = traceback.extract_stack()[-2][2]
    if caller in functionsToDebug:
        printList(args)
def printList(list, file=None):
    '''Write str() of every item in list to file, each item followed by a
    single space, and finish with a newline.

    file defaults to the *current* sys.stdout, looked up at call time
    rather than when this module was imported, so later redirection of
    sys.stdout is honoured.
    '''
    # The parameter keeps its historical name 'list' (shadowing the
    # builtin) so keyword callers keep working.
    if file is None:
        file = sys.stdout
    file.write(''.join([str(x) + ' ' for x in list]) + '\n')
41 # Program execution
42 # -----------------
class ProgramError(Exception):
    '''Raised by runProgram() when a program cannot be run or fails.

    progStr is the command line as a string; error is the error text
    (the OS error message, or the program's captured output).
    '''
    def __init__(self, progStr, error):
        # Hand the arguments to the base class too, so e.args, repr(e) and
        # pickling carry the details.
        Exception.__init__(self, progStr, error)
        self.progStr = progStr
        self.error = error

    def __str__(self):
        # '%s' formatting (instead of '+') also copes with a non-string
        # error such as None, which OSError.strerror can be.
        return '%s: %s' % (self.progStr, self.error)
addDebug('runProgram')
def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
    '''Run prog and return what it printed.

    prog is either a string, run through the shell, or a list of argument
    words, executed directly.  input, if not None, is fed to the program's
    stdin; stdin is closed afterwards in every case.  env, if given, is the
    environment passed to the child.

    With pipeOutput true, the program's stdout and stderr are captured
    together and returned; otherwise the child inherits them and the
    returned output is ''.

    Returns the output, or the list [output, exitCode] when returnCode is
    true.  Raises ProgramError if the program cannot be started, or if it
    exits with a non-zero status while returnCode is false.
    '''
    debug('runProgram prog:', str(prog), 'input:', str(input))
    useShell = type(prog) is str
    if useShell:
        progStr = prog
    else:
        progStr = ' '.join(prog)

    if pipeOutput:
        stderr = subprocess.STDOUT
        stdout = subprocess.PIPE
    else:
        stderr = None
        stdout = None

    try:
        pop = subprocess.Popen(prog,
                               shell=useShell,
                               stderr=stderr,
                               stdout=stdout,
                               stdin=subprocess.PIPE,
                               env=env)
    except OSError:
        # Fetched via sys.exc_info() rather than 'except OSError, e' so the
        # code parses on every Python from 2.4 onwards.
        e = sys.exc_info()[1]
        debug('strerror:', e.strerror)
        raise ProgramError(progStr, e.strerror)

    # communicate() feeds input while draining the output pipe, then closes
    # stdin and waits for the child.  Writing all of input before reading
    # any output can deadlock: the child blocks on a full output pipe while
    # we block on its full stdin pipe.
    captured = pop.communicate(input)[0]
    code = pop.returncode
    if pipeOutput:
        out = captured
    else:
        out = ''

    if code != 0 and not returnCode:
        debug('error output:', out)
        debug('prog:', prog)
        raise ProgramError(progStr, out)
    if returnCode:
        return [out, code]
    return out
98 # Code for computing common ancestors
99 # -----------------------------------
currentId = 0
def getUniqueId():
    '''Return the next integer id: 1 on the first call, then 2, 3, ...

    The counter is kept in the module-level variable currentId.
    '''
    global currentId
    nextId = currentId + 1
    currentId = nextId
    return nextId
# The 'virtual' commit objects have SHAs which are integers.
# \Z anchors at the very end of the string; '$' would also accept a
# trailing newline, letting unstripped git output pass as a SHA1.
shaRE = re.compile(r'^[0-9a-f]{40}\Z')
def isSha(obj):
    '''Return True if obj is an object id.

    An id is either a string of exactly 40 lowercase hex digits, or a
    positive int (the id of a virtual commit).  The exact type checks are
    deliberate: bool values are not accepted as ids.
    '''
    return (type(obj) is str and bool(shaRE.match(obj))) or \
           (type(obj) is int and obj >= 1)
class Commit:
    '''A node of the commit graph.

    A real commit is identified by its SHA1 string.  Passing a false sha
    creates a "virtual" commit instead: it is numbered with getUniqueId(),
    must be given a tree, and never consults git.  For real commits the
    tree and the first message line are read lazily by getInfo().
    '''

    def __init__(self, sha, parents, tree=None):
        self.parents = parents
        self.firstLineMsg = None
        self.children = []

        if tree:
            tree = tree.rstrip()
            assert(isSha(tree))
        self._tree = tree

        if sha:
            self.virtual = False
            self.sha = sha.rstrip()
            assert(isSha(self.sha))
        else:
            # Virtual commits get an integer id and must carry a tree.
            self.sha = getUniqueId()
            self.virtual = True
            self.firstLineMsg = 'virtual commit'
            assert(isSha(tree))

    def tree(self):
        '''Return the SHA1 of this commit's tree.'''
        self.getInfo()
        assert(self._tree is not None)
        return self._tree

    def shortInfo(self):
        '''Return the sha, a space, and the first line of the message.'''
        self.getInfo()
        return str(self.sha) + ' ' + self.firstLineMsg

    def __str__(self):
        return self.shortInfo()

    def getInfo(self):
        '''Fill in _tree and firstLineMsg from "git-cat-file commit".

        Does nothing for virtual commits, or once the message is known.
        '''
        if self.virtual or self.firstLineMsg is not None:
            return
        text = runProgram(['git-cat-file', 'commit', self.sha])
        inHeader = True
        for line in text.split('\n'):
            if not inHeader:
                # First line after the blank header/body separator.
                self.firstLineMsg = line
                break
            if line.startswith('tree'):
                self._tree = line[5:].rstrip()
            elif line == '':
                inHeader = False
class Graph:
    '''The commit graph: a list of Commit nodes plus a sha -> node map.'''

    def __init__(self):
        # All nodes, in insertion order.
        self.commits = []
        # Maps a node's sha (SHA1 string, or integer id for a virtual
        # commit) to the node.
        self.shaMap = {}

    def addNode(self, node):
        '''Add node, register it as a child of each parent, and return it.'''
        assert(isinstance(node, Commit))
        self.shaMap[node.sha] = node
        self.commits.append(node)
        for p in node.parents:
            p.children.append(node)
        return node

    def reachableNodes(self, n1, n2):
        '''Return a dict whose keys are n1, n2 and every node reachable
        from either of them through parent links; all values are True.'''
        res = {}
        # Iterative walk with a visited check: a recursive walk without one
        # revisits shared ancestors (exponential on merge-heavy histories)
        # and exceeds the recursion limit on long linear histories.
        stack = [n1, n2]
        while stack:
            n = stack.pop()
            if n in res:
                continue
            res[n] = True
            stack.extend(n.parents)
        return res

    def fixParents(self, node):
        '''Replace the SHA1 strings in node.parents, in place, with the
        nodes they name in shaMap.'''
        node.parents[:] = [self.shaMap[p] for p in node.parents]
191 # addDebug('buildGraph')
def buildGraph(heads):
    '''Build a Graph of every commit reachable from the given head SHA1s.

    Parent and child links between the Commit nodes are filled in.
    '''
    debug('buildGraph heads:', heads)
    for head in heads:
        assert(isSha(head))

    graph = Graph()
    revList = runProgram(['git-rev-list', '--parents'] + heads)
    for line in revList.split('\n'):
        if not line:
            continue
        fields = line.split(' ')
        # Hack: 'parents' temporarily holds SHA1 strings; fixParents()
        # below swaps them for the corresponding Commit objects.
        commit = Commit(fields[0], fields[1:])
        graph.commits.append(commit)
        graph.shaMap[commit.sha] = commit

    for commit in graph.commits:
        graph.fixParents(commit)
    for commit in graph.commits:
        for parent in commit.parents:
            parent.children.append(commit)
    return graph
def writeEmptyTree():
    '''Write the empty tree to the object database and return its SHA1.

    git-write-tree is run with GIT_INDEX_FILE pointing at a scratch index,
    $GIT_DIR/merge-tmp-index, which is removed beforehand so the index is
    empty.  The scratch file is removed again afterwards, even if
    git-write-tree fails.  Raises KeyError if GIT_DIR is not set.
    '''
    tmpIndex = os.path.join(os.environ['GIT_DIR'], 'merge-tmp-index')

    def delTmpIndex():
        # Best effort: the file may well not exist.
        try:
            os.unlink(tmpIndex)
        except OSError:
            pass

    delTmpIndex()
    newEnv = os.environ.copy()
    newEnv['GIT_INDEX_FILE'] = tmpIndex
    try:
        return runProgram(['git-write-tree'], env=newEnv).rstrip()
    finally:
        # Do not leave the scratch index behind when git-write-tree fails.
        delTmpIndex()
def addCommonRoot(graph):
    '''Graft a virtual empty-tree commit under every root of graph.

    Every parentless commit in graph gets the new commit as its only
    parent.  The new commit is added to graph and returned.
    '''
    # Collect the roots before adding superRoot, which is parentless too.
    roots = [c for c in graph.commits if not c.parents]
    superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
    graph.addNode(superRoot)
    for root in roots:
        root.parents = [superRoot]
    superRoot.children = roots
    return superRoot
def getCommonAncestors(graph, commit1, commit2):
    '''Find the common ancestors for commit1 and commit2

    Returns, as a list in no particular order, the commits that are
    ancestors of both commit1 and commit2 (a commit counts as its own
    ancestor) and that have no child which is also such a common ancestor.
    If the two histories share no commit, addCommonRoot() grafts a virtual
    empty-tree root under every root of graph, and that root is returned
    as the only common ancestor.
    '''
    assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))

    def ancestors(start):
        # Iterative walk over parent links (no recursion limit); returns
        # start plus every commit reachable from it.
        seen = set()
        stack = [start]
        while stack:
            el = stack.pop()
            if el in seen:
                continue
            seen.add(el)
            for p in el.parents:
                if p not in seen:
                    stack.append(p)
        return seen

    # Builtin set (Python 2.4+) instead of the deprecated sets.Set.
    shared = ancestors(commit1) & ancestors(commit2)
    if not shared:
        shared = set([addCommonRoot(graph)])

    # Keep only the lowest common ancestors: drop any commit that has a
    # child which is itself a common ancestor.
    res = []
    for s in shared:
        if not [c for c in s.children if c in shared]:
            res.append(s)
    return res