Converted print statements to functions
[zeroinstall.git] / zeroinstall / injector / sat.py
blob9dbce04aed88e9b4f77e53a72624a79a72237a72
1 """
2 Internal implementation of a SAT solver, used by L{solver.SATSolver}.
3 This is not part of the public API.
4 """
6 from __future__ import print_function
8 # Copyright (C) 2010, Thomas Leonard
9 # See the README file for details, or visit http://0install.net.
11 # The design of this solver is very heavily based on the one described in
12 # the MiniSat paper "An Extensible SAT-solver [extended version 1.2]"
13 # http://minisat.se/Papers.html
15 # The main differences are:
17 # - We care about which solution we find (not just "satisfiable" or "not").
18 # - We take care to be deterministic (always select the same versions given
19 # the same input). We do not do random restarts, etc.
# - We add an AtMostOneClause (the paper suggests this in the Exercises, and
#   it's very useful for our purposes).
# Set to True to print a trace of the solver's progress to stdout.
_DEBUG = False


def debug(msg, *args):
    """Print a solver trace message, prefixed with "SAT:", if _DEBUG is set.

    msg is either a complete message or, when args are given, a %-format
    string that is combined with args. Formatting only happens when tracing
    is enabled, so callers should prefer debug("x=%s", x) to
    debug("x=%s" % x). A pre-formatted message is printed as-is, so a
    literal '%' in it (e.g. from a package name) cannot break the call.
    """
    if not _DEBUG:
        return
    if args:
        msg = msg % args
    print("SAT:", msg)
27 # variables are numbered from 0
28 # literals have the same number as the corresponding variable,
# except for negated literals, which are numbered (-1-v):
31 # Variable Literal not(Literal)
32 # 0 0 -1
33 # 1 1 -2
def neg(lit):
    """Return the literal for not(lit).

    Variable v is the literal v and not(v) is the literal -1-v, so this
    mapping is its own inverse: neg(neg(x)) == x.
    """
    return -(lit + 1)
def watch_index(lit):
    """Map a literal to its slot in the solver's watch-list array.

    Variable v owns two adjacent slots: 2*v for the positive literal v and
    2*v+1 for its negation (encoded as -1-v).
    """
    if lit < 0:
        # Recover the variable number from not(v) == -1-v; use the odd slot.
        return 2 * (-1 - lit) + 1
    return 2 * lit
def makeAtMostOneClause(solver):
    """Return an AtMostOneClause class whose methods operate on *solver*.

    The class is defined inside this factory so that its methods reach the
    owning solver through the closure instead of storing it on every clause.
    """
    class AtMostOneClause:
        """Constraint: at most one of self.lits may be True."""

        def __init__(self, lits):
            """Preferred literals come first."""
            self.lits = lits

            # The single literal from our set that is True.
            # We store this explicitly because the decider needs to know quickly.
            self.current = None

        def propagate(self, lit):
            """Handle one of our literals, lit, having just become True.

            Every other undecided literal in the set is enqueued as False.
            Returns False on conflict (another literal is already True, or
            one of the others could not be set False), True otherwise.
            """
            # Re-add ourselves to the watch list.
            # (we won't get any more notifications unless we backtrack,
            # in which case we'd need to get back on the list anyway)
            solver.watch_lit(lit, self)

            # value[lit] has just become True
            assert solver.lit_value(lit) == True
            assert lit >= 0

            # If we already propagated successfully when the first
            # one was set then we set all the others to False and
            # anyone trying to set one True will get rejected. And
            # if we didn't propagate yet, current will still be
            # None, even if we now have a conflict (which we'll
            # detect below).
            assert self.current is None

            self.current = lit

            # If we later backtrack, call our undo function to unset current
            solver.get_varinfo_for_lit(lit).undo.append(self)

            for l in self.lits:
                value = solver.lit_value(l)
                # Literals are ints: compare by value. `is not` would treat
                # an equal literal produced elsewhere (e.g. by neg()) as a
                # different literal once it is outside the small-int cache.
                if value is True and l != lit:
                    # Due to queuing, we might get called with current = None
                    # and two versions already selected.
                    debug("CONFLICT: already selected %s", solver.name_lit(l))
                    return False
                if value is None:
                    # Since one of our lits is already true, all unknown ones
                    # can be set to False.
                    if not solver.enqueue(neg(l), self):
                        debug("CONFLICT: enqueue failed for %s", solver.name_lit(neg(l)))
                        return False    # Conflict; abort

            return True

        def undo(self, lit):
            """Called when lit (our current literal) is unassigned by backtracking."""
            debug("(backtracking: no longer selected %s)", solver.name_lit(lit))
            assert lit == self.current
            self.current = None

        # Why is lit True?
        # Or, why are we causing a conflict (if lit is None)?
        # (The misspelt name is the interface the solver calls; keep it.)
        def cacl_reason(self, lit):
            """Return the list of True literals that explain lit (or the conflict)."""
            if lit is None:
                # Find two True literals
                trues = []
                for l in self.lits:
                    if solver.lit_value(l) is True:
                        trues.append(l)
                        if len(trues) == 2:
                            return trues
            else:
                # Find one True literal
                for l in self.lits:
                    if l != lit and solver.lit_value(l) is True:
                        return [l]
            assert 0, "don't know why!"

        def best_undecided(self):
            """Return the first (most preferred) undecided literal, or None."""
            debug("best_undecided: %s", solver.name_lits(self.lits))
            for lit in self.lits:
                if solver.lit_value(lit) is None:
                    return lit
            return None

        def __repr__(self):
            return "<lone: %s>" % (', '.join(solver.name_lits(self.lits)))

    return AtMostOneClause
