From 8c60946f5fd9104d0bf0bd3f7e44ece52366700c Mon Sep 17 00:00:00 2001 From: prsabahrami Date: Wed, 29 Jan 2025 13:56:43 -0500 Subject: [PATCH] initial commit for optional dependencies support in resolvo --- cpp/src/lib.rs | 40 +++++ src/conflict.rs | 2 + src/requirement.rs | 15 +- src/solver/clause.rs | 233 +++++++++++++++++++++++++++ src/solver/mod.rs | 311 +++++++++++++++++++++++++++---------- src/solver/variable_map.rs | 28 +++- 6 files changed, 548 insertions(+), 81 deletions(-) diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index a35b576..8f4a393 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -31,6 +31,41 @@ impl From for resolvo::SolvableId { } } +/// A wrapper around an optional string id. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct FfiOptionStringId { + pub is_some: bool, + pub value: StringId, +} + +impl From> for FfiOptionStringId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v.into(), + }, + None => Self { + is_some: false, + value: StringId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionStringId) -> Self { + if ffi.is_some { + Some(ffi.value.into()) + } else { + None + } + } +} + /// A wrapper around an optional version set id. /// cbindgen:derive-eq /// cbindgen:derive-neq @@ -100,6 +135,7 @@ impl From for Option { pub struct ConditionalRequirement { pub condition: FfiOptionVersionSetId, pub requirement: Requirement, + pub extra: FfiOptionStringId, } impl From for ConditionalRequirement { @@ -107,6 +143,7 @@ impl From for ConditionalRequirement { Self { condition: value.condition.into(), requirement: value.requirement.into(), + extra: value.extra.into(), } } } @@ -116,6 +153,7 @@ impl From for resolvo::ConditionalRequirement { Self { condition: value.condition.into(), requirement: value.requirement.into(), + extra: value.extra.into(), } } } @@ -622,6 +660,7 @@ pub extern "C" fn resolvo_conditional_requirement_single( ConditionalRequirement { condition: Option::::None.into(), requirement: Requirement::Single(version_set_id), + extra: None.into(), } } @@ -633,6 +672,7 @@ pub extern "C" fn resolvo_conditional_requirement_union( ConditionalRequirement { condition: Option::::None.into(), requirement: Requirement::Union(version_set_union_id), + extra: None.into(), } } diff --git a/src/conflict.rs b/src/conflict.rs index 428caf8..1656209 100644 --- a/src/conflict.rs +++ b/src/conflict.rs @@ -212,6 +212,8 @@ impl Conflict { } } } + &Clause::RequiresWithExtra(..) => todo!(), + &Clause::ConditionalWithExtra(..) => todo!(), } } diff --git a/src/requirement.rs b/src/requirement.rs index 9641e27..b249262 100644 --- a/src/requirement.rs +++ b/src/requirement.rs @@ -1,4 +1,4 @@ -use crate::{Interner, VersionSetId, VersionSetUnionId}; +use crate::{Interner, StringId, VersionSetId, VersionSetUnionId}; use itertools::Itertools; use std::fmt::Display; @@ -10,14 +10,21 @@ pub struct ConditionalRequirement { pub condition: Option, /// The requirement that is only active when the condition is met. pub requirement: Requirement, + /// The extra that must be enabled for the requirement to be active. + pub extra: Option, } impl ConditionalRequirement { /// Creates a new conditional requirement. - pub fn new(condition: Option, requirement: Requirement) -> Self { + pub fn new( + condition: Option, + requirement: Requirement, + extra: Option, + ) -> Self { Self { condition, requirement, + extra, } } /// Returns the version sets that satisfy the requirement. @@ -49,6 +56,7 @@ impl From for ConditionalRequirement { Self { condition: None, requirement: value, + extra: None, } } } @@ -58,6 +66,7 @@ impl From for ConditionalRequirement { Self { condition: None, requirement: value.into(), + extra: None, } } } @@ -67,6 +76,7 @@ impl From for ConditionalRequirement { Self { condition: None, requirement: value.into(), + extra: None, } } } @@ -76,6 +86,7 @@ impl From<(VersionSetId, Option)> for ConditionalRequirement { Self { condition, requirement: requirement.into(), + extra: None, } } } diff --git a/src/solver/clause.rs b/src/solver/clause.rs index 99bb693..038aaa1 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -83,6 +83,25 @@ pub(crate) enum Clause { /// We need to store the version set id because in the conflict graph, the version set id /// is used to identify the condition variable. Conditional(VariableId, VariableId, VersionSetId, Requirement), + /// A conditional clause that requires a feature to be enabled for the requirement to be active. + /// + /// In SAT terms: (¬A ∨ ¬C ∨ ¬F ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// C is the condition, F is the feature, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + ConditionalWithExtra( + VariableId, // solvable + VariableId, // condition + VersionSetId, // condition version set + VariableId, // extra + StringId, // extra name + Requirement, // requirement + ), + /// A requirement that requires a feature to be enabled for the requirement to be active. + /// + /// In SAT terms: (¬A ∨ ¬F ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// F is the feature, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + RequiresWithExtra(VariableId, VariableId, StringId, Requirement), /// Forbids the package on the right-hand side /// /// Note that the package on the left-hand side is not part of the clause, @@ -273,6 +292,100 @@ impl Clause { ) } + fn conditional_with_extra( + parent_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + extra_variable: VariableId, + extra_name: StringId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Self, Option<[Literal; 2]>, bool) { + assert_ne!(decision_tracker.assigned_value(parent_id), Some(false)); + let mut requirement_candidates = requirement_candidates.into_iter(); + + let requirement_literal = + if decision_tracker.assigned_value(condition_variable) == Some(true) { + // then ~condition is false + // if the feature is enabled, then we need to watch the requirement candidates + if decision_tracker.assigned_value(extra_variable) == Some(true) { + requirement_candidates + .find(|&id| decision_tracker.assigned_value(id) != Some(false)) + .map(|id| id.positive()) + } else { + // if the feature is disabled, then we need to watch the feature variable + Some(extra_variable.negative()) + } + } else { + None + }; + ( + Clause::ConditionalWithExtra( + parent_id, + condition_variable, + condition_version_set_id, + extra_variable, + extra_name, + requirement, + ), + Some([ + parent_id.negative(), + requirement_literal.unwrap_or(condition_variable.negative()), + ]), + requirement_literal.is_none() + && decision_tracker.assigned_value(condition_variable) == Some(true) + && decision_tracker.assigned_value(extra_variable) == Some(true), + ) + } + + fn requires_with_extra( + solvable_id: VariableId, + extra_variable: VariableId, + extra_name: StringId, + requirement: Requirement, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Self, Option<[Literal; 2]>, bool) { + // It only makes sense to introduce a requires clause when the parent solvable + // is undecided or going to be installed + assert_ne!(decision_tracker.assigned_value(solvable_id), Some(false)); + + let kind = Clause::RequiresWithExtra(solvable_id, extra_variable, extra_name, requirement); + let mut candidates = requirement_candidates.into_iter().peekable(); + let first_candidate = candidates.peek().copied(); + + if decision_tracker.assigned_value(extra_variable) == Some(true) { + // Feature is enabled, so watch the requirement candidates + if let Some(first_candidate) = first_candidate { + match candidates.find(|&c| decision_tracker.assigned_value(c) != Some(false)) { + // Watch any candidate that is not assigned to false + Some(watched_candidate) => ( + kind, + Some([solvable_id.negative(), watched_candidate.positive()]), + false, + ), + // All candidates are assigned to false - conflict + None => ( + kind, + Some([solvable_id.negative(), first_candidate.positive()]), + true, + ), + } + } else { + // No candidates available + (kind, None, false) + } + } else { + // Feature is not enabled, so watch the feature variable + ( + kind, + Some([solvable_id.negative(), extra_variable.negative()]), + false, + ) + } + } + /// Tries to fold over all the literals in the clause. /// /// This function is useful to iterate, find, or filter the literals in a @@ -326,6 +439,35 @@ impl Clause { ) .try_fold(init, visit) } + Clause::ConditionalWithExtra( + package_id, + condition_variable, + _, + extra_variable, + _, + requirement, + ) => iter::once(package_id.negative()) + .chain(iter::once(condition_variable.negative())) + .chain(iter::once(extra_variable.negative())) + .chain( + requirements_to_sorted_candidates[&requirement] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit), + Clause::RequiresWithExtra(solvable_id, extra_variable, _, requirement) => { + iter::once(solvable_id.negative()) + .chain(iter::once(solvable_id.negative())) + .chain(iter::once(extra_variable.negative())) + .chain( + requirements_to_sorted_candidates[&requirement] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit) + } } } @@ -502,6 +644,66 @@ impl WatchedLiterals { ) } + /// Shorthand method to construct a [Clause::ConditionalWithExtra] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + pub fn conditional_with_extra( + package_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + extra_variable: VariableId, + extra_name: StringId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::conditional_with_extra( + package_id, + requirement, + condition_variable, + condition_version_set_id, + extra_variable, + extra_name, + decision_tracker, + requirement_candidates, + ); + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + + /// Shorthand method to construct a [Clause::RequiresWithExtra] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + pub fn requires_with_extra( + solvable_id: VariableId, + extra_variable: VariableId, + extra_name: StringId, + requirement: Requirement, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::requires_with_extra( + solvable_id, + extra_variable, + extra_name, + requirement, + decision_tracker, + requirement_candidates, + ); + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Option { let watched_literals = watched_literals?; debug_assert!(watched_literals[0] != watched_literals[1]); @@ -705,6 +907,37 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { requirement.display(self.interner), ) } + Clause::ConditionalWithExtra( + package_id, + condition_variable, + _, + extra_variable, + _, + requirement, + ) => { + write!( + f, + "ConditionalWithExtra({}({:?}), {}({:?}), {}({:?}), {})", + package_id.display(self.variable_map, self.interner), + package_id, + condition_variable.display(self.variable_map, self.interner), + condition_variable, + extra_variable.display(self.variable_map, self.interner), + extra_variable, + requirement.display(self.interner), + ) + } + Clause::RequiresWithExtra(solvable_id, extra_variable, _, requirement) => { + write!( + f, + "RequiresWithExtra({}({:?}), {}({:?}), {})", + solvable_id.display(self.variable_map, self.interner), + solvable_id, + extra_variable.display(self.variable_map, self.interner), + extra_variable, + requirement.display(self.interner), + ) + } } } } diff --git a/src/solver/mod.rs b/src/solver/mod.rs index bda9bd4..9e2f3d8 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -22,7 +22,8 @@ use crate::{ requirement::ConditionalRequirement, runtime::{AsyncRuntime, NowOrNeverRuntime}, solver::binary_encoding::AtMostOnceTracker, - Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, VersionSetId, + Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, StringId, + VersionSetId, }; mod binary_encoding; @@ -37,7 +38,15 @@ mod watch_map; #[derive(Default)] struct AddClauseOutput { new_requires_clauses: Vec<(VariableId, Requirement, ClauseId)>, - new_conditional_clauses: Vec<(VariableId, VariableId, Requirement, ClauseId)>, + /// A vector of tuples from a solvable variable, conditional variable and extra(feature name) variable to + /// the clauses that need to be watched. + new_conditional_clauses: Vec<( + VariableId, + Option, + Option, + Requirement, + ClauseId, + )>, conflicting_clauses: Vec, negative_assertions: Vec<(VariableId, ClauseId)>, clauses_to_watch: Vec, @@ -152,8 +161,14 @@ pub struct Solver { pub(crate) clauses: Clauses, requires_clauses: IndexMap, ahash::RandomState>, - conditional_clauses: - IndexMap<(VariableId, VariableId), Vec<(Requirement, ClauseId)>, ahash::RandomState>, + + /// A map from a solvable variable, conditional variable and extra(feature name) variable to + /// the clauses that need to be watched. + conditional_clauses: IndexMap< + (VariableId, Option, Option), + Vec<(Requirement, ClauseId)>, + ahash::RandomState, + >, watches: WatchMap, /// A mapping from requirements to the variables that represent the @@ -666,11 +681,11 @@ impl Solver { .push((requirement, clause_id)); } - for (solvable_id, condition_variable, requirement, clause_id) in + for (solvable_id, condition_variable, extra_variable, requirement, clause_id) in output.new_conditional_clauses { self.conditional_clauses - .entry((solvable_id, condition_variable)) + .entry((solvable_id, condition_variable, extra_variable)) .or_default() .push((requirement, clause_id)); } @@ -791,6 +806,7 @@ impl Solver { ( solvable_id, None, + None, requirements .iter() .map(|(r, c)| (*r, *c)) @@ -801,15 +817,16 @@ impl Solver { let conditional_iter = self.conditional_clauses .iter() - .map(|((solvable_id, condition), clauses)| { + .map(|((solvable_id, condition, extra), clauses)| { ( *solvable_id, - Some(*condition), + *condition, + *extra, clauses.iter().map(|(r, c)| (*r, *c)).collect::>(), ) }); - for (solvable_id, condition, requirements) in requires_iter.chain(conditional_iter) { + for (solvable_id, condition, extra, requirements) in requires_iter.chain(conditional_iter) { let is_explicit_requirement = solvable_id == VariableId::root(); if let Some(best_decision) = &best_decision { @@ -837,6 +854,13 @@ impl Solver { } } + if let Some(extra_variable) = extra { + // If the extra is not enabled, skip this requirement entirely + if self.decision_tracker.assigned_value(extra_variable) != Some(true) { + continue; + } + } + for (requirement, clause_id) in requirements { let mut candidate = ControlFlow::Break(()); @@ -1589,6 +1613,7 @@ async fn add_clauses_for_solvables( solvable_id: SolvableOrRootId, requirement: Requirement, condition: Option<(SolvableId, VersionSetId)>, + extra: Option<(VariableId, StringId)>, candidates: Vec<&'i [SolvableId]>, }, NonMatchingCandidates { @@ -1759,37 +1784,83 @@ async fn add_clauses_for_solvables( // condition is `VersionSetId` right now but it will become a `Requirement` // in the next versions of resolvo - if let Some(condition) = conditional_requirement.condition { - let condition_candidates = - cache.get_or_cache_matching_candidates(condition).await?; - - for &condition_candidate in condition_candidates { - let candidates = candidates.clone(); + match ( + conditional_requirement.condition, + conditional_requirement.extra, + ) { + (None, Some(extra)) => { + let extra_variable = variable_map.intern_extra(extra); + + // Add a task result for the condition pending_futures.push( async move { Ok(TaskResult::SortedCandidates { solvable_id, requirement: conditional_requirement.requirement, - condition: Some((condition_candidate, condition)), - candidates, + condition: None, + extra: Some((extra_variable, extra)), + candidates: candidates.clone(), }) } .boxed_local(), ); } - } else { - // Add a task result for the condition - pending_futures.push( - async move { - Ok(TaskResult::SortedCandidates { - solvable_id, - requirement: conditional_requirement.requirement, - condition: None, - candidates: candidates.clone(), - }) + (Some(condition), Some(extra)) => { + let condition_candidates = + cache.get_or_cache_matching_candidates(condition).await?; + let extra_variable = variable_map.intern_extra(extra); + + for &condition_candidate in condition_candidates { + let candidates = candidates.clone(); + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: Some((condition_candidate, condition)), + extra: Some((extra_variable, extra)), + candidates, + }) + } + .boxed_local(), + ); } - .boxed_local(), - ); + } + (Some(condition), None) => { + let condition_candidates = + cache.get_or_cache_matching_candidates(condition).await?; + + for &condition_candidate in condition_candidates { + let candidates = candidates.clone(); + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: Some((condition_candidate, condition)), + extra: None, + candidates, + }) + } + .boxed_local(), + ); + } + } + (None, None) => { + // Add a task result for the condition + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: None, + extra: None, + candidates: candidates.clone(), + }) + } + .boxed_local(), + ); + } } } @@ -1857,6 +1928,7 @@ async fn add_clauses_for_solvables( solvable_id, requirement, condition, + extra, candidates, } => { tracing::trace!( @@ -1926,69 +1998,152 @@ async fn add_clauses_for_solvables( ); } - if let Some((condition, condition_version_set_id)) = condition { - let condition_variable = variable_map.intern_solvable(condition); + match (condition, extra) { + ( + Some((condition_variable, condition_version_set_id)), + Some((extra_variable, extra_name)), + ) => { + let condition_variable = variable_map.intern_solvable(condition_variable); + + let (watched_literals, conflict, kind) = + WatchedLiterals::conditional_with_extra( + variable, + requirement, + condition_variable, + condition_version_set_id, + extra_variable, + extra_name, + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); - // Add a condition clause - let (watched_literals, conflict, kind) = WatchedLiterals::conditional( - variable, - requirement, - condition_variable, - condition_version_set_id, - decision_tracker, - version_set_variables.iter().flatten().copied(), - ); + // Add the conditional clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); - // Add the conditional clause - let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); - let has_watches = watched_literals.is_some(); - let clause_id = clauses.alloc(watched_literals, kind); + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output.new_conditional_clauses.push(( + variable, + Some(condition_variable), + Some(extra_variable), + requirement, + clause_id, + )); - if has_watches { - output.clauses_to_watch.push(clause_id); + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } } + (None, Some((extra_variable, extra_name))) => { + let (watched_literals, conflict, kind) = + WatchedLiterals::requires_with_extra( + variable, + extra_variable, + extra_name, + requirement, + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); - output.new_conditional_clauses.push(( - variable, - condition_variable, - requirement, - clause_id, - )); + // Add the requirements clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); - if conflict { - output.conflicting_clauses.push(clause_id); - } else if no_candidates { - // Add assertions for unit clauses (i.e. those with no matching candidates) - output.negative_assertions.push((variable, clause_id)); + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output + .new_requires_clauses + .push((variable, requirement, clause_id)); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } } - } else { - let (watched_literals, conflict, kind) = WatchedLiterals::requires( - variable, - requirement, - version_set_variables.iter().flatten().copied(), - decision_tracker, - ); + (Some((condition, condition_version_set_id)), None) => { + let condition_variable = variable_map.intern_solvable(condition); + + // Add a condition clause + let (watched_literals, conflict, kind) = WatchedLiterals::conditional( + variable, + requirement, + condition_variable, + condition_version_set_id, + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); - // Add the requirements clause - let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); + // Add the conditional clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); - let has_watches = watched_literals.is_some(); - let clause_id = clauses.alloc(watched_literals, kind); + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output.new_conditional_clauses.push(( + variable, + Some(condition_variable), + None, + requirement, + clause_id, + )); - if has_watches { - output.clauses_to_watch.push(clause_id); + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } } + (None, None) => { + let (watched_literals, conflict, kind) = WatchedLiterals::requires( + variable, + requirement, + version_set_variables.iter().flatten().copied(), + decision_tracker, + ); - output - .new_requires_clauses - .push((variable, requirement, clause_id)); + // Add the requirements clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); - if conflict { - output.conflicting_clauses.push(clause_id); - } else if no_candidates { - // Add assertions for unit clauses (i.e. those with no matching candidates) - output.negative_assertions.push((variable, clause_id)); + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output + .new_requires_clauses + .push((variable, requirement, clause_id)); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } } } } diff --git a/src/solver/variable_map.rs b/src/solver/variable_map.rs index 608ed68..286f187 100644 --- a/src/solver/variable_map.rs +++ b/src/solver/variable_map.rs @@ -7,7 +7,7 @@ use crate::{ arena::ArenaId, id::{SolvableOrRootId, VariableId}, }, - Interner, NameId, SolvableId, + Interner, NameId, SolvableId, StringId, }; /// All variables in the solver are stored in a `VariableMap`. This map is used @@ -23,6 +23,9 @@ pub struct VariableMap { /// A map from solvable id to variable id. solvable_to_variable: HashMap, + /// A map from extra name to variable id. + extra_to_variable: HashMap, + /// Records the origins of all variables. origins: HashMap, } @@ -38,6 +41,9 @@ pub enum VariableOrigin { /// A variable that helps encode an at most one constraint. ForbidMultiple(NameId), + + /// A variable that represents an extra variable. + Extra(StringId), } impl Default for VariableMap { @@ -48,6 +54,7 @@ impl Default for VariableMap { Self { next_id: 1, // The first variable id is 1 because 0 is reserved for the root. solvable_to_variable: HashMap::default(), + extra_to_variable: HashMap::default(), origins, } } @@ -78,6 +85,22 @@ impl VariableMap { } } + /// Allocate a variable that represents an extra variable. + pub fn intern_extra(&mut self, extra_name: StringId) -> VariableId { + match self.extra_to_variable.entry(extra_name) { + Entry::Occupied(entry) => *entry.get(), + Entry::Vacant(entry) => { + let id = self.next_id; + self.next_id += 1; + let variable_id = VariableId::from_usize(id); + entry.insert(variable_id); + self.origins + .insert(variable_id, VariableOrigin::Extra(extra_name)); + variable_id + } + } + } + /// Allocate a variable that helps encode an at most one constraint. pub fn alloc_forbid_multiple_variable(&mut self, name: NameId) -> VariableId { let id = self.next_id; @@ -141,6 +164,9 @@ impl<'i, I: Interner> Display for VariableDisplay<'i, I> { VariableOrigin::ForbidMultiple(name) => { write!(f, "forbid-multiple({})", self.interner.display_name(name)) } + VariableOrigin::Extra(extra_name) => { + write!(f, "extra({})", self.interner.display_string(extra_name)) + } } } }