Large-scale API cleanup
[zeroinstall/zeroinstall-afb.git] / zeroinstall / injector / sat.py
blob1f9ff2571577c038ee665fa3de875a4e514f9b4f
1 """
2 Internal implementation of a SAT solver, used by L{solver.SATSolver}.
3 This is not part of the public API.
4 """
6 # Copyright (C) 2010, Thomas Leonard
7 # See the README file for details, or visit http://0install.net.
9 # The design of this solver is very heavily based on the one described in
10 # the MiniSat paper "An Extensible SAT-solver [extended version 1.2]"
11 # http://minisat.se/Papers.html
13 # The main differences are:
15 # - We care about which solution we find (not just "satisfiable" or "not").
16 # - We take care to be deterministic (always select the same versions given
17 # the same input). We do not do random restarts, etc.
18 # - We add an AtMostOneClause (the paper suggests this in the Exercises, and
19 # it's very useful for our purposes).
# Set to True to print a trace of the solver's work to stdout.
_DEBUG = False

def debug(msg, *args):
	"""Print a solver trace message when _DEBUG is set; otherwise do nothing.

	Callers use both styles: debug("x = %s", x) and debug("x = %s" % x).
	%-formatting is only applied when extra arguments are given, so an
	already-formatted message that happens to contain a '%' cannot raise.
	The parenthesised single-argument print works on Python 2 and 3.
	"""
	if not _DEBUG:
		return
	if args:
		msg = msg % args
	print("SAT: %s" % msg)
# Variables are numbered from 0.
# A positive literal has the same number as its variable; the negative
# literal of variable v is numbered (-1 - v), so that variable 0 also has a
# distinct negative literal:
#
# Variable  Literal  not(Literal)
# 0         0        -1
# 1         1        -2
def neg(lit):
	"""Return the complementary literal of lit (applying it twice gives lit back)."""
	complement = -(lit + 1)
	return complement
def watch_index(lit):
	"""Map a literal to its slot in the watch lists.

	The positive literal of variable v uses slot 2*v; its negative literal
	(-1 - v) uses slot 2*v + 1.
	"""
	if lit < 0:
		# lit == -1 - v, hence 2*v + 1 == -2*lit - 1
		return -2 * lit - 1
	return 2 * lit
def makeAtMostOneClause(solver):
	"""Return an AtMostOneClause class whose instances operate on `solver`.

	An AtMostOneClause constrains its literals so that at most one of them
	is True.
	"""
	class AtMostOneClause:
		def __init__(self, lits):
			"""Preferred literals come first."""
			self.lits = lits

			# The single literal from our set that is True.
			# We store this explicitly because the decider needs to know quickly.
			self.current = None

		def propagate(self, lit):
			"""Called when lit (one of our literals) has just become True.

			Sets every other undecided literal to False.
			Returns False on conflict, True otherwise.
			"""
			# Re-add ourselves to the watch list.
			# (we won't get any more notifications unless we backtrack,
			# in which case we'd need to get back on the list anyway)
			solver.watch_lit(lit, self)

			# value[lit] has just become True
			assert solver.lit_value(lit) is True
			assert lit >= 0

			# If we already propagated successfully when the first
			# one was set then we set all the others to False and
			# anyone trying to set one True will get rejected. And
			# if we didn't propagate yet, current will still be
			# None, even if we now have a conflict (which we'll
			# detect below).
			assert self.current is None

			self.current = lit

			# If we later backtrack, call our undo function to unset current
			solver.get_varinfo_for_lit(lit).undo.append(self)

			for l in self.lits:
				value = solver.lit_value(l)
				# Literals are plain ints and may be distinct objects with equal
				# values (e.g. built by neg(neg(x)) for learnt clauses), so
				# compare by value: an identity test would flag lit itself.
				if value is True and l != lit:
					# Due to queuing, we might get called with current = None
					# and two versions already selected.
					debug("CONFLICT: already selected %s" % solver.name_lit(l))
					return False
				if value is None:
					# Since one of our lits is already true, all unknown ones
					# can be set to False.
					if not solver.enqueue(neg(l), self):
						debug("CONFLICT: enqueue failed for %s", solver.name_lit(neg(l)))
						return False	# Conflict; abort

			return True

		def undo(self, lit):
			"""Called on backtracking when lit (our current literal) is unassigned."""
			debug("(backtracking: no longer selected %s)" % solver.name_lit(lit))
			assert lit == self.current
			self.current = None

		def cacl_reason(self, lit):
			"""Explain why lit is True, or why we conflict when lit is None.

			Returns a list of literals that are currently True. Raises
			AssertionError when no explanation exists; an explicit raise is
			used so the check survives `python -O`, under which a bare
			assert would let None reach the caller.
			"""
			if lit is None:
				# Find two True literals
				trues = []
				for l in self.lits:
					if solver.lit_value(l) is True:
						trues.append(l)
						if len(trues) == 2:
							return trues
			else:
				# Find one True literal
				for l in self.lits:
					if l != lit and solver.lit_value(l) is True:
						return [l]
			raise AssertionError("%r: no reason found for %r" % (self, lit))

		def best_undecided(self):
			"""Return the first (most preferred) undecided literal, or None."""
			debug("best_undecided: %s" % (solver.name_lits(self.lits)))
			for lit in self.lits:
				if solver.lit_value(lit) is None:
					return lit
			return None

		def __repr__(self):
			return "<lone: %s>" % (', '.join(solver.name_lits(self.lits)))

	return AtMostOneClause
