Started implementing SAT solver in Python
[zeroinstall/solver.git] / tests / testsat.py
blob 9d8e237afd4855b87a80f3d782a7aaf4997caaef
1 #!/usr/bin/env python
2 from basetest import BaseTest
3 import sys, os
4 import unittest
6 sys.path.insert(0, '..')
7 from zeroinstall.injector import model, arch, qdom
8 from zeroinstall.injector.namespaces import XMLNS_IFACE
10 #from zeroinstall.injector.origsolver import DefaultSolver as Solver
11 #from zeroinstall.injector.pbsolver import PBSolver as Solver
12 #from zeroinstall.injector.sgsolver import DefaultSolver as Solver
13 from zeroinstall.injector.solver import SATSolver as Solver
15 import logging
# Root logger (getLogger with no name). Not referenced by the code visible
# here; presumably kept so log levels can be adjusted while debugging.
logger = logging.getLogger(name=None)
class Stores:
    """Fake implementation store used by the solver tests.

    lookup_any ignores the digests it is given and always answers with the
    same path -- presumably so every implementation looks locally available
    to an offline solver (TODO confirm against the solver).
    """

    _ROOT = "/"

    def lookup_any(self, digests):
        """Return a fixed path regardless of *digests*."""
        return self._ROOT


# Shared store instance passed to every Solver built by the tests.
stores = Stores()

# Every synthetic program "name" is published at uri_prefix + name.
uri_prefix = 'http://localhost/tests/'
class Version:
    """One version of a synthetic program and the libraries it requires."""

    def __init__(self, n):
        self.n = n            # version label (callers in this file pass strings)
        self.requires = []    # list of (lib, min_v, max_v) tuples

    def add_requires(self, lib, min_v, max_v):
        """Record that this version depends on *lib* in the range min_v..max_v."""
        self.requires += [(lib, min_v, max_v)]
class Program:
    """A synthetic program/library from the test mini-language.

    Holds the known versions of the program and can render them as a
    model.ZeroInstallFeed for the solver.
    """

    def __init__(self, name):
        self.name = name
        self.versions = {}      # version label -> Version

    def get_version(self, version):
        """Return Version *version* of this program, creating it on first use."""
        if version not in self.versions:
            self.versions[version] = Version(version)
        return self.versions[version]

    def build_feed(self):
        """Build a feed listing every known version of this program.

        Each version becomes an <implementation> whose id is its position in
        iteration order of self.versions ("0", "1", ...), with a placeholder
        sha1new digest and one <requires> element per dependency. A
        dependency range min_v..max_v is inclusive: it is encoded as
        not-before=min_v and before=max_v + 1 (versions are integers here).
        """
        def child(parent, name, attrs=None):
            # Create a namespaced element, attach it to parent and return it.
            new = qdom.Element(XMLNS_IFACE, name, attrs or {})
            parent.childNodes.append(new)
            return new

        root = qdom.Element(XMLNS_IFACE, 'interface', {'uri': uri_prefix + self.name})
        child(root, 'name').content = self.name
        child(root, 'summary').content = self.name

        for i, version in enumerate(self.versions.values()):
            impl = child(root, 'implementation', {
                'id': str(i),
                'version': str(version.n),
            })
            # Same dummy digest for every implementation; the fake Stores
            # in this file ignores digests anyway.
            child(impl, 'manifest-digest', {'sha1new': '1234'})
            for lib, min_v, max_v in version.requires:
                req = child(impl, 'requires', {'interface': uri_prefix + lib})
                child(req, 'version', {
                    'before': str(int(max_v) + 1),
                    'not-before': min_v,
                })

        feed = model.ZeroInstallFeed(root)
        # NOTE(review): presumably set so the feed is not treated as
        # never-downloaded -- confirm against model/solver.
        feed.last_modified = 1
        return feed
class TestCache:
    """In-memory stand-in for the interface cache handed to the solver.

    Programs are registered by short name in ``progs``. get_interface builds
    a model.Interface for a URI once, from the Program whose name is the last
    path component of the URI, and memoises it in ``interfaces``.
    """

    def __init__(self):
        self.progs = {}         # short program name -> Program
        self.interfaces = {}    # full interface URI -> model.Interface

    def get_prog(self, prog):
        """Return the Program named *prog*, creating an empty one on first use."""
        if prog in self.progs:
            return self.progs[prog]
        created = self.progs[prog] = Program(prog)
        return created

    def get_interface(self, uri):
        """Return the memoised Interface for *uri*, building it if needed.

        Raises KeyError if no program matches the URI's last path component.
        """
        if uri in self.interfaces:
            return self.interfaces[uri]
        iface = model.Interface(uri)
        short_name = uri.rsplit('/', 1)[1]
        iface._main_feed = self.progs[short_name].build_feed()
        self.interfaces[uri] = iface
        return iface
