Updated epydoc
[zeroinstall/solver.git] / tests / testsat.py
blob720e12a1ef84e308686040c76dc5cd60a48a68cc
1 #!/usr/bin/env python
2 from basetest import BaseTest
3 import sys, os
4 import unittest
6 sys.path.insert(0, '..')
7 from zeroinstall.injector import model, arch, qdom
8 from zeroinstall.injector.namespaces import XMLNS_IFACE
10 from zeroinstall.injector.solver import SATSolver as Solver
11 from zeroinstall.injector import sat
13 import logging
14 logger = logging.getLogger()
class Stores:
    """Stand-in for the implementation store object handed to the solver."""

    def lookup_any(self, digests):
        """Return the fixed path "/" no matter which digests are requested."""
        fixed_path = "/"
        return fixed_path
# Shared stub store instance; assertSelection passes it to the Solver.
stores = Stores()

# Prefix of every test interface URI; the program name is appended to it.
uri_prefix = 'http://localhost/tests/'
class Version:
    """A single version of a test program together with its requirements.

    Attributes:
        n: the version identifier given to the constructor.
        requires: list of (lib, min_v, max_v) tuples added via add_requires.
        arch: architecture string, or None when none has been assigned.
    """

    def __init__(self, n):
        self.arch = None
        self.requires = []
        self.n = n

    def add_requires(self, lib, min_v, max_v):
        """Record a dependency on `lib` restricted to versions min_v..max_v."""
        constraint = (lib, min_v, max_v)
        self.requires.append(constraint)
class Program:
    """A named test program holding a set of Version objects.

    build_feed() turns the recorded versions into a ZeroInstallFeed.
    """

    def __init__(self, name):
        self.versions = {}
        self.name = name

    def get_version(self, version):
        """Return the Version stored under `version`, creating it if absent."""
        try:
            return self.versions[version]
        except KeyError:
            pass
        created = Version(version)
        self.versions[version] = created
        return created

    def build_feed(self):
        """Build a feed with one <implementation> element per known version.

        Each implementation gets a sequential 'id', its 'version', an optional
        'arch', a dummy manifest digest, and one <requires> element per
        recorded requirement.
        """
        def child(parent, name, attrs=None):
            # Create an element in the interface namespace and attach it
            # to `parent`.
            element = qdom.Element(XMLNS_IFACE, name, attrs or {})
            parent.childNodes.append(element)
            return element

        iface_uri = uri_prefix + self.name
        root = qdom.Element(XMLNS_IFACE, 'interface', {'uri': iface_uri})
        child(root, 'name').content = self.name
        child(root, 'summary').content = self.name

        for impl_id, version in enumerate(self.versions.values()):
            attrs = {
                'id': str(impl_id),
                'version': str(version.n),
            }
            if version.arch:
                attrs['arch'] = version.arch
            impl = child(root, 'implementation', attrs)
            child(impl, 'manifest-digest', {'sha1new': '1234'})
            for lib, min_v, max_v in version.requires:
                req = child(impl, 'requires', {'interface': uri_prefix + lib})
                # 'before' is an exclusive bound, so the inclusive max_v
                # is bumped by one.
                child(req, 'version', {
                    'before': str(int(max_v) + 1),
                    'not-before': min_v,
                })

        feed = model.ZeroInstallFeed(root)
        feed.last_modified = 1
        return feed
class TestCache:
    """In-memory cache of test programs, interfaces and feeds.

    Programs are created on demand by get_prog(); interfaces and feeds are
    built lazily the first time they are requested and then reused.
    """

    def __init__(self):
        self.progs, self.interfaces, self.feeds = {}, {}, {}

    def get_prog(self, prog):
        """Return the Program named `prog`, creating it on first use."""
        if prog in self.progs:
            return self.progs[prog]
        self.progs[prog] = Program(prog)
        return self.progs[prog]

    def get_interface(self, uri):
        """Return the model.Interface for `uri`, creating it on first use."""
        if uri in self.interfaces:
            return self.interfaces[uri]
        self.interfaces[uri] = model.Interface(uri)
        return self.interfaces[uri]

    def get_feed(self, url):
        """Return the feed for `url`, building it from the matching Program.

        The program name is the last '/'-separated component of `url`; it
        must already be present in self.progs.
        """
        if url in self.feeds:
            return self.feeds[url]
        prog_name = url.rsplit('/', 1)[1]
        self.feeds[url] = self.progs[prog_name].build_feed()
        return self.feeds[url]