def makeUnionClause(solver):
	"""Return a UnionClause class whose instances operate on `solver`.

	A UnionClause requires at least one of its literals to be True.
	"""
	class UnionClause:
		def __init__(self, lits):
			self.lits = lits

		# Try to infer new facts.
		# We can do this only when all of our literals are False except one,
		# which is undecided. That is,
		#   False... or X or False... = True  =>  X = True
		#
		# To get notified when this happens, we tell the solver to
		# watch two of our undecided literals. Watching two undecided
		# literals is sufficient. When one changes we check the state
		# again. If we still have two or more undecided then we switch
		# to watching them, otherwise we propagate.
		#
		# Returns False on conflict.
		def propagate(self, lit):
			# value[neg(lit)] has just become False

			# For simplicity, only handle the case where self.lits[1]
			# is the one that just got set to False, so that:
			# - value[lits[0]] = None | True
			# - value[lits[1]] = False
			# If it's the other way around, just swap them before we start.
			if self.lits[0] == neg(lit):
				self.lits[0], self.lits[1] = self.lits[1], self.lits[0]

			if solver.lit_value(self.lits[0]) is True:
				# We're already satisfied. Do nothing.
				solver.watch_lit(lit, self)
				return True

			assert solver.lit_value(self.lits[1]) is False

			# Find a new literal to watch now that lits[1] is resolved,
			# swap it with lits[1], and start watching it.
			for i in range(2, len(self.lits)):
				if solver.lit_value(self.lits[i]) is not False:
					# Could be None or True. If it's True then we've already done our job,
					# so this means we don't get notified unless we backtrack, which is fine.
					self.lits[1], self.lits[i] = self.lits[i], self.lits[1]
					solver.watch_lit(neg(self.lits[1]), self)
					return True

			# Only lits[0] is now undefined.
			solver.watch_lit(lit, self)
			return solver.enqueue(self.lits[0], self)

		def undo(self, lit):
			pass

		def cacl_reason(self, lit):
			"""Explain why lit is True, or why we conflict when lit is None.

			The cause is the negation of every other literal in the clause.
			Literals are compared by value: equal ints need not be the
			same object, so an identity test could wrongly keep lit in
			its own reason.
			"""
			assert lit is None or lit == self.lits[0]

			# The cause is everything except lit.
			return [neg(l) for l in self.lits if l != lit]

		def __repr__(self):
			return "<some: %s>" % (', '.join(solver.name_lits(self.lits)))

	return UnionClause
# Using an array of VarInfo objects is less efficient than using multiple arrays, but
# easier for me to understand.
class VarInfo(object):
	"""Solver state for one variable, plus the caller's object it stands for."""
	__slots__ = ['value', 'reason', 'level', 'undo', 'obj']

	def __init__(self, obj):
		self.obj = obj		# caller's object for this variable (also used to name it)
		self.value = None	# True, False, or None while unassigned
		self.reason = None	# constraint that implied our value, if assigned
		self.level = -1		# decision level of the assignment (-1 while unassigned)
		self.undo = []		# constraints to notify when backtracking unassigns us

	def __repr__(self):
		return self.name + '=' + str(self.value)

	@property
	def name(self):
		return str(self.obj)
