Test SAT learning
[zeroinstall/solver.git] / zeroinstall / injector / sat.py
blobf4904a649ce2dd30a4c0c6f896567fafc362c1c2
1 # Copyright (C) 2010, Thomas Leonard
2 # See the README file for details, or visit http://0install.net.
4 # The design of this solver is very heavily based on the one described in
5 # the MiniSat paper "An Extensible SAT-solver [extended version 1.2]"
6 # http://minisat.se/Papers.html
8 # The main differences are:
10 # - We care about which solution we find (not just "satisfiable" or "not").
11 # - We take care to be deterministic (always select the same versions given
12 # the same input). We do not do random restarts, etc.
13 # - We add an AtMostOneClause (the paper suggests this in the Exercises, and
14 # it's very useful for our purposes).
15 # - We don't currently do conflict-driven learning.
17 # Also, as this is a work-in-progress, we don't support back-tracking yet!
19 import tempfile, subprocess, os, sys
20 from logging import warn
# Set to True to print a trace of the solver's reasoning to stdout.
_DEBUG = False


def debug(msg, *args):
    """Print a "SAT:"-prefixed trace message when tracing is enabled.

    msg  -- the message text; %-formatted with args when any are given.
    args -- optional values to interpolate into msg.

    Callers in this file use both debug("x %s" % y) and debug("x %s", y),
    so msg is only interpolated when args is non-empty. That way an
    already-formatted message containing a literal "%" is printed as-is
    instead of raising.
    """
    if not _DEBUG:
        return
    if args:
        msg = msg % args
    # A single parenthesised argument prints identically on Python 2 and 3.
    print("SAT: " + msg)
26 # variables are numbered from 0
27 # literals have the same number as the corresponding variable,
28 # except that for negatives they are (-1-v):
30 # Variable Literal not(Literal)
31 # 0 0 -1
32 # 1 1 -2
def neg(lit):
    """Return the literal that is the logical negation of lit.

    Variable v is written v as a positive literal and -1 - v as a
    negative one, so neg(0) == -1, neg(1) == -2 and neg(neg(x)) == x.
    """
    return -(lit + 1)
def watch_index(lit):
    """Return the slot used for lit in a per-literal watch table.

    A positive literal v maps to the even slot 2 * v; its negation
    (the literal -1 - v) maps to the odd slot 2 * v + 1.
    """
    if lit < 0:
        var = -1 - lit  # the variable this negative literal refers to
        return 2 * var + 1
    return 2 * lit
