diff --git a/src/conflict.rs b/src/conflict.rs index a597301..2fddec3 100644 --- a/src/conflict.rs +++ b/src/conflict.rs @@ -14,10 +14,14 @@ use petgraph::{ use crate::{ internal::{ arena::ArenaId, - id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId}, + id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VariableId, VersionSetId}, }, runtime::AsyncRuntime, - solver::{clause::Clause, variable_map::VariableOrigin, Solver}, + solver::{ + clause::Clause, + variable_map::{VariableMap, VariableOrigin}, + Solver, + }, DependencyProvider, Interner, Requirement, }; @@ -159,7 +163,12 @@ impl Conflict { ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)), ); } - &Clause::Conditional(package_id, condition, requirement) => { + &Clause::Conditional( + package_id, + condition_variable, + condition_version_set_id, + requirement, + ) => { let solvable = package_id .as_solvable_or_root(&solver.variable_map) .expect("only solvables can be excluded"); @@ -176,10 +185,6 @@ impl Conflict { ) }); - let conditional_candidates = solver.async_runtime.block_on(solver.cache.get_or_cache_sorted_candidates(condition.into())).unwrap_or_else(|_| { - unreachable!("The condition's version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`") - }); - if requirement_candidates.is_empty() { tracing::trace!( "{package_id:?} conditionally requires {requirement:?}, which has no candidates" @@ -187,32 +192,27 @@ impl Conflict { graph.add_edge( package_node, unresolved_node, - ConflictEdge::ConditionalRequires(condition, requirement), + ConflictEdge::ConditionalRequires( + condition_version_set_id, + requirement, + ), ); - } else if conditional_candidates.is_empty() { + } else { tracing::trace!( - "{package_id:?} conditionally requires {requirement:?}, but the condition has no candidates" + "{package_id:?} conditionally requires {requirement:?} if {condition_variable:?}" ); - graph.add_edge( - package_node, - unresolved_node, - ConflictEdge::ConditionalRequires(condition, requirement), - ); - } else { - for &candidate_id in conditional_candidates { - tracing::trace!( - "{package_id:?} conditionally requires {requirement:?} if {candidate_id:?}" - ); - for &candidate_id in requirement_candidates { - let candidate_node = - Self::add_node(&mut graph, &mut nodes, candidate_id.into()); - graph.add_edge( - package_node, - candidate_node, - ConflictEdge::ConditionalRequires(condition, requirement), - ); - } + for &candidate_id in requirement_candidates { + let candidate_node = + Self::add_node(&mut graph, &mut nodes, candidate_id.into()); + graph.add_edge( + package_node, + candidate_node, + ConflictEdge::ConditionalRequires( + condition_version_set_id, + requirement, + ), + ); } } } @@ -415,10 +415,10 @@ impl ConflictGraph { ConflictEdge::Requires(requirement) => { requirement.display(interner).to_string() } - ConflictEdge::ConditionalRequires(version_set_id, requirement) => { + ConflictEdge::ConditionalRequires(condition_version_set_id, requirement) => { format!( "if {} then {}", - interner.display_version_set(*version_set_id), + interner.display_version_set(*condition_version_set_id), requirement.display(interner) ) } diff --git a/src/solver/clause.rs b/src/solver/clause.rs index 2693f3f..99bb693 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -80,7 +80,9 @@ pub(crate) enum Clause { /// In SAT terms: (¬A ∨ ¬C ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, /// C is the condition, and B1 to B99 represent the possible candidates for /// the provided [`Requirement`]. - Conditional(VariableId, VersionSetId, Requirement), + /// We need to store the version set id because in the conflict graph, the version set id + /// is used to identify the condition variable. + Conditional(VariableId, VariableId, VersionSetId, Requirement), /// Forbids the package on the right-hand side /// /// Note that the package on the left-hand side is not part of the clause, @@ -237,49 +239,38 @@ impl Clause { fn conditional( parent_id: VariableId, requirement: Requirement, - condition: VersionSetId, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, decision_tracker: &DecisionTracker, requirement_candidates: impl IntoIterator, - condition_candidates: impl IntoIterator, ) -> (Self, Option<[Literal; 2]>, bool) { assert_ne!(decision_tracker.assigned_value(parent_id), Some(false)); - let mut condition_candidates = condition_candidates.into_iter(); - let requirement_candidates = requirement_candidates.into_iter(); - - // Check if we have any condition candidates - let first_condition = condition_candidates - .next() - .expect("no condition candidates"); - - // Map condition candidates to negative literals and requirement candidates to positive literals - let mut iter = condition_candidates - .map(|id| (id, id.negative())) - .chain(requirement_candidates.map(|id| (id, id.positive()))) - .peekable(); - - let condition_literal = if iter.peek().is_some() { - iter.find(|&(id, _)| { - let value = decision_tracker.assigned_value(id); - value.is_none() || value == Some(true) - }) - .map(|(_, literal)| literal) - } else { - None - }; - match condition_literal { - // Found a valid literal - use it - Some(literal) => ( - Clause::Conditional(parent_id, condition, requirement), - Some([parent_id.negative(), literal]), - false, - ), - // No valid literals found - conflict case - None => ( - Clause::Conditional(parent_id, condition, requirement), - Some([parent_id.negative(), first_condition.negative()]), - true, + let mut requirement_candidates = requirement_candidates.into_iter(); + + let requirement_literal = + if decision_tracker.assigned_value(condition_variable) == Some(true) { + // then ~condition is false + requirement_candidates + .find(|&id| decision_tracker.assigned_value(id) != Some(false)) + .map(|id| id.positive()) + } else { + None + }; + + ( + Clause::Conditional( + parent_id, + condition_variable, + condition_version_set_id, + requirement, ), - } + Some([ + parent_id.negative(), + requirement_literal.unwrap_or(condition_variable.negative()), + ]), + requirement_literal.is_none() + && decision_tracker.assigned_value(condition_variable) == Some(true), + ) } /// Tries to fold over all the literals in the clause. @@ -294,11 +285,6 @@ impl Clause { Vec>, ahash::RandomState, >, - version_set_to_variables: &FrozenMap< - VersionSetId, - Vec>, - ahash::RandomState, - >, init: C, mut visit: F, ) -> ControlFlow @@ -329,14 +315,9 @@ impl Clause { Clause::Lock(_, s) => [s.negative(), VariableId::root().negative()] .into_iter() .try_fold(init, visit), - Clause::Conditional(package_id, condition, requirement) => { + Clause::Conditional(package_id, condition_variable, _, requirement) => { iter::once(package_id.negative()) - .chain( - version_set_to_variables[&condition] - .iter() - .flatten() - .map(|&s| s.negative()), - ) + .chain(iter::once(condition_variable.negative())) .chain( requirements_to_sorted_candidates[&requirement] .iter() @@ -359,17 +340,11 @@ impl Clause { Vec>, ahash::RandomState, >, - version_set_to_variables: &FrozenMap< - VersionSetId, - Vec>, - ahash::RandomState, - >, mut visit: impl FnMut(Literal), ) { self.try_fold_literals( learnt_clauses, requirements_to_sorted_candidates, - version_set_to_variables, (), |_, lit| { visit(lit); @@ -506,18 +481,18 @@ impl WatchedLiterals { pub fn conditional( package_id: VariableId, requirement: Requirement, - condition: VersionSetId, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, decision_tracker: &DecisionTracker, requirement_candidates: impl IntoIterator, - condition_candidates: impl IntoIterator, ) -> (Option, bool, Clause) { let (kind, watched_literals, conflict) = Clause::conditional( package_id, requirement, - condition, + condition_variable, + condition_version_set_id, decision_tracker, requirement_candidates, - condition_candidates, ); ( @@ -545,11 +520,6 @@ impl WatchedLiterals { Vec>, ahash::RandomState, >, - version_set_to_variables: &FrozenMap< - VersionSetId, - Vec>, - ahash::RandomState, - >, decision_map: &DecisionMap, for_watch_index: usize, ) -> Option { @@ -566,7 +536,6 @@ impl WatchedLiterals { let next = clause.try_fold_literals( learnt_clauses, requirement_to_sorted_candidates, - version_set_to_variables, (), |_, lit| { // The next unwatched variable (if available), is a variable that is: @@ -725,13 +694,14 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { other, ) } - Clause::Conditional(package_id, condition, requirement) => { + Clause::Conditional(package_id, condition_variable, _, requirement) => { write!( f, - "Conditional({}({:?}), {}, {})", + "Conditional({}({:?}), {}({:?}), {})", package_id.display(self.variable_map, self.interner), package_id, - self.interner.display_version_set(condition), + condition_variable.display(self.variable_map, self.interner), + condition_variable, requirement.display(self.interner), ) } diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 7e7e97a..bda9bd4 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -37,7 +37,7 @@ mod watch_map; #[derive(Default)] struct AddClauseOutput { new_requires_clauses: Vec<(VariableId, Requirement, ClauseId)>, - new_conditional_clauses: Vec<(VariableId, VersionSetId, Requirement, ClauseId)>, + new_conditional_clauses: Vec<(VariableId, VariableId, Requirement, ClauseId)>, conflicting_clauses: Vec, negative_assertions: Vec<(VariableId, ClauseId)>, clauses_to_watch: Vec, @@ -153,7 +153,7 @@ pub struct Solver { pub(crate) clauses: Clauses, requires_clauses: IndexMap, ahash::RandomState>, conditional_clauses: - IndexMap<(VariableId, VersionSetId), Vec<(Requirement, ClauseId)>, ahash::RandomState>, + IndexMap<(VariableId, VariableId), Vec<(Requirement, ClauseId)>, ahash::RandomState>, watches: WatchMap, /// A mapping from requirements to the variables that represent the @@ -161,10 +161,6 @@ pub struct Solver { requirement_to_sorted_candidates: FrozenMap, - /// A mapping from version sets to the variables that represent the - /// candidates. - version_set_to_variables: FrozenMap>, ahash::RandomState>, - pub(crate) variable_map: VariableMap, negative_assertions: Vec<(VariableId, ClauseId)>, @@ -210,7 +206,6 @@ impl Solver { requires_clauses: Default::default(), conditional_clauses: Default::default(), requirement_to_sorted_candidates: FrozenMap::default(), - version_set_to_variables: FrozenMap::default(), watches: WatchMap::new(), negative_assertions: Default::default(), learnt_clauses: Arena::new(), @@ -291,7 +286,6 @@ impl Solver { requires_clauses: self.requires_clauses, conditional_clauses: self.conditional_clauses, requirement_to_sorted_candidates: self.requirement_to_sorted_candidates, - version_set_to_variables: self.version_set_to_variables, watches: self.watches, negative_assertions: self.negative_assertions, learnt_clauses: self.learnt_clauses, @@ -489,7 +483,6 @@ impl Solver { &mut self.clauses_added_for_package, &mut self.forbidden_clauses_added, &mut self.requirement_to_sorted_candidates, - &self.version_set_to_variables, &self.root_requirements, &self.root_constraints, ))?; @@ -606,7 +599,6 @@ impl Solver { &mut self.clauses_added_for_package, &mut self.forbidden_clauses_added, &mut self.requirement_to_sorted_candidates, - &self.version_set_to_variables, &self.root_requirements, &self.root_constraints, ))?; @@ -674,9 +666,11 @@ impl Solver { .push((requirement, clause_id)); } - for (solvable_id, condition, requirement, clause_id) in output.new_conditional_clauses { + for (solvable_id, condition_variable, requirement, clause_id) in + output.new_conditional_clauses + { self.conditional_clauses - .entry((solvable_id, condition)) + .entry((solvable_id, condition_variable)) .or_default() .push((requirement, clause_id)); } @@ -832,25 +826,10 @@ impl Solver { } // For conditional clauses, check that at least one conditional variable is true - if let Some(condition) = condition { - tracing::trace!("condition o kir: {:?}", condition); - let condition_requirement: Requirement = condition.into(); - let conditional_candidates = match self - .requirement_to_sorted_candidates - .get(&condition_requirement) - { - Some(candidates) => candidates, - None => continue, - }; - + if let Some(condition_variable) = condition { // Check if any candidate that matches the condition's version set is installed - let condition_met = conditional_candidates.iter().any(|candidates| { - candidates.iter().any(|&candidate| { - // Only consider the condition met if a candidate that exactly matches - // the condition's version set is installed - self.decision_tracker.assigned_value(candidate) == Some(true) - }) - }); + let condition_met = + self.decision_tracker.assigned_value(condition_variable) == Some(true); // If the condition is not met, skip this requirement entirely if !condition_met { @@ -1216,7 +1195,6 @@ impl Solver { clause, &self.learnt_clauses, &self.requirement_to_sorted_candidates, - &self.version_set_to_variables, self.decision_tracker.map(), watch_index, ) { @@ -1376,7 +1354,6 @@ impl Solver { self.clauses.kinds[clause_id.to_usize()].visit_literals( &self.learnt_clauses, &self.requirement_to_sorted_candidates, - &self.version_set_to_variables, |literal| { involved.insert(literal.variable()); }, @@ -1415,7 +1392,6 @@ impl Solver { self.clauses.kinds[why.to_usize()].visit_literals( &self.learnt_clauses, &self.requirement_to_sorted_candidates, - &self.version_set_to_variables, |literal| { if literal.eval(self.decision_tracker.map()) == Some(true) { assert_eq!(literal.variable(), decision.variable); @@ -1463,7 +1439,6 @@ impl Solver { clause_kinds[clause_id.to_usize()].visit_literals( &self.learnt_clauses, &self.requirement_to_sorted_candidates, - &self.version_set_to_variables, |literal| { if !first_iteration && literal.variable() == conflicting_solvable { // We are only interested in the causes of the conflict, so we ignore the @@ -1598,7 +1573,6 @@ async fn add_clauses_for_solvables( RequirementCandidateVariables, ahash::RandomState, >, - version_set_to_variables: &FrozenMap>, ahash::RandomState>, root_requirements: &[ConditionalRequirement], root_constraints: &[VersionSetId], ) -> Result> { @@ -1614,7 +1588,7 @@ async fn add_clauses_for_solvables( SortedCandidates { solvable_id: SolvableOrRootId, requirement: Requirement, - condition: Option<(VersionSetId, Vec<&'i [SolvableId]>)>, + condition: Option<(SolvableId, VersionSetId)>, candidates: Vec<&'i [SolvableId]>, }, NonMatchingCandidates { @@ -1775,40 +1749,48 @@ async fn add_clauses_for_solvables( for conditional_requirement in conditional_requirements { // Find all the solvable that match for the given version set - pending_futures.push( - async move { - let version_sets = - conditional_requirement.requirement_version_sets(cache.provider()); - let candidates = - futures::future::try_join_all(version_sets.map(|version_set| { - cache - .get_or_cache_sorted_candidates_for_version_set(version_set) - })) - .await?; - - // condition is `VersionSetId` right now but it will become a `Requirement` - // in the next versions of resolvo - if let Some(condition) = conditional_requirement.condition { - let condition_candidates = - cache.get_or_cache_matching_candidates(condition).await?; - - Ok(TaskResult::SortedCandidates { - solvable_id, - requirement: conditional_requirement.requirement, - condition: Some((condition, vec![condition_candidates])), - candidates, - }) - } else { + let version_sets = + conditional_requirement.requirement_version_sets(cache.provider()); + let candidates = + futures::future::try_join_all(version_sets.map(|version_set| { + cache.get_or_cache_sorted_candidates_for_version_set(version_set) + })) + .await?; + + // condition is `VersionSetId` right now but it will become a `Requirement` + // in the next versions of resolvo + if let Some(condition) = conditional_requirement.condition { + let condition_candidates = + cache.get_or_cache_matching_candidates(condition).await?; + + for &condition_candidate in condition_candidates { + let candidates = candidates.clone(); + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: Some((condition_candidate, condition)), + candidates, + }) + } + .boxed_local(), + ); + } + } else { + // Add a task result for the condition + pending_futures.push( + async move { Ok(TaskResult::SortedCandidates { solvable_id, requirement: conditional_requirement.requirement, condition: None, - candidates, + candidates: candidates.clone(), }) } - } - .boxed_local(), - ); + .boxed_local(), + ); + } } for version_set_id in constrains { @@ -1944,28 +1926,17 @@ async fn add_clauses_for_solvables( ); } - if let Some((condition, condition_candidates)) = condition { - let condition_version_set_variables = version_set_to_variables.insert( - condition, - condition_candidates - .iter() - .map(|&candidates| { - candidates - .iter() - .map(|&var| variable_map.intern_solvable(var)) - .collect() - }) - .collect(), - ); + if let Some((condition, condition_version_set_id)) = condition { + let condition_variable = variable_map.intern_solvable(condition); // Add a condition clause let (watched_literals, conflict, kind) = WatchedLiterals::conditional( variable, requirement, - condition, + condition_variable, + condition_version_set_id, decision_tracker, version_set_variables.iter().flatten().copied(), - condition_version_set_variables.iter().flatten().copied(), ); // Add the conditional clause @@ -1980,7 +1951,7 @@ async fn add_clauses_for_solvables( output.new_conditional_clauses.push(( variable, - condition, + condition_variable, requirement, clause_id, )); diff --git a/tests/solver.rs b/tests/solver.rs index 963799c..d556dd7 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -136,10 +136,10 @@ impl FromStr for Spec { type Err = (); fn from_str(s: &str) -> Result { - let (spec, condition) = s.split_once("; if").unwrap(); + let split = s.split_once("; if"); - if condition.is_empty() { - let split = spec.split(' ').collect::>(); + if split.is_none() { + let split = s.split(' ').collect::>(); let name = split .first() .expect("spec does not have a name") @@ -148,6 +148,8 @@ impl FromStr for Spec { return Ok(Spec::new(name, versions, None)); } + let (spec, condition) = split.unwrap(); + let condition = Spec::parse_union(condition).next().unwrap().unwrap(); let spec = Spec::from_str(spec).unwrap(); @@ -1641,7 +1643,7 @@ fn test_conditional_requirements_multiple_versions() { // Since b=4 is installed (not b 1..3), c should not be installed insta::assert_snapshot!(result, @r###" a=1 - b=4 + b=5 "###); }