def _load_repo(cache, repo):
    """Populate `cache` with the programs described by the text `repo`.

    Each non-blank line is one of:

      name [arch]: v1 v2 ...      declare versions of `name` (optional arch)
      name[v] => lib min max      version v of `name` requires lib min..max
      name[a,b] => lib min max    same, for every version in a..b inclusive

    Lines matching neither form are ignored.
    """
    for line in repo.split('\n'):
        line = line.strip()
        if not line:
            continue
        if ':' in line:
            prog, versions = line.split(':')
            prog = prog.strip()
            if ' ' in prog:
                prog, prog_arch = prog.split()
            else:
                prog_arch = None
            for v in versions.split():
                cache.get_prog(prog).get_version(v).arch = prog_arch
        elif '=>' in line:
            prog, requires = line.split('=>')
            prog, version_range = prog.strip().split('[')
            lib, min_v, max_v = requires.split()
            assert version_range.endswith(']')
            version_range = version_range[:-1]
            if ',' in version_range:
                min_p, max_p = map(int, version_range.split(','))
                prog_versions = range(min_p, max_p + 1)
            else:
                prog_versions = [int(version_range)]
            for prog_version in prog_versions:
                version = cache.get_prog(prog).get_version(str(prog_version))
                version.add_requires(lib, min_v, max_v)


def assertSelection(expected, repo):
    """Solve the repository described by `repo` and check the selections.

    `expected` is a comma-separated list of "name-version" items. The first
    item names the root program to solve for. If its version is "FAIL" the
    solver must report that it is not ready; otherwise it must be ready and
    its selections must match `expected` exactly.

    Returns the solver so that callers can make further checks.
    Raises Exception if the selections differ from `expected`.
    """
    cache = TestCache()

    expected = [tuple(e.strip().split('-')) for e in expected.split(",")]

    _load_repo(cache, repo)

    root = uri_prefix + expected[0][0]
    s = Solver(model.network_offline, cache, stores)
    s.solve(root, arch.get_architecture('Linux', 'x86_64'))

    if expected[0][1] == 'FAIL':
        # Selections may contain None in the failed state, so they are not
        # compared here.
        assert not s.ready
        return s

    assert s.ready

    # items() rather than the Python 2-only iteritems(), so this helper
    # runs on both Python 2 and Python 3.
    actual = sorted(
        (iface.uri.rsplit('/', 1)[1], impl.get_version())
        for iface, impl in s.selections.items())

    expected.sort()
    if expected != actual:
        raise Exception("Solve failed:\nExpected: %s\n Actual: %s" % (expected, actual))
    return s
class TestSAT(BaseTest):
    # Tests for the SAT-based solver. Most tests describe a small repository
    # in the mini-language understood by assertSelection and check which
    # versions get selected; a few drive sat.SATProblem directly.

    def testTrivial(self):
        assertSelection("prog-2", """
            prog: 1 2
            """)

    def testSimple(self):
        assertSelection("prog-5, liba-5", """
            prog: 1 2 3 4 5
            liba: 1 2 3 4 5
            prog[1] => liba 0 4
            prog[2] => liba 1 5
            prog[5] => liba 4 5
            """)

    def testBestImpossible(self):
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1
            prog[2] => liba 3 4
            """)

    def testSlow(self):
        assertSelection("prog-1", """
            prog: 1 2 3 4 5 6 7 8 9
            liba: 1 2 3 4 5 6 7 8 9
            libb: 1 2 3 4 5 6 7 8 9
            libc: 1 2 3 4 5 6 7 8 9
            libd: 1 2 3 4 5 6 7 8 9
            libe: 1
            prog[2,9] => liba 1 9
            liba[1,9] => libb 1 9
            libb[1,9] => libc 1 9
            libc[1,9] => libd 1 9
            libd[1,9] => libe 0 0
            """)

    def testNoSolution(self):
        assertSelection("prog-FAIL", """
            prog: 1 2 3
            liba: 1
            prog[1,3] => liba 2 3
            """)

    def testBacktrackSimple(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work ever.
        assertSelection("prog-1, liba-2", """
            prog: 1
            liba: 1 2 3
            prog[1] => liba 1 2
            """)

    def testBacktrackLocal(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work with prog-1.
        assertSelection("prog-2, liba-2", """
            prog: 1 2
            liba: 1 2 3
            prog[1,2] => liba 1 2
            """)

    def testLearning(self):
        # Prog-2 depends on libb (via liba) and libz, but we can't have both
        # at once. The learning means we don't have to explore every
        # possible combination of liba and libb.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1 2 3
            libb Linux-i486: 1 2 3
            libz Linux-x86_64: 1 2
            prog[2] => liba 1 3
            prog[2] => libz 1 2
            liba[1,3] => libb 1 3
            """)

    def testToplevelConflict(self):
        # We don't detect the conflict until we start solving, but the
        # conflict is top-level so we abort immediately without
        # backtracking.
        assertSelection("prog-FAIL", """
            prog Linux-i386: 1
            liba Linux-x86_64: 1
            prog[1] => liba 1 1
            """)

    def testDiamondConflict(self):
        # prog depends on liba and libb, which depend on incompatible
        # versions of libc.
        assertSelection("prog-FAIL", """
            prog: 1
            liba: 1
            libb: 1
            libc: 1 2
            prog[1] => liba 1 1
            prog[1] => libb 1 1
            liba[1] => libc 1 1
            libb[1] => libc 2 3
            """)

    def testCoverage(self):
        # Try to trigger some edge cases...

        # An at_most_one clause must be analysed for causing
        # a conflict.
        solver = sat.SATProblem()
        v1 = solver.add_variable("v1")
        v2 = solver.add_variable("v2")
        v3 = solver.add_variable("v3")
        solver.at_most_one([v1, v2])
        solver.add_clause([v1, sat.neg(v3)])
        solver.add_clause([v2, sat.neg(v3)])
        solver.add_clause([v1, v3])
        solver.run_solver(lambda: v3)

    def testFailState(self):
        # If we can't select a valid combination,
        # try to select as many as we can.
        s = assertSelection("prog-FAIL", """
            prog: 1 2
            liba: 1 2
            libb: 1 2
            libc: 5
            prog[1,2] => liba 1 2
            liba[1,2] => libb 1 2
            libb[1,2] => libc 0 0
            """)
        assert not s.ready
        selected = {}
        # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
        for iface, impl in s.selections.items():
            if impl is not None:
                impl = impl.get_version()
            selected[iface.uri.rsplit('/', 1)[1]] = impl
        # assertEqual, not the deprecated assertEquals alias.
        self.assertEqual({
            'prog': '2',
            'liba': '2',
            'libb': '2',
            'libc': None
        }, selected)

    def testWatch(self):
        solver = sat.SATProblem()

        a = solver.add_variable('a')
        b = solver.add_variable('b')
        c = solver.add_variable('c')

        # Add a clause. It starts watching the first two variables (a and b).
        # (use the internal function to avoid variable reordering)
        solver._add_clause([a, b, c], False)

        # b is False, so it switches to watching a and c
        solver.add_clause([sat.neg(b)])

        # Try to trigger bug.
        solver.add_clause([c])

        decisions = [a]
        solver.run_solver(lambda: decisions.pop())
        assert not decisions    # All used up

        assert solver.assigns[a].value == True

    def testOverbacktrack(self):
        # After learning that prog-3 => m0 we backtrack all the way up to the prog-3
        # assignment, unselecting liba-3, and then select it again.
        assertSelection("prog-3, liba-3, libb-3, libc-1, libz-2", """
            prog: 1 2 3
            liba: 1 2 3
            libb: 1 2 3
            libc Linux-x86_64: 2 3
            libc Linux-i486: 1
            libz Linux-i386: 1 2
            prog[2,3] => liba 1 3
            prog[2,3] => libz 1 2
            liba[1,3] => libb 1 3
            libb[1,3] => libc 1 3
            """)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()