Fix some bugs found by pychecker
[zeroinstall.git] / tests / testsat.py
blobe6a51698ab1a561ca3e5a55596d62b3b94c64e65
1 #!/usr/bin/env python
2 from basetest import BaseTest
3 import sys, os
4 import unittest
6 sys.path.insert(0, '..')
7 from zeroinstall.injector import model, arch, qdom
8 from zeroinstall.injector.namespaces import XMLNS_IFACE
10 from zeroinstall.injector.solver import SATSolver as Solver
11 from zeroinstall.injector import sat
13 import logging
# Root logger: getLogger() with no name (or None) returns logging's root.
logger = logging.getLogger(None)
class Stores:
    """Minimal stand-in for the implementation store handed to the solver."""

    # Path reported for every lookup.
    _ROOT = "/"

    def lookup_any(self, digests):
        """Return the stub path for *digests*; the digests are not inspected."""
        return self._ROOT


# Shared store instance passed to every Solver built by these tests.
stores = Stores()

# Base URI under which every fake test interface is published.
uri_prefix = 'http://localhost/tests/'
class Version:
    """One version of a fake program, with its dependency constraints."""

    def __init__(self, n):
        # Architecture string (e.g. 'Linux-x86_64'); None means no arch
        # attribute is written for this version.
        self.arch = None
        # (lib, min_v, max_v) tuples, in the order add_requires was called.
        self.requires = []
        # The version number exactly as supplied by the caller.
        self.n = n

    def add_requires(self, lib, min_v, max_v):
        """Record that this version needs *lib* in the range min_v..max_v."""
        self.requires += [(lib, min_v, max_v)]
class Program:
    """A fake interface: a named set of Versions that can be rendered
    as a model.ZeroInstallFeed."""

    def __init__(self, name):
        self.name = name
        # version key (as passed to get_version) -> Version
        self.versions = {}

    def get_version(self, version):
        """Return the Version for *version*, creating it on first use."""
        try:
            return self.versions[version]
        except KeyError:
            created = self.versions[version] = Version(version)
            return created

    def build_feed(self):
        """Build a feed listing every known version of this program.

        Each version becomes an <implementation> with a sequential id, a
        dummy manifest digest and one <requires> per recorded dependency.
        """
        def add_child(parent, tag, attributes=None):
            # Create <tag> in the interface namespace and attach it to parent.
            element = qdom.Element(XMLNS_IFACE, tag, attributes or {})
            parent.childNodes.append(element)
            return element

        root = qdom.Element(XMLNS_IFACE, 'interface',
                            {'uri': uri_prefix + self.name})
        for tag in ('name', 'summary'):
            add_child(root, tag).content = self.name

        for impl_id, version in enumerate(self.versions.values()):
            impl_attrs = {'id': str(impl_id), 'version': str(version.n)}
            if version.arch:
                impl_attrs['arch'] = version.arch
            impl = add_child(root, 'implementation', impl_attrs)
            add_child(impl, 'manifest-digest', {'sha1new': '1234'})
            for lib, min_v, max_v in version.requires:
                requires = add_child(impl, 'requires',
                                     {'interface': uri_prefix + lib})
                # 'before' is exclusive, so max_v + 1 makes max_v itself allowed.
                add_child(requires, 'version', {
                    'before': str(int(max_v) + 1),
                    'not-before': min_v,
                })

        feed = model.ZeroInstallFeed(root)
        feed.last_modified = 1
        return feed
class TestCache:
    """In-memory stand-in for the interface/feed cache used by the solver."""

    def __init__(self):
        self.progs = {}       # program name -> Program
        self.interfaces = {}  # uri -> model.Interface
        self.feeds = {}       # url -> feed built from the matching Program

    def get_prog(self, prog):
        """Return the Program named *prog*, creating it if necessary."""
        if prog not in self.progs:
            self.progs[prog] = Program(prog)
        return self.progs[prog]

    def get_interface(self, uri):
        """Return the cached model.Interface for *uri*, creating it once."""
        if uri not in self.interfaces:
            self.interfaces[uri] = model.Interface(uri)
        return self.interfaces[uri]

    def get_feed(self, url):
        """Return the feed for *url*.

        On first request the feed is built from the Program whose name is
        the last '/'-separated component of the URL.
        """
        if url not in self.feeds:
            prog_name = url.rsplit('/', 1)[1]
            self.feeds[url] = self.progs[prog_name].build_feed()
        return self.feeds[url]

    def get_feed_imports(self, iface):
        """Report no additional feeds for *iface*."""
        return []
