Added Selection.get_path()
[zeroinstall/solver.git] / tests / testsat.py
blob82186673f13ecb5e95cadda24911b564da0e521d
1 #!/usr/bin/env python
2 from basetest import BaseTest
3 import sys
4 import unittest
6 sys.path.insert(0, '..')
7 from zeroinstall.injector import model, arch, qdom
8 from zeroinstall.injector.namespaces import XMLNS_IFACE
10 from zeroinstall.injector.solver import SATSolver as Solver
11 from zeroinstall.injector import sat
13 import logging
# The root logger (getLogger with no name).
logger = logging.getLogger(None)
class Stores:
    """Stub implementation store whose lookups always succeed with "/"."""

    def lookup_maybe(self, digests):
        # The digests are ignored; every lookup resolves to the same path.
        # Presumably this lets the solver treat every implementation as
        # available while offline -- confirm against the solver's use of stores.
        found_path = "/"
        return found_path
# Feed URLs are this prefix followed by the program name.
uri_prefix = 'http://localhost/tests/'

# Single shared store stub handed to every solver configuration.
stores = Stores()
class Version:
    """One version of a test program: its number, dependencies and optional arch."""

    def __init__(self, n):
        self.n = n
        # When left as None, build_feed() emits no 'arch' attribute.
        self.arch = None
        # List of (lib, min_v, max_v) tuples, in the order they were added.
        self.requires = []

    def add_requires(self, lib, min_v, max_v):
        """Record that this version needs *lib* in the range min_v..max_v."""
        dependency = (lib, min_v, max_v)
        self.requires.append(dependency)
class Program:
    """A test program with a set of numbered versions, convertible to a feed."""

    def __init__(self, name):
        self.name = name
        self.versions = {}  # version string -> Version

    def get_version(self, version):
        """Return the Version for *version*, creating it on first use."""
        if version not in self.versions:
            self.versions[version] = Version(version)
        return self.versions[version]

    def build_feed(self):
        """Return a model.ZeroInstallFeed describing every recorded version.

        Each version becomes an <implementation> element with a sequential
        id, a dummy main and a fixed sha1new digest. Each requirement becomes
        a <requires> element whose <version> child admits min_v..max_v
        inclusive.
        """
        def child(parent, name, attrs=None):
            # Append a new element in the 0install namespace and return it.
            new = qdom.Element(XMLNS_IFACE, name, attrs or {})
            parent.childNodes.append(new)
            return new

        root = qdom.Element(XMLNS_IFACE, 'interface', {'uri': uri_prefix + self.name})
        child(root, 'name').content = self.name
        child(root, 'summary').content = self.name

        for i, version in enumerate(self.versions.values()):
            attrs = {
                'id': str(i),
                'version': str(version.n),
                'main': 'dummy',
            }
            if version.arch:
                attrs['arch'] = version.arch
            impl = child(root, 'implementation', attrs)
            child(impl, 'manifest-digest', {'sha1new': '1234'})
            for lib, min_v, max_v in version.requires:
                req = child(impl, 'requires', {'interface': uri_prefix + lib})
                # 'before' is an exclusive bound, so bump max_v by one to make
                # the integer range inclusive.
                child(req, 'version', {
                    'before': str(int(max_v) + 1),
                    'not-before': min_v})

        feed = model.ZeroInstallFeed(root)
        feed.last_modified = 1
        return feed
class TestCache:
    """In-memory stand-in for the interface cache consulted by the solver.

    Programs are created on demand by get_prog(). Feeds are built lazily from
    those programs the first time get_feed() asks for them.
    """

    def __init__(self):
        self.progs = {}       # program name -> Program
        self.interfaces = {}  # interface URI -> model.Interface
        self.feeds = {}       # feed URL -> feed built by Program.build_feed()

    def get_prog(self, prog):
        """Return the Program called *prog*, creating it on first use."""
        if prog not in self.progs:
            self.progs[prog] = Program(prog)
        return self.progs[prog]

    def get_interface(self, uri):
        """Return the model.Interface for *uri*, creating it on first use."""
        if uri not in self.interfaces:
            self.interfaces[uri] = model.Interface(uri)
        return self.interfaces[uri]

    def get_feed(self, url):
        """Return the feed for *url*, building it from the matching Program.

        The program name is the last '/'-separated component of *url*.
        Raises KeyError if no program of that name was registered through
        get_prog().
        """
        if url not in self.feeds:
            prog_name = url.rsplit('/', 1)[1]
            self.feeds[url] = self.progs[prog_name].build_feed()
        return self.feeds[url]

    def get_feed_imports(self, iface):
        """Return the extra feeds imported by *iface*; this test repo has none."""
        return []