def makeAtMostOneClause(solver):
    """Return an AtMostOneClause class whose instances are bound to solver.

    solver -- the object the clauses query (lit_value, name_lit, ...) and
              update (enqueue, watch_lit); the class closes over it.
    """
    class AtMostOneClause:
        """Constraint: at most one of our literals may be True."""

        def __init__(self, lits):
            """Preferred literals come first."""
            self.lits = lits

            # The single literal from our set that is True.
            # We store this explicitly because the decider needs to know quickly.
            self.current = None

        def remove(self):
            """Remove ourself from solver. Not implemented yet.

            Raises NotImplementedError. (A string exception such as
            raise "help" is itself a TypeError on Python 2.6 and later.)
            """
            # TODO: something like solver.watches.remove(index(neg(lits[0])))
            raise NotImplementedError("AtMostOneClause.remove")

        def simplify(self):
            """Simplify ourself and return True if we are no longer needed,
            or False if we are."""
            # TODO
            return False

        def propagate(self, lit):
            """React to lit, one of our literals, having just become True.

            Sets every still-undecided literal of ours to False.
            Returns False on conflict, True otherwise.
            """
            # Re-add ourselves to the watch list first, so that we stay
            # registered on every exit path, including the conflict ones.
            # (we won't get any more notifications unless we backtrack,
            # in which case we'd need to get back on the list anyway)
            solver.watch_lit(lit, self)

            # value[lit] has just become True
            assert solver.lit_value(lit) is True
            assert lit >= 0

            #debug("%s: noticed %s has become True" % (self, solver.name_lit(lit)))

            # One is already selected
            if self.current is not None:
                debug("CONFLICT: already selected %s" % self.current)
                return False

            self.current = lit

            # If we later backtrack, call our undo function to unset current
            solver.get_varinfo_for_lit(lit).undo.append(self)

            count = 0
            for l in self.lits:
                value = solver.lit_value(l)
                #debug("Value of %s is %s" % (solver.name_lit(l), value))
                if value is True:
                    count += 1
                    if count > 1:
                        debug("CONFLICT: already selected %s" % self.current)
                        return False
                if value is None:
                    # Since one of our lits is already true, all unknown ones
                    # can be set to False.
                    if not solver.enqueue(neg(l), self):
                        debug("CONFLICT: enqueue failed for %s", solver.name_lit(neg(l)))
                        return False  # Conflict; abort

            return True

        def undo(self, lit):
            """lit, our selected literal, has been unset by backtracking."""
            debug("(backtracking: no longer selected %s)" % solver.name_lit(lit))
            assert lit == self.current
            self.current = None

        def cacl_reason(self, lit):
            """Explain why lit is True, or why we conflict if lit is None.

            Returns a list of literals that are currently True: two of
            ours for a conflict, otherwise one of ours other than lit.
            Raises AssertionError if no such explanation exists.
            """
            if lit is None:
                # Find two True literals
                trues = []
                for l in self.lits:
                    if solver.lit_value(l) is True:
                        trues.append(l)
                        if len(trues) == 2:
                            return trues
            else:
                # Find one True literal.
                # Compare by value: literals are ints, and identity is not
                # guaranteed for equal ints.
                for l in self.lits:
                    if l != lit and solver.lit_value(l) is True:
                        return [l]
            raise AssertionError("don't know why!")

        def best_undecided(self):
            """Return our first (most preferred) undecided literal, or None."""
            debug("best_undecided: %s" % (solver.name_lits(self.lits)))
            for lit in self.lits:
                #debug("%s = %s" % (solver.name_lit(lit), solver.lit_value(lit)))
                if solver.lit_value(lit) is None:
                    return lit
            return None

        def __repr__(self):
            return "<lone: %s>" % (', '.join(solver.name_lits(self.lits)))

    return AtMostOneClause
def makeUnionClause(solver):
    """Return a UnionClause class whose instances are bound to solver.

    solver -- the object the clauses query (lit_value, name_lits) and
              update (enqueue, watch_lit); the class closes over it.
    """
    class UnionClause:
        """Constraint: at least one of our literals must be True."""

        def __init__(self, lits):
            self.lits = lits

        def remove(self):
            """Remove ourself from solver. Not implemented yet.

            Raises NotImplementedError. (A string exception such as
            raise "help" is itself a TypeError on Python 2.6 and later.)
            """
            # TODO: something like solver.watches.remove(index(neg(lits[0])))
            raise NotImplementedError("UnionClause.remove")

        def simplify(self):
            """Simplify ourself and return True if we are no longer needed,
            or False if we are.

            When False is returned, self.lits has been reduced to the
            literals that are still undecided.
            """
            new_lits = []
            for l in self.lits:
                value = solver.lit_value(l)
                if value is True:
                    # (... or True or ...) = True
                    return True
                elif value is None:
                    new_lits.append(l)
            self.lits = new_lits
            return False

        # Try to infer new facts.
        # We can do this only when all of our literals are False except one,
        # which is undecided. That is,
        #   False... or X or False... = True  =>  X = True
        #
        # To get notified when this happens, we tell the solver to
        # watch two of our undecided literals. Watching two undecided
        # literals is sufficient. When one changes we check the state
        # again. If we still have two or more undecided then we switch
        # to watching them, otherwise we propagate.
        def propagate(self, lit):
            """React to lit becoming True, i.e. our literal neg(lit) becoming False.

            Returns False on conflict.
            """
            #debug("%s: noticed %s has become False" % (self, solver.name_lit(neg(lit))))

            # For simplicity, only handle the case where self.lits[1]
            # is the one that just got set to False, so that:
            # - value[lits[0]] = None | True
            # - value[lits[1]] = False
            # If it's the other way around, just swap them before we start.
            if self.lits[0] == neg(lit):
                self.lits[0], self.lits[1] = self.lits[1], self.lits[0]

            if solver.lit_value(self.lits[0]) is True:
                # We're already satisfied. Do nothing.
                solver.watch_lit(lit, self)
                return True

            # Find a new literal to watch now that lits[1] is resolved,
            # swap it with lits[1], and start watching it.
            for i in range(2, len(self.lits)):
                value = solver.lit_value(self.lits[i])
                if value is not False:
                    # Could be None or True. If it's True then we've already done our job,
                    # so this means we don't get notified unless we backtrack, which is fine.
                    self.lits[1], self.lits[i] = self.lits[i], self.lits[1]
                    # Watches fire when a literal becomes True, so to hear
                    # about lits[1] becoming False we watch its negation
                    # (the same convention used when the clause is first
                    # registered with the solver).
                    solver.watch_lit(neg(self.lits[1]), self)
                    return True

            # Only lits[0] is now undefined.
            solver.watch_lit(lit, self)
            return solver.enqueue(self.lits[0], self)

        def undo(self, lit):
            """Nothing to restore when lit is unset by backtracking."""
            pass

        def cacl_reason(self, lit):
            """Explain why lit is True, or why we conflict if lit is None.

            Returns the negation of every literal of ours except lit.
            """
            # Compare by value: literals are ints, and identity is not
            # guaranteed for equal ints.
            assert lit is None or lit == self.lits[0]

            # The cause is everything except lit.
            return [neg(l) for l in self.lits if l != lit]

        def __repr__(self):
            return "<some: %s>" % (', '.join(solver.name_lits(self.lits)))

    return UnionClause