def _parse_repo(cache, repo):
    """Populate *cache* with the programs described by the *repo* spec.

    Each non-blank line must be one of:
      ``name [arch]: v1 v2 ...``   declare versions (with optional arch)
      ``name[a] => lib min max``   version a of name requires lib min..max
      ``name[a,b] => lib min max`` versions a..b (inclusive) require lib
    Raises ValueError for any other non-blank line, so that typos in a test
    spec are reported instead of being silently ignored.
    """
    for line in repo.split('\n'):
        line = line.strip()
        if not line:
            continue
        if ':' in line:
            prog, versions = line.split(':')
            prog = prog.strip()
            if ' ' in prog:
                prog, prog_arch = prog.split()
            else:
                prog_arch = None
            for v in versions.split():
                cache.get_prog(prog).get_version(v).arch = prog_arch
        elif '=>' in line:
            prog, requires = line.split('=>')
            prog, version_range = prog.strip().split('[')
            lib, min_v, max_v = requires.split()
            # Explicit check rather than assert: must still run under -O.
            if not version_range.endswith(']'):
                raise ValueError("Missing ']' in repository line: %r" % line)
            version_range = version_range[:-1]
            if ',' in version_range:
                min_p, max_p = map(int, version_range.split(','))
                prog_versions = range(min_p, max_p + 1)
            else:
                prog_versions = [int(version_range)]
            for prog_version in prog_versions:
                version = cache.get_prog(prog).get_version(str(prog_version))
                version.add_requires(lib, min_v, max_v)
        else:
            raise ValueError("Unrecognised repository line: %r" % line)

def assertSelection(expected, repo):
    """Solve a fake repository and check the solver's result.

    expected: comma-separated ``name-version`` pairs. The first entry's name
              is the root interface to solve for; a first entry of
              ``name-FAIL`` means the solve is expected to fail.
    repo:     repository spec text (see _parse_repo).

    Returns the solver, so callers can inspect its state further.
    Raises AssertionError if the solver's readiness is not what *expected*
    demands, and Exception if the selected versions differ from *expected*.
    """
    cache = TestCache()

    expected = [tuple(e.strip().split('-')) for e in expected.split(",")]
    _parse_repo(cache, repo)

    root = uri_prefix + expected[0][0]
    s = Solver(model.network_offline, cache, stores)
    s.solve(root, arch.get_architecture('Linux', 'x86_64'))

    # Raise explicitly instead of using ``assert``: plain asserts are stripped
    # under ``python -O``, which would make every check pass silently.
    if expected[0][1] == 'FAIL':
        if s.ready:
            raise AssertionError("Expected the solve to fail, but it succeeded")
    else:
        if not s.ready:
            raise AssertionError("Solve failed: no valid selection was found")

        actual = sorted((iface.uri.rsplit('/', 1)[1], impl.get_version())
                        for iface, impl in s.selections.iteritems())
        expected = sorted(expected)
        if expected != actual:
            raise Exception("Solve failed:\nExpected: %s\n Actual: %s" % (expected, actual))
    return s