def _populate_cache(cache, repo):
    """Register the programs and dependencies described by *repo* in *cache*.

    *repo* is a multi-line string. Blank lines are ignored; every other line
    must be one of:

      ``name: v1 v2 ...``
          program ``name`` exists in each listed version.
      ``name[v] => lib min max`` or ``name[lo,hi] => lib min max``
          version ``v`` (or each integer version lo..hi inclusive) of
          ``name`` requires ``lib`` with a version between min and max.

    Raises ValueError for a line matching neither form.
    """
    for line in repo.split('\n'):
        line = line.strip()
        if not line:
            continue
        if ':' in line:
            prog, versions = line.split(':')
            prog = prog.strip()
            for v in versions.split():
                cache.get_prog(prog).get_version(v)
        elif '=>' in line:
            prog, requires = line.split('=>')
            prog, version_range = prog.strip().split('[')
            lib, min_v, max_v = requires.split()
            # Explicit check rather than assert, which `python -O` strips.
            if not version_range.endswith(']'):
                raise ValueError("Missing ']' in repository line: %r" % line)
            version_range = version_range[:-1]
            if ',' in version_range:
                min_p, max_p = map(int, version_range.split(','))
                prog_versions = range(min_p, max_p + 1)
            else:
                prog_versions = [int(version_range)]
            for prog_version in prog_versions:
                cache.get_prog(prog).get_version(str(prog_version)).add_requires(lib, min_v, max_v)
        else:
            # Previously such lines were silently ignored, hiding typos.
            raise ValueError("Unrecognised repository line: %r" % line)


def assertSelection(expected, repo):
    """Solve the repository described by *repo* and check the result.

    *expected* is a comma-separated list of ``name-version`` pairs; the first
    pair names the root program to solve for. If its version is ``FAIL`` the
    solver must report not ready, and no selections are compared. Otherwise
    the solver must be ready and its selections must be exactly the listed
    pairs (order-independent). See _populate_cache for the *repo* syntax.

    Raises AssertionError on a wrong readiness state, or Exception listing
    both sides when the selections differ.
    """
    cache = TestCache()

    expected = [tuple(e.strip().split('-')) for e in expected.split(",")]
    _populate_cache(cache, repo)

    root = uri_prefix + expected[0][0]
    s = Solver(model.network_offline, cache, stores)
    s.solve(root, arch.get_architecture('Linux', 'x86_64'))

    if expected[0][1] == 'FAIL':
        # A failed solve has nothing meaningful to compare against
        # [('name', 'FAIL')], so only readiness is checked.
        assert not s.ready
    else:
        assert s.ready

        # Reduce each selection to (short program name, version);
        # items() works on both Python 2 and 3, unlike iteritems().
        actual = [(iface.uri.rsplit('/', 1)[1], impl.get_version())
                  for iface, impl in s.selections.items()]

        expected.sort()
        actual.sort()
        if expected != actual:
            raise Exception("Solve failed:\nExpected: %s\n Actual: %s" % (expected, actual))
class TestSAT(BaseTest):
    """Solver scenarios written in the mini-language parsed by assertSelection."""

    def testTrivial(self):
        """Without dependencies, the highest version is selected."""
        repo = """
            prog: 1 2
        """
        assertSelection("prog-2", repo)

    def testSimple(self):
        """The highest prog is chosen together with the highest liba it allows."""
        repo = """
            prog: 1 2 3 4 5
            liba: 1 2 3 4 5
            prog[1] => liba 0 4
            prog[2] => liba 1 5
            prog[5] => liba 4 5
        """
        assertSelection("prog-5, liba-5", repo)

    def testBestImpossible(self):
        """prog 2 needs liba 3..4, which does not exist, so prog 1 is used."""
        repo = """
            prog: 1 2
            liba: 1
            prog[2] => liba 3 4
        """
        assertSelection("prog-1", repo)

    def testSlow(self):
        """prog 2..9 need a liba->libb->libc->libd chain ending in libe 0.

        libe 0 does not exist, so only the dependency-free prog 1 can be
        chosen. The name presumably refers to the large search space a naive
        solver would explore before concluding this (unconfirmed).
        """
        repo = """
            prog: 1 2 3 4 5 6 7 8 9
            liba: 1 2 3 4 5 6 7 8 9
            libb: 1 2 3 4 5 6 7 8 9
            libc: 1 2 3 4 5 6 7 8 9
            libd: 1 2 3 4 5 6 7 8 9
            libe: 1
            prog[2,9] => liba 1 9
            liba[1,9] => libb 1 9
            libb[1,9] => libc 1 9
            libc[1,9] => libd 1 9
            libd[1,9] => libe 0 0
        """
        assertSelection("prog-1", repo)

    def testNoSolution(self):
        """Every prog version needs liba 2..3 but only liba 1 exists: solve fails."""
        repo = """
            prog: 1 2 3
            liba: 1
            prog[1,3] => liba 2 3
        """
        assertSelection("prog-FAIL", repo)
# Module-level suite, presumably collected by an aggregate test runner
# (TODO confirm). unittest.makeSuite is deprecated and was removed in
# Python 3.13; TestLoader.loadTestsFromTestCase builds the same suite and
# exists in Python 2.7 as well.
suite = unittest.TestLoader().loadTestsFromTestCase(TestSAT)

if __name__ == '__main__':
    sys.argv.append('-v')   # always run verbosely when executed directly
    unittest.main()