Test SAT learning
[zeroinstall/solver.git] / tests / testsat.py
blob 673313f8246d222fb686859fd79ece9faf945080
1 #!/usr/bin/env python
2 from basetest import BaseTest
3 import sys, os
4 import unittest
6 sys.path.insert(0, '..')
7 from zeroinstall.injector import model, arch, qdom
8 from zeroinstall.injector.namespaces import XMLNS_IFACE
10 #from zeroinstall.injector.origsolver import DefaultSolver as Solver
11 #from zeroinstall.injector.pbsolver import PBSolver as Solver
12 #from zeroinstall.injector.sgsolver import DefaultSolver as Solver
13 from zeroinstall.injector.solver import SATSolver as Solver
15 import logging
# Module-level handle on the root logger (getLogger with a name of None
# returns the root logger, exactly like calling it with no argument).
logger = logging.getLogger(None)
class Stores:
    """Stand-in implementation store used by every solver run in this module."""

    def lookup_any(self, digests):
        # The requested digests are ignored: the answer is always "/",
        # presumably so every implementation looks available locally.
        return "/"


# Base URI under which every test program's interface is published.
uri_prefix = 'http://localhost/tests/'

# Single shared store instance handed to the solver.
stores = Stores()
class Version:
    """One version of a test program: its number, arch and dependencies."""

    def __init__(self, n):
        self.arch = None    # arch string such as 'Linux-i486', or None
        self.requires = []  # list of (lib, min_v, max_v) tuples
        self.n = n          # the version number, as given by the caller

    def add_requires(self, lib, min_v, max_v):
        """Record that this version needs `lib` in the range min_v..max_v."""
        dependency = (lib, min_v, max_v)
        self.requires.append(dependency)
class Program:
    """A named test program: a set of versions that can be turned into a feed."""

    def __init__(self, name):
        self.name = name
        self.versions = {}  # version string -> Version

    def get_version(self, version):
        """Return the Version for `version`, creating it on first use."""
        if version not in self.versions:
            self.versions[version] = Version(version)
        return self.versions[version]

    def build_feed(self):
        """Build a model.ZeroInstallFeed describing every version of this program.

        Each version becomes an <implementation> element with a sequential
        id, a dummy manifest digest, an optional arch attribute, and one
        <requires> element per dependency recorded on the Version.
        """
        def child(parent, name, attrs=None):
            # Create a namespaced element, attach it to `parent`, return it.
            new = qdom.Element(XMLNS_IFACE, name, attrs or {})
            parent.childNodes.append(new)
            return new

        root = qdom.Element(XMLNS_IFACE, 'interface', {'uri': uri_prefix + self.name})
        child(root, 'name').content = self.name
        child(root, 'summary').content = self.name

        for i, version in enumerate(self.versions.values()):
            attrs = {
                'id': str(i),
                'version': str(version.n),
            }
            if version.arch:
                attrs['arch'] = version.arch
            impl = child(root, 'implementation', attrs)
            child(impl, 'manifest-digest', {'sha1new': '1234'})
            for lib, min_v, max_v in version.requires:
                req = child(impl, 'requires', {'interface': uri_prefix + lib})
                # 'before' is an exclusive upper bound, so add one to keep
                # max_v itself acceptable (the range is min_v..max_v inclusive).
                child(req, 'version', {
                    'before': str(int(max_v) + 1),
                    'not-before': min_v})

        feed = model.ZeroInstallFeed(root)
        feed.last_modified = 1
        return feed
class TestCache:
    """In-memory interface cache populated from a test repository description."""

    def __init__(self):
        self.progs = {}       # program name -> Program
        self.interfaces = {}  # interface URI -> model.Interface

    def get_prog(self, prog):
        """Return the Program called `prog`, creating it on first use."""
        if prog not in self.progs:
            self.progs[prog] = Program(prog)
        return self.progs[prog]

    def get_interface(self, uri):
        """Return the Interface for `uri`, building and caching it on first use.

        The program name is taken from the last path component of `uri`.
        Raises KeyError if no program with that name has been declared.
        """
        if uri not in self.interfaces:
            iface = model.Interface(uri)
            prog_name = uri.rsplit('/', 1)[1]
            iface._main_feed = self.progs[prog_name].build_feed()
            self.interfaces[uri] = iface
        return self.interfaces[uri]
def _parse_repo(cache, repo):
    """Populate `cache` from the textual repository description `repo`.

    Each non-blank line (surrounding whitespace ignored) is one of:

      name: v1 v2 ...            declare versions of a program
      name ARCH: v1 v2 ...       same, with an arch attribute (e.g. Linux-i486)
      name[lo,hi] => lib min max versions lo..hi of name require lib min..max
      name[v] => lib min max     a single version v requires lib min..max

    Lines matching neither form are ignored.

    Raises ValueError if a '=>' line has no closing ']' on its version range.
    """
    for line in repo.split('\n'):
        line = line.strip()
        if not line:
            continue
        if ':' in line:
            prog, versions = line.split(':')
            prog = prog.strip()
            if ' ' in prog:
                prog, prog_arch = prog.split()
            else:
                prog_arch = None
            for v in versions.split():
                cache.get_prog(prog).get_version(v).arch = prog_arch
        elif '=>' in line:
            prog, requires = line.split('=>')
            prog, version_range = prog.strip().split('[')
            lib, min_v, max_v = requires.split()
            # Explicit check rather than `assert`, so it still runs under -O.
            if not version_range.endswith(']'):
                raise ValueError("Missing ']' in version range: %r" % line)
            version_range = version_range[:-1]
            if ',' in version_range:
                min_p, max_p = map(int, version_range.split(','))
                prog_versions = range(min_p, max_p + 1)
            else:
                prog_versions = [int(version_range)]
            for prog_version in prog_versions:
                cache.get_prog(prog).get_version(str(prog_version)).add_requires(lib, min_v, max_v)