def makeUnionClause(solver):
    """Return a UnionClause class whose methods operate on *solver*.

    A UnionClause is the disjunction of its literals: at least one of
    self.lits must be True.
    """
    class UnionClause:
        def __init__(self, lits):
            # NOTE: the list is stored, not copied; the solver may reorder it
            # after construction (see _add_clause), and propagate() swaps
            # entries in place to keep the two watched literals at lits[0:2].
            self.lits = lits

        # Try to infer new facts.
        # We can do this only when all of our literals are False except one,
        # which is undecided. That is,
        # False... or X or False... = True => X = True
        #
        # To get notified when this happens, we tell the solver to
        # watch two of our undecided literals. Watching two undecided
        # literals is sufficient. When one changes we check the state
        # again. If we still have two or more undecided then we switch
        # to watching them, otherwise we propagate.
        #
        # Returns False on conflict.
        def propagate(self, lit):
            """Handle lit having become True, i.e. our literal neg(lit) became False."""
            # For simplicity, only handle the case where self.lits[1]
            # is the one that just got set to False, so that:
            # - value[lits[0]] = None | True
            # - value[lits[1]] = False
            # If it's the other way around, just swap them before we start.
            if self.lits[0] == neg(lit):
                self.lits[0], self.lits[1] = self.lits[1], self.lits[0]

            if solver.lit_value(self.lits[0]) == True:
                # We're already satisfied. Do nothing.
                solver.watch_lit(lit, self)
                return True

            assert solver.lit_value(self.lits[1]) == False

            # Find a new literal to watch now that lits[1] is resolved,
            # swap it with lits[1], and start watching it.
            for i in range(2, len(self.lits)):
                value = solver.lit_value(self.lits[i])
                if value != False:
                    # Could be None or True. If it's True then we've already done our job,
                    # so this means we don't get notified unless we backtrack, which is fine.
                    self.lits[1], self.lits[i] = self.lits[i], self.lits[1]
                    solver.watch_lit(neg(self.lits[1]), self)
                    return True

            # Only lits[0] is now undefined.
            solver.watch_lit(lit, self)
            return solver.enqueue(self.lits[0], self)

        def undo(self, lit):
            pass

        # Why is lit True?
        # Or, why are we causing a conflict (if lit is None)?
        # (The misspelt name is the interface the solver calls; keep it.)
        def cacl_reason(self, lit):
            """Return the literals whose truth implies lit (or the conflict)."""
            # Literals are ints: compare by value, not identity. An equal
            # literal created by neg() is a distinct object for large values.
            assert lit is None or lit == self.lits[0]

            # The cause is everything except lit.
            return [neg(l) for l in self.lits if l != lit]

        def __repr__(self):
            return "<some: %s>" % (', '.join(solver.name_lits(self.lits)))

    return UnionClause
# One VarInfo per variable. Parallel arrays (one per field) would be more
# efficient, but a single array of records is easier to follow.
class VarInfo(object):
    """Per-variable assignment state kept by the solver."""

    __slots__ = ('value', 'reason', 'level', 'undo', 'obj')

    def __init__(self, obj):
        self.obj = obj          # The caller's object for this variable (also used in debug output)
        self.value = None       # True, False, or None when unassigned
        self.reason = None      # What implied the current value (None when unassigned)
        self.level = -1         # Decision level at which the value was assigned (-1 when unassigned)
        self.undo = []          # Constraints to notify when the variable is unassigned by backtracking

    @property
    def name(self):
        """Printable name of the variable: str() of its object."""
        return str(self.obj)

    def __repr__(self):
        return '{}={}'.format(self.name, self.value)