def _load_repo(repo):
    """Build a TestCache from the textual repository description *repo*.

    Each non-blank line (surrounding whitespace ignored) is either
      ``name: v1 v2 ...`` or ``name ARCH: v1 v2 ...``
          declare versions of *name*, optionally restricted to ARCH; or
      ``name[N] => lib MIN MAX`` or ``name[A,B] => lib MIN MAX``
          version N (or every version A..B inclusive) needs lib MIN..MAX.
    Raises ValueError for a line matching neither form.
    """
    cache = TestCache()
    for line in repo.split('\n'):
        line = line.strip()
        if not line:
            continue
        if ':' in line:
            prog, versions = line.split(':')
            prog = prog.strip()
            if ' ' in prog:
                prog, prog_arch = prog.split()
            else:
                prog_arch = None
            for v in versions.split():
                cache.get_prog(prog).get_version(v).arch = prog_arch
        elif '=>' in line:
            prog, requires = line.split('=>')
            prog, version_range = prog.strip().split('[')
            lib, min_v, max_v = requires.split()
            if not version_range.endswith(']'):
                raise ValueError("Missing ']' in repository line: %r" % line)
            version_range = version_range[:-1]
            if ',' in version_range:
                min_p, max_p = map(int, version_range.split(','))
                prog_versions = range(min_p, max_p + 1)
            else:
                prog_versions = [int(version_range)]
            for prog_version in prog_versions:
                cache.get_prog(prog).get_version(str(prog_version)).add_requires(lib, min_v, max_v)
        else:
            # Silently skipping such a line would hide typos in the test data.
            raise ValueError("Unparseable repository line: %r" % line)
    return cache

def assertSelection(expected, repo):
    """Solve the repository described by *repo* and check the selections.

    *expected* is a comma-separated list of ``name-version`` pairs. The first
    name is the root program that is solved for. If its version is ``FAIL``,
    the solver must not become ready. Otherwise the solver must be ready and
    its selections must equal *expected* exactly.

    Returns the solver so callers can inspect it further. Raises Exception
    when the selections differ, and ValueError when *repo* is malformed.
    """
    expected = [tuple(e.strip().split('-')) for e in expected.split(",")]
    cache = _load_repo(repo)

    root = uri_prefix + expected[0][0]

    class TestConfig:
        help_with_testing = False
        network_use = model.network_offline
        stores = stores
        iface_cache = cache

    s = Solver(TestConfig())
    s.solve(root, arch.get_architecture('Linux', 'x86_64'))

    if expected[0][1] == 'FAIL':
        assert not s.ready
    else:
        assert s.ready

        # .items() (rather than the Python-2-only iteritems) works on 2 and 3.
        actual = sorted((iface_uri.rsplit('/', 1)[1], impl.version)
                        for iface_uri, impl in s.selections.selections.items())
        expected.sort()
        if expected != actual:
            raise Exception("Solve failed:\nExpected: %s\n Actual: %s" % (expected, actual))
    return s
