diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h index 97d00f5..5343ac7 100644 --- a/cpp/include/resolvo.h +++ b/cpp/include/resolvo.h @@ -4,6 +4,7 @@ #include "resolvo_internal.h" namespace resolvo { +using cbindgen_private::ConditionalRequirement; using cbindgen_private::Problem; using cbindgen_private::Requirement; @@ -24,6 +25,23 @@ inline Requirement requirement_union(VersionSetUnionId id) { return cbindgen_private::resolvo_requirement_union(id); } +/** + * Specifies a conditional requirement (dependency) of a single version set. + * A solvable belonging to the version set satisfies the requirement if the condition is true. + */ +inline ConditionalRequirement conditional_requirement_single(VersionSetId id) { + return cbindgen_private::resolvo_conditional_requirement_single(id); +} + +/** + * Specifies a conditional requirement (dependency) of the union (logical OR) of multiple version + * sets. A solvable belonging to any of the version sets contained in the union satisfies the + * requirement if the condition is true. + */ +inline ConditionalRequirement conditional_requirement_union(VersionSetUnionId id) { + return cbindgen_private::resolvo_conditional_requirement_union(id); +} + /** * Called to solve a package problem. * diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index 781e365..8f4a393 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -31,6 +31,133 @@ impl From for resolvo::SolvableId { } } +/// A wrapper around an optional string id. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct FfiOptionStringId { + pub is_some: bool, + pub value: StringId, +} + +impl From> for FfiOptionStringId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v.into(), + }, + None => Self { + is_some: false, + value: StringId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionStringId) -> Self { + if ffi.is_some { + Some(ffi.value.into()) + } else { + None + } + } +} + +/// A wrapper around an optional version set id. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct FfiOptionVersionSetId { + pub is_some: bool, + pub value: VersionSetId, +} + +impl From> for FfiOptionVersionSetId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v.into(), + }, + None => Self { + is_some: false, + value: VersionSetId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionVersionSetId) -> Self { + if ffi.is_some { + Some(ffi.value.into()) + } else { + None + } + } +} + +impl From> for FfiOptionVersionSetId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v, + }, + None => Self { + is_some: false, + value: VersionSetId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionVersionSetId) -> Self { + if ffi.is_some { + Some(ffi.value) + } else { + None + } + } +} + +/// Specifies a conditional requirement, where the requirement is only active when the condition is met. +/// First VersionSetId is the condition, second is the requirement. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct ConditionalRequirement { + pub condition: FfiOptionVersionSetId, + pub requirement: Requirement, + pub extra: FfiOptionStringId, +} + +impl From for ConditionalRequirement { + fn from(value: resolvo::ConditionalRequirement) -> Self { + Self { + condition: value.condition.into(), + requirement: value.requirement.into(), + extra: value.extra.into(), + } + } +} + +impl From for resolvo::ConditionalRequirement { + fn from(value: ConditionalRequirement) -> Self { + Self { + condition: value.condition.into(), + requirement: value.requirement.into(), + extra: value.extra.into(), + } + } +} + /// Specifies the dependency of a solvable on a set of version sets. /// cbindgen:derive-eq /// cbindgen:derive-neq @@ -162,7 +289,7 @@ pub struct Dependencies { /// A pointer to the first element of a list of requirements. Requirements /// defines which packages should be installed alongside the depending /// package and the constraints applied to the package. - pub requirements: Vector, + pub requirements: Vector, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set @@ -475,7 +602,7 @@ impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { #[repr(C)] pub struct Problem<'a> { - pub requirements: Slice<'a, Requirement>, + pub requirements: Slice<'a, ConditionalRequirement>, pub constraints: Slice<'a, VersionSetId>, pub soft_requirements: Slice<'a, SolvableId>, } @@ -525,6 +652,30 @@ pub extern "C" fn resolvo_solve( } } +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_conditional_requirement_single( + version_set_id: VersionSetId, +) -> ConditionalRequirement { + ConditionalRequirement { + condition: Option::::None.into(), + requirement: Requirement::Single(version_set_id), + extra: None.into(), + } +} + +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_conditional_requirement_union( + version_set_union_id: VersionSetUnionId, +) -> ConditionalRequirement { + ConditionalRequirement { + condition: Option::::None.into(), + requirement: Requirement::Union(version_set_union_id), + extra: None.into(), + } +} + #[no_mangle] #[allow(unused)] pub extern "C" fn resolvo_requirement_single(version_set_id: VersionSetId) -> Requirement { diff --git a/cpp/tests/solve.cpp b/cpp/tests/solve.cpp index 1bb02b7..952e86e 100644 --- a/cpp/tests/solve.cpp +++ b/cpp/tests/solve.cpp @@ -48,16 +48,17 @@ struct PackageDatabase : public resolvo::DependencyProvider { /** * Allocates a new requirement for a single version set. */ - resolvo::Requirement alloc_requirement(std::string_view package, uint32_t version_start, - uint32_t version_end) { + resolvo::ConditionalRequirement alloc_requirement(std::string_view package, + uint32_t version_start, + uint32_t version_end) { auto id = alloc_version_set(package, version_start, version_end); - return resolvo::requirement_single(id); + return resolvo::conditional_requirement_single(id); } /** * Allocates a new requirement for a version set union. */ - resolvo::Requirement alloc_requirement_union( + resolvo::ConditionalRequirement alloc_requirement_union( std::initializer_list> version_sets) { std::vector version_set_union{version_sets.size()}; @@ -69,7 +70,7 @@ struct PackageDatabase : public resolvo::DependencyProvider { auto id = resolvo::VersionSetUnionId{static_cast(version_set_unions.size())}; version_set_unions.push_back(std::move(version_set_union)); - return resolvo::requirement_union(id); + return resolvo::conditional_requirement_union(id); } /** @@ -219,7 +220,8 @@ SCENARIO("Solve") { const auto d_1 = db.alloc_candidate("d", 1, {}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; + resolvo::Vector requirements = { + db.alloc_requirement("a", 1, 3)}; resolvo::Vector constraints = { db.alloc_version_set("b", 1, 3), db.alloc_version_set("c", 1, 3), @@ -263,7 +265,7 @@ SCENARIO("Solve Union") { "f", 1, {{db.alloc_requirement("b", 1, 10)}, {db.alloc_version_set("a", 10, 20)}}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = { + resolvo::Vector requirements = { db.alloc_requirement_union({{"c", 1, 10}, {"d", 1, 10}}), db.alloc_requirement("e", 1, 10), db.alloc_requirement("f", 1, 10), diff --git a/src/conflict.rs b/src/conflict.rs index 3d121b6..1656209 100644 --- a/src/conflict.rs +++ b/src/conflict.rs @@ -11,14 +11,13 @@ use petgraph::{ Direction, }; -use crate::solver::variable_map::VariableOrigin; use crate::{ internal::{ arena::ArenaId, id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId}, }, runtime::AsyncRuntime, - solver::{clause::Clause, Solver}, + solver::{clause::Clause, variable_map::VariableOrigin, Solver}, DependencyProvider, Interner, Requirement, }; @@ -160,6 +159,61 @@ impl Conflict { ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)), ); } + &Clause::Conditional( + package_id, + condition_variable, + condition_version_set_id, + requirement, + ) => { + let solvable = package_id + .as_solvable_or_root(&solver.variable_map) + .expect("only solvables can be excluded"); + let package_node = Self::add_node(&mut graph, &mut nodes, solvable); + + let requirement_candidates = solver + .async_runtime + .block_on(solver.cache.get_or_cache_sorted_candidates( + requirement, + )) + .unwrap_or_else(|_| { + unreachable!( + "The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`" + ) + }); + + if requirement_candidates.is_empty() { + tracing::trace!( + "{package_id:?} conditionally requires {requirement:?}, which has no candidates" + ); + graph.add_edge( + package_node, + unresolved_node, + ConflictEdge::ConditionalRequires( + condition_version_set_id, + requirement, + ), + ); + } else { + tracing::trace!( + "{package_id:?} conditionally requires {requirement:?} if {condition_variable:?}" + ); + + for &candidate_id in requirement_candidates { + let candidate_node = + Self::add_node(&mut graph, &mut nodes, candidate_id.into()); + graph.add_edge( + package_node, + candidate_node, + ConflictEdge::ConditionalRequires( + condition_version_set_id, + requirement, + ), + ); + } + } + } + &Clause::RequiresWithExtra(..) => todo!(), + &Clause::ConditionalWithExtra(..) => todo!(), } } @@ -239,19 +293,22 @@ impl ConflictNode { } /// An edge in the graph representation of a [`Conflict`] -#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)] pub(crate) enum ConflictEdge { /// The target node is a candidate for the dependency specified by the /// [`Requirement`] Requires(Requirement), /// The target node is involved in a conflict, caused by `ConflictCause` Conflict(ConflictCause), + /// The target node is a candidate for a conditional dependency + ConditionalRequires(VersionSetId, Requirement), } impl ConflictEdge { fn try_requires(self) -> Option { match self { ConflictEdge::Requires(match_spec_id) => Some(match_spec_id), + ConflictEdge::ConditionalRequires(_, _) => None, ConflictEdge::Conflict(_) => None, } } @@ -259,6 +316,9 @@ impl ConflictEdge { fn requires(self) -> Requirement { match self { ConflictEdge::Requires(match_spec_id) => match_spec_id, + ConflictEdge::ConditionalRequires(_, _) => { + panic!("expected requires edge, found conditional requires") + } ConflictEdge::Conflict(_) => panic!("expected requires edge, found conflict"), } } @@ -341,6 +401,11 @@ impl ConflictGraph { ConflictEdge::Requires(_) if target != ConflictNode::UnresolvedDependency => { "black" } + ConflictEdge::ConditionalRequires(_, _) + if target != ConflictNode::UnresolvedDependency => + { + "blue" + } _ => "red", }; @@ -348,6 +413,13 @@ impl ConflictGraph { ConflictEdge::Requires(requirement) => { requirement.display(interner).to_string() } + ConflictEdge::ConditionalRequires(condition_version_set_id, requirement) => { + format!( + "if {} then {}", + interner.display_version_set(*condition_version_set_id), + requirement.display(interner) + ) + } ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { interner.display_version_set(*version_set_id).to_string() } @@ -493,9 +565,12 @@ impl ConflictGraph { .graph .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { - ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::Requires(req) => (req, e.target()), + ConflictEdge::ConditionalRequires(_, req) => (req, e.target()), ConflictEdge::Conflict(_) => unreachable!(), }) + .collect::>() + .into_iter() .chunk_by(|(&version_set_id, _)| version_set_id); for (_, mut deps) in &dependencies { @@ -540,8 +615,13 @@ impl ConflictGraph { .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::ConditionalRequires(_, version_set_id) => { + (version_set_id, e.target()) + } ConflictEdge::Conflict(_) => unreachable!(), }) + .collect::>() + .into_iter() .chunk_by(|(&version_set_id, _)| version_set_id); // Missing if at least one dependency is missing @@ -1020,6 +1100,7 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { let conflict = match e.weight() { ConflictEdge::Requires(_) => continue, ConflictEdge::Conflict(conflict) => conflict, + ConflictEdge::ConditionalRequires(_, _) => continue, }; // The only possible conflict at the root level is a Locked conflict diff --git a/src/internal/id.rs b/src/internal/id.rs index 47fe226..e3b160a 100644 --- a/src/internal/id.rs +++ b/src/internal/id.rs @@ -46,6 +46,12 @@ impl ArenaId for StringId { #[cfg_attr(feature = "serde", serde(transparent))] pub struct VersionSetId(pub u32); +impl From<(VersionSetId, Option)> for VersionSetId { + fn from((id, _): (VersionSetId, Option)) -> Self { + id + } +} + impl ArenaId for VersionSetId { fn from_usize(x: usize) -> Self { Self(x as u32) diff --git a/src/lib.rs b/src/lib.rs index 575c678..74eb27e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ pub use internal::{ mapping::Mapping, }; use itertools::Itertools; -pub use requirement::Requirement; +pub use requirement::{ConditionalRequirement, Requirement}; pub use solver::{Problem, Solver, SolverCache, UnsolvableOrCancelled}; /// An object that is used by the solver to query certain properties of @@ -206,7 +206,7 @@ pub struct KnownDependencies { feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty") )] - pub requirements: Vec, + pub requirements: Vec, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set diff --git a/src/requirement.rs b/src/requirement.rs index 244ec48..b249262 100644 --- a/src/requirement.rs +++ b/src/requirement.rs @@ -1,7 +1,96 @@ -use crate::{Interner, VersionSetId, VersionSetUnionId}; +use crate::{Interner, StringId, VersionSetId, VersionSetUnionId}; use itertools::Itertools; use std::fmt::Display; +/// Specifies a conditional requirement, where the requirement is only active when the condition is met. +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct ConditionalRequirement { + /// The condition that must be met for the requirement to be active. + pub condition: Option, + /// The requirement that is only active when the condition is met. + pub requirement: Requirement, + /// The extra that must be enabled for the requirement to be active. + pub extra: Option, +} + +impl ConditionalRequirement { + /// Creates a new conditional requirement. + pub fn new( + condition: Option, + requirement: Requirement, + extra: Option, + ) -> Self { + Self { + condition, + requirement, + extra, + } + } + /// Returns the version sets that satisfy the requirement. + pub fn requirement_version_sets<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator + 'i { + self.requirement.version_sets(interner) + } + + /// Returns the version sets that satisfy the requirement, along with the condition that must be met. + pub fn version_sets_with_condition<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator)> + 'i { + self.requirement + .version_sets(interner) + .map(move |vs| (vs, self.condition)) + } + + /// Returns the condition and requirement. + pub fn into_condition_and_requirement(self) -> (Option, Requirement) { + (self.condition, self.requirement) + } +} + +impl From for ConditionalRequirement { + fn from(value: Requirement) -> Self { + Self { + condition: None, + requirement: value, + extra: None, + } + } +} + +impl From for ConditionalRequirement { + fn from(value: VersionSetId) -> Self { + Self { + condition: None, + requirement: value.into(), + extra: None, + } + } +} + +impl From for ConditionalRequirement { + fn from(value: VersionSetUnionId) -> Self { + Self { + condition: None, + requirement: value.into(), + extra: None, + } + } +} + +impl From<(VersionSetId, Option)> for ConditionalRequirement { + fn from((requirement, condition): (VersionSetId, Option)) -> Self { + Self { + condition, + requirement: requirement.into(), + extra: None, + } + } +} + /// Specifies the dependency of a solvable on a set of version sets. #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] diff --git a/src/snapshot.rs b/src/snapshot.rs index 0b8b6d2..ab6d926 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -220,7 +220,15 @@ impl DependencySnapshot { } } - for &requirement in deps.requirements.iter() { + for &req in deps.requirements.iter() { + let (condition, requirement) = req.into_condition_and_requirement(); + + if let Some(condition) = condition { + if seen.insert(Element::VersionSet(condition)) { + queue.push_back(Element::VersionSet(condition)); + } + } + match requirement { Requirement::Single(version_set) => { if seen.insert(Element::VersionSet(version_set)) { diff --git a/src/solver/clause.rs b/src/solver/clause.rs index f034130..8b9d40b 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -46,7 +46,7 @@ use crate::{ /// limited set of clauses. There are thousands of clauses for a particular /// dependency resolution problem, and we try to keep the [`Clause`] enum small. /// A naive implementation would store a `Vec`. -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Copy, Debug)] pub(crate) enum Clause { /// An assertion that the root solvable must be installed /// @@ -77,6 +77,31 @@ pub(crate) enum Clause { /// /// In SAT terms: (¬A ∨ ¬B) Constrains(VariableId, VariableId, VersionSetId), + /// In SAT terms: (¬A ∨ ¬C ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// C is the condition, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + /// We need to store the version set id because in the conflict graph, the version set id + /// is used to identify the condition variable. + Conditional(VariableId, VariableId, VersionSetId, Requirement), + /// A conditional clause that requires a feature to be enabled for the requirement to be active. + /// + /// In SAT terms: (¬A ∨ ¬C ∨ ¬F ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// C is the condition, F is the feature, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + ConditionalWithExtra( + VariableId, // solvable + VariableId, // condition + VersionSetId, // condition version set + VariableId, // extra + StringId, // extra name + Requirement, // requirement + ), + /// A requirement that requires a feature to be enabled for the requirement to be active. + /// + /// In SAT terms: (¬A ∨ ¬F ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// F is the feature, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + RequiresWithExtra(VariableId, VariableId, StringId, Requirement), /// Forbids the package on the right-hand side /// /// Note that the package on the left-hand side is not part of the clause, @@ -230,6 +255,138 @@ impl Clause { ) } + fn conditional( + parent_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Self, Option<[Literal; 2]>, bool) { + assert_ne!(decision_tracker.assigned_value(parent_id), Some(false)); + let mut requirement_candidates = requirement_candidates.into_iter(); + + let requirement_literal = + if decision_tracker.assigned_value(condition_variable) == Some(true) { + // then ~condition is false + requirement_candidates + .find(|&id| decision_tracker.assigned_value(id) != Some(false)) + .map(|id| id.positive()) + } else { + None + }; + + ( + Clause::Conditional( + parent_id, + condition_variable, + condition_version_set_id, + requirement, + ), + Some([ + parent_id.negative(), + requirement_literal.unwrap_or(condition_variable.negative()), + ]), + requirement_literal.is_none() + && decision_tracker.assigned_value(condition_variable) == Some(true), + ) + } + + #[allow(clippy::too_many_arguments)] + fn conditional_with_extra( + parent_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + extra_variable: VariableId, + extra_name: StringId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Self, Option<[Literal; 2]>, bool) { + assert_ne!(decision_tracker.assigned_value(parent_id), Some(false)); + let mut requirement_candidates = requirement_candidates.into_iter(); + + let requirement_literal = + if decision_tracker.assigned_value(condition_variable) == Some(true) { + // then ~condition is false + // if the feature is enabled, then we need to watch the requirement candidates + if decision_tracker.assigned_value(extra_variable) == Some(true) { + requirement_candidates + .find(|&id| decision_tracker.assigned_value(id) != Some(false)) + .map(|id| id.positive()) + } else { + // if the feature is disabled, then we need to watch the feature variable + Some(extra_variable.negative()) + } + } else { + None + }; + ( + Clause::ConditionalWithExtra( + parent_id, + condition_variable, + condition_version_set_id, + extra_variable, + extra_name, + requirement, + ), + Some([ + parent_id.negative(), + requirement_literal.unwrap_or(condition_variable.negative()), + ]), + requirement_literal.is_none() + && decision_tracker.assigned_value(condition_variable) == Some(true) + && decision_tracker.assigned_value(extra_variable) == Some(true), + ) + } + + fn requires_with_extra( + solvable_id: VariableId, + extra_variable: VariableId, + extra_name: StringId, + requirement: Requirement, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Self, Option<[Literal; 2]>, bool) { + // It only makes sense to introduce a requires clause when the parent solvable + // is undecided or going to be installed + assert_ne!(decision_tracker.assigned_value(solvable_id), Some(false)); + + let kind = Clause::RequiresWithExtra(solvable_id, extra_variable, extra_name, requirement); + let mut candidates = requirement_candidates.into_iter().peekable(); + let first_candidate = candidates.peek().copied(); + + if decision_tracker.assigned_value(extra_variable) == Some(true) { + // Feature is enabled, so watch the requirement candidates + if let Some(first_candidate) = first_candidate { + match candidates.find(|&c| decision_tracker.assigned_value(c) != Some(false)) { + // Watch any candidate that is not assigned to false + Some(watched_candidate) => ( + kind, + Some([solvable_id.negative(), watched_candidate.positive()]), + false, + ), + // All candidates are assigned to false - conflict + None => ( + kind, + Some([solvable_id.negative(), first_candidate.positive()]), + true, + ), + } + } else { + // No candidates available + (kind, None, false) + } + } else { + // Feature is not enabled, so watch the feature variable + ( + kind, + Some([solvable_id.negative(), extra_variable.negative()]), + false, + ) + } + } + /// Tries to fold over all the literals in the clause. /// /// This function is useful to iterate, find, or filter the literals in a @@ -272,6 +429,46 @@ impl Clause { Clause::Lock(_, s) => [s.negative(), VariableId::root().negative()] .into_iter() .try_fold(init, visit), + Clause::Conditional(package_id, condition_variable, _, requirement) => { + iter::once(package_id.negative()) + .chain(iter::once(condition_variable.negative())) + .chain( + requirements_to_sorted_candidates[&requirement] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit) + } + Clause::ConditionalWithExtra( + package_id, + condition_variable, + _, + extra_variable, + _, + requirement, + ) => iter::once(package_id.negative()) + .chain(iter::once(condition_variable.negative())) + .chain(iter::once(extra_variable.negative())) + .chain( + requirements_to_sorted_candidates[&requirement] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit), + Clause::RequiresWithExtra(solvable_id, extra_variable, _, requirement) => { + iter::once(solvable_id.negative()) + .chain(iter::once(solvable_id.negative())) + .chain(iter::once(extra_variable.negative())) + .chain( + requirements_to_sorted_candidates[&requirement] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit) + } } } @@ -419,6 +616,96 @@ impl WatchedLiterals { (Self::from_kind_and_initial_watches(watched_literals), kind) } + /// Shorthand method to construct a [Clause::Conditional] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + pub fn conditional( + package_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::conditional( + package_id, + requirement, + condition_variable, + condition_version_set_id, + decision_tracker, + requirement_candidates, + ); + + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + + /// Shorthand method to construct a [Clause::ConditionalWithExtra] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + #[allow(clippy::too_many_arguments)] + pub fn conditional_with_extra( + package_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + extra_variable: VariableId, + extra_name: StringId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::conditional_with_extra( + package_id, + requirement, + condition_variable, + condition_version_set_id, + extra_variable, + extra_name, + decision_tracker, + requirement_candidates, + ); + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + + /// Shorthand method to construct a [Clause::RequiresWithExtra] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + pub fn requires_with_extra( + solvable_id: VariableId, + extra_variable: VariableId, + extra_name: StringId, + requirement: Requirement, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::requires_with_extra( + solvable_id, + extra_variable, + extra_name, + requirement, + decision_tracker, + requirement_candidates, + ); + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Option { let watched_literals = watched_literals?; debug_assert!(watched_literals[0] != watched_literals[1]); @@ -611,6 +898,48 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { other, ) } + Clause::Conditional(package_id, condition_variable, _, requirement) => { + write!( + f, + "Conditional({}({:?}), {}({:?}), {})", + package_id.display(self.variable_map, self.interner), + package_id, + condition_variable.display(self.variable_map, self.interner), + condition_variable, + requirement.display(self.interner), + ) + } + Clause::ConditionalWithExtra( + package_id, + condition_variable, + _, + extra_variable, + _, + requirement, + ) => { + write!( + f, + "ConditionalWithExtra({}({:?}), {}({:?}), {}({:?}), {})", + package_id.display(self.variable_map, self.interner), + package_id, + condition_variable.display(self.variable_map, self.interner), + condition_variable, + extra_variable.display(self.variable_map, self.interner), + extra_variable, + requirement.display(self.interner), + ) + } + Clause::RequiresWithExtra(solvable_id, extra_variable, _, requirement) => { + write!( + f, + "RequiresWithExtra({}({:?}), {}({:?}), {})", + solvable_id.display(self.variable_map, self.interner), + solvable_id, + extra_variable.display(self.variable_map, self.interner), + extra_variable, + requirement.display(self.interner), + ) + } } } } @@ -671,17 +1000,11 @@ mod test { clause.as_ref().unwrap().watched_literals[0].variable(), parent ); - assert_eq!( - clause.unwrap().watched_literals[1].variable(), - candidate1.into() - ); + assert_eq!(clause.unwrap().watched_literals[1].variable(), candidate1); // No conflict, still one candidate available decisions - .try_add_decision( - Decision::new(candidate1.into(), false, ClauseId::from_usize(0)), - 1, - ) + .try_add_decision(Decision::new(candidate1, false, ClauseId::from_usize(0)), 1) .unwrap(); let (clause, conflict, _kind) = WatchedLiterals::requires( parent, @@ -696,13 +1019,13 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate2.into() + candidate2 ); // Conflict, no candidates available decisions .try_add_decision( - Decision::new(candidate2.into(), false, ClauseId::install_root()), + Decision::new(candidate2, false, ClauseId::install_root()), 1, ) .unwrap(); @@ -719,7 +1042,7 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate1.into() + candidate1 ); // Panic diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 8c0e026..f142fa1 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -19,9 +19,11 @@ use crate::{ id::{ClauseId, LearntClauseId, NameId, SolvableId, SolvableOrRootId, VariableId}, mapping::Mapping, }, + requirement::ConditionalRequirement, runtime::{AsyncRuntime, NowOrNeverRuntime}, solver::binary_encoding::AtMostOnceTracker, - Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, VersionSetId, + Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, StringId, + VersionSetId, }; mod binary_encoding; @@ -33,9 +35,21 @@ mod decision_tracker; pub(crate) mod variable_map; mod watch_map; +/// The output of the `add_clauses_for_solvables` function. +type AddConditionalClauseOutput = ( + VariableId, + Option, + Option, + Requirement, + ClauseId, +); + #[derive(Default)] struct AddClauseOutput { new_requires_clauses: Vec<(VariableId, Requirement, ClauseId)>, + /// A vector of tuples from a solvable variable, conditional variable and extra(feature name) variable to + /// the clauses that need to be watched. + new_conditional_clauses: Vec, conflicting_clauses: Vec, negative_assertions: Vec<(VariableId, ClauseId)>, clauses_to_watch: Vec, @@ -51,7 +65,7 @@ struct AddClauseOutput { /// This struct follows the builder pattern and can have its fields set by one /// of the available setter methods. pub struct Problem { - requirements: Vec, + requirements: Vec, constraints: Vec, soft_requirements: S, } @@ -80,7 +94,7 @@ impl> Problem { /// /// Returns the [`Problem`] for further mutation or to pass to /// [`Solver::solve`]. - pub fn requirements(self, requirements: Vec) -> Self { + pub fn requirements(self, requirements: Vec) -> Self { Self { requirements, ..self @@ -142,6 +156,7 @@ impl Clauses { } type RequirementCandidateVariables = Vec>; +type ConditionalClauseMap = (VariableId, Option, Option); /// Drives the SAT solving process. pub struct Solver { @@ -150,6 +165,11 @@ pub struct Solver { pub(crate) clauses: Clauses, requires_clauses: IndexMap, ahash::RandomState>, + + /// A map from a solvable variable, conditional variable and extra(feature name) variable to + /// the clauses that need to be watched. + conditional_clauses: + IndexMap, ahash::RandomState>, watches: WatchMap, /// A mapping from requirements to the variables that represent the @@ -172,7 +192,7 @@ pub struct Solver { decision_tracker: DecisionTracker, /// The [`Requirement`]s that must be installed as part of the solution. - root_requirements: Vec, + root_requirements: Vec, /// Additional constraints imposed by the root. root_constraints: Vec, @@ -200,6 +220,7 @@ impl Solver { clauses: Clauses::default(), variable_map: VariableMap::default(), requires_clauses: Default::default(), + conditional_clauses: Default::default(), requirement_to_sorted_candidates: FrozenMap::default(), watches: WatchMap::new(), negative_assertions: Default::default(), @@ -213,7 +234,6 @@ impl Solver { clauses_added_for_solvable: Default::default(), forbidden_clauses_added: Default::default(), name_activity: Default::default(), - activity_add: 1.0, activity_decay: 0.95, } @@ -280,6 +300,7 @@ impl Solver { clauses: self.clauses, variable_map: self.variable_map, requires_clauses: self.requires_clauses, + conditional_clauses: self.conditional_clauses, requirement_to_sorted_candidates: self.requirement_to_sorted_candidates, watches: self.watches, negative_assertions: self.negative_assertions, @@ -660,6 +681,16 @@ impl Solver { .or_default() .push((requirement, clause_id)); } + + for (solvable_id, condition_variable, extra_variable, requirement, clause_id) in + output.new_conditional_clauses + { + self.conditional_clauses + .entry((solvable_id, condition_variable, extra_variable)) + .or_default() + .push((requirement, clause_id)); + } + self.negative_assertions .append(&mut output.negative_assertions); @@ -695,7 +726,7 @@ impl Solver { fn resolve_dependencies(&mut self, mut level: u32) -> Result { loop { // Make a decision. If no decision could be made it means the problem is - // satisfyable. + // satisfiable. let Some((candidate, required_by, clause_id)) = self.decide() else { break; }; @@ -767,8 +798,38 @@ impl Solver { } let mut best_decision: Option = None; - for (&solvable_id, requirements) in self.requires_clauses.iter() { + + // Chain together the requires_clauses and conditional_clauses iterations + let requires_iter = self + .requires_clauses + .iter() + .map(|(&solvable_id, requirements)| { + ( + solvable_id, + None, + None, + requirements + .iter() + .map(|(r, c)| (*r, *c)) + .collect::>(), + ) + }); + + let conditional_iter = + self.conditional_clauses + .iter() + .map(|((solvable_id, condition, extra), clauses)| { + ( + *solvable_id, + *condition, + *extra, + clauses.iter().map(|(r, c)| (*r, *c)).collect::>(), + ) + }); + + for (solvable_id, condition, extra, requirements) in requires_iter.chain(conditional_iter) { let is_explicit_requirement = solvable_id == VariableId::root(); + if let Some(best_decision) = &best_decision { // If we already have an explicit requirement, there is no need to evaluate // non-explicit requirements. @@ -782,11 +843,32 @@ impl Solver { continue; } - for (deps, clause_id) in requirements.iter() { + // For conditional clauses, check that at least one conditional variable is true + if let Some(condition_variable) = condition { + // Check if any candidate that matches the condition's version set is installed + let condition_met = + self.decision_tracker.assigned_value(condition_variable) == Some(true); + + // If the condition is not met, skip this requirement entirely + if !condition_met { + continue; + } + } + + if let Some(extra_variable) = extra { + // If the extra is not enabled, skip this requirement entirely + if self.decision_tracker.assigned_value(extra_variable) != Some(true) { + continue; + } + } + + for (requirement, clause_id) in requirements { let mut candidate = ControlFlow::Break(()); // Get the candidates for the individual version sets. - let version_set_candidates = &self.requirement_to_sorted_candidates[deps]; + let version_set_candidates = &self.requirement_to_sorted_candidates[&requirement]; + + let version_sets = requirement.version_sets(self.provider()); // Iterate over all version sets in the requirement and find the first version // set that we can act on, or if a single candidate (from any version set) makes @@ -795,10 +877,7 @@ impl Solver { // NOTE: We zip the version sets from the requirements and the variables that we // previously cached. This assumes that the order of the version sets is the // same in both collections. - for (version_set, candidates) in deps - .version_sets(self.provider()) - .zip(version_set_candidates) - { + for (version_set, candidates) in version_sets.zip(version_set_candidates) { // Find the first candidate that is not yet assigned a value or find the first // value that makes this clause true. candidate = candidates.iter().try_fold( @@ -875,7 +954,7 @@ impl Solver { candidate_count, package_activity, ))) => { - let decision = (candidate, solvable_id, *clause_id); + let decision = (candidate, solvable_id, clause_id); best_decision = Some(match &best_decision { None => PossibleDecision { is_explicit_requirement, @@ -1519,7 +1598,7 @@ async fn add_clauses_for_solvables( RequirementCandidateVariables, ahash::RandomState, >, - root_requirements: &[Requirement], + root_requirements: &[ConditionalRequirement], root_constraints: &[VersionSetId], ) -> Result> { let mut output = AddClauseOutput::default(); @@ -1534,6 +1613,8 @@ async fn add_clauses_for_solvables( SortedCandidates { solvable_id: SolvableOrRootId, requirement: Requirement, + condition: Option<(SolvableId, VersionSetId)>, + extra: Option<(VariableId, StringId)>, candidates: Vec<&'i [SolvableId]>, }, NonMatchingCandidates { @@ -1615,7 +1696,7 @@ async fn add_clauses_for_solvables( None => variable_map.root(), }; - let (requirements, constrains) = match dependencies { + let (conditional_requirements, constrains) = match dependencies { Dependencies::Known(deps) => (deps.requirements, deps.constrains), Dependencies::Unknown(reason) => { // There is no information about the solvable's dependencies, so we add @@ -1637,17 +1718,29 @@ async fn add_clauses_for_solvables( } }; - for version_set_id in requirements + for (version_set_id, condition) in conditional_requirements .iter() - .flat_map(|requirement| requirement.version_sets(cache.provider())) - .chain(constrains.iter().copied()) + .flat_map(|conditional_requirement| { + conditional_requirement.version_sets_with_condition(cache.provider()) + }) + .chain(constrains.iter().map(|&vs| (vs, None))) { let dependency_name = cache.provider().version_set_name(version_set_id); if clauses_added_for_package.insert(dependency_name) { - tracing::trace!( - "┝━ Adding clauses for package '{}'", - cache.provider().display_name(dependency_name), - ); + if let Some(condition) = condition { + let condition_name = cache.provider().version_set_name(condition); + tracing::trace!( + "┝━ Adding conditional clauses for package '{}' with condition '{}' and version set '{}'", + cache.provider().display_name(dependency_name), + cache.provider().display_name(condition_name), + cache.provider().display_version_set(condition), + ); + } else { + tracing::trace!( + "┝━ Adding clauses for package '{}'", + cache.provider().display_name(dependency_name), + ); + } pending_futures.push( async move { @@ -1660,32 +1753,116 @@ async fn add_clauses_for_solvables( } .boxed_local(), ); + + if let Some(condition) = condition { + let condition_name = cache.provider().version_set_name(condition); + if clauses_added_for_package.insert(condition_name) { + pending_futures.push( + async move { + let condition_candidates = + cache.get_or_cache_candidates(condition_name).await?; + Ok(TaskResult::Candidates { + name_id: condition_name, + package_candidates: condition_candidates, + }) + } + .boxed_local(), + ); + } + } } } - for requirement in requirements { + for conditional_requirement in conditional_requirements { // Find all the solvable that match for the given version set - pending_futures.push( - async move { - let candidates = futures::future::try_join_all( - requirement - .version_sets(cache.provider()) - .map(|version_set| { - cache.get_or_cache_sorted_candidates_for_version_set( - version_set, - ) - }), - ) - .await?; - - Ok(TaskResult::SortedCandidates { - solvable_id, - requirement, - candidates, - }) + let version_sets = + conditional_requirement.requirement_version_sets(cache.provider()); + let candidates = + futures::future::try_join_all(version_sets.map(|version_set| { + cache.get_or_cache_sorted_candidates_for_version_set(version_set) + })) + .await?; + + // condition is `VersionSetId` right now but it will become a `Requirement` + // in the next versions of resolvo + match ( + conditional_requirement.condition, + conditional_requirement.extra, + ) { + (None, Some(extra)) => { + let extra_variable = variable_map.intern_extra(extra); + + // Add a task result for the condition + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: None, + extra: Some((extra_variable, extra)), + candidates: candidates.clone(), + }) + } + .boxed_local(), + ); } - .boxed_local(), - ); + (Some(condition), Some(extra)) => { + let condition_candidates = + cache.get_or_cache_matching_candidates(condition).await?; + let extra_variable = variable_map.intern_extra(extra); + + for &condition_candidate in condition_candidates { + let candidates = candidates.clone(); + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: Some((condition_candidate, condition)), + extra: Some((extra_variable, extra)), + candidates, + }) + } + .boxed_local(), + ); + } + } + (Some(condition), None) => { + let condition_candidates = + cache.get_or_cache_matching_candidates(condition).await?; + + for &condition_candidate in condition_candidates { + let candidates = candidates.clone(); + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: Some((condition_candidate, condition)), + extra: None, + candidates, + }) + } + .boxed_local(), + ); + } + } + (None, None) => { + // Add a task result for the condition + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: None, + extra: None, + candidates: candidates.clone(), + }) + } + .boxed_local(), + ); + } + } } for version_set_id in constrains { @@ -1751,6 +1928,8 @@ async fn add_clauses_for_solvables( TaskResult::SortedCandidates { solvable_id, requirement, + condition, + extra, candidates, } => { tracing::trace!( @@ -1820,30 +1999,153 @@ async fn add_clauses_for_solvables( ); } - // Add the requirements clause - let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); - let (watched_literals, conflict, kind) = WatchedLiterals::requires( - variable, - requirement, - version_set_variables.iter().flatten().copied(), - decision_tracker, - ); - let has_watches = watched_literals.is_some(); - let clause_id = clauses.alloc(watched_literals, kind); + match (condition, extra) { + ( + Some((condition_variable, condition_version_set_id)), + Some((extra_variable, extra_name)), + ) => { + let condition_variable = variable_map.intern_solvable(condition_variable); - if has_watches { - output.clauses_to_watch.push(clause_id); - } + let (watched_literals, conflict, kind) = + WatchedLiterals::conditional_with_extra( + variable, + requirement, + condition_variable, + condition_version_set_id, + extra_variable, + extra_name, + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); + + // Add the conditional clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); + + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output.new_conditional_clauses.push(( + variable, + Some(condition_variable), + Some(extra_variable), + requirement, + clause_id, + )); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } + } + (None, Some((extra_variable, extra_name))) => { + let (watched_literals, conflict, kind) = + WatchedLiterals::requires_with_extra( + variable, + extra_variable, + extra_name, + requirement, + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); + + // Add the requirements clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); + + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output + .new_requires_clauses + .push((variable, requirement, clause_id)); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } + } + (Some((condition, condition_version_set_id)), None) => { + let condition_variable = variable_map.intern_solvable(condition); + + // Add a condition clause + let (watched_literals, conflict, kind) = WatchedLiterals::conditional( + variable, + requirement, + condition_variable, + condition_version_set_id, + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); - output - .new_requires_clauses - .push((variable, requirement, clause_id)); + // Add the conditional clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); - if conflict { - output.conflicting_clauses.push(clause_id); - } else if no_candidates { - // Add assertions for unit clauses (i.e. those with no matching candidates) - output.negative_assertions.push((variable, clause_id)); + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output.new_conditional_clauses.push(( + variable, + Some(condition_variable), + None, + requirement, + clause_id, + )); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } + } + (None, None) => { + let (watched_literals, conflict, kind) = WatchedLiterals::requires( + variable, + requirement, + version_set_variables.iter().flatten().copied(), + decision_tracker, + ); + + // Add the requirements clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); + + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output + .new_requires_clauses + .push((variable, requirement, clause_id)); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } + } } } TaskResult::NonMatchingCandidates { diff --git a/src/solver/variable_map.rs b/src/solver/variable_map.rs index 608ed68..286f187 100644 --- a/src/solver/variable_map.rs +++ b/src/solver/variable_map.rs @@ -7,7 +7,7 @@ use crate::{ arena::ArenaId, id::{SolvableOrRootId, VariableId}, }, - Interner, NameId, SolvableId, + Interner, NameId, SolvableId, StringId, }; /// All variables in the solver are stored in a `VariableMap`. This map is used @@ -23,6 +23,9 @@ pub struct VariableMap { /// A map from solvable id to variable id. solvable_to_variable: HashMap, + /// A map from extra name to variable id. + extra_to_variable: HashMap, + /// Records the origins of all variables. origins: HashMap, } @@ -38,6 +41,9 @@ pub enum VariableOrigin { /// A variable that helps encode an at most one constraint. ForbidMultiple(NameId), + + /// A variable that represents an extra variable. + Extra(StringId), } impl Default for VariableMap { @@ -48,6 +54,7 @@ impl Default for VariableMap { Self { next_id: 1, // The first variable id is 1 because 0 is reserved for the root. solvable_to_variable: HashMap::default(), + extra_to_variable: HashMap::default(), origins, } } @@ -78,6 +85,22 @@ impl VariableMap { } } + /// Allocate a variable that represents an extra variable. + pub fn intern_extra(&mut self, extra_name: StringId) -> VariableId { + match self.extra_to_variable.entry(extra_name) { + Entry::Occupied(entry) => *entry.get(), + Entry::Vacant(entry) => { + let id = self.next_id; + self.next_id += 1; + let variable_id = VariableId::from_usize(id); + entry.insert(variable_id); + self.origins + .insert(variable_id, VariableOrigin::Extra(extra_name)); + variable_id + } + } + } + /// Allocate a variable that helps encode an at most one constraint. pub fn alloc_forbid_multiple_variable(&mut self, name: NameId) -> VariableId { let id = self.next_id; @@ -141,6 +164,9 @@ impl<'i, I: Interner> Display for VariableDisplay<'i, I> { VariableOrigin::ForbidMultiple(name) => { write!(f, "forbid-multiple({})", self.interner.display_name(name)) } + VariableOrigin::Extra(extra_name) => { + write!(f, "extra({})", self.interner.display_string(extra_name)) + } } } } diff --git a/tests/solver.rs b/tests/solver.rs index de15d8a..63a97fd 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -22,9 +22,9 @@ use itertools::Itertools; use resolvo::{ snapshot::{DependencySnapshot, SnapshotProvider}, utils::Pool, - Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, Problem, - Requirement, SolvableId, Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, - VersionSetUnionId, + Candidates, ConditionalRequirement, Dependencies, DependencyProvider, Interner, + KnownDependencies, NameId, Problem, Requirement, SolvableId, Solver, SolverCache, StringId, + UnsolvableOrCancelled, VersionSetId, VersionSetUnionId, }; use tracing_test::traced_test; use version_ranges::Ranges; @@ -113,19 +113,22 @@ impl FromStr for Pack { struct Spec { name: String, versions: Ranges, + condition: Option>, } impl Spec { - pub fn new(name: String, versions: Ranges) -> Self { - Self { name, versions } + pub fn new(name: String, versions: Ranges, condition: Option>) -> Self { + Self { + name, + versions, + condition, + } } pub fn parse_union( spec: &str, ) -> impl Iterator::Err>> + '_ { - spec.split('|') - .map(str::trim) - .map(|dep| Spec::from_str(dep)) + spec.split('|').map(str::trim).map(Spec::from_str) } } @@ -133,11 +136,23 @@ impl FromStr for Spec { type Err = (); fn from_str(s: &str) -> Result { - let split = s.split(' ').collect::>(); - let name = split - .first() - .expect("spec does not have a name") - .to_string(); + let split = s.split_once("; if"); + + if split.is_none() { + let split = s.split(' ').collect::>(); + let name = split + .first() + .expect("spec does not have a name") + .to_string(); + let versions = version_range(split.get(1)); + return Ok(Spec::new(name, versions, None)); + } + + let (spec, condition) = split.unwrap(); + + let condition = Spec::parse_union(condition).next().unwrap().unwrap(); + + let spec = Spec::from_str(spec).unwrap(); fn version_range(s: Option<&&str>) -> Ranges { if let Some(s) = s { @@ -156,9 +171,11 @@ impl FromStr for Spec { } } - let versions = version_range(split.get(1)); - - Ok(Spec::new(name, versions)) + Ok(Spec::new( + spec.name, + spec.versions, + Some(Box::new(condition)), + )) } } @@ -200,11 +217,19 @@ impl BundleBoxProvider { .expect("package missing") } - pub fn requirements>(&self, requirements: &[&str]) -> Vec { + pub fn requirements)>>( + &self, + requirements: &[&str], + ) -> Vec { requirements .iter() .map(|dep| Spec::from_str(dep).unwrap()) - .map(|spec| self.intern_version_set(&spec)) + .map(|spec| { + ( + self.intern_version_set(&spec), + spec.condition.as_ref().map(|c| self.intern_version_set(c)), + ) + }) .map(From::from) .collect() } @@ -386,7 +411,7 @@ impl DependencyProvider for BundleBoxProvider { candidates .iter() .copied() - .filter(|s| range.contains(&self.pool.resolve_solvable(*s).record) == !inverse) + .filter(|s| range.contains(&self.pool.resolve_solvable(*s).record) != inverse) .collect() } @@ -502,18 +527,44 @@ impl DependencyProvider for BundleBoxProvider { .intern_version_set(first_name, first.versions.clone()); let requirement = if remaining_req_specs.len() == 0 { - first_version_set.into() + if let Some(condition) = &first.condition { + ConditionalRequirement::new( + Some(self.intern_version_set(condition)), + first_version_set.into(), + ) + } else { + first_version_set.into() + } } else { - let other_version_sets = remaining_req_specs.map(|spec| { - self.pool.intern_version_set( + // Check if all specs have the same condition + let common_condition = first.condition.as_ref().map(|c| self.intern_version_set(c)); + + // Collect version sets for union + let mut version_sets = vec![first_version_set]; + for spec in remaining_req_specs { + // Verify condition matches + if spec.condition.as_ref().map(|c| self.intern_version_set(c)) + != common_condition + { + panic!("All specs in a union must have the same condition"); + } + + version_sets.push(self.pool.intern_version_set( self.pool.intern_package_name(&spec.name), spec.versions.clone(), - ) - }); - - self.pool - .intern_version_set_union(first_version_set, other_version_sets) - .into() + )); + } + + // Create union and wrap in conditional if needed + let union = self + .pool + .intern_version_set_union(version_sets[0], version_sets.into_iter().skip(1)); + + if let Some(condition) = common_condition { + ConditionalRequirement::new(Some(condition), union.into()) + } else { + union.into() + } }; result.requirements.push(requirement); @@ -538,7 +589,7 @@ impl DependencyProvider for BundleBoxProvider { } /// Create a string from a [`Transaction`] -fn transaction_to_string(interner: &impl Interner, solvables: &Vec) -> String { +fn transaction_to_string(interner: &impl Interner, solvables: &[SolvableId]) -> String { use std::fmt::Write; let mut buf = String::new(); for solvable in solvables @@ -590,7 +641,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.parse_requirements(specs); let mut solver = Solver::new(provider).with_runtime(runtime); - let problem = Problem::new().requirements(requirements); + let problem = Problem::new().requirements(requirements.into_iter().map(|r| r.into()).collect()); match solver.solve(problem) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { @@ -1429,6 +1480,238 @@ fn test_explicit_root_requirements() { "###); } +#[test] +#[traced_test] +fn test_conditional_requirements() { + let mut provider = BundleBoxProvider::new(); + + // Add packages + provider.add_package("a", 1.into(), &["b"], &[]); // a depends on b + provider.add_package("b", 1.into(), &[], &[]); // Simple package b + provider.add_package("c", 1.into(), &[], &[]); // Simple package c + + // Create problem with both regular and conditional requirements + let requirements = provider.requirements(&["a", "c 1; if b 1..2"]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + "###); +} + +#[test] +#[traced_test] +fn test_conditional_requirements_not_met() { + let mut provider = BundleBoxProvider::new(); + provider.add_package("b", 1.into(), &[], &[]); // Add b=1 as a candidate + provider.add_package("b", 2.into(), &[], &[]); // Different version of b + provider.add_package("c", 1.into(), &[], &[]); // Simple package c + provider.add_package("a", 1.into(), &["b 2"], &[]); // a depends on b=2 specifically + + // Create conditional requirement: if b=1 is installed, require c + let requirements = provider.requirements(&[ + "a", // Require package a + "c 1; if b 1", // If b=1 is installed, require c (note the exact version) + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (not b=1), c should not be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + "###); +} + +#[test] +fn test_nested_conditional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[]); // Base package + provider.add_package("b", 1.into(), &[], &[]); // First level conditional + provider.add_package("c", 1.into(), &[], &[]); // Second level conditional + provider.add_package("d", 1.into(), &[], &[]); // Third level conditional + + // Create nested conditional requirements using the parser + let requirements = provider.requirements(&[ + "a", // Require package a + "b 1; if a 1", // If a is installed, require b + "c 1; if b 1", // If b is installed, require c + "d 1; if c 1", // If c is installed, require d + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // All packages should be installed due to chain reaction + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + d=1 + "###); +} + +#[test] +fn test_multiple_conditions_same_package() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[]); + provider.add_package("b", 1.into(), &[], &[]); + provider.add_package("c", 1.into(), &[], &[]); + provider.add_package("target", 1.into(), &[], &[]); + + // Create multiple conditions that all require the same package using the parser + let requirements = provider.requirements(&[ + "b", // Only require package b + "target 1; if a 1", // If a is installed, require target + "target 1; if b 1", // If b is installed, require target + "target 1; if c 1", // If c is installed, require target + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // b and target should be installed + insta::assert_snapshot!(result, @r###" + b=1 + target=1 + "###); +} + +#[test] +fn test_circular_conditional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[]); + provider.add_package("b", 1.into(), &[], &[]); + + // Create circular conditional dependencies using the parser + let requirements = provider.requirements(&[ + "a", // Require package a + "b 1; if a 1", // If a is installed, require b + "a 1; if b 1", // If b is installed, require a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Both packages should be installed due to circular dependency + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + "###); +} + +#[test] +fn test_conditional_requirements_multiple_versions() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[]); + provider.add_package("b", 2.into(), &[], &[]); + provider.add_package("b", 3.into(), &[], &[]); + provider.add_package("b", 4.into(), &[], &[]); + provider.add_package("b", 5.into(), &[], &[]); + + provider.add_package("c", 1.into(), &[], &[]); // Simple package c + provider.add_package("a", 1.into(), &["b 4..6"], &[]); // a depends on b versions 4-5 + + // Create conditional requirement: if b=1..3 is installed, require c + let requirements = provider.requirements(&[ + "a", // Require package a + "c 1; if b 1..3", // If b is version 1-2, require c + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=4 is installed (not b 1..3), c should not be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=5 + "###); +} + +#[test] +fn test_conditional_requirements_multiple_versions_met() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[]); + provider.add_package("b", 2.into(), &[], &[]); + provider.add_package("b", 3.into(), &[], &[]); + provider.add_package("b", 4.into(), &[], &[]); + provider.add_package("b", 5.into(), &[], &[]); + + provider.add_package("c", 1.into(), &[], &[]); // Simple package c + provider.add_package("c", 2.into(), &[], &[]); // Version 2 of c + provider.add_package("c", 3.into(), &[], &[]); // Version 3 of c + provider.add_package("a", 1.into(), &["b 1..3", "c 1..3; if b 1..3"], &[]); // a depends on b 1-3 and conditionally on c 1-3 + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (within b 1..2), c should be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + c=2 + "###); +} + +/// In this test, the resolver installs the highest available version of b which is b 2 +/// However, the conditional requirement is that if b 1..2 is installed, require c +/// Since b 2 is installed, c should not be installed +#[test] +fn test_conditional_requirements_multiple_versions_not_met() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[]); + provider.add_package("b", 2.into(), &[], &[]); + provider.add_package("b", 3.into(), &[], &[]); + provider.add_package("b", 4.into(), &[], &[]); + provider.add_package("b", 5.into(), &[], &[]); + + provider.add_package("c", 1.into(), &[], &[]); // Simple package c + provider.add_package("c", 2.into(), &[], &[]); // Version 2 of c + provider.add_package("c", 3.into(), &[], &[]); // Version 3 of c + provider.add_package("a", 1.into(), &["b 1..3", "c 1..3; if b 1..2"], &[]); // a depends on b 1-3 and conditionally on c 1-3 + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (within b 1..2), c should be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + "###); +} #[cfg(feature = "serde")] fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef) { let file = std::io::BufWriter::new(std::fs::File::create(destination.as_ref()).unwrap()); diff --git a/tools/solve-snapshot/src/main.rs b/tools/solve-snapshot/src/main.rs index 901996c..3629eaf 100644 --- a/tools/solve-snapshot/src/main.rs +++ b/tools/solve-snapshot/src/main.rs @@ -128,7 +128,8 @@ fn main() { let start = Instant::now(); - let problem = Problem::default().requirements(requirements); + let problem = + Problem::default().requirements(requirements.into_iter().map(Into::into).collect()); let mut solver = Solver::new(provider); let mut records = None; let mut error = None;