If we fail to find a set of versions, try to return a close match
[zeroinstall/solver.git] / tests / testsat.py
blob b5e8cc6d43b5db4e40a9915e2c62479a53297101
1 #!/usr/bin/env python
2 from basetest import BaseTest
3 import sys, os
4 import unittest
6 sys.path.insert(0, '..')
7 from zeroinstall.injector import model, arch, qdom
8 from zeroinstall.injector.namespaces import XMLNS_IFACE
10 #from zeroinstall.injector.origsolver import DefaultSolver as Solver
11 #from zeroinstall.injector.pbsolver import PBSolver as Solver
12 #from zeroinstall.injector.sgsolver import DefaultSolver as Solver
13 from zeroinstall.injector.solver import SATSolver as Solver
14 from zeroinstall.injector import sat
16 import logging
# Root logger (no name passed); not referenced elsewhere in this file.
logger = logging.getLogger()


class Stores:
    """Stub implementation store in which every lookup resolves to "/"."""

    def lookup_any(self, digests):
        # Whatever digests are asked for, pretend they are already unpacked here.
        location = '/'
        return location


# Single shared stub passed to every Solver that assertSelection creates.
stores = Stores()

# Every synthetic interface URI is this prefix followed by the program name.
uri_prefix = 'http://localhost/tests/'
class Version:
    """One candidate version of a synthetic test program.

    Attributes:
        n: the version label (the parser passes it as a string).
        requires: list of ``(lib, min_v, max_v)`` dependency tuples.
        arch: optional architecture string; when truthy it is written to the
            feed as the implementation's ``arch`` attribute.
    """

    def __init__(self, n):
        self.n, self.requires, self.arch = n, [], None

    def add_requires(self, lib, min_v, max_v):
        """Record that this version needs ``lib`` in the inclusive range [min_v, max_v]."""
        dependency = (lib, min_v, max_v)
        self.requires.append(dependency)
class Program:
    """A synthetic program/library whose versions are rendered as a 0install feed."""

    def __init__(self, name):
        self.name = name
        # Version label -> Version, populated lazily by get_version().
        self.versions = {}

    def get_version(self, version):
        """Return the Version labelled ``version``, creating it on first use."""
        if version not in self.versions:
            self.versions[version] = Version(version)
        return self.versions[version]

    def build_feed(self):
        """Build a model.ZeroInstallFeed describing every registered version.

        Each version becomes an ``<implementation>`` element with a sequential
        id, a dummy manifest digest, an optional ``arch`` attribute and one
        ``<requires>`` element per dependency.  Dependency ranges are
        inclusive: ``[min_v, max_v]`` is written as ``not-before=min_v`` and
        ``before=max_v + 1``.
        """
        def child(parent, name, attrs=None):
            # Append a new element in the 0install namespace and return it.
            new = qdom.Element(XMLNS_IFACE, name, attrs or {})
            parent.childNodes.append(new)
            return new

        root = qdom.Element(XMLNS_IFACE, 'interface', {'uri': uri_prefix + self.name})
        child(root, 'name').content = self.name
        child(root, 'summary').content = self.name

        # Implementations are numbered '0', '1', ... in dict iteration order.
        for i, version in enumerate(self.versions.values()):
            attrs = {
                'id': str(i),
                'version': str(version.n),
            }
            if version.arch:
                attrs['arch'] = version.arch
            impl = child(root, 'implementation', attrs)
            child(impl, 'manifest-digest', {'sha1new': '1234'})
            for lib, min_v, max_v in version.requires:
                req = child(impl, 'requires', {'interface': uri_prefix + lib})
                child(req, 'version', {
                    'before': str(int(max_v) + 1),
                    'not-before': str(min_v)})

        feed = model.ZeroInstallFeed(root)
        # NOTE(review): arbitrary non-zero timestamp; confirm what the solver needs here.
        feed.last_modified = 1
        return feed
