From 0effd969a424a903e895e8d2551dae28a2e2c91d Mon Sep 17 00:00:00 2001 From: Ilya Lesokhin Date: Tue, 29 Apr 2025 10:58:43 +0300 Subject: [PATCH] Add SpecializedFunction. --- crates/cairo-lang-lowering/src/cache/mod.rs | 1 + .../cairo-lang-lowering/src/concretize/mod.rs | 1 + crates/cairo-lang-lowering/src/db.rs | 85 ++++++++++-- crates/cairo-lang-lowering/src/ids.rs | 121 +++++++++++++++--- crates/cairo-lang-lowering/src/inline/mod.rs | 3 + crates/cairo-lang-lowering/src/lower/mod.rs | 3 + .../src/lower/specialized_test.rs | 64 +++++++++ .../src/lower/test_data/specialized | 33 +++++ 8 files changed, 286 insertions(+), 25 deletions(-) create mode 100644 crates/cairo-lang-lowering/src/lower/specialized_test.rs create mode 100644 crates/cairo-lang-lowering/src/lower/test_data/specialized diff --git a/crates/cairo-lang-lowering/src/cache/mod.rs b/crates/cairo-lang-lowering/src/cache/mod.rs index aa6f259db01..3cb9d91b4ec 100644 --- a/crates/cairo-lang-lowering/src/cache/mod.rs +++ b/crates/cairo-lang-lowering/src/cache/mod.rs @@ -1393,6 +1393,7 @@ impl FunctionCached { FunctionLongId::Generated(id) => { FunctionCached::Generated(GeneratedFunctionCached::new(id, ctx)) } + FunctionLongId::Specialized(_) => todo!(), } } fn embed(self, ctx: &mut CacheLoadingContext<'_>) -> FunctionId { diff --git a/crates/cairo-lang-lowering/src/concretize/mod.rs b/crates/cairo-lang-lowering/src/concretize/mod.rs index 19bed2ee20a..fa23a5bc3ad 100644 --- a/crates/cairo-lang-lowering/src/concretize/mod.rs +++ b/crates/cairo-lang-lowering/src/concretize/mod.rs @@ -24,6 +24,7 @@ fn concretize_function( }) .intern(db)) } + FunctionLongId::Specialized(_) => unreachable!("This should not be called."), } } diff --git a/crates/cairo-lang-lowering/src/db.rs b/crates/cairo-lang-lowering/src/db.rs index 7e69a6d22dc..02962372aa5 100644 --- a/crates/cairo-lang-lowering/src/db.rs +++ b/crates/cairo-lang-lowering/src/db.rs @@ -16,11 +16,11 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; use cairo_lang_utils::unordered_hash_set::UnorderedHashSet; use cairo_lang_utils::{Intern, LookupIntern, Upcast}; use defs::ids::NamedLanguageElementId; -use itertools::{Itertools, chain}; +use itertools::{Itertools, chain, zip_eq}; use num_traits::ToPrimitive; use crate::add_withdraw_gas::add_withdraw_gas; -use crate::blocks::Blocks; +use crate::blocks::{Blocks, BlocksBuilder}; use crate::borrow_check::{ PotentialDestructCalls, borrow_check, borrow_check_possible_withdraw_gas, }; @@ -29,8 +29,9 @@ use crate::concretize::concretize_lowered; use crate::destructs::add_destructs; use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind}; use crate::graph_algorithms::feedback_set::flag_add_withdraw_gas; -use crate::ids::{FunctionId, FunctionLongId}; +use crate::ids::{FunctionId, FunctionLongId, LocationId}; use crate::inline::get_inline_diagnostics; +use crate::lower::context::{VarRequest, VariableAllocator}; use crate::lower::{MultiLowering, lower_semantic_function}; use crate::optimizations::config::OptimizationConfig; use crate::optimizations::scrub_units::scrub_units; @@ -38,7 +39,8 @@ use crate::optimizations::strategy::{OptimizationStrategy, OptimizationStrategyI use crate::panic::lower_panics; use crate::utils::InliningStrategy; use crate::{ - BlockId, DependencyType, FlatBlockEnd, FlatLowered, Location, MatchInfo, Statement, ids, + BlockId, DependencyType, FlatBlockEnd, FlatLowered, Location, MatchInfo, Statement, + StatementCall, StatementConst, VarUsage, VariableId, ids, }; // Salsa database interface. @@ -432,6 +434,9 @@ fn priv_function_with_body_lowering( ids::FunctionWithBodyLongId::Generated { key, .. } => { multi_lowering.generated_lowerings[key].clone() } + ids::FunctionWithBodyLongId::Specialized(_specialized) => { + unreachable!("There is no generic version of a specialized function.") + } }; Ok(Arc::new(lowered)) } @@ -471,6 +476,7 @@ fn priv_concrete_function_with_body_lowered_flat( function: ids::ConcreteFunctionWithBodyId, ) -> Maybe> { let semantic_db = db; + let generic_function_id = function.function_with_body_id(db); db.function_with_body_lowering_diagnostics(generic_function_id)?.check_error_free()?; let mut lowered = (*db.function_with_body_lowering(generic_function_id)?).clone(); @@ -485,11 +491,74 @@ fn concrete_function_with_body_postpanic_lowered( db: &dyn LoweringGroup, function: ids::ConcreteFunctionWithBodyId, ) -> Maybe> { - let mut lowered = (*db.priv_concrete_function_with_body_lowered_flat(function)?).clone(); + let mut lowered = if let ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) = + function.lookup_intern(db) + { + let base = db.concrete_function_with_body_postpanic_lowered(specialized.base)?; + + let base_semantic = specialized.base.base_semantic_function(db); + + let mut variables = VariableAllocator::new( + db, + base_semantic.function_with_body_id(db), + Default::default(), + )?; + let mut statement = vec![]; + let mut parameters = vec![]; + for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) { + let var_id = variables.variables.alloc(base.variables[*param].clone()); + if let Some(arg) = arg { + statement + .push(Statement::Const(StatementConst { value: arg.clone(), output: var_id })); + continue; + } + parameters.push(var_id); + } - add_withdraw_gas(db, function, &mut lowered)?; - lower_panics(db, function, &mut lowered)?; - add_destructs(db, function, &mut lowered); + let location = LocationId::from_stable_location( + db, + specialized.base.base_semantic_function(db).stable_location(db), + ); + let inputs = + variables.variables.iter().map(|(var_id, _)| VarUsage { var_id, location }).collect(); + + let outputs: Vec = chain!( + base.signature.extra_rets.iter().map(|ret| ret.ty()), + [base.signature.return_type] + ) + .map(|ty| variables.new_var(VarRequest { ty, location })) + .collect_vec(); + + let mut block_builder = BlocksBuilder::new(); + + let ret_usage = + outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec(); + + statement.push(Statement::Call(StatementCall { + function: specialized.base.function_id(db)?, + with_coupon: false, + inputs, + outputs, + location, + })); + block_builder.alloc(crate::FlatBlock { + statements: statement, + end: FlatBlockEnd::Return(ret_usage, location), + }); + FlatLowered { + signature: function.signature(db)?, + variables: variables.variables, + blocks: block_builder.build().unwrap(), + parameters, + diagnostics: Default::default(), + } + } else { + let mut lowered = (*db.priv_concrete_function_with_body_lowered_flat(function)?).clone(); + add_withdraw_gas(db, function, &mut lowered)?; + lower_panics(db, function, &mut lowered)?; + add_destructs(db, function, &mut lowered); + lowered + }; scrub_units(db, &mut lowered); Ok(Arc::new(lowered)) diff --git a/crates/cairo-lang-lowering/src/ids.rs b/crates/cairo-lang-lowering/src/ids.rs index d5843250f25..fe9129b09cf 100644 --- a/crates/cairo-lang-lowering/src/ids.rs +++ b/crates/cairo-lang-lowering/src/ids.rs @@ -1,9 +1,12 @@ +use std::sync::Arc; + use cairo_lang_debug::DebugWithDb; use cairo_lang_defs::ids::{ NamedLanguageElementId, TopLevelLanguageElementId, TraitFunctionId, UnstableSalsaId, }; use cairo_lang_diagnostics::{DiagnosticAdded, DiagnosticNote, Maybe}; use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; +use cairo_lang_semantic::items::constant::ConstValue; use cairo_lang_semantic::items::functions::ImplGenericFunctionId; use cairo_lang_semantic::items::imp::ImplLongId; use cairo_lang_semantic::{GenericArgumentId, TypeLongId}; @@ -13,6 +16,7 @@ use cairo_lang_syntax::node::{TypedStablePtr, ast}; use cairo_lang_utils::{Intern, LookupIntern, define_short_id, try_extract_matches}; use defs::diagnostic_utils::StableLocation; use defs::ids::{ExternFunctionId, FreeFunctionId}; +use itertools::zip_eq; use semantic::items::functions::GenericFunctionId; use semantic::substitution::{GenericSubstitution, SubstitutionRewriter}; use semantic::{ExprVar, Mutability}; @@ -22,10 +26,11 @@ use crate::Location; use crate::db::LoweringGroup; use crate::ids::semantic::substitution::SemanticRewriter; -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum FunctionWithBodyLongId { Semantic(defs::ids::FunctionWithBodyId), Generated { parent: defs::ids::FunctionWithBodyId, key: GeneratedFunctionKey }, + Specialized(SpecializedFunction), } define_short_id!( FunctionWithBodyId, @@ -37,24 +42,30 @@ define_short_id!( impl FunctionWithBodyLongId { pub fn base_semantic_function( &self, - _db: &dyn LoweringGroup, + db: &dyn LoweringGroup, ) -> cairo_lang_defs::ids::FunctionWithBodyId { - match *self { - FunctionWithBodyLongId::Semantic(id) => id, - FunctionWithBodyLongId::Generated { parent, .. } => parent, + match self { + FunctionWithBodyLongId::Semantic(id) => *id, + FunctionWithBodyLongId::Generated { parent, .. } => *parent, + FunctionWithBodyLongId::Specialized(specialized) => { + specialized.base.base_semantic_function(db).function_with_body_id(db) + } } } pub fn to_concrete(&self, db: &dyn LoweringGroup) -> Maybe { - Ok(match *self { + Ok(match self { FunctionWithBodyLongId::Semantic(semantic) => ConcreteFunctionWithBodyLongId::Semantic( - semantic::ConcreteFunctionWithBodyId::from_generic(db, semantic)?, + semantic::ConcreteFunctionWithBodyId::from_generic(db, *semantic)?, ), FunctionWithBodyLongId::Generated { parent, key } => { ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction { - parent: semantic::ConcreteFunctionWithBodyId::from_generic(db, parent)?, - key, + parent: semantic::ConcreteFunctionWithBodyId::from_generic(db, *parent)?, + key: *key, }) } + FunctionWithBodyLongId::Specialized(specialized) => { + ConcreteFunctionWithBodyLongId::Specialized(specialized.clone()) + } }) } } @@ -82,10 +93,11 @@ impl SemanticFunctionWithBodyIdEx for cairo_lang_defs::ids::FunctionWithBodyId { } /// Concrete function with body. -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum ConcreteFunctionWithBodyLongId { Semantic(semantic::ConcreteFunctionWithBodyId), Generated(GeneratedFunction), + Specialized(SpecializedFunction), } define_short_id!( ConcreteFunctionWithBodyId, @@ -117,12 +129,18 @@ impl UnstableSalsaId for ConcreteFunctionWithBodyId { } impl ConcreteFunctionWithBodyLongId { pub fn function_with_body_id(&self, db: &dyn LoweringGroup) -> FunctionWithBodyId { - let long_id = match *self { + let long_id = match self { ConcreteFunctionWithBodyLongId::Semantic(id) => { FunctionWithBodyLongId::Semantic(id.function_with_body_id(db)) } ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction { parent, key }) => { - FunctionWithBodyLongId::Generated { parent: parent.function_with_body_id(db), key } + FunctionWithBodyLongId::Generated { + parent: parent.function_with_body_id(db), + key: *key, + } + } + ConcreteFunctionWithBodyLongId::Specialized(specialized_function) => { + return specialized_function.base.function_with_body_id(db); } }; long_id.intern(db) @@ -133,6 +151,9 @@ impl ConcreteFunctionWithBodyLongId { ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction { parent, .. }) => { parent.substitution(db) } + ConcreteFunctionWithBodyLongId::Specialized(specialized) => { + specialized.base.substitution(db) + } } } pub fn function_id(&self, db: &dyn LoweringGroup) -> Maybe { @@ -143,22 +164,29 @@ impl ConcreteFunctionWithBodyLongId { ConcreteFunctionWithBodyLongId::Generated(generated) => { FunctionLongId::Generated(*generated) } + ConcreteFunctionWithBodyLongId::Specialized(specialized) => { + FunctionLongId::Specialized(specialized.clone()) + } }; Ok(long_id.intern(db)) } pub fn base_semantic_function( &self, - _db: &dyn LoweringGroup, + db: &dyn LoweringGroup, ) -> semantic::ConcreteFunctionWithBodyId { - match *self { - ConcreteFunctionWithBodyLongId::Semantic(id) => id, + match self { + ConcreteFunctionWithBodyLongId::Semantic(id) => *id, ConcreteFunctionWithBodyLongId::Generated(generated) => generated.parent, + ConcreteFunctionWithBodyLongId::Specialized(specialized) => { + specialized.base.base_semantic_function(db) + } } } pub fn full_path(&self, db: &dyn LoweringGroup) -> String { match self { ConcreteFunctionWithBodyLongId::Semantic(semantic) => semantic.full_path(db), ConcreteFunctionWithBodyLongId::Generated(generated) => generated.full_path(db), + ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.full_path(db), } } } @@ -206,17 +234,22 @@ impl ConcreteFunctionWithBodyId { GeneratedFunctionKey::Loop(stable_ptr) => StableLocation::new(stable_ptr.untyped()), GeneratedFunctionKey::TraitFunc(_, stable_location) => stable_location, }, + ConcreteFunctionWithBodyLongId::Specialized(specialized_function) => { + specialized_function.base.stable_location(db)? + } }) } } /// Function. -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum FunctionLongId { /// An original function from the user code. Semantic(semantic::FunctionId), /// A function generated by the compiler. Generated(GeneratedFunction), + /// A specialized function. + Specialized(SpecializedFunction), } define_short_id!( FunctionId, @@ -227,7 +260,7 @@ define_short_id!( ); impl FunctionLongId { pub fn body(&self, db: &dyn LoweringGroup) -> Maybe> { - Ok(Some(match *self { + Ok(Some(match self { FunctionLongId::Semantic(id) => { let concrete_function = id.get_concrete(db); if let GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function }) = @@ -273,6 +306,9 @@ impl FunctionLongId { ConcreteFunctionWithBodyLongId::Semantic(body).intern(db) } FunctionLongId::Generated(generated) => generated.body(db), + FunctionLongId::Specialized(specialized) => { + ConcreteFunctionWithBodyLongId::Specialized(specialized.clone()).intern(db) + } })) } pub fn signature(&self, db: &dyn LoweringGroup) -> Maybe { @@ -281,6 +317,20 @@ impl FunctionLongId { Ok(Signature::from_semantic(db, db.concrete_function_signature(*semantic)?)) } FunctionLongId::Generated(generated) => generated.body(db).signature(db), + FunctionLongId::Specialized(specialized) => { + let mut base_sign = specialized.base.signature(db)?; + + base_sign.params = zip_eq(base_sign.params, specialized.args.iter()) + .filter_map(|(param, arg)| { + if arg.is_none() { + return Some(param); + } + None + }) + .collect::>(); + + Ok(base_sign) + } } } pub fn full_path(&self, db: &dyn LoweringGroup) -> String { @@ -293,6 +343,7 @@ impl FunctionLongId { match self { FunctionLongId::Semantic(id) => id.full_path(db), FunctionLongId::Generated(generated) => generated.parent.full_path(db), + FunctionLongId::Specialized(specialized) => specialized.full_path(db), } } } @@ -349,6 +400,7 @@ impl<'a> DebugWithDb for FunctionLongId { match self { FunctionLongId::Semantic(semantic) => write!(f, "{:?}", semantic.debug(db)), FunctionLongId::Generated(generated) => write!(f, "{:?}", generated.debug(db)), + FunctionLongId::Specialized(specialized) => write!(f, "{:?}", specialized.debug(db)), } } } @@ -416,6 +468,41 @@ impl<'a> DebugWithDb for GeneratedFunction { } } +/// Specialized function. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct SpecializedFunction { + /// The base function. + pub base: crate::ids::ConcreteFunctionWithBodyId, + /// Optional const assigment of the arguments. + pub args: Arc<[Option]>, +} + +impl SpecializedFunction { + pub fn body(&self, db: &dyn LoweringGroup) -> ConcreteFunctionWithBodyId { + let long_id = ConcreteFunctionWithBodyLongId::Specialized(self.clone()); + long_id.intern(db) + } + pub fn full_path(&self, db: &dyn LoweringGroup) -> String { + format!("{:?}", self.debug(db)) + } +} +impl<'a> DebugWithDb for SpecializedFunction { + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + db: &(dyn LoweringGroup + 'a), + ) -> std::fmt::Result { + write!(f, "{}{{", self.base.full_path(db))?; + for arg in self.args.iter() { + match arg { + Some(value) => write!(f, "{:?}, ", value.debug(db))?, + None => write!(f, "None, ")?, + } + } + write!(f, "}}") + } +} + /// Lowered signature of a function. #[derive(Clone, Debug, PartialEq, Eq, DebugWithDb, SemanticObject, Hash)] #[debug_db(dyn LoweringGroup + 'a)] diff --git a/crates/cairo-lang-lowering/src/inline/mod.rs b/crates/cairo-lang-lowering/src/inline/mod.rs index a66455538e8..d213c2e1206 100644 --- a/crates/cairo-lang-lowering/src/inline/mod.rs +++ b/crates/cairo-lang-lowering/src/inline/mod.rs @@ -38,6 +38,9 @@ pub fn get_inline_diagnostics( let inline_config = match function_id.lookup_intern(db) { FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(id)?, FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None, + FunctionWithBodyLongId::Specialized(specialized) => db.function_declaration_inline_config( + specialized.base.base_semantic_function(db).function_with_body_id(db), + )?, }; let mut diagnostics = LoweringDiagnostics::default(); diff --git a/crates/cairo-lang-lowering/src/lower/mod.rs b/crates/cairo-lang-lowering/src/lower/mod.rs index d6eff3d988b..7d713135ace 100644 --- a/crates/cairo-lang-lowering/src/lower/mod.rs +++ b/crates/cairo-lang-lowering/src/lower/mod.rs @@ -72,6 +72,9 @@ pub mod refs; #[cfg(test)] mod generated_test; +#[cfg(test)] +mod specialized_test; + /// Lowering of a function together with extra generated functions. #[derive(Clone, Debug, PartialEq, Eq)] pub struct MultiLowering { diff --git a/crates/cairo-lang-lowering/src/lower/specialized_test.rs b/crates/cairo-lang-lowering/src/lower/specialized_test.rs new file mode 100644 index 00000000000..495295fc2a1 --- /dev/null +++ b/crates/cairo-lang-lowering/src/lower/specialized_test.rs @@ -0,0 +1,64 @@ +use std::sync::Arc; + +use cairo_lang_debug::DebugWithDb; +use cairo_lang_semantic::db::SemanticGroup; +use cairo_lang_semantic::items::constant::ConstValue; +use cairo_lang_semantic::test_utils::setup_test_function; +use cairo_lang_test_utils::parse_test_file::TestRunnerResult; +use cairo_lang_utils::Intern; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; +use num_bigint::BigInt; +use num_traits::One; + +use crate::db::LoweringGroup; +use crate::fmt::LoweredFormatter; +use crate::ids::{ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, SpecializedFunction}; +use crate::test_utils::LoweringDatabaseForTesting; + +cairo_lang_test_utils::test_file_test!( + specialized, + "src/lower/test_data", + { + specialized :"specialized", + }, + test_specialized_function +); + +fn test_specialized_function( + inputs: &OrderedHashMap, + _args: &OrderedHashMap, +) -> TestRunnerResult { + let db = &mut LoweringDatabaseForTesting::default(); + let (test_function, semantic_diagnostics) = setup_test_function( + db, + inputs["function"].as_str(), + inputs["function_name"].as_str(), + inputs["module_code"].as_str(), + ) + .split(); + + let function_id = + ConcreteFunctionWithBodyId::from_semantic(db, test_function.concrete_function_id); + + let core = db.core_info(); + + let specialized_func = SpecializedFunction { + base: function_id, + args: Arc::new([None, Some(ConstValue::Int(BigInt::one(), core.felt252))]), + }; + + let specialized_func = ConcreteFunctionWithBodyLongId::Specialized(specialized_func).intern(db); + let lowered = db.final_concrete_function_with_body_lowered(specialized_func).unwrap(); + let lowered_formatter = LoweredFormatter::new(db, &lowered.variables); + let lowered = format!("{:?}", lowered.debug(&lowered_formatter)); + + let lowering_diagnostics = + db.module_lowering_diagnostics(test_function.module_id).unwrap_or_default(); + + TestRunnerResult::success(OrderedHashMap::from([ + ("full_path".into(), specialized_func.full_path(db)), + ("semantic_diagnostics".into(), semantic_diagnostics), + ("lowering".into(), lowered), + ("lowering_diagnostics".into(), lowering_diagnostics.format(db)), + ])) +} diff --git a/crates/cairo-lang-lowering/src/lower/test_data/specialized b/crates/cairo-lang-lowering/src/lower/test_data/specialized new file mode 100644 index 00000000000..a2017efee3a --- /dev/null +++ b/crates/cairo-lang-lowering/src/lower/test_data/specialized @@ -0,0 +1,33 @@ +//! > Test simple specialization. + +//! > test_runner_name +test_specialized_function + +//! > function +fn foo(x: felt252, y: felt252) -> felt252 { + x + y +} + +//! > function_name +foo + +//! > module_code + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > lowering_flat +Parameters: + +//! > full_path +test::foo{None, 1, } + +//! > lowering +Parameters: v0: core::felt252 +blk0 (root): +Statements: + (v1: core::felt252) <- 1 + (v2: core::felt252) <- core::felt252_add(v0, v1) +End: + Return(v2)