def assertSelection(expected, repo):
    """Solve the repository described by `repo` and check the selections.

    `expected` is a comma-separated list of "name-version" entries; the first
    entry names the root program. If its version is "FAIL", the solve is
    expected not to succeed and no selections are compared.

    Raises AssertionError if the solver's readiness is not as expected, and
    Exception if the selected versions differ from `expected`. Returns the
    solver so callers can inspect it further.
    """
    cache = TestCache()

    expected = [tuple(e.strip().split('-')) for e in expected.split(",")]

    _parse_repo(cache, repo)

    root = uri_prefix + expected[0][0]
    s = Solver(model.network_offline, cache, stores)
    s.solve(root, arch.get_architecture('Linux', 'x86_64'))

    # Explicit raises instead of `assert`, which `python -O` would strip.
    if expected[0][1] == 'FAIL':
        if s.ready:
            raise AssertionError("Solve unexpectedly succeeded for %s" % root)
    else:
        if not s.ready:
            raise AssertionError("Solve unexpectedly failed for %s" % root)

        # .items() works on both Python 2 and 3 (iteritems is Python 2 only);
        # assumes s.selections is a dict of interface -> implementation.
        actual = [(iface.uri.rsplit('/', 1)[1], impl.get_version())
                  for iface, impl in s.selections.items()]

        expected.sort()
        actual.sort()
        if expected != actual:
            raise Exception("Solve failed:\nExpected: %s\n Actual: %s" % (expected, actual))
    return s
class TestSAT(BaseTest):
    """Solver scenarios written in assertSelection's repository notation."""

    def testTrivial(self):
        # No dependencies: the newest version is selected.
        assertSelection("prog-2", """
            prog: 1 2
            """)

    def testSimple(self):
        # The newest prog is selected together with the newest liba that
        # satisfies prog-5's requirement (liba 4..5).
        assertSelection("prog-5, liba-5", """
            prog: 1 2 3 4 5
            liba: 1 2 3 4 5
            prog[1] => liba 0 4
            prog[2] => liba 1 5
            prog[5] => liba 4 5
            """)

    def testBestImpossible(self):
        # prog-2 needs liba 3..4, which doesn't exist, so the solver must
        # fall back to prog-1.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1
            prog[2] => liba 3 4
            """)

    def testSlow(self):
        # Every prog from 2 to 9 pulls in the chain liba -> libb -> libc ->
        # libd -> libe 0..0, but only libe-1 exists, so only prog-1 (which
        # has no dependencies) can be selected. Presumably named for the
        # many combinations a naive search would have to try.
        assertSelection("prog-1", """
            prog: 1 2 3 4 5 6 7 8 9
            liba: 1 2 3 4 5 6 7 8 9
            libb: 1 2 3 4 5 6 7 8 9
            libc: 1 2 3 4 5 6 7 8 9
            libd: 1 2 3 4 5 6 7 8 9
            libe: 1
            prog[2,9] => liba 1 9
            liba[1,9] => libb 1 9
            libb[1,9] => libc 1 9
            libc[1,9] => libd 1 9
            libd[1,9] => libe 0 0
            """)

    def testNoSolution(self):
        # Every prog needs liba 2..3, but only liba-1 exists.
        assertSelection("prog-FAIL", """
            prog: 1 2 3
            liba: 1
            prog[1,3] => liba 2 3
            """)

    def testBacktrackSimple(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work ever.
        assertSelection("prog-1, liba-2", """
            prog: 1
            liba: 1 2 3
            prog[1] => liba 1 2
            """)

    def testBacktrackLocal(self):
        # We initially try liba-3 before learning that it
        # is incompatible and backtracking.
        # We learn that liba-3 doesn't work with prog-2 (the preferred
        # version), rather than that it never works.
        assertSelection("prog-2, liba-2", """
            prog: 1 2
            liba: 1 2 3
            prog[1,2] => liba 1 2
            """)

    def testLearning(self):
        # Prog-2 depends on libz and (via liba) on libb, but we can't have
        # both at once: libb is only available for Linux-i486 and libz only
        # for Linux-x86_64. The learning means we don't have to explore
        # every possible combination of liba and libb.
        assertSelection("prog-1", """
            prog: 1 2
            liba: 1 2 3
            libb Linux-i486: 1 2 3
            libz Linux-x86_64: 1 2
            prog[2] => liba 1 3
            prog[2] => libz 1 2
            liba[1,3] => libb 1 3
            """)
# Collect every test method of TestSAT into a suite. unittest.makeSuite is
# deprecated since Python 3.11 and removed in 3.13; TestLoader's
# loadTestsFromTestCase is the equivalent and works on Python 2 and 3.
suite = unittest.TestLoader().loadTestsFromTestCase(TestSAT)

if __name__ == '__main__':
    # Run verbosely when executed directly.
    sys.argv.append('-v')
    unittest.main()