211 class SATProblem(object):
212 def __init__(self):
213 # Propagation
214 self.watches = [] # watches[2i,2i+1] = constraints to check when literal[i] becomes True/False
215 self.propQ = [] # propagation queue
217 # Assignments
218 self.assigns = [] # [VarInfo]
219 self.trail = [] # order of assignments
220 self.trail_lim = [] # decision levels
222 self.toplevel_conflict = False
224 self.makeAtMostOneClause = makeAtMostOneClause(self)
225 self.makeUnionClause = makeUnionClause(self)
227 def get_decision_level(self):
228 return len(self.trail_lim)
230 def add_variable(self, obj):
231 debug("add_variable('%s')", obj)
232 index = len(self.assigns)
234 self.watches += [[], []] # Add watch lists for X and not(X)
235 self.assigns.append(VarInfo(obj))
236 return index
238 # lit is now True
239 # reason is the clause that is asserting this
240 # Returns False if this immediately causes a conflict.
241 def enqueue(self, lit, reason):
242 debug("%s => %s" % (reason, self.name_lit(lit)))
243 old_value = self.lit_value(lit)
244 if old_value is not None:
245 if old_value is False:
246 # Conflict
247 return False
248 else:
249 # Already set (shouldn't happen)
250 return True
252 if lit < 0:
253 var_info = self.assigns[neg(lit)]
254 var_info.value = False
255 else:
256 var_info = self.assigns[lit]
257 var_info.value = True
258 var_info.level = self.get_decision_level()
259 var_info.reason = reason
261 self.trail.append(lit)
262 self.propQ.append(lit)
264 return True
266 # Pop most recent assignment from self.trail
267 def undo_one(self):
268 lit = self.trail[-1]
269 debug("(pop %s)", self.name_lit(lit))
270 var_info = self.get_varinfo_for_lit(lit)
271 var_info.value = None
272 var_info.reason = None
273 var_info.level = -1
274 self.trail.pop()
276 while var_info.undo:
277 var_info.undo.pop().undo(lit)
279 def cancel(self):
280 n_this_level = len(self.trail) - self.trail_lim[-1]
281 debug("backtracking from level %d (%d assignments)" %
282 (self.get_decision_level(), n_this_level))
283 while n_this_level != 0:
284 self.undo_one()
285 n_this_level -= 1
286 self.trail_lim.pop()
288 def cancel_until(self, level):
289 while self.get_decision_level() > level:
290 self.cancel()
292 # Process the propQ.
293 # Returns None when done, or the clause that caused a conflict.
294 def propagate(self):
295 #debug("propagate: queue length = %d", len(self.propQ))
296 while self.propQ:
297 lit = self.propQ[0]
298 del self.propQ[0]
299 wi = watch_index(lit)
300 watches = self.watches[wi]
301 self.watches[wi] = []
303 debug("%s -> True : watches: %s" % (self.name_lit(lit), watches))
305 # Notifiy all watchers
306 for i in range(len(watches)):
307 clause = watches[i]
308 if not clause.propagate(lit):
309 # Conflict
311 # Re-add remaining watches
312 self.watches[wi] += watches[i+1:]
314 # No point processing the rest of the queue as
315 # we'll have to backtrack now.
316 self.propQ = []
318 return clause
319 return None
321 def impossible(self):
322 self.toplevel_conflict = True
324 def get_varinfo_for_lit(self, lit):
325 if lit >= 0:
326 return self.assigns[lit]
327 else:
328 return self.assigns[neg(lit)]
330 def lit_value(self, lit):
331 if lit >= 0:
332 value = self.assigns[lit].value
333 return value
334 else:
335 v = -1 - lit
336 value = self.assigns[v].value
337 if value is None:
338 return None
339 else:
340 return not value
342 # Call cb when lit becomes True
343 def watch_lit(self, lit, cb):
344 #debug("%s is watching for %s to become True" % (cb, self.name_lit(lit)))
345 self.watches[watch_index(lit)].append(cb)
347 # Returns the new clause if one was added, True if none was added
348 # because this clause is trivially True, or False if the clause is
349 # False.
350 def _add_clause(self, lits, learnt):
351 if not lits:
352 assert not learnt
353 self.toplevel_conflict = True
354 return False
355 elif len(lits) == 1:
356 # A clause with only a single literal is represented
357 # as an assignment rather than as a clause.
358 if learnt:
359 reason = "learnt"
360 else:
361 reason = "top-level"
362 return self.enqueue(lits[0], reason)
364 clause = self.makeUnionClause(lits)
365 clause.learnt = learnt
367 if learnt:
368 # lits[0] is None because we just backtracked.
369 # Start watching the next literal that we will
370 # backtrack over.
371 best_level = -1
372 best_i = 1
373 for i in range(1, len(lits)):
374 level = self.get_varinfo_for_lit(lits[i]).level
375 if level > best_level:
376 best_level = level
377 best_i = i
378 lits[1], lits[best_i] = lits[best_i], lits[1]
380 # Watch the first two literals in the clause (both must be
381 # undefined at this point).
382 for lit in lits[:2]:
383 self.watch_lit(neg(lit), clause)
385 return clause
387 def name_lits(self, lst):
388 return [self.name_lit(l) for l in lst]
390 # For nicer debug messages
391 def name_lit(self, lit):
392 if lit >= 0:
393 return self.assigns[lit].name
394 return "not(%s)" % self.assigns[neg(lit)].name
396 def add_clause(self, lits):
397 # Public interface. Only used before the solve starts.
398 assert lits
400 debug("add_clause([%s])" % ', '.join(self.name_lits(lits)))
402 if any(self.lit_value(l) == True for l in lits):
403 # Trivially true already.
404 return True
405 lit_set = set(lits)
406 for l in lits:
407 if neg(l) in lit_set:
408 # X or not(X) is always True.
409 return True
410 # Remove duplicates and values known to be False
411 lits = [l for l in lit_set if self.lit_value(l) != False]
413 retval = self._add_clause(lits, learnt = False)
414 if not retval:
415 self.toplevel_conflict = True
416 return retval
418 def at_most_one(self, lits):
419 assert lits
421 debug("at_most_one(%s)" % ', '.join(self.name_lits(lits)))
423 # If we have zero or one literals then we're trivially true
424 # and not really needed for the solve. However, Zero Install
425 # monitors these objects to find out what was selected, so
426 # keep even trivial ones around for that.
428 #if len(lits) < 2:
429 # return True # Trivially true
431 # Ensure no duplicates
432 assert len(set(lits)) == len(lits), lits
434 # Ignore any literals already known to be False.
435 # If any are True then they're enqueued and we'll process them
436 # soon.
437 lits = [l for l in lits if self.lit_value(l) != False]
439 clause = self.makeAtMostOneClause(lits)
441 for lit in lits:
442 self.watch_lit(lit, clause)
444 return clause
446 def analyse(self, cause):
447 # After trying some assignments, we've discovered a conflict.
448 # e.g.
449 # - we selected A then B then C
450 # - from A, B, C we got X, Y
451 # - we have a rule: not(A) or not(X) or not(Y)
453 # The simplest thing to do would be:
454 # 1. add the rule "not(A) or not(B) or not(C)"
455 # 2. unassign C
457 # Then we we'd deduce not(C) and we could try something else.
458 # However, that would be inefficient. We want to learn a more
459 # general rule that will help us with the rest of the problem.
461 # We take the clause that caused the conflict ("cause") and
462 # ask it for its cause. In this case:
464 # A and X and Y => conflict
466 # Since X and Y followed logically from A, B, C there's no
467 # point learning this rule; we need to know to avoid A, B, C
468 # *before* choosing C. We ask the two variables deduced at the
469 # current level (X and Y) what caused them, and work backwards.
470 # e.g.
472 # X: A and C => X
473 # Y: C => Y
475 # Combining these, we get the cause of the conflict in terms of
476 # things we knew before the current decision level:
478 # A and X and Y => conflict
479 # A and (A and C) and (C) => conflict
480 # A and C => conflict
482 # We can then learn (record) the more general rule:
484 # not(A) or not(C)
486 # Then, in future, whenever A is selected we can remove C and
487 # everything that depends on it from consideration.
490 learnt = [None] # The general rule we're learning
491 btlevel = 0 # The deepest decision in learnt
492 p = None # The literal we want to expand now
493 seen = set() # The variables involved in the conflict
495 counter = 0
497 while True:
498 # cause is the reason why p is True (i.e. it enqueued it).
499 # The first time, p is None, which requests the reason
500 # why it is conflicting.
501 if p is None:
502 debug("Why did %s make us fail?" % cause)
503 p_reason = cause.cacl_reason(p)
504 debug("Because: %s => conflict" % (' and '.join(self.name_lits(p_reason))))
505 else:
506 debug("Why did %s lead to %s?" % (cause, self.name_lit(p)))
507 p_reason = cause.cacl_reason(p)
508 debug("Because: %s => %s" % (' and '.join(self.name_lits(p_reason)), self.name_lit(p)))
510 # p_reason is in the form (A and B and ...)
511 # p_reason => p
513 # Check each of the variables in p_reason that we haven't
514 # already considered:
515 # - if the variable was assigned at the current level,
516 # mark it for expansion
517 # - otherwise, add it to learnt
519 for lit in p_reason:
520 var_info = self.get_varinfo_for_lit(lit)
521 if var_info not in seen:
522 seen.add(var_info)
523 if var_info.level == self.get_decision_level():
524 # We deduced this var since the last decision.
525 # It must be in self.trail, so we'll get to it
526 # soon. Remember not to stop until we've processed it.
527 counter += 1
528 elif var_info.level > 0:
529 # We won't expand lit, just remember it.
530 # (we could expand it if it's not a decision, but
531 # apparently not doing so is useful)
532 learnt.append(neg(lit))
533 btlevel = max(btlevel, var_info.level)
534 # else we already considered the cause of this assignment
536 # At this point, counter is the number of assigned
537 # variables in self.trail at the current decision level that
538 # we've seen. That is, the number left to process. Pop
539 # the next one off self.trail (as well as any unrelated
540 # variables before it; everything up to the previous
541 # decision has to go anyway).
543 # On the first time round the loop, we must find the
544 # conflict depends on at least one assignment at the
545 # current level. Otherwise, simply setting the decision
546 # variable caused a clause to conflict, in which case
547 # the clause should have asserted not(decision-variable)
548 # before we ever made the decision.
549 # On later times round the loop, counter was already >
550 # 0 before we started iterating over p_reason.
551 assert counter > 0
553 while True:
554 p = self.trail[-1]
555 var_info = self.get_varinfo_for_lit(p)
556 cause = var_info.reason
557 self.undo_one()
558 if var_info in seen:
559 break
560 debug("(irrelevant)")
561 counter -= 1
563 if counter <= 0:
564 assert counter == 0
565 # If counter = 0 then we still have one more
566 # literal (p) at the current level that we
567 # could expand. However, apparently it's best
568 # to leave this unprocessed (says the minisat
569 # paper).
570 break
572 # p is the literal we decided to stop processing on. It's either
573 # a derived variable at the current level, or the decision that
574 # led to this level. Since we're not going to expand it, add it
575 # directly to the learnt clause.
576 learnt[0] = neg(p)
578 debug("Learnt: %s" % (' or '.join(self.name_lits(learnt))))
580 return learnt, btlevel
582 def run_solver(self, decide):
583 # Check whether we detected a trivial problem
584 # during setup.
585 if self.toplevel_conflict:
586 debug("FAIL: toplevel_conflict before starting solve!")
587 return False
589 while True:
590 # Use logical deduction to simplify the clauses
591 # and assign literals where there is only one possibility.
592 conflicting_clause = self.propagate()
593 if not conflicting_clause:
594 debug("new state: %s", self.assigns)
595 if all(info.value != None for info in self.assigns):
596 # Everything is assigned without conflicts
597 debug("SUCCESS!")
598 return True
599 else:
600 # Pick a variable and try assigning it one way.
601 # If it leads to a conflict, we'll backtrack and
602 # try it the other way.
603 lit = decide()
604 #print "TRYING:", self.name_lit(lit)
605 assert lit is not None, "decide function returned None!"
606 assert self.lit_value(lit) is None
607 self.trail_lim.append(len(self.trail))
608 r = self.enqueue(lit, reason = "considering")
609 assert r is True
610 else:
611 if self.get_decision_level() == 0:
612 debug("FAIL: conflict found at top level")
613 return False
614 else:
615 # Figure out the root cause of this failure.
616 learnt, backtrack_level = self.analyse(conflicting_clause)
618 self.cancel_until(backtrack_level)
620 c = self._add_clause(learnt, learnt = True)
622 if c is not True:
623 # Everything except the first literal in learnt is known to
624 # be False, so the first must be True.
625 e = self.enqueue(learnt[0], c)
626 assert e is True