class TestSAT(BaseTest):
    """SAT solver tests driven by small textual repository descriptions.

    Most tests call assertSelection with the expected selections and a
    repository spec made of ``name [arch]: versions`` and
    ``name[range] => lib min max`` lines.
    """

    def testTrivial(self):
        assertSelection("prog-2", """
            prog: 1 2
            """)

    def testSimple(self):
        assertSelection("prog-5, liba-5", """
            prog: 1 2 3 4 5
            liba: 1 2 3 4 5
            prog[1] => liba 0 4
            prog[2] => liba 1 5
            prog[5] => liba 4 5
            """)

    def testBestImpossible(self):
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1
            prog[2] => liba 3 4
            """)

    def testSlow(self):
        assertSelection("prog-1", """
            prog: 1 2 3 4 5 6 7 8 9
            liba: 1 2 3 4 5 6 7 8 9
            libb: 1 2 3 4 5 6 7 8 9
            libc: 1 2 3 4 5 6 7 8 9
            libd: 1 2 3 4 5 6 7 8 9
            libe: 1
            prog[2,9] => liba 1 9
            liba[1,9] => libb 1 9
            libb[1,9] => libc 1 9
            libc[1,9] => libd 1 9
            libd[1,9] => libe 0 0
            """)

    def testNoSolution(self):
        assertSelection("prog-FAIL", """
            prog: 1 2 3
            liba: 1
            prog[1,3] => liba 2 3
            """)

    def testBacktrackSimple(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work ever.
        assertSelection("prog-1, liba-2", """
            prog: 1
            liba: 1 2 3
            prog[1] => liba 1 2
            """)

    def testBacktrackLocal(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work with prog-1.
        assertSelection("prog-2, liba-2", """
            prog: 1 2
            liba: 1 2 3
            prog[1,2] => liba 1 2
            """)

    def testLearning(self):
        # Prog-2 depends on libz and (via liba) on libb, but we can't have
        # both at once: libb is only available for Linux-i486 and libz only
        # for Linux-x86_64. The learning means we don't have to explore
        # every possible combination of liba and libb.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1 2 3
            libb Linux-i486: 1 2 3
            libz Linux-x86_64: 1 2
            prog[2] => liba 1 3
            prog[2] => libz 1 2
            liba[1,3] => libb 1 3
            """)

    def testToplevelConflict(self):
        # We don't detect the conflict until we start solving, but the
        # conflict is top-level so we abort immediately without
        # backtracking.
        assertSelection("prog-FAIL", """
            prog Linux-i386: 1
            liba Linux-x86_64: 1
            prog[1] => liba 1 1
            """)

    def testDiamondConflict(self):
        # prog depends on liba and libb, which depend on incompatible
        # versions of libc.
        assertSelection("prog-FAIL", """
            prog: 1
            liba: 1
            libb: 1
            libc: 1 2
            prog[1] => liba 1 1
            prog[1] => libb 1 1
            liba[1] => libc 1 1
            libb[1] => libc 2 3
            """)

    def testCoverage(self):
        # Try to trigger some edge cases...

        # An at_most_one clause must be analysed for causing
        # a conflict.
        solver = sat.SATProblem()
        v1 = solver.add_variable("v1")
        v2 = solver.add_variable("v2")
        v3 = solver.add_variable("v3")
        solver.at_most_one([v1, v2])
        solver.add_clause([v1, sat.neg(v3)])
        solver.add_clause([v2, sat.neg(v3)])
        solver.add_clause([v1, v3])
        solver.run_solver(lambda: v3)

    def testFailState(self):
        # If we can't select a valid combination,
        # try to select as many as we can.
        s = assertSelection("prog-FAIL", """
            prog: 1 2
            liba: 1 2
            libb: 1 2
            libc: 5
            prog[1,2] => liba 1 2
            liba[1,2] => libb 1 2
            libb[1,2] => libc 0 0
            """)
        self.assertFalse(s.ready)
        selected = {}
        for iface, impl in s.selections.iteritems():
            if impl is not None:
                impl = impl.get_version()
            selected[iface.uri.rsplit('/', 1)[1]] = impl
        # assertEqual: assertEquals is a deprecated alias (removed in 3.12).
        self.assertEqual({
            'prog': '2',
            'liba': '2',
            'libb': '2',
            'libc': None
        }, selected)

    def testWatch(self):
        solver = sat.SATProblem()

        a = solver.add_variable('a')
        b = solver.add_variable('b')
        c = solver.add_variable('c')

        # Add a clause. It starts watching the first two variables (a and b).
        # (use the internal function to avoid variable reordering)
        solver._add_clause([a, b, c], False)

        # b is False, so it switches to watching a and c
        solver.add_clause([sat.neg(b)])

        # Try to trigger bug.
        solver.add_clause([c])

        decisions = [a]
        solver.run_solver(decisions.pop)
        self.assertFalse(decisions)     # All used up

        self.assertEqual(True, solver.assigns[a].value)

    def testOverbacktrack(self):
        # After learning that prog-3 => m0 we backtrack all the way up to the prog-3
        # assignment, unselecting liba-3, and then select it again.
        assertSelection("prog-3, liba-3, libb-3, libc-1, libz-2", """
            prog: 1 2 3
            liba: 1 2 3
            libb: 1 2 3
            libc Linux-x86_64: 2 3
            libc Linux-i486: 1
            libz Linux-i386: 1 2
            prog[2,3] => liba 1 3
            prog[2,3] => libz 1 2
            liba[1,3] => libb 1 3
            libb[1,3] => libc 1 3
            """)
if __name__ == '__main__':
    # When run as a script, let unittest discover and run the tests
    # defined in this module.
    unittest.main()