class TestCache:
    """In-memory stand-in for the interface cache handed to the solver.

    Programs are registered by short name through get_prog(); get_interface()
    lazily builds a model.Interface whose main feed is generated from the
    Program named by the last path component of the URI.
    """

    def __init__(self):
        # Short program name -> Program.
        self.progs = {}
        # Full interface URI -> model.Interface, memoised by get_interface().
        self.interfaces = {}

    def get_prog(self, prog):
        """Return the Program called ``prog``, creating an empty one on first use."""
        if prog not in self.progs:
            self.progs[prog] = Program(prog)
        return self.progs[prog]

    def get_interface(self, uri):
        """Return the memoised Interface for ``uri``.

        Raises KeyError if no Program matching the URI's last path component
        was registered.
        """
        if uri not in self.interfaces:
            iface = model.Interface(uri)
            # NOTE(review): sets the private _main_feed directly instead of
            # loading a feed; presumably to keep the test offline. Confirm.
            iface._main_feed = self.progs[uri.rsplit('/', 1)[1]].build_feed()
            self.interfaces[uri] = iface
        return self.interfaces[uri]
def _build_cache(repo):
    """Parse the ``repo`` mini-language into a populated TestCache.

    Each non-blank line is one of:

    * ``name: v1 v2 ...`` or ``name ARCH: v1 v2 ...`` declares versions of
      ``name``, optionally restricted to architecture ARCH.
    * ``name[v] => lib min max`` or ``name[lo,hi] => lib min max`` says that
      version v (or every version lo..hi inclusive) of ``name`` requires
      ``lib`` in the inclusive range [min, max].

    Raises ValueError for any other non-blank line.  A typo in test data
    would otherwise be silently ignored.
    """
    cache = TestCache()
    for line in repo.split('\n'):
        line = line.strip()
        if not line:
            continue
        if ':' in line:
            prog, versions = line.split(':')
            prog = prog.strip()
            if ' ' in prog:
                prog, prog_arch = prog.split()
            else:
                prog_arch = None
            for v in versions.split():
                cache.get_prog(prog).get_version(v).arch = prog_arch
        elif '=>' in line:
            prog, requires = line.split('=>')
            prog, version_range = prog.strip().split('[')
            lib, min_v, max_v = requires.split()
            assert version_range.endswith(']')
            version_range = version_range[:-1]
            if ',' in version_range:
                min_p, max_p = map(int, version_range.split(','))
                prog_versions = range(min_p, max_p + 1)
            else:
                prog_versions = [int(version_range)]
            for prog_version in prog_versions:
                cache.get_prog(prog).get_version(str(prog_version)).add_requires(lib, min_v, max_v)
        else:
            raise ValueError("Unrecognised repository line: %r" % line)
    return cache


def assertSelection(expected, repo):
    """Solve the synthetic ``repo`` and check the selected versions.

    ``expected`` is a comma-separated list of ``name-version`` pairs.  The
    first entry's name is the root program to solve for.  If its version is
    ``FAIL``, the solver must report that it is not ready; otherwise it must
    be ready and select exactly the listed versions.

    Returns the solver so callers can inspect it further.  Raises
    AssertionError (reported by unittest as a failure, not an error) when
    the check does not hold.
    """
    expected = [tuple(e.strip().split('-')) for e in expected.split(",")]
    cache = _build_cache(repo)

    root = uri_prefix + expected[0][0]
    s = Solver(model.network_offline, cache, stores)
    s.solve(root, arch.get_architecture('Linux', 'x86_64'))

    if expected[0][1] == 'FAIL':
        assert not s.ready
    else:
        assert s.ready

        # Compare on (short program name, version) pairs, order-independently.
        actual = [(iface.uri.rsplit('/', 1)[1], impl.get_version())
                  for iface, impl in s.selections.iteritems()]
        expected.sort()
        actual.sort()
        if expected != actual:
            raise AssertionError("Solve failed:\nExpected: %s\n Actual: %s" % (expected, actual))
    return s
