diff --git a/src/solver/clause.rs b/src/solver/clause.rs index 406e258..24931a8 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -6,6 +6,7 @@ use std::{ }; use elsa::FrozenMap; +use itertools::Itertools; use crate::{ internal::{ @@ -237,8 +238,7 @@ impl Clause { fn conditional( parent_id: VariableId, requirement: Requirement, - condition_variable: VariableId, - condition_version_set_id: VersionSetId, + condition_variables: Vec, decision_tracker: &DecisionTracker, requirement_candidates: impl IntoIterator, ) -> (Self, Option<[Literal; 2]>, bool) { @@ -246,8 +246,8 @@ impl Clause { let mut requirement_candidates = requirement_candidates.into_iter(); let requirement_literal = - if decision_tracker.assigned_value(condition_variable) == Some(true) { - // then ~condition is false + if condition_variables.iter().all(|condition_variable| decision_tracker.assigned_value(*condition_variable) == Some(true)) { + // then all of the conditions are true, so we can require the requirement requirement_candidates .find(|&id| decision_tracker.assigned_value(id) != Some(false)) .map(|id| id.positive()) @@ -256,13 +256,13 @@ impl Clause { }; ( - Clause::Conditional(parent_id, condition_variable, requirement), + Clause::Conditional(parent_id, condition_variables, requirement), Some([ parent_id.negative(), - requirement_literal.unwrap_or(condition_variable.negative()), + requirement_literal.unwrap_or(condition_variables.first().unwrap().negative()), ]), requirement_literal.is_none() - && decision_tracker.assigned_value(condition_variable) == Some(true), + && condition_variables.iter().all(|condition_variable| decision_tracker.assigned_value(*condition_variable) == Some(true)), ) } @@ -308,9 +308,9 @@ impl Clause { Clause::Lock(_, s) => [s.negative(), VariableId::root().negative()] .into_iter() .try_fold(init, visit), - Clause::Conditional(package_id, condition_variable, _, requirement) => { + Clause::Conditional(package_id, condition_variables, requirement) => { iter::once(package_id.negative()) - .chain(iter::once(condition_variable.negative())) + .chain(condition_variables.iter().map(|c| c.negative())) .chain( requirements_to_sorted_candidates[&requirement] .iter() @@ -474,16 +474,14 @@ impl WatchedLiterals { pub fn conditional( package_id: VariableId, requirement: Requirement, - condition_variable: VariableId, - condition_version_set_id: VersionSetId, + condition_variables: Vec, decision_tracker: &DecisionTracker, requirement_candidates: impl IntoIterator, ) -> (Option, bool, Clause) { let (kind, watched_literals, conflict) = Clause::conditional( package_id, requirement, - condition_variable, - condition_version_set_id, + condition_variables, decision_tracker, requirement_candidates, ); @@ -687,14 +685,16 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { other, ) } - Clause::Conditional(package_id, condition_variable, _, requirement) => { + Clause::Conditional(package_id, condition_variables, requirement) => { write!( f, - "Conditional({}({:?}), {}({:?}), {})", + "Conditional({}({:?}), {}, {})", package_id.display(self.variable_map, self.interner), package_id, - condition_variable.display(self.variable_map, self.interner), - condition_variable, + condition_variables + .iter() + .map(|v| v.display(self.variable_map, self.interner)) + .join(", "), requirement.display(self.interner), ) }