213 class SATProblem(object):
214 def __init__(self):
215 # Propagation
216 self.watches = [] # watches[2i,2i+1] = constraints to check when literal[i] becomes True/False
217 self.propQ = [] # propagation queue
219 # Assignments
220 self.assigns = [] # [VarInfo]
221 self.trail = [] # order of assignments
222 self.trail_lim = [] # decision levels
224 self.toplevel_conflict = False
226 self.makeAtMostOneClause = makeAtMostOneClause(self)
227 self.makeUnionClause = makeUnionClause(self)
229 def get_decision_level(self):
230 return len(self.trail_lim)
232 def add_variable(self, obj):
233 debug("add_variable('%s')", obj)
234 index = len(self.assigns)
236 self.watches += [[], []] # Add watch lists for X and not(X)
237 self.assigns.append(VarInfo(obj))
238 return index
240 # lit is now True
241 # reason is the clause that is asserting this
242 # Returns False if this immediately causes a conflict.
243 def enqueue(self, lit, reason):
244 debug("%s => %s" % (reason, self.name_lit(lit)))
245 old_value = self.lit_value(lit)
246 if old_value is not None:
247 if old_value is False:
248 # Conflict
249 return False
250 else:
251 # Already set (shouldn't happen)
252 return True
254 if lit < 0:
255 var_info = self.assigns[neg(lit)]
256 var_info.value = False
257 else:
258 var_info = self.assigns[lit]
259 var_info.value = True
260 var_info.level = self.get_decision_level()
261 var_info.reason = reason
263 self.trail.append(lit)
264 self.propQ.append(lit)
266 return True
268 # Pop most recent assignment from self.trail
269 def undo_one(self):
270 lit = self.trail[-1]
271 debug("(pop %s)", self.name_lit(lit))
272 var_info = self.get_varinfo_for_lit(lit)
273 var_info.value = None
274 var_info.reason = None
275 var_info.level = -1
276 self.trail.pop()
278 while var_info.undo:
279 var_info.undo.pop().undo(lit)
281 def cancel(self):
282 n_this_level = len(self.trail) - self.trail_lim[-1]
283 debug("backtracking from level %d (%d assignments)" %
284 (self.get_decision_level(), n_this_level))
285 while n_this_level != 0:
286 self.undo_one()
287 n_this_level -= 1
288 self.trail_lim.pop()
290 def cancel_until(self, level):
291 while self.get_decision_level() > level:
292 self.cancel()
294 # Process the propQ.
295 # Returns None when done, or the clause that caused a conflict.
296 def propagate(self):
297 #debug("propagate: queue length = %d", len(self.propQ))
298 while self.propQ:
299 lit = self.propQ[0]
300 del self.propQ[0]
301 wi = watch_index(lit)
302 watches = self.watches[wi]
303 self.watches[wi] = []
305 debug("%s -> True : watches: %s" % (self.name_lit(lit), watches))
307 # Notifiy all watchers
308 for i in range(len(watches)):
309 clause = watches[i]
310 if not clause.propagate(lit):
311 # Conflict
313 # Re-add remaining watches
314 self.watches[wi] += watches[i+1:]
316 # No point processing the rest of the queue as
317 # we'll have to backtrack now.
318 self.propQ = []
320 return clause
321 return None
323 def impossible(self):
324 self.toplevel_conflict = True
326 def get_varinfo_for_lit(self, lit):
327 if lit >= 0:
328 return self.assigns[lit]
329 else:
330 return self.assigns[neg(lit)]
332 def lit_value(self, lit):
333 if lit >= 0:
334 value = self.assigns[lit].value
335 return value
336 else:
337 v = -1 - lit
338 value = self.assigns[v].value
339 if value is None:
340 return None
341 else:
342 return not value
344 # Call cb when lit becomes True
345 def watch_lit(self, lit, cb):
346 #debug("%s is watching for %s to become True" % (cb, self.name_lit(lit)))
347 self.watches[watch_index(lit)].append(cb)
349 # Returns the new clause if one was added, True if none was added
350 # because this clause is trivially True, or False if the clause is
351 # False.
352 def _add_clause(self, lits, learnt):
353 if not lits:
354 assert not learnt
355 self.toplevel_conflict = True
356 return False
357 elif len(lits) == 1:
358 # A clause with only a single literal is represented
359 # as an assignment rather than as a clause.
360 if learnt:
361 reason = "learnt"
362 else:
363 reason = "top-level"
364 return self.enqueue(lits[0], reason)
366 clause = self.makeUnionClause(lits)
367 clause.learnt = learnt
369 if learnt:
370 # lits[0] is None because we just backtracked.
371 # Start watching the next literal that we will
372 # backtrack over.
373 best_level = -1
374 best_i = 1
375 for i in range(1, len(lits)):
376 level = self.get_varinfo_for_lit(lits[i]).level
377 if level > best_level:
378 best_level = level
379 best_i = i
380 lits[1], lits[best_i] = lits[best_i], lits[1]
382 # Watch the first two literals in the clause (both must be
383 # undefined at this point).
384 for lit in lits[:2]:
385 self.watch_lit(neg(lit), clause)
387 return clause
389 def name_lits(self, lst):
390 return [self.name_lit(l) for l in lst]
392 # For nicer debug messages
393 def name_lit(self, lit):
394 if lit >= 0:
395 return self.assigns[lit].name
396 return "not(%s)" % self.assigns[neg(lit)].name
398 def add_clause(self, lits):
399 # Public interface. Only used before the solve starts.
400 assert lits
402 debug("add_clause([%s])" % ', '.join(self.name_lits(lits)))
404 if any(self.lit_value(l) == True for l in lits):
405 # Trivially true already.
406 return True
407 lit_set = set(lits)
408 for l in lits:
409 if neg(l) in lit_set:
410 # X or not(X) is always True.
411 return True
412 # Remove duplicates and values known to be False
413 lits = [l for l in lit_set if self.lit_value(l) != False]
415 retval = self._add_clause(lits, learnt = False)
416 if not retval:
417 self.toplevel_conflict = True
418 return retval
420 def at_most_one(self, lits):
421 assert lits
423 debug("at_most_one(%s)" % ', '.join(self.name_lits(lits)))
425 # If we have zero or one literals then we're trivially true
426 # and not really needed for the solve. However, Zero Install
427 # monitors these objects to find out what was selected, so
428 # keep even trivial ones around for that.
430 #if len(lits) < 2:
431 # return True # Trivially true
433 # Ensure no duplicates
434 assert len(set(lits)) == len(lits), lits
436 # Ignore any literals already known to be False.
437 # If any are True then they're enqueued and we'll process them
438 # soon.
439 lits = [l for l in lits if self.lit_value(l) != False]
441 clause = self.makeAtMostOneClause(lits)
443 for lit in lits:
444 self.watch_lit(lit, clause)
446 return clause
448 def analyse(self, cause):
449 # After trying some assignments, we've discovered a conflict.
450 # e.g.
451 # - we selected A then B then C
452 # - from A, B, C we got X, Y
453 # - we have a rule: not(A) or not(X) or not(Y)
455 # The simplest thing to do would be:
456 # 1. add the rule "not(A) or not(B) or not(C)"
457 # 2. unassign C
459 # Then we we'd deduce not(C) and we could try something else.
460 # However, that would be inefficient. We want to learn a more
461 # general rule that will help us with the rest of the problem.
463 # We take the clause that caused the conflict ("cause") and
464 # ask it for its cause. In this case:
466 # A and X and Y => conflict
468 # Since X and Y followed logically from A, B, C there's no
469 # point learning this rule; we need to know to avoid A, B, C
470 # *before* choosing C. We ask the two variables deduced at the
471 # current level (X and Y) what caused them, and work backwards.
472 # e.g.
474 # X: A and C => X
475 # Y: C => Y
477 # Combining these, we get the cause of the conflict in terms of
478 # things we knew before the current decision level:
480 # A and X and Y => conflict
481 # A and (A and C) and (C) => conflict
482 # A and C => conflict
484 # We can then learn (record) the more general rule:
486 # not(A) or not(C)
488 # Then, in future, whenever A is selected we can remove C and
489 # everything that depends on it from consideration.
492 learnt = [None] # The general rule we're learning
493 btlevel = 0 # The deepest decision in learnt
494 p = None # The literal we want to expand now
495 seen = set() # The variables involved in the conflict
497 counter = 0
499 while True:
500 # cause is the reason why p is True (i.e. it enqueued it).
501 # The first time, p is None, which requests the reason
502 # why it is conflicting.
503 if p is None:
504 debug("Why did %s make us fail?" % cause)
505 p_reason = cause.cacl_reason(p)
506 debug("Because: %s => conflict" % (' and '.join(self.name_lits(p_reason))))
507 else:
508 debug("Why did %s lead to %s?" % (cause, self.name_lit(p)))
509 p_reason = cause.cacl_reason(p)
510 debug("Because: %s => %s" % (' and '.join(self.name_lits(p_reason)), self.name_lit(p)))
512 # p_reason is in the form (A and B and ...)
513 # p_reason => p
515 # Check each of the variables in p_reason that we haven't
516 # already considered:
517 # - if the variable was assigned at the current level,
518 # mark it for expansion
519 # - otherwise, add it to learnt
521 for lit in p_reason:
522 var_info = self.get_varinfo_for_lit(lit)
523 if var_info not in seen:
524 seen.add(var_info)
525 if var_info.level == self.get_decision_level():
526 # We deduced this var since the last decision.
527 # It must be in self.trail, so we'll get to it
528 # soon. Remember not to stop until we've processed it.
529 counter += 1
530 elif var_info.level > 0:
531 # We won't expand lit, just remember it.
532 # (we could expand it if it's not a decision, but
533 # apparently not doing so is useful)
534 learnt.append(neg(lit))
535 btlevel = max(btlevel, var_info.level)
536 # else we already considered the cause of this assignment
538 # At this point, counter is the number of assigned
539 # variables in self.trail at the current decision level that
540 # we've seen. That is, the number left to process. Pop
541 # the next one off self.trail (as well as any unrelated
542 # variables before it; everything up to the previous
543 # decision has to go anyway).
545 # On the first time round the loop, we must find the
546 # conflict depends on at least one assignment at the
547 # current level. Otherwise, simply setting the decision
548 # variable caused a clause to conflict, in which case
549 # the clause should have asserted not(decision-variable)
550 # before we ever made the decision.
551 # On later times round the loop, counter was already >
552 # 0 before we started iterating over p_reason.
553 assert counter > 0
555 while True:
556 p = self.trail[-1]
557 var_info = self.get_varinfo_for_lit(p)
558 cause = var_info.reason
559 self.undo_one()
560 if var_info in seen:
561 break
562 debug("(irrelevant)")
563 counter -= 1
565 if counter <= 0:
566 assert counter == 0
567 # If counter = 0 then we still have one more
568 # literal (p) at the current level that we
569 # could expand. However, apparently it's best
570 # to leave this unprocessed (says the minisat
571 # paper).
572 break
574 # p is the literal we decided to stop processing on. It's either
575 # a derived variable at the current level, or the decision that
576 # led to this level. Since we're not going to expand it, add it
577 # directly to the learnt clause.
578 learnt[0] = neg(p)
580 debug("Learnt: %s" % (' or '.join(self.name_lits(learnt))))
582 return learnt, btlevel
584 def run_solver(self, decide):
585 # Check whether we detected a trivial problem
586 # during setup.
587 if self.toplevel_conflict:
588 debug("FAIL: toplevel_conflict before starting solve!")
589 return False
591 while True:
592 # Use logical deduction to simplify the clauses
593 # and assign literals where there is only one possibility.
594 conflicting_clause = self.propagate()
595 if not conflicting_clause:
596 debug("new state: %s", self.assigns)
597 if all(info.value != None for info in self.assigns):
598 # Everything is assigned without conflicts
599 debug("SUCCESS!")
600 return True
601 else:
602 # Pick a variable and try assigning it one way.
603 # If it leads to a conflict, we'll backtrack and
604 # try it the other way.
605 lit = decide()
606 #print "TRYING:", self.name_lit(lit)
607 assert lit is not None, "decide function returned None!"
608 assert self.lit_value(lit) is None
609 self.trail_lim.append(len(self.trail))
610 r = self.enqueue(lit, reason = "considering")
611 assert r is True
612 else:
613 if self.get_decision_level() == 0:
614 debug("FAIL: conflict found at top level")
615 return False
616 else:
617 # Figure out the root cause of this failure.
618 learnt, backtrack_level = self.analyse(conflicting_clause)
620 self.cancel_until(backtrack_level)
622 c = self._add_clause(learnt, learnt = True)
624 if c is not True:
625 # Everything except the first literal in learnt is known to
626 # be False, so the first must be True.
627 e = self.enqueue(learnt[0], c)
628 assert e is True