diff --git a/src/internal/mod.rs b/src/internal/mod.rs index 7f11dea..08a55f4 100644 --- a/src/internal/mod.rs +++ b/src/internal/mod.rs @@ -3,3 +3,6 @@ pub mod frozen_copy_map; pub mod id; pub mod mapping; pub mod small_vec; +mod unwrap_unchecked; + +pub use unwrap_unchecked::debug_expect_unchecked; diff --git a/src/internal/unwrap_unchecked.rs b/src/internal/unwrap_unchecked.rs new file mode 100644 index 0000000..eee6970 --- /dev/null +++ b/src/internal/unwrap_unchecked.rs @@ -0,0 +1,13 @@ +/// An unsafe method that unwraps an option without checking if it is `None` in +/// release mode but does check the value in debug mode. +#[track_caller] +pub unsafe fn debug_expect_unchecked(opt: Option, _msg: &str) -> T { + #[cfg(debug_assertions)] + { + opt.expect(_msg) + } + #[cfg(not(debug_assertions))] + { + opt.unwrap_unchecked() + } +} diff --git a/src/solver/clause.rs b/src/solver/clause.rs index 4ccc610..f034130 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -1,6 +1,7 @@ use std::{ fmt::{Debug, Display, Formatter}, iter, + num::NonZeroU32, ops::ControlFlow, }; @@ -99,7 +100,7 @@ pub(crate) enum Clause { } impl Clause { - /// Returns the building blocks needed for a new [ClauseState] of the + /// Returns the building blocks needed for a new [WatchedLiterals] of the /// [Clause::Requires] kind. /// /// These building blocks are: @@ -149,7 +150,7 @@ impl Clause { } } - /// Returns the building blocks needed for a new [ClauseState] of the + /// Returns the building blocks needed for a new [WatchedLiterals] of the /// [Clause::Constrains] kind. /// /// These building blocks are: @@ -321,17 +322,19 @@ impl Clause { /// variable are grouped together in a linked list, so it becomes easy to notify /// them all. #[derive(Clone)] -pub(crate) struct ClauseState { - // The ids of the literals this clause is watching +pub(crate) struct WatchedLiterals { + /// The ids of the literals this clause is watching. A clause that is + /// watching literals is always watching two literals, no more, no less. pub watched_literals: [Literal; 2], - // The ids of the next clause in each linked list that this clause is part of + /// The ids of the next clause in each linked list that this clause is part + /// of. If either of these or `None` then there is no next clause. pub(crate) next_watches: [Option; 2], } -impl ClauseState { +impl WatchedLiterals { /// Shorthand method to construct a [`Clause::InstallRoot`] without /// requiring complicated arguments. - pub fn root() -> (Self, Clause) { + pub fn root() -> (Option, Clause) { let (kind, watched_literals) = Clause::root(); (Self::from_kind_and_initial_watches(watched_literals), kind) } @@ -346,7 +349,7 @@ impl ClauseState { requirement: Requirement, matching_candidates: impl IntoIterator, decision_tracker: &DecisionTracker, - ) -> (Self, bool, Clause) { + ) -> (Option, bool, Clause) { let (kind, watched_literals, conflict) = Clause::requires( candidate, requirement, @@ -371,7 +374,7 @@ impl ClauseState { constrained_package: VariableId, requirement: VersionSetId, decision_tracker: &DecisionTracker, - ) -> (Self, bool, Clause) { + ) -> (Option, bool, Clause) { let (kind, watched_literals, conflict) = Clause::constrains( candidate, constrained_package, @@ -386,7 +389,10 @@ impl ClauseState { ) } - pub fn lock(locked_candidate: VariableId, other_candidate: VariableId) -> (Self, Clause) { + pub fn lock( + locked_candidate: VariableId, + other_candidate: VariableId, + ) -> (Option, Clause) { let (kind, watched_literals) = Clause::lock(locked_candidate, other_candidate); (Self::from_kind_and_initial_watches(watched_literals), kind) } @@ -395,61 +401,31 @@ impl ClauseState { candidate: VariableId, other_candidate: Literal, name: NameId, - ) -> (Self, Clause) { + ) -> (Option, Clause) { let (kind, watched_literals) = Clause::forbid_multiple(candidate, other_candidate, name); (Self::from_kind_and_initial_watches(watched_literals), kind) } - pub fn learnt(learnt_clause_id: LearntClauseId, literals: &[Literal]) -> (Self, Clause) { + pub fn learnt( + learnt_clause_id: LearntClauseId, + literals: &[Literal], + ) -> (Option, Clause) { let (kind, watched_literals) = Clause::learnt(learnt_clause_id, literals); (Self::from_kind_and_initial_watches(watched_literals), kind) } - pub fn exclude(candidate: VariableId, reason: StringId) -> (Self, Clause) { + pub fn exclude(candidate: VariableId, reason: StringId) -> (Option, Clause) { let (kind, watched_literals) = Clause::exclude(candidate, reason); (Self::from_kind_and_initial_watches(watched_literals), kind) } - fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Self { - let watched_literals = watched_literals.unwrap_or([Literal::null(), Literal::null()]); - - let clause = Self { + fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Option { + let watched_literals = watched_literals?; + debug_assert!(watched_literals[0] != watched_literals[1]); + Some(Self { watched_literals, next_watches: [None, None], - }; - - debug_assert!(!clause.has_watches() || watched_literals[0] != watched_literals[1]); - - clause - } - - pub fn unlink_clause( - &mut self, - linked_clause: &ClauseState, - watched_solvable: VariableId, - linked_clause_watch_index: usize, - ) { - if self.watched_literals[0].variable() == watched_solvable { - self.next_watches[0] = linked_clause.next_watches[linked_clause_watch_index]; - } else { - debug_assert_eq!(self.watched_literals[1].variable(), watched_solvable); - self.next_watches[1] = linked_clause.next_watches[linked_clause_watch_index]; - } - } - - #[inline] - pub fn next_watched_clause(&self, solvable_id: VariableId) -> Option { - if solvable_id == self.watched_literals[0].variable() { - self.next_watches[0] - } else { - debug_assert_eq!(self.watched_literals[1].variable(), solvable_id); - self.next_watches[1] - } - } - - pub fn has_watches(&self) -> bool { - // If the first watch is not null, the second won't be either - !self.watched_literals[0].is_null() + }) } pub fn next_unwatched_literal( @@ -482,7 +458,7 @@ impl ClauseState { // The next unwatched variable (if available), is a variable that is: // * Not already being watched // * Not yet decided, or decided in such a way that the literal yields true - if self.watched_literals[other_watch_index].variable() != lit.variable() + if self.watched_literals[other_watch_index] != lit && lit.eval(decision_map).unwrap_or(true) { ControlFlow::Break(lit) @@ -500,43 +476,40 @@ impl ClauseState { } } -/// Represents a literal in a SAT clause (i.e. either A or ¬A) +/// Represents a literal in a SAT clause, a literal holds a variable and +/// indicates whether it should be positive or negative (i.e. either A or ¬A). +/// +/// A [`Literal`] stores a [`NonZeroU32`] which ensures that the size of an +/// `Option` is the same as a `Literal`. #[repr(transparent)] #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub(crate) struct Literal(u32); +pub(crate) struct Literal(NonZeroU32); impl Literal { /// Constructs a new [`Literal`] from a [`VariableId`] and a boolean /// indicating whether the literal should be negated. pub fn new(variable: VariableId, negate: bool) -> Self { let variable_idx = variable.to_usize(); - let literal_idx = variable_idx << 1 | negate as usize; - Self(literal_idx.try_into().expect("literal id too big")) + let encoded_literal = variable_idx << 1 | negate as usize; + Self::from_usize(encoded_literal) } } impl ArenaId for Literal { fn from_usize(x: usize) -> Self { - debug_assert!(x <= u32::MAX as usize, "watched literal id too big"); - Literal(x as u32) + let idx: u32 = (x + 1).try_into().expect("watched literal id too big"); + // SAFETY: This is safe because we are adding 1 to the index + unsafe { Literal(NonZeroU32::new_unchecked(idx)) } } fn to_usize(self) -> usize { - self.0 as usize + self.0.get() as usize - 1 } } impl Literal { - pub fn null() -> Self { - Self(u32::MAX) - } - - pub fn is_null(&self) -> bool { - self.0 == u32::MAX - } - pub fn negate(&self) -> bool { - (self.0 & 1) == 1 + (self.0.get() & 1) == 0 } /// Returns the value that would make the literal evaluate to true if @@ -549,7 +522,7 @@ impl Literal { /// assigned to the literal's solvable #[inline] pub(crate) fn variable(self) -> VariableId { - VariableId::from_usize((self.0 >> 1) as usize) + VariableId::from_usize(self.to_usize() >> 1) } /// Evaluates the literal, or returns `None` if no value has been assigned @@ -647,13 +620,6 @@ mod test { use super::*; use crate::{internal::arena::ArenaId, solver::decision::Decision}; - fn clause(next_clauses: [Option; 2], watch_literals: [Literal; 2]) -> ClauseState { - ClauseState { - watched_literals: watch_literals, - next_watches: next_clauses, - } - } - #[test] #[allow(clippy::bool_assert_comparison)] fn test_literal_satisfying_value() { @@ -685,111 +651,6 @@ mod test { assert_eq!(negated_literal.eval(&decision_map), Some(true)); } - #[test] - fn test_unlink_clause_different() { - let clause1 = clause( - [ - ClauseId::from_usize(2).into(), - ClauseId::from_usize(3).into(), - ], - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1211).negative(), - ], - ); - let clause2 = clause( - [None, ClauseId::from_usize(3).into()], - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1208).negative(), - ], - ); - let clause3 = clause( - [None, None], - [ - VariableId::from_usize(1211).negative(), - VariableId::from_usize(42).negative(), - ], - ); - - // Unlink 0 - { - let mut clause1 = clause1.clone(); - clause1.unlink_clause(&clause2, VariableId::from_usize(1596), 0); - assert_eq!( - clause1.watched_literals, - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1211).negative() - ] - ); - assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(3).into()]) - } - - // Unlink 1 - { - let mut clause1 = clause1; - clause1.unlink_clause(&clause3, VariableId::from_usize(1211), 0); - assert_eq!( - clause1.watched_literals, - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1211).negative() - ] - ); - assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None]) - } - } - - #[test] - fn test_unlink_clause_same() { - let clause1 = clause( - [ - ClauseId::from_usize(2).into(), - ClauseId::from_usize(2).into(), - ], - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1211).negative(), - ], - ); - let clause2 = clause( - [None, None], - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1211).negative(), - ], - ); - - // Unlink 0 - { - let mut clause1 = clause1.clone(); - clause1.unlink_clause(&clause2, VariableId::from_usize(1596), 0); - assert_eq!( - clause1.watched_literals, - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1211).negative() - ] - ); - assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(2).into()]) - } - - // Unlink 1 - { - let mut clause1 = clause1; - clause1.unlink_clause(&clause2, VariableId::from_usize(1211), 1); - assert_eq!( - clause1.watched_literals, - [ - VariableId::from_usize(1596).negative(), - VariableId::from_usize(1211).negative() - ] - ); - assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None]) - } - } - #[test] fn test_requires_with_and_without_conflict() { let mut decisions = DecisionTracker::new(); @@ -799,15 +660,21 @@ mod test { let candidate2 = VariableId::from_usize(3); // No conflict, all candidates available - let (clause, conflict, _kind) = ClauseState::requires( + let (clause, conflict, _kind) = WatchedLiterals::requires( parent, VersionSetId::from_usize(0).into(), [candidate1, candidate2], &decisions, ); assert!(!conflict); - assert_eq!(clause.watched_literals[0].variable(), parent); - assert_eq!(clause.watched_literals[1].variable(), candidate1.into()); + assert_eq!( + clause.as_ref().unwrap().watched_literals[0].variable(), + parent + ); + assert_eq!( + clause.unwrap().watched_literals[1].variable(), + candidate1.into() + ); // No conflict, still one candidate available decisions @@ -816,15 +683,21 @@ mod test { 1, ) .unwrap(); - let (clause, conflict, _kind) = ClauseState::requires( + let (clause, conflict, _kind) = WatchedLiterals::requires( parent, VersionSetId::from_usize(0).into(), [candidate1, candidate2], &decisions, ); assert!(!conflict); - assert_eq!(clause.watched_literals[0].variable(), parent); - assert_eq!(clause.watched_literals[1].variable(), candidate2.into()); + assert_eq!( + clause.as_ref().unwrap().watched_literals[0].variable(), + parent + ); + assert_eq!( + clause.as_ref().unwrap().watched_literals[1].variable(), + candidate2.into() + ); // Conflict, no candidates available decisions @@ -833,22 +706,28 @@ mod test { 1, ) .unwrap(); - let (clause, conflict, _kind) = ClauseState::requires( + let (clause, conflict, _kind) = WatchedLiterals::requires( parent, VersionSetId::from_usize(0).into(), [candidate1, candidate2], &decisions, ); assert!(conflict); - assert_eq!(clause.watched_literals[0].variable(), parent); - assert_eq!(clause.watched_literals[1].variable(), candidate1.into()); + assert_eq!( + clause.as_ref().unwrap().watched_literals[0].variable(), + parent + ); + assert_eq!( + clause.as_ref().unwrap().watched_literals[1].variable(), + candidate1.into() + ); // Panic decisions .try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1) .unwrap(); let panicked = std::panic::catch_unwind(|| { - ClauseState::requires( + WatchedLiterals::requires( parent, VersionSetId::from_usize(0).into(), [candidate1, candidate2], @@ -868,37 +747,79 @@ mod test { // No conflict, forbidden package not installed let (clause, conflict, _kind) = - ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions); + WatchedLiterals::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions); assert!(!conflict); - assert_eq!(clause.watched_literals[0].variable(), parent); - assert_eq!(clause.watched_literals[1].variable(), forbidden); + assert_eq!( + clause.as_ref().unwrap().watched_literals[0].variable(), + parent + ); + assert_eq!( + clause.as_ref().unwrap().watched_literals[1].variable(), + forbidden + ); // Conflict, forbidden package installed decisions .try_add_decision(Decision::new(forbidden, true, ClauseId::install_root()), 1) .unwrap(); let (clause, conflict, _kind) = - ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions); + WatchedLiterals::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions); assert!(conflict); - assert_eq!(clause.watched_literals[0].variable(), parent); - assert_eq!(clause.watched_literals[1].variable(), forbidden); + assert_eq!( + clause.as_ref().unwrap().watched_literals[0].variable(), + parent + ); + assert_eq!( + clause.as_ref().unwrap().watched_literals[1].variable(), + forbidden + ); // Panic decisions .try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1) .unwrap(); let panicked = std::panic::catch_unwind(|| { - ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions) + WatchedLiterals::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions) }) .is_err(); assert!(panicked); } #[test] - fn test_clause_size() { - // This test is here to ensure we don't increase the size of `ClauseState` by - // accident, as we are creating thousands of instances. + fn test_watched_literals_size() { + // This test is here to ensure we don't increase the size of `WatchedLiterals` + // by accident, as we are creating thousands of instances. // libsolv: 24 bytes - assert_eq!(std::mem::size_of::(), 16); + assert_eq!(std::mem::size_of::(), 16); + } + + #[test] + fn test_literal_size() { + assert_eq!(std::mem::size_of::(), 4); + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::>() + ); + assert_eq!( + std::mem::size_of::() * 2, + std::mem::size_of::<[Literal; 2]>() + ); + assert_eq!( + std::mem::size_of::() * 2, + std::mem::size_of::<[Option; 2]>() + ); + assert_eq!( + std::mem::size_of::() * 2, + std::mem::size_of::>() + ); + } + + #[test] + fn test_watched_literal_size() { + assert_eq!(std::mem::size_of::(), 16); + assert_eq!( + std::mem::size_of::>(), + std::mem::size_of::() + ); } } diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 9fca0bf..8c0e026 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -2,7 +2,7 @@ use std::{any::Any, fmt::Display, future::ready, ops::ControlFlow}; use ahash::{HashMap, HashSet}; pub use cache::SolverCache; -use clause::{Clause, ClauseState, Literal}; +use clause::{Clause, Literal, WatchedLiterals}; use decision::Decision; use decision_tracker::DecisionTracker; use elsa::FrozenMap; @@ -129,14 +129,14 @@ impl> Problem { #[derive(Default)] pub(crate) struct Clauses { pub(crate) kinds: Vec, - states: Vec, + watched_literals: Vec>, } impl Clauses { - pub fn alloc(&mut self, state: ClauseState, kind: Clause) -> ClauseId { + pub fn alloc(&mut self, watched_literals: Option, kind: Clause) -> ClauseId { let id = ClauseId::from_usize(self.kinds.len()); self.kinds.push(kind); - self.states.push(state); + self.watched_literals.push(watched_literals); id } } @@ -352,7 +352,7 @@ impl Solver { // The first clause will always be the install root clause. Here we verify that // this is indeed the case. let root_clause = { - let (state, kind) = ClauseState::root(); + let (state, kind) = WatchedLiterals::root(); self.clauses.alloc(state, kind) }; assert_eq!(root_clause, ClauseId::install_root()); @@ -646,14 +646,12 @@ impl Solver { } fn process_add_clause_output(&mut self, mut output: AddClauseOutput) -> Result<(), ClauseId> { - let clauses = &mut self.clauses.states; + let watched_literals = &mut self.clauses.watched_literals; for clause_id in output.clauses_to_watch { - debug_assert!( - clauses[clause_id.to_usize()].has_watches(), - "attempting to watch a clause without watches!" - ); - self.watches - .start_watching(&mut clauses[clause_id.to_usize()], clause_id); + let watched_literals = watched_literals[clause_id.to_usize()] + .as_mut() + .expect("attempting to watch a clause without watches!"); + self.watches.start_watching(watched_literals, clause_id); } for (solvable_id, requirement, clause_id) in output.new_requires_clauses { @@ -1093,8 +1091,112 @@ impl Solver { return Err(PropagationError::Cancelled(value)); }; - // Negative assertions derived from other rules (assertions are clauses that - // consist of a single literal, and therefore do not have watches) + // Add decisions from assertions and learned clauses. If any of these cause a + // conflict, we will return an error. + self.decide_assertions(level)?; + self.decide_learned(level)?; + + // For each decision that has not been propagated yet, we propagate the + // decision. + // + // Propagation entails iterating through the linked list of clauses that watch + // the literal that the decision caused to turn false. If a clause can only be + // satisfied if one of the literals involved is assigned a value, we also make a + // decision on that literal to ensure that the clause is satisfied. + // + // Any new decision is also propagated. If by making a decision on one of the + // remaining literals of a clause we cause a conflict, propagation is halted and + // an error is returned. + + let interner = self.cache.provider(); + let clause_kinds = &self.clauses.kinds; + + while let Some(decision) = self.decision_tracker.next_unpropagated() { + let watched_literal = Literal::new(decision.variable, decision.value); + + debug_assert!( + watched_literal.eval(self.decision_tracker.map()) == Some(false), + "we are only watching literals that are turning false" + ); + + // Propagate, iterating through the linked list of clauses that watch this + // solvable + let mut next_cursor = self + .watches + .cursor(&mut self.clauses.watched_literals, watched_literal); + while let Some(cursor) = next_cursor.take() { + let clause_id = cursor.clause_id(); + let clause = &clause_kinds[clause_id.to_usize()]; + let watch_index = cursor.watch_index(); + + // If the other literal the current clause is watching is already true, we can + // skip this clause. Its is already satisfied. + let watched_literals = cursor.watched_literals(); + let other_watched_literal = + watched_literals.watched_literals[1 - cursor.watch_index()]; + if other_watched_literal.eval(self.decision_tracker.map()) == Some(true) { + // Continue with the next clause in the linked list. + next_cursor = cursor.next(); + } else if let Some(literal) = watched_literals.next_unwatched_literal( + clause, + &self.learnt_clauses, + &self.requirement_to_sorted_candidates, + self.decision_tracker.map(), + watch_index, + ) { + // Update the watch to point to the new literal + next_cursor = cursor.update(literal); + } else { + // We could not find another literal to watch, which means the remaining + // watched literal must be set to true. + let decided = self + .decision_tracker + .try_add_decision( + Decision::new( + other_watched_literal.variable(), + other_watched_literal.satisfying_value(), + clause_id, + ), + level, + ) + .map_err(|_| { + PropagationError::Conflict( + other_watched_literal.variable(), + true, + clause_id, + ) + })?; + + if decided { + match clause { + // Skip logging for ForbidMultipleInstances, which is so noisy + Clause::ForbidMultipleInstances(..) => {} + _ => { + tracing::debug!( + "├ Propagate {} = {}. {}", + other_watched_literal + .variable() + .display(&self.variable_map, interner), + other_watched_literal.satisfying_value(), + clause.display(&self.variable_map, interner) + ); + } + } + } + + // Skip to the next clause in the linked list. + next_cursor = cursor.next(); + } + } + } + + Ok(()) + } + + /// Add decisions for negative assertions derived from other rules + /// (assertions are clauses that consist of a single literal, and + /// therefore do not have watches). + fn decide_assertions(&mut self, level: u32) -> Result<(), PropagationError> { for &(solvable_id, clause_id) in &self.negative_assertions { let value = false; let decided = self @@ -1110,7 +1212,11 @@ impl Solver { ); } } + Ok(()) + } + /// Add decisions derived from learnt clauses. + fn decide_learned(&mut self, level: u32) -> Result<(), PropagationError> { // Assertions derived from learnt rules for learn_clause_idx in 0..self.learnt_clause_ids.len() { let clause_id = self.learnt_clause_ids[learn_clause_idx]; @@ -1147,133 +1253,6 @@ impl Solver { ); } } - - // Watched solvables - let clauses = &self.clauses.kinds; - let clause_states = &mut self.clauses.states; - let interner = self.cache.provider(); - while let Some(decision) = self.decision_tracker.next_unpropagated() { - let watched_literal = Literal::new(decision.variable, decision.value); - - // Propagate, iterating through the linked list of clauses that watch this - // solvable - let mut old_predecessor_clause_id: Option; - let mut predecessor_clause_id: Option = None; - let mut next_clause_id = self.watches.first_clause_watching_literal(watched_literal); - while let Some(clause_id) = next_clause_id { - debug_assert!( - predecessor_clause_id != Some(clause_id), - "Linked list is circular!" - ); - - // Get mutable access to both clauses. - let (predecessor_clause_state, clause_state) = - if let Some(prev_clause_id) = predecessor_clause_id { - let prev_idx = prev_clause_id.to_usize(); - let current_idx = clause_id.to_usize(); - if prev_idx < current_idx { - let (left, right) = clause_states.split_at_mut(current_idx); - (Some(&mut left[prev_idx]), &mut right[0]) - } else { - let (left, right) = clause_states.split_at_mut(prev_idx); - (Some(&mut right[0]), &mut left[current_idx]) - } - } else { - (None, &mut clause_states[clause_id.to_usize()]) - }; - - // Update the prev_clause_id for the next run - old_predecessor_clause_id = predecessor_clause_id; - predecessor_clause_id = Some(clause_id); - - // Configure the next clause to visit - next_clause_id = clause_state.next_watched_clause(watched_literal.variable()); - - // Determine which watch turned false. - let (watch_index, other_watch_index) = - if clause_state.watched_literals[0].variable() == watched_literal.variable() { - (0, 1) - } else { - (1, 0) - }; - debug_assert!( - clause_state.watched_literals[watch_index].eval(self.decision_tracker.map()) - == Some(false) - ); - - // Find another literal to watch. If we can't find one, the other literal must - // be set to true for the clause to still hold. - if clause_state.watched_literals[other_watch_index] - .eval(self.decision_tracker.map()) - == Some(true) - { - // If the other watch is already true, we can simply skip - // this clause. - } else if let Some(variable) = clause_state.next_unwatched_literal( - &clauses[clause_id.to_usize()], - &self.learnt_clauses, - &self.requirement_to_sorted_candidates, - self.decision_tracker.map(), - watch_index, - ) { - self.watches.update_watched( - predecessor_clause_state, - clause_state, - clause_id, - watch_index, - watched_literal, - variable, - ); - - // Make sure the right predecessor is kept for the next iteration (i.e. the - // current clause is no longer a predecessor of the next one; the current - // clause's predecessor is) - predecessor_clause_id = old_predecessor_clause_id; - } else { - // We could not find another literal to watch, which means the remaining - // watched literal can be set to true - let remaining_watch_index = match watch_index { - 0 => 1, - 1 => 0, - _ => unreachable!(), - }; - - let remaining_watch = clause_state.watched_literals[remaining_watch_index]; - let decided = self - .decision_tracker - .try_add_decision( - Decision::new( - remaining_watch.variable(), - remaining_watch.satisfying_value(), - clause_id, - ), - level, - ) - .map_err(|_| { - PropagationError::Conflict(remaining_watch.variable(), true, clause_id) - })?; - - if decided { - let clause = &clauses[clause_id.to_usize()]; - match clause { - // Skip logging for ForbidMultipleInstances, which is so noisy - Clause::ForbidMultipleInstances(..) => {} - _ => { - tracing::debug!( - "├ Propagate {} = {}. {}", - remaining_watch - .variable() - .display(&self.variable_map, interner), - remaining_watch.satisfying_value(), - clause.display(&self.variable_map, interner) - ); - } - } - } - } - } - } - Ok(()) } @@ -1479,13 +1458,12 @@ impl Solver { let learnt_id = self.learnt_clauses.alloc(learnt.clone()); self.learnt_why.insert(learnt_id, learnt_why); - let (state, kind) = ClauseState::learnt(learnt_id, &learnt); - let has_watches = state.has_watches(); - let clause_id = self.clauses.alloc(state, kind); + let (watched_literals, kind) = WatchedLiterals::learnt(learnt_id, &learnt); + let clause_id = self.clauses.alloc(watched_literals, kind); self.learnt_clause_ids.push(clause_id); - if has_watches { - self.watches - .start_watching(&mut self.clauses.states[clause_id.to_usize()], clause_id); + if let Some(watched_literals) = self.clauses.watched_literals[clause_id.to_usize()].as_mut() + { + self.watches.start_watching(watched_literals, clause_id); } tracing::debug!("│├ Learnt disjunction:",); @@ -1643,7 +1621,7 @@ async fn add_clauses_for_solvables( // There is no information about the solvable's dependencies, so we add // an exclusion clause for it - let (state, kind) = ClauseState::exclude(variable, reason); + let (state, kind) = WatchedLiterals::exclude(variable, reason); let clause_id = clauses.alloc(state, kind); // Exclusions are negative assertions, tracked outside the watcher @@ -1747,11 +1725,11 @@ async fn add_clauses_for_solvables( for &other_candidate in candidates { if other_candidate != locked_solvable_id { let other_candidate_var = variable_map.intern_solvable(other_candidate); - let (state, kind) = - ClauseState::lock(locked_solvable_var, other_candidate_var); - let clause_id = clauses.alloc(state, kind); + let (watched_literals, kind) = + WatchedLiterals::lock(locked_solvable_var, other_candidate_var); + let clause_id = clauses.alloc(watched_literals, kind); - debug_assert!(clauses.states[clause_id.to_usize()].has_watches()); + debug_assert!(clauses.watched_literals[clause_id.to_usize()].is_some()); output.clauses_to_watch.push(clause_id); } } @@ -1760,8 +1738,8 @@ async fn add_clauses_for_solvables( // Add a clause for solvables that are externally excluded. for (solvable, reason) in package_candidates.excluded.iter().copied() { let solvable_var = variable_map.intern_solvable(solvable); - let (state, kind) = ClauseState::exclude(solvable_var, reason); - let clause_id = clauses.alloc(state, kind); + let (watched_literals, kind) = WatchedLiterals::exclude(solvable_var, reason); + let clause_id = clauses.alloc(watched_literals, kind); // Exclusions are negative assertions, tracked outside the watcher system output.negative_assertions.push((solvable_var, clause_id)); @@ -1829,13 +1807,13 @@ async fn add_clauses_for_solvables( other_solvables.add( candidate_var, |a, b, positive| { - let (state, kind) = ClauseState::forbid_multiple( + let (watched_literals, kind) = WatchedLiterals::forbid_multiple( a, if positive { b.positive() } else { b.negative() }, name_id, ); - let clause_id = clauses.alloc(state, kind); - debug_assert!(clauses.states[clause_id.to_usize()].has_watches()); + let clause_id = clauses.alloc(watched_literals, kind); + debug_assert!(clauses.watched_literals[clause_id.to_usize()].is_some()); output.clauses_to_watch.push(clause_id); }, || variable_map.alloc_forbid_multiple_variable(name_id), @@ -1844,14 +1822,14 @@ async fn add_clauses_for_solvables( // Add the requirements clause let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); - let (state, conflict, kind) = ClauseState::requires( + let (watched_literals, conflict, kind) = WatchedLiterals::requires( variable, requirement, version_set_variables.iter().flatten().copied(), decision_tracker, ); - let has_watches = state.has_watches(); - let clause_id = clauses.alloc(state, kind); + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); if has_watches { output.clauses_to_watch.push(clause_id); @@ -1890,7 +1868,7 @@ async fn add_clauses_for_solvables( // Add forbidden clauses for the candidates for &forbidden_candidate in non_matching_candidates { let forbidden_candidate_var = variable_map.intern_solvable(forbidden_candidate); - let (state, conflict, kind) = ClauseState::constrains( + let (state, conflict, kind) = WatchedLiterals::constrains( variable, forbidden_candidate_var, version_set_id, diff --git a/src/solver/watch_map.rs b/src/solver/watch_map.rs index 07cce36..adc5c9e 100644 --- a/src/solver/watch_map.rs +++ b/src/solver/watch_map.rs @@ -1,6 +1,6 @@ use crate::{ - internal::{id::ClauseId, mapping::Mapping}, - solver::clause::{ClauseState, Literal}, + internal::{arena::ArenaId, debug_expect_unchecked, id::ClauseId, mapping::Mapping}, + solver::clause::{Literal, WatchedLiterals}, }; /// A map from literals to the clauses that are watching them. Each literal @@ -20,7 +20,7 @@ impl WatchMap { /// Add the clause to the linked list of the literals that the clause is /// watching. - pub(crate) fn start_watching(&mut self, clause: &mut ClauseState, clause_id: ClauseId) { + pub(crate) fn start_watching(&mut self, clause: &mut WatchedLiterals, clause_id: ClauseId) { for (watch_index, watched_literal) in clause.watched_literals.into_iter().enumerate() { // Construct a linked list by adding the clause to the start of the linked list // and setting the previous head of the chain as the next element in the linked @@ -31,39 +31,180 @@ impl WatchMap { } } - pub(crate) fn update_watched( - &mut self, - predecessor_clause: Option<&mut ClauseState>, - clause: &mut ClauseState, - clause_id: ClauseId, - watch_index: usize, - previous_watch: Literal, - new_watch: Literal, - ) { - // Remove this clause from its current place in the linked list, because we - // are no longer watching what brought us here - if let Some(predecessor_clause) = predecessor_clause { - // Unlink the clause - predecessor_clause.unlink_clause(clause, previous_watch.variable(), watch_index); - } else if let Some(next_watch) = clause.next_watches[watch_index] { - // This was the first clause in the chain - self.map.insert(previous_watch, next_watch); + /// Returns a [`WatchMapCursor`] that can be used to navigate and manipulate + /// the linked list of the clauses that are watching the specified + /// literal. + pub fn cursor<'a>( + &'a mut self, + watches: &'a mut [Option], + literal: Literal, + ) -> Option> { + let clause_id = *self.map.get(literal)?; + let watched_literal = watches[clause_id.to_usize()] + .as_ref() + .expect("no watches found for clause"); + let watch_index = if watched_literal.watched_literals[0] == literal { + 0 } else { - self.map.unset(previous_watch); + debug_assert_eq!( + watched_literal.watched_literals[1], literal, + "the clause is not actually watching the literal" + ); + 1 + }; + + Some(WatchMapCursor { + watch_map: self, + watches, + literal, + previous: None, + current: WatchNode { + clause_id, + watch_index, + }, + }) + } +} + +struct WatchNode { + /// The index of the [`WatchedLiterals`] + clause_id: ClauseId, + + /// A [`WatchedLiterals`] contains the state for two linked lists. This + /// index indicates which of the two linked-list nodes is referenced. + watch_index: usize, +} + +/// The watchmap contains a linked-list of clauses that are watching a certain +/// literal. This linked-list is a singly linked list, which requires some +/// administration when trying to modify the list. The [`WatchMapCursor`] is a +/// utility that allows navigating the linked-list and manipulate it. +/// +/// A cursor is created using [`WatchMap::cursor`]. The cursor can iterate +/// through all the clauses using [`WatchMapCursor::next`] and a single watch +/// can be updated using the [`WatchMapCursor::update`] method. +pub struct WatchMapCursor<'a> { + /// The watchmap that is being navigated. + watch_map: &'a mut WatchMap, + + /// The nodes of the linked list. + watches: &'a mut [Option], + + /// The literal who's linked list is being navigated. + literal: Literal, + + /// The previous node we iterated or `None` if this is the head. + previous: Option, + + /// The current node. + current: WatchNode, +} + +impl<'a> WatchMapCursor<'a> { + /// Skip to the next node in the linked list. Returns `None` if there is no + /// next node. + pub fn next(mut self) -> Option { + let next = self.next_node()?; + + self.previous = Some(self.current); + self.current = next; + + Some(self) + } + + /// Returns the next node in the linked list or `None` if there is no next. + fn next_node(&self) -> Option { + let current_watch = self.watched_literals(); + let next_clause_id = current_watch.next_watches[self.current.watch_index]?; + let next_watch = self.watches[next_clause_id.to_usize()] + .as_ref() + .expect("watches are missing"); + let next_clause_watch_index = if next_watch.watched_literals[0] == self.literal { + 0 + } else { + debug_assert_eq!( + next_watch.watched_literals[1], self.literal, + "the clause is not actually watching the literal" + ); + 1 + }; + + Some(WatchNode { + clause_id: next_clause_id, + watch_index: next_clause_watch_index, + }) + } + + /// The current clause that is being navigated. + pub fn clause_id(&self) -> ClauseId { + self.current.clause_id + } + + /// Returns the watches of the current clause. + pub fn watched_literals(&self) -> &WatchedLiterals { + // SAFETY: Within the cursor, the current clause is always watching literals. + unsafe { + debug_expect_unchecked( + self.watches[self.current.clause_id.to_usize()].as_ref(), + "clause is not watching literals", + ) } + } - // Set the new watch - clause.watched_literals[watch_index] = new_watch; - let previous_clause_id = self.map.insert(new_watch, clause_id); - clause.next_watches[watch_index] = previous_clause_id; + /// Returns the index of the current watch in the current clause. + pub fn watch_index(&self) -> usize { + self.current.watch_index } - /// Returns the id of the first clause that is watching the specified - /// literal. - pub(crate) fn first_clause_watching_literal( - &mut self, - watched_literal: Literal, - ) -> Option { - self.map.get(watched_literal).copied() + /// Update the current watch to a new literal. This removes the current node + /// from the linked-list and sets up a watch on the new literal. + /// + /// Returns a cursor that points to the next node in the linked list or + /// `None` if there is no next. + pub fn update(mut self, new_watch: Literal) -> Option { + debug_assert_ne!( + new_watch, self.literal, + "cannot update watch to the same literal" + ); + + let clause_idx = self.current.clause_id.to_usize(); + let next_node = self.next_node(); + + // Update the previous node to point to the next node in the linked list + // (effectively removing this one). + if let Some(previous) = &self.previous { + // If there is a previous node we update that node to point to the next. + // SAFETY: Within the cursor, the watches are never unset, so if we have a + // previous index there will also be watch literals for that clause. + let previous_watches = unsafe { + debug_expect_unchecked( + self.watches[previous.clause_id.to_usize()].as_mut(), + "previous clause has no watches", + ) + }; + previous_watches.next_watches[previous.watch_index] = + next_node.as_ref().map(|node| node.clause_id); + } else if let Some(next_clause_id) = next_node.as_ref().map(|node| node.clause_id) { + // If there is no previous node, we are the head of the linked list. + self.watch_map.map.insert(self.literal, next_clause_id); + } else { + self.watch_map.map.unset(self.literal); + } + + // Set the new watch for the current clause. + let watch = unsafe { + debug_expect_unchecked( + self.watches[clause_idx].as_mut(), + "clause is not watching literals", + ) + }; + watch.watched_literals[self.current.watch_index] = new_watch; + let previous_clause_id = self.watch_map.map.insert(new_watch, self.current.clause_id); + watch.next_watches[self.current.watch_index] = previous_clause_id; + + // Update the current + self.current = next_node?; + + Some(self) } }