218 # Using an array of VarInfo objects is less efficient than using multiple arrays, but
219 # easier for me to understand.
class VarInfo(object):
    """Everything the solver tracks about one variable."""
    __slots__ = ['value', 'reason', 'level', 'undo', 'obj']

    def __init__(self, obj):
        # The object this corresponds to (for our caller and for debugging)
        self.obj = obj
        # True/False/None
        self.value = None
        # The constraint that implied our value, if True or False
        self.reason = None
        # The decision level at which we got a value (when not None)
        self.level = -1
        # Constraints to update if we become unbound (by backtracking)
        self.undo = []

    @property
    def name(self):
        """The string form of the wrapped object."""
        return str(self.obj)

    def __repr__(self):
        label = self.name
        return '%s=%s' % (label, self.value)
class Solver(object):
    """SAT solver state plus the propagate / analyse / backtrack loop.

    Variables are created with add_variable(), constraints with
    add_clause() and at_most_one(), and run_solver() then searches for
    an assignment, asking a caller-supplied function what to try next.
    """

    def __init__(self):
        # Constraints
        self.constrs = []       # Constraints set by our user XXX - do we ever use this?
        self.learnt = []        # Constraints we learnt while solving
        # order?

        # Propagation
        self.watches = []       # watches[2i,2i+1] = constraints to check when literal[i] becomes True/False
        self.propQ = []         # propagation queue

        # Assignments
        self.assigns = []       # [VarInfo]
        self.trail = []         # order of assignments
        self.trail_lim = []     # decision levels

        self.toplevel_conflict = False

        self.makeAtMostOneClause = makeAtMostOneClause(self)
        self.makeUnionClause = makeUnionClause(self)

    def get_decision_level(self):
        """Return the number of decisions currently in force."""
        return len(self.trail_lim)

    def add_variable(self, obj):
        """Create a new variable wrapping obj and return its index.

        The index is also the variable's positive literal.
        """
        index = len(self.assigns)

        self.watches += [[], []]    # Add watch lists for X and not(X)
        self.assigns.append(VarInfo(obj))
        return index

    def enqueue(self, lit, reason):
        """Record that lit is now True and queue it for propagation.

        lit    -- the literal being asserted.
        reason -- the clause (or a descriptive string) asserting it.

        Returns False if this immediately causes a conflict (lit is
        already False); True if lit was already True or is now set.
        """
        # Arguments are passed unformatted so nothing is rendered unless
        # debugging is actually enabled.
        debug("%s => %s", reason, self.name_lit(lit))
        old_value = self.lit_value(lit)
        if old_value is not None:
            if old_value is False:
                # Conflict
                return False
            else:
                # Already set
                return True

        if lit < 0:
            var_info = self.assigns[neg(lit)]
            var_info.value = False
        else:
            var_info = self.assigns[lit]
            var_info.value = True
        var_info.level = self.get_decision_level()
        var_info.reason = reason

        self.trail.append(lit)
        self.propQ.append(lit)

        return True

    def undo_one(self):
        """Pop the most recent assignment from self.trail and unset it.

        Every constraint registered in the variable's undo list is told
        about it via its undo(lit) method.
        """
        lit = self.trail[-1]
        debug("(pop %s)", self.name_lit(lit))
        var_info = self.get_varinfo_for_lit(lit)
        var_info.value = None
        var_info.reason = None
        var_info.level = -1
        self.trail.pop()

        while var_info.undo:
            var_info.undo.pop().undo(lit)

    def cancel(self):
        """Undo every assignment made at the current decision level,
        then drop that level."""
        n_this_level = len(self.trail) - self.trail_lim[-1]
        debug("backtracking from level %d (%d assignments)" %
              (self.get_decision_level(), n_this_level))
        while n_this_level != 0:
            self.undo_one()
            n_this_level -= 1
        self.trail_lim.pop()

    def cancel_until(self, level):
        """Backtrack until the decision level is no greater than level."""
        while self.get_decision_level() > level:
            self.cancel()

    def propagate(self):
        """Process the propagation queue.

        Returns None when done, or the clause that caused a conflict.
        """
        #debug("propagate: queue length = %d", len(self.propQ))
        while self.propQ:
            lit = self.propQ.pop(0)
            wi = watch_index(lit)
            # Detach the watch list; each clause re-registers itself
            # (possibly on a different literal) from its propagate().
            watches = self.watches[wi]
            self.watches[wi] = []

            debug("%s -> True : watches: %s", self.name_lit(lit), watches)

            # Notify all watchers
            for i, clause in enumerate(watches):
                if not clause.propagate(lit):
                    # Conflict

                    # Re-add remaining watches
                    self.watches[wi] += watches[i + 1:]

                    # No point processing the rest of the queue as
                    # we'll have to backtrack now.
                    self.propQ = []

                    return clause
        return None

    def impossible(self):
        """Mark the problem as unsolvable before the solve starts."""
        self.toplevel_conflict = True

    def get_varinfo_for_lit(self, lit):
        """Return the VarInfo of the variable that lit refers to."""
        if lit >= 0:
            return self.assigns[lit]
        else:
            return self.assigns[neg(lit)]

    def lit_value(self, lit):
        """Return the current value of lit: True, False, or None if unset."""
        if lit >= 0:
            return self.assigns[lit].value

        value = self.assigns[-1 - lit].value
        if value is None:
            return None
        return not value

    def watch_lit(self, lit, cb):
        """Arrange for cb.propagate(lit) to be called when lit becomes True."""
        #debug("%s is watching for %s to become True" % (cb, self.name_lit(lit)))
        self.watches[watch_index(lit)].append(cb)

    def _add_clause(self, lits, learnt):
        """Add the clause (lits[0] or lits[1] or ...).

        lits   -- the literals; the list is kept (and reordered) by the clause.
        learnt -- True if the clause was derived by analyse().

        Returns the new clause if one was added, True if none was added
        because this clause is trivially True, or False if the clause is
        False.
        """
        if not lits:
            assert not learnt
            self.toplevel_conflict = True
            return False
        elif len(lits) == 1:
            # A clause with only a single literal is represented
            # as an assignment rather than as a clause.
            reason = "learnt" if learnt else "top-level"
            return self.enqueue(lits[0], reason)

        clause = self.makeUnionClause(lits)
        clause.learnt = learnt
        self.constrs.append(clause)

        if learnt:
            # lits[0] is None because we just backtracked.
            # Start watching the next literal that we will
            # backtrack over.
            best_level = -1
            best_i = 1
            for i in range(1, len(lits)):
                level = self.get_varinfo_for_lit(lits[i]).level
                if level > best_level:
                    best_level = level
                    best_i = i
            lits[1], lits[best_i] = lits[best_i], lits[1]

        # Watch the first two literals in the clause (both must be
        # undefined at this point).
        for lit in lits[:2]:
            self.watch_lit(neg(lit), clause)

        return clause

    def name_lits(self, lst):
        """Return the debug names of all the literals in lst."""
        return [self.name_lit(l) for l in lst]

    def name_lit(self, lit):
        """Return a readable name for lit, for nicer debug messages."""
        if lit >= 0:
            return self.assigns[lit].name
        return "not(%s)" % self.assigns[neg(lit)].name

    def add_clause(self, lits):
        """Require at least one of lits to be True.

        Public interface. Only used before the solve starts.
        Returns True if the clause is trivially satisfied, otherwise
        whatever _add_clause() returns.
        """
        assert lits

        debug("add_clause: %s" % self.name_lits(lits))

        if any(self.lit_value(l) is True for l in lits):
            # Trivially true already.
            return True
        lit_set = set(lits)
        for l in lits:
            if neg(l) in lit_set:
                # X or not(X) is always True.
                return True
        # Remove duplicates and values known to be False
        lits = [l for l in lit_set if self.lit_value(l) is not False]

        return self._add_clause(lits, learnt=False)

    def at_most_one(self, lits):
        """Require at most one of lits to be True; return the new clause.

        lits -- distinct literals, preferred ones first.
        """
        assert lits

        debug("at_most_one: %s" % self.name_lits(lits))

        # If we have zero or one literals then we're trivially true
        # and not really needed for the solve. However, Zero Install
        # monitors these objects to find out what was selected, so
        # keep even trivial ones around for that.
        #
        #if len(lits) < 2:
        #    return True    # Trivially true

        # Ensure no duplicates
        assert len(set(lits)) == len(lits), lits

        # Ignore any literals already known to be False.
        # If any are True then they're enqueued and we'll process them
        # soon.
        lits = [l for l in lits if self.lit_value(l) is not False]

        clause = self.makeAtMostOneClause(lits)

        self.constrs.append(clause)

        for lit in lits:
            self.watch_lit(lit, clause)

        return clause

    def analyse(self, cause):
        """Work out a clause to learn from a conflict.

        cause -- the clause whose propagate() reported the conflict.

        Assignments on the trail are undone as they are examined.
        Returns (learnt, btlevel): the literals of the clause to learn,
        and the deepest decision level among learnt[1:].
        """
        # After trying some assignments, we've discovered a conflict.
        # e.g.
        # - we selected A then B then C
        # - from A, B, C we got X, Y
        # - we have a rule: not(A) or not(X) or not(Y)
        #
        # The simplest thing to do would be:
        # 1. add the rule "not(A) or not(B) or not(C)"
        # 2. unassign C
        #
        # Then we'd deduce not(C) and we could try something else.
        # However, that would be inefficient. We want to learn a more
        # general rule that will help us with the rest of the problem.
        #
        # We take the clause that caused the conflict ("cause") and
        # ask it for its cause. In this case:
        #
        #  A and X and Y => conflict
        #
        # Since X and Y followed logically from A, B, C there's no
        # point learning this rule; we need to know to avoid A, B, C
        # *before* choosing C. We ask the two variables deduced at the
        # current level (X and Y) what caused them, and work backwards.
        # e.g.
        #
        #  X: A and C => X
        #  Y: C => Y
        #
        # Combining these, we get the cause of the conflict in terms of
        # things we knew before the current decision level:
        #
        #  A and X and Y => conflict
        #  A and (A and C) and (C) => conflict
        #  A and C => conflict
        #
        # We can then learn (record) the more general rule:
        #
        #  not(A) or not(C)
        #
        # Then, in future, whenever A is selected we can remove C and
        # everything that depends on it from consideration.

        learnt = [None]     # The general rule we're learning
        btlevel = 0         # The deepest decision in learnt
        p = None            # The literal we want to expand now
        seen = set()        # The variables involved in the conflict

        counter = 0

        while True:
            # cause is the reason why p is True (i.e. it enqueued it).
            # The first time, p is None, which requests the reason
            # why it is conflicting.
            if p is None:
                debug("Why did %s make us fail?" % cause)
                p_reason = cause.cacl_reason(p)
                debug("Because: %s => conflict" % (' and '.join(self.name_lits(p_reason))))
            else:
                debug("Why did %s lead to %s?" % (cause, self.name_lit(p)))
                p_reason = cause.cacl_reason(p)
                debug("Because: %s => %s" % (' and '.join(self.name_lits(p_reason)), self.name_lit(p)))

            # p_reason is in the form (A and B and ...)
            # p_reason => p

            # Check each of the variables in p_reason that we haven't
            # already considered:
            # - if the variable was assigned at the current level,
            #   mark it for expansion
            # - otherwise, add it to learnt

            for lit in p_reason:
                var_info = self.get_varinfo_for_lit(lit)
                if var_info not in seen:
                    seen.add(var_info)
                    if var_info.level == self.get_decision_level():
                        # We deduced this var since the last decision.
                        # It must be in self.trail, so we'll get to it
                        # soon. Remember not to stop until we've processed it.
                        counter += 1
                    elif var_info.level > 0:
                        # We won't expand lit, just remember it.
                        # (we could expand it if it's not a decision, but
                        # apparently not doing so is useful)
                        learnt.append(neg(lit))
                        btlevel = max(btlevel, var_info.level)
                    # else we already considered the cause of this assignment

            # At this point, counter is the number of assigned
            # variables in self.trail at the current decision level that
            # we've seen. That is, the number left to process. Pop
            # the next one off self.trail (as well as any unrelated
            # variables before it; everything up to the previous
            # decision has to go anyway).

            while True:
                p = self.trail[-1]
                var_info = self.get_varinfo_for_lit(p)
                # Read the reason before undo_one() clears it.
                cause = var_info.reason
                self.undo_one()
                if var_info in seen:
                    break
                debug("(irrelevant)")
            counter -= 1

            if counter <= 0:
                # If counter = 0 then we still have one more
                # literal (p) at the current level that we
                # could expand. However, apparently it's best
                # to leave this unprocessed (says the minisat
                # paper).
                # If counter is -1 then we popped one extra
                # assignment, but it doesn't matter because
                # either it's at this level or it's the
                # decision that led to this level. Either way,
                # we'd have removed it anyway.
                # XXX: is this true? won't self.cancel() get upset?
                break

        # p is the literal we decided to stop processing on. It's either
        # a derived variable at the current level, or the decision that
        # led to this level. Since we're not going to expand it, add it
        # directly to the learnt clause.
        learnt[0] = neg(p)

        debug("Learnt: %s" % (' or '.join(self.name_lits(learnt))))

        return learnt, btlevel

    def run_solver(self, decide):
        """Search for an assignment satisfying every constraint.

        decide -- called with no arguments whenever propagation has
                  finished without conflict but variables remain unset.
                  It must return an unset literal to try as True, or
                  None to abandon the search.

        Returns True once every variable is assigned without conflict.
        Returns False if the problem was already marked as conflicting,
        if a conflict occurs with no decisions left to undo, or if
        decide() returns None.
        """
        # Check whether we detected a trivial problem
        # during setup.
        if self.toplevel_conflict:
            return False

        while True:
            # Use logical deduction to simplify the clauses
            # and assign literals where there is only one possibility.
            conflicting_clause = self.propagate()
            if conflicting_clause is None:
                debug("new state: %s", self.assigns)
                if all(info.value is not None for info in self.assigns):
                    # Everything is assigned without conflicts
                    debug("SUCCESS!")
                    return True

                # Pick a variable and try assigning it one way.
                # If it leads to a conflict, we'll backtrack and
                # try it the other way.
                lit = decide()
                if lit is None:
                    debug("decide -> None")
                    return False
                assert self.lit_value(lit) is None
                self.trail_lim.append(len(self.trail))
                r = self.enqueue(lit, reason="considering")
                assert r is True
            else:
                if self.get_decision_level() == 0:
                    return False

                # Figure out the root cause of this failure.
                learnt, backtrack_level = self.analyse(conflicting_clause)

                self.cancel_until(backtrack_level)

                c = self._add_clause(learnt, learnt=True)

                if c is not True:
                    # Everything except the first literal in learnt is known to
                    # be False, so the first must be True.
                    e = self.enqueue(learnt[0], c)
                    assert e is True