class TestSAT(BaseTest):
    """End-to-end checks of the SAT-based solver on small synthetic repositories.

    Each test describes a repository in assertSelection()'s mini-language and
    states which versions should be chosen (``prog-FAIL`` meaning that no
    valid selection exists).
    """

    def testTrivial(self):
        # With no dependencies, the newest version is chosen.
        assertSelection("prog-2", """
            prog: 1 2
            """)

    def testSimple(self):
        # prog-5 accepts liba 4..5, so the newest of both can be chosen.
        assertSelection("prog-5, liba-5", """
            prog: 1 2 3 4 5
            liba: 1 2 3 4 5
            prog[1] => liba 0 4
            prog[2] => liba 1 5
            prog[5] => liba 4 5
            """)

    def testBestImpossible(self):
        # prog-2 needs liba 3..4 but only liba-1 exists, so fall back to prog-1.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1
            prog[2] => liba 3 4
            """)

    def testSlow(self):
        # prog-2..9 pull in liba -> libb -> libc -> libd, and every libd
        # needs libe 0, which does not exist; only prog-1 is selectable.
        assertSelection("prog-1", """
            prog: 1 2 3 4 5 6 7 8 9
            liba: 1 2 3 4 5 6 7 8 9
            libb: 1 2 3 4 5 6 7 8 9
            libc: 1 2 3 4 5 6 7 8 9
            libd: 1 2 3 4 5 6 7 8 9
            libe: 1
            prog[2,9] => liba 1 9
            liba[1,9] => libb 1 9
            libb[1,9] => libc 1 9
            libc[1,9] => libd 1 9
            libd[1,9] => libe 0 0
            """)

    def testNoSolution(self):
        # Every prog version needs liba 2..3, but only liba-1 exists.
        assertSelection("prog-FAIL", """
            prog: 1 2 3
            liba: 1
            prog[1,3] => liba 2 3
            """)

    def testBacktrackSimple(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 never works.
        assertSelection("prog-1, liba-2", """
            prog: 1
            liba: 1 2 3
            prog[1] => liba 1 2
            """)

    def testBacktrackLocal(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work with prog-1.
        assertSelection("prog-2, liba-2", """
            prog: 1 2
            liba: 1 2 3
            prog[1,2] => liba 1 2
            """)

    def testLearning(self):
        # prog-2 depends on libz directly and on libb via liba, but libb
        # exists only for Linux-i486 and libz only for Linux-x86_64, so we
        # can't have both at once. The learning means we don't have to
        # explore every possible combination of liba and libb.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1 2 3
            libb Linux-i486: 1 2 3
            libz Linux-x86_64: 1 2
            prog[2] => liba 1 3
            prog[2] => libz 1 2
            liba[1,3] => libb 1 3
            """)

    def testToplevelConflict(self):
        # We don't detect the conflict until we start solving, but the
        # conflict is top-level so we abort immediately without
        # backtracking.
        assertSelection("prog-FAIL", """
            prog Linux-i386: 1
            liba Linux-x86_64: 1
            prog[1] => liba 1 1
            """)

    def testDiamondConflict(self):
        # prog depends on liba and libb, which depend on incompatible
        # versions of libc.
        assertSelection("prog-FAIL", """
            prog: 1
            liba: 1
            libb: 1
            libc: 1 2
            prog[1] => liba 1 1
            prog[1] => libb 1 1
            liba[1] => libc 1 1
            libb[1] => libc 2 3
            """)

    def testCoverage(self):
        # Try to trigger some edge cases...

        # An at_most_one clause must be analysed for causing
        # a conflict.
        solver = sat.Solver()
        v1 = solver.add_variable("v1")
        v2 = solver.add_variable("v2")
        v3 = solver.add_variable("v3")
        solver.at_most_one([v1, v2])
        solver.add_clause([v1, sat.neg(v3)])
        solver.add_clause([v2, sat.neg(v3)])
        solver.add_clause([v1, v3])
        solver.run_solver(lambda: v3)

    def testFailState(self):
        # If we can't select a valid combination,
        # try to select as many as we can.
        s = assertSelection("prog-FAIL", """
            prog: 1 2
            liba: 1 2
            libb: 1 2
            libc: 5
            prog[1,2] => liba 1 2
            liba[1,2] => libb 1 2
            libb[1,2] => libc 0 0
            """)
        assert not s.ready
        selected = {}
        for iface, impl in s.selections.iteritems():
            if impl is not None: impl = impl.get_version()
            selected[iface.uri.rsplit('/', 1)[1]] = impl
        # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
        self.assertEqual({
            'prog': '2',
            'liba': '2',
            'libb': '2',
            'libc': None
        }, selected)
# Module-level suite, presumably collected by an aggregate test runner.
# TestLoader is used because unittest.makeSuite is deprecated (removed in 3.13).
suite = unittest.TestLoader().loadTestsFromTestCase(TestSAT)

if __name__ == '__main__':
    # Run verbosely when executed directly.
    sys.argv.append('-v')
    unittest.main()