class TestSAT(BaseTest):
    """Solver scenarios written in assertSelection()'s compact repository notation."""

    def testTrivial(self):
        # With no dependencies the newest version is chosen.
        assertSelection("prog-2", """
            prog: 1 2
            """)

    def testSimple(self):
        assertSelection("prog-5, liba-5", """
            prog: 1 2 3 4 5
            liba: 1 2 3 4 5
            prog[1] => liba 0 4
            prog[2] => liba 1 5
            prog[5] => liba 4 5
            """)

    def testBestImpossible(self):
        # prog-2 needs a liba version that doesn't exist, so fall back to prog-1.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1
            prog[2] => liba 3 4
            """)

    def testSlow(self):
        # Every prog from 2 upwards leads, via liba..libd, to a need for libe 0
        # (which doesn't exist), so only the dependency-free prog-1 works.
        assertSelection("prog-1", """
            prog: 1 2 3 4 5 6 7 8 9
            liba: 1 2 3 4 5 6 7 8 9
            libb: 1 2 3 4 5 6 7 8 9
            libc: 1 2 3 4 5 6 7 8 9
            libd: 1 2 3 4 5 6 7 8 9
            libe: 1
            prog[2,9] => liba 1 9
            liba[1,9] => libb 1 9
            libb[1,9] => libc 1 9
            libc[1,9] => libd 1 9
            libd[1,9] => libe 0 0
            """)

    def testNoSolution(self):
        assertSelection("prog-FAIL", """
            prog: 1 2 3
            liba: 1
            prog[1,3] => liba 2 3
            """)

    def testBacktrackSimple(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work ever.
        assertSelection("prog-1, liba-2", """
            prog: 1
            liba: 1 2 3
            prog[1] => liba 1 2
            """)

    def testBacktrackLocal(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work with prog-1.
        assertSelection("prog-2, liba-2", """
            prog: 1 2
            liba: 1 2 3
            prog[1,2] => liba 1 2
            """)

    def testLearning(self):
        # prog-2 depends (via liba) on libb and directly on libz, but libb is
        # only available for i486 and libz only for x86_64, so we can't have
        # both at once. The learning means we don't have to explore every
        # possible combination of liba and libb.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1 2 3
            libb Linux-i486: 1 2 3
            libz Linux-x86_64: 1 2
            prog[2] => liba 1 3
            prog[2] => libz 1 2
            liba[1,3] => libb 1 3
            """)

    def testToplevelConflict(self):
        # We don't detect the conflict until we start solving, but the
        # conflict is top-level so we abort immediately without
        # backtracking.
        assertSelection("prog-FAIL", """
            prog Linux-i386: 1
            liba Linux-x86_64: 1
            prog[1] => liba 1 1
            """)

    def testDiamondConflict(self):
        # prog depends on liba and libb, which depend on incompatible
        # versions of libc.
        assertSelection("prog-FAIL", """
            prog: 1
            liba: 1
            libb: 1
            libc: 1 2
            prog[1] => liba 1 1
            prog[1] => libb 1 1
            liba[1] => libc 1 1
            libb[1] => libc 2 3
            """)

    def testCoverage(self):
        # Try to trigger some edge cases...

        # An at_most_one clause must be analysed for causing
        # a conflict. The test passes as long as the solver runs without error.
        solver = sat.SATProblem()
        v1 = solver.add_variable("v1")
        v2 = solver.add_variable("v2")
        v3 = solver.add_variable("v3")
        solver.at_most_one([v1, v2])
        solver.add_clause([v1, sat.neg(v3)])
        solver.add_clause([v2, sat.neg(v3)])
        solver.add_clause([v1, v3])
        solver.run_solver(lambda: v3)

    def testFailState(self):
        # If we can't select a valid combination,
        # try to select as many as we can.
        s = assertSelection("prog-FAIL", """
            prog: 1 2
            liba: 1 2
            libb: 1 2
            libc: 5
            prog[1,2] => liba 1 2
            liba[1,2] => libb 1 2
            libb[1,2] => libc 0 0
            """)
        self.assertFalse(s.ready)
        selected = {}
        for iface_uri, impl in s.selections.selections.items():
            if impl is not None:
                impl = impl.version
            selected[iface_uri.rsplit('/', 1)[1]] = impl
        self.assertEqual({
            'prog': '2',
            'liba': '2',
            'libb': '2',
            'libc': None
        }, selected)

    def testWatch(self):
        solver = sat.SATProblem()

        a = solver.add_variable('a')
        b = solver.add_variable('b')
        c = solver.add_variable('c')

        # Add a clause. It starts watching the first two variables (a and b).
        # (use the internal function to avoid variable reordering)
        solver._add_clause([a, b, c], False)

        # b is False, so it switches to watching a and c
        solver.add_clause([sat.neg(b)])

        # Try to trigger bug.
        solver.add_clause([c])

        decisions = [a]
        solver.run_solver(lambda: decisions.pop())
        self.assertEqual([], decisions)  # All used up

        self.assertEqual(True, solver.assigns[a].value)

    def testOverbacktrack(self):
        # After learning that prog-3 => m0 we backtrack all the way up to the prog-3
        # assignment, unselecting liba-3, and then select it again.
        assertSelection("prog-3, liba-3, libb-3, libc-1, libz-2", """
            prog: 1 2 3
            liba: 1 2 3
            libb: 1 2 3
            libc Linux-x86_64: 2 3
            libc Linux-i486: 1
            libz Linux-i386: 1 2
            prog[2,3] => liba 1 3
            prog[2,3] => libz 1 2
            liba[1,3] => libb 1 3
            libb[1,3] => libc 1 3
            """)
if __name__ == '__main__':
    # Discover and run every TestCase defined in this script.
    unittest.main(module='__main__')