diff --git a/src/internal/id.rs b/src/internal/id.rs index 2a46f60..6d1d33c 100644 --- a/src/internal/id.rs +++ b/src/internal/id.rs @@ -1,4 +1,7 @@ -use std::fmt::{Display, Formatter}; +use std::{ + fmt::{Display, Formatter}, + num::NonZeroU32, +}; use crate::{internal::arena::ArenaId, Interner}; @@ -165,32 +168,24 @@ impl From for u32 { #[repr(transparent)] #[derive(Copy, Clone, PartialOrd, Ord, Eq, PartialEq, Debug, Hash)] -pub(crate) struct ClauseId(u32); +pub(crate) struct ClauseId(NonZeroU32); impl ClauseId { - /// There is a guarentee that ClauseId(0) will always be + /// There is a guarentee that ClauseId(1) will always be /// "Clause::InstallRoot". This assumption is verified by the solver. pub(crate) fn install_root() -> Self { - Self(0) - } - - pub(crate) fn is_null(self) -> bool { - self.0 == u32::MAX - } - - pub(crate) fn null() -> ClauseId { - ClauseId(u32::MAX) + Self(unsafe { NonZeroU32::new_unchecked(1) }) } } impl ArenaId for ClauseId { fn from_usize(x: usize) -> Self { - assert!(x < u32::MAX as usize, "clause id too big"); - Self(x as u32) + // SAFETY: Safe because we always add 1 to the index + Self(unsafe { NonZeroU32::new_unchecked((x + 1).try_into().expect("clause id too big")) }) } fn to_usize(self) -> usize { - self.0 as usize + (self.0.get() - 1) as usize } } @@ -236,3 +231,17 @@ impl ArenaId for DependenciesId { self.0 as usize } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_clause_id_size() { + // Verify that the size of a ClauseId is the same as an Option. + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::>() + ); + } +} diff --git a/src/internal/mapping.rs b/src/internal/mapping.rs index 26df86b..d0e8ea7 100644 --- a/src/internal/mapping.rs +++ b/src/internal/mapping.rs @@ -65,6 +65,21 @@ impl Mapping { previous_value } + /// Unset a specific value in the mapping, returns the previous value. + pub fn unset(&mut self, id: TId) -> Option { + let idx = id.to_usize(); + let (chunk, offset) = Self::chunk_and_offset(idx); + if chunk >= self.chunks.len() { + return None; + } + + let previous_value = self.chunks[chunk][offset].take(); + if previous_value.is_some() { + self.len -= 1; + } + previous_value + } + /// Get a specific value in the mapping with bound checks pub fn get(&self, id: TId) -> Option<&TValue> { let (chunk, offset) = Self::chunk_and_offset(id.to_usize()); diff --git a/src/solver/clause.rs b/src/solver/clause.rs index ca4f5bf..7624d6d 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -324,7 +324,7 @@ pub(crate) struct ClauseState { // The ids of the solvables this clause is watching pub watched_literals: [Literal; 2], // The ids of the next clause in each linked list that this clause is part of - pub(crate) next_watches: [ClauseId; 2], + pub(crate) next_watches: [Option; 2], } impl ClauseState { @@ -417,7 +417,7 @@ impl ClauseState { let clause = Self { watched_literals, - next_watches: [ClauseId::null(), ClauseId::null()], + next_watches: [None, None], }; debug_assert!(!clause.has_watches() || watched_literals[0] != watched_literals[1]); @@ -425,7 +425,7 @@ impl ClauseState { clause } - pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: ClauseId) { + pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: Option) { self.next_watches[watch_index] = linked_clause; } @@ -444,7 +444,7 @@ impl ClauseState { } #[inline] - pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> ClauseId { + pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> Option { if solvable_id == self.watched_literals[0].solvable_id() { self.next_watches[0] } else { @@ -650,7 +650,7 @@ mod test { use super::*; use crate::{internal::arena::ArenaId, solver::decision::Decision}; - fn clause(next_clauses: [ClauseId; 2], watch_literals: [Literal; 2]) -> ClauseState { + fn clause(next_clauses: [Option; 2], watch_literals: [Literal; 2]) -> ClauseState { ClauseState { watched_literals: watch_literals, next_watches: next_clauses, @@ -691,21 +691,24 @@ mod test { #[test] fn test_unlink_clause_different() { let clause1 = clause( - [ClauseId::from_usize(2), ClauseId::from_usize(3)], + [ + ClauseId::from_usize(2).into(), + ClauseId::from_usize(3).into(), + ], [ InternalSolvableId::from_usize(1596).negative(), InternalSolvableId::from_usize(1211).negative(), ], ); let clause2 = clause( - [ClauseId::null(), ClauseId::from_usize(3)], + [None, ClauseId::from_usize(3).into()], [ InternalSolvableId::from_usize(1596).negative(), InternalSolvableId::from_usize(1208).negative(), ], ); let clause3 = clause( - [ClauseId::null(), ClauseId::null()], + [None, None], [ InternalSolvableId::from_usize(1211).negative(), InternalSolvableId::from_usize(42).negative(), @@ -723,10 +726,7 @@ mod test { InternalSolvableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::null(), ClauseId::from_usize(3)] - ) + assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(3).into()]) } // Unlink 1 @@ -740,24 +740,24 @@ mod test { InternalSolvableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::from_usize(2), ClauseId::null()] - ) + assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None]) } } #[test] fn test_unlink_clause_same() { let clause1 = clause( - [ClauseId::from_usize(2), ClauseId::from_usize(2)], + [ + ClauseId::from_usize(2).into(), + ClauseId::from_usize(2).into(), + ], [ InternalSolvableId::from_usize(1596).negative(), InternalSolvableId::from_usize(1211).negative(), ], ); let clause2 = clause( - [ClauseId::null(), ClauseId::null()], + [None, None], [ InternalSolvableId::from_usize(1596).negative(), InternalSolvableId::from_usize(1211).negative(), @@ -775,10 +775,7 @@ mod test { InternalSolvableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::null(), ClauseId::from_usize(2)] - ) + assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(2).into()]) } // Unlink 1 @@ -792,10 +789,7 @@ mod test { InternalSolvableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::from_usize(2), ClauseId::null()] - ) + assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None]) } } @@ -820,7 +814,10 @@ mod test { // No conflict, still one candidate available decisions - .try_add_decision(Decision::new(candidate1.into(), false, ClauseId::null()), 1) + .try_add_decision( + Decision::new(candidate1.into(), false, ClauseId::from_usize(0)), + 1, + ) .unwrap(); let (clause, conflict, _kind) = ClauseState::requires( parent, @@ -834,7 +831,10 @@ mod test { // Conflict, no candidates available decisions - .try_add_decision(Decision::new(candidate2.into(), false, ClauseId::null()), 1) + .try_add_decision( + Decision::new(candidate2.into(), false, ClauseId::install_root()), + 1, + ) .unwrap(); let (clause, conflict, _kind) = ClauseState::requires( parent, @@ -848,7 +848,7 @@ mod test { // Panic decisions - .try_add_decision(Decision::new(parent, false, ClauseId::null()), 1) + .try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1) .unwrap(); let panicked = std::panic::catch_unwind(|| { ClauseState::requires( @@ -878,7 +878,7 @@ mod test { // Conflict, forbidden package installed decisions - .try_add_decision(Decision::new(forbidden, true, ClauseId::null()), 1) + .try_add_decision(Decision::new(forbidden, true, ClauseId::install_root()), 1) .unwrap(); let (clause, conflict, _kind) = ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions); @@ -888,7 +888,7 @@ mod test { // Panic decisions - .try_add_decision(Decision::new(parent, false, ClauseId::null()), 1) + .try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1) .unwrap(); let panicked = std::panic::catch_unwind(|| { ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions) diff --git a/src/solver/mod.rs b/src/solver/mod.rs index c01997d..447314a 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -1435,11 +1435,8 @@ impl Solver { // solvable let mut old_predecessor_clause_id: Option; let mut predecessor_clause_id: Option = None; - let mut clause_id = self - .watches - .first_clause_watching_literal(watched_literal) - .unwrap_or(ClauseId::null()); - while !clause_id.is_null() { + let mut next_clause_id = self.watches.first_clause_watching_literal(watched_literal); + while let Some(clause_id) = next_clause_id { debug_assert!( predecessor_clause_id != Some(clause_id), "Linked list is circular!" @@ -1466,8 +1463,7 @@ impl Solver { predecessor_clause_id = Some(clause_id); // Configure the next clause to visit - let this_clause_id = clause_id; - clause_id = clause_state.next_watched_clause(watched_literal.solvable_id()); + next_clause_id = clause_state.next_watched_clause(watched_literal.solvable_id()); // Determine which watch turned false. let (watch_index, other_watch_index) = if clause_state.watched_literals[0] @@ -1492,7 +1488,7 @@ impl Solver { // If the other watch is already true, we can simply skip // this clause. } else if let Some(variable) = clause_state.next_unwatched_literal( - &clauses[this_clause_id.to_usize()], + &clauses[clause_id.to_usize()], &self.learnt_clauses, &self.cache.requirement_to_sorted_candidates, self.decision_tracker.map(), @@ -1501,7 +1497,7 @@ impl Solver { self.watches.update_watched( predecessor_clause_state, clause_state, - this_clause_id, + clause_id, watch_index, watched_literal, variable, @@ -1527,7 +1523,7 @@ impl Solver { Decision::new( remaining_watch.solvable_id(), remaining_watch.satisfying_value(), - this_clause_id, + clause_id, ), level, ) @@ -1535,12 +1531,12 @@ impl Solver { PropagationError::Conflict( remaining_watch.solvable_id(), true, - this_clause_id, + clause_id, ) })?; if decided { - let clause = &clauses[this_clause_id.to_usize()]; + let clause = &clauses[clause_id.to_usize()]; match clause { // Skip logging for ForbidMultipleInstances, which is so noisy Clause::ForbidMultipleInstances(..) => {} diff --git a/src/solver/watch_map.rs b/src/solver/watch_map.rs index 698aaf9..86fb281 100644 --- a/src/solver/watch_map.rs +++ b/src/solver/watch_map.rs @@ -1,7 +1,6 @@ -use crate::solver::clause::Literal; use crate::{ internal::{id::ClauseId, mapping::Mapping}, - solver::clause::ClauseState, + solver::clause::{ClauseState, Literal}, }; /// A map from solvables to the clauses that are watching them @@ -20,9 +19,7 @@ impl WatchMap { pub(crate) fn start_watching(&mut self, clause: &mut ClauseState, clause_id: ClauseId) { for (watch_index, watched_literal) in clause.watched_literals.into_iter().enumerate() { - let already_watching = self - .first_clause_watching_literal(watched_literal) - .unwrap_or(ClauseId::null()); + let already_watching = self.first_clause_watching_literal(watched_literal); clause.link_to_clause(watch_index, already_watching); self.watch_literal(watched_literal, clause_id); } @@ -42,18 +39,16 @@ impl WatchMap { if let Some(predecessor_clause) = predecessor_clause { // Unlink the clause predecessor_clause.unlink_clause(clause, previous_watch.solvable_id(), watch_index); - } else { + } else if let Some(next_watch) = clause.next_watches[watch_index] { // This was the first clause in the chain - self.map - .insert(previous_watch, clause.next_watches[watch_index]); + self.map.insert(previous_watch, next_watch); + } else { + self.map.unset(previous_watch); } // Set the new watch clause.watched_literals[watch_index] = new_watch; - let previous_clause_id = self - .map - .insert(new_watch, clause_id) - .unwrap_or(ClauseId::null()); + let previous_clause_id = self.map.insert(new_watch, clause_id); clause.next_watches[watch_index] = previous_clause_id; }