Skip to content

Commit 003f908

Browse files
Add SpecializedFunction.
1 parent 8614785 commit 003f908

File tree

9 files changed

+325
-30
lines changed

9 files changed

+325
-30
lines changed

crates/cairo-lang-lowering/src/cache/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,9 @@ impl FunctionCached {
13901390
FunctionLongId::Generated(id) => {
13911391
FunctionCached::Generated(GeneratedFunctionCached::new(id, ctx))
13921392
}
1393+
FunctionLongId::Specialized(_) => {
1394+
unreachable!("Specialization of functions only occurs post concretization.")
1395+
}
13931396
}
13941397
}
13951398
fn embed(self, ctx: &mut CacheLoadingContext<'_>) -> FunctionId {

crates/cairo-lang-lowering/src/concretize/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ fn concretize_function(
2424
})
2525
.intern(db))
2626
}
27+
FunctionLongId::Specialized(_) => {
28+
unreachable!("Specialization of functions only occurs post concretization.")
29+
}
2730
}
2831
}
2932

crates/cairo-lang-lowering/src/db.rs

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
1616
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
1717
use cairo_lang_utils::{Intern, LookupIntern, Upcast};
1818
use defs::ids::NamedLanguageElementId;
19-
use itertools::{Itertools, chain};
19+
use itertools::{Itertools, chain, zip_eq};
2020
use num_traits::ToPrimitive;
2121

2222
use crate::add_withdraw_gas::add_withdraw_gas;
23-
use crate::blocks::Blocks;
23+
use crate::blocks::{Blocks, BlocksBuilder};
2424
use crate::borrow_check::{
2525
PotentialDestructCalls, borrow_check, borrow_check_possible_withdraw_gas,
2626
};
@@ -29,17 +29,21 @@ use crate::concretize::concretize_lowered;
2929
use crate::destructs::add_destructs;
3030
use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind};
3131
use crate::graph_algorithms::feedback_set::flag_add_withdraw_gas;
32-
use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId};
32+
use crate::ids::{
33+
ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId, SpecializedFunction,
34+
};
3335
use crate::inline::get_inline_diagnostics;
3436
use crate::inline::statements_weights::{ApproxCasmInlineWeight, InlineWeight};
37+
use crate::lower::context::{VarRequest, VariableAllocator};
3538
use crate::lower::{MultiLowering, lower_semantic_function};
3639
use crate::optimizations::config::OptimizationConfig;
3740
use crate::optimizations::scrub_units::scrub_units;
3841
use crate::optimizations::strategy::{OptimizationStrategy, OptimizationStrategyId};
3942
use crate::panic::lower_panics;
4043
use crate::utils::InliningStrategy;
4144
use crate::{
42-
BlockEnd, BlockId, DependencyType, Location, Lowered, LoweringStage, MatchInfo, Statement, ids,
45+
Block, BlockEnd, BlockId, DependencyType, Location, Lowered, LoweringStage, MatchInfo,
46+
Statement, StatementCall, StatementConst, VarUsage, VariableId, ids,
4347
};
4448

4549
/// A trait for estimation of the code size of a function.
@@ -432,34 +436,97 @@ fn lowered_body(
432436
function: ids::ConcreteFunctionWithBodyId,
433437
stage: LoweringStage,
434438
) -> Maybe<Arc<Lowered>> {
435-
match stage {
439+
let lowered = match stage {
436440
LoweringStage::Monomorphized => {
437-
let generic_function_id = function.function_with_body_id(db);
441+
let generic_function_id = match function.lookup_intern(db) {
442+
ids::ConcreteFunctionWithBodyLongId::Semantic(id) => {
443+
ids::FunctionWithBodyLongId::Semantic(id.function_with_body_id(db))
444+
}
445+
ids::ConcreteFunctionWithBodyLongId::Generated(id) => {
446+
id.function_with_body_long_id(db)
447+
}
448+
ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
449+
return Ok(Arc::new(specialized_function_lowered(db, specialized)?));
450+
}
451+
}
452+
.intern(db);
453+
438454
db.function_with_body_lowering_diagnostics(generic_function_id)?.check_error_free()?;
439455
let mut lowered = (*db.function_with_body_lowering(generic_function_id)?).clone();
440456
concretize_lowered(db, &mut lowered, &function.substitution(db)?)?;
441-
Ok(Arc::new(lowered))
457+
lowered
442458
}
443459
LoweringStage::PreOptimizations => {
444460
let mut lowered = (*db.lowered_body(function, LoweringStage::Monomorphized)?).clone();
445461
add_withdraw_gas(db, function, &mut lowered)?;
446462
lower_panics(db, function, &mut lowered)?;
447463
add_destructs(db, function, &mut lowered);
448464
scrub_units(db, &mut lowered);
449-
Ok(Arc::new(lowered))
465+
lowered
450466
}
451467
LoweringStage::PostBaseline => {
452468
let mut lowered =
453469
(*db.lowered_body(function, LoweringStage::PreOptimizations)?).clone();
454470
db.baseline_optimization_strategy().apply_strategy(db, function, &mut lowered)?;
455-
Ok(Arc::new(lowered))
471+
lowered
456472
}
457473
LoweringStage::Final => {
458474
let mut lowered = (*db.lowered_body(function, LoweringStage::PostBaseline)?).clone();
459475
db.final_optimization_strategy().apply_strategy(db, function, &mut lowered)?;
460-
Ok(Arc::new(lowered))
476+
lowered
477+
}
478+
};
479+
Ok(Arc::new(lowered))
480+
}
481+
482+
/// Returns the lowering of a specialized function.
483+
fn specialized_function_lowered(
484+
db: &dyn LoweringGroup,
485+
specialized: SpecializedFunction,
486+
) -> Maybe<Lowered> {
487+
let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
488+
let base_semantic = specialized.base.base_semantic_function(db);
489+
let mut variables =
490+
VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
491+
let mut statement = vec![];
492+
let mut parameters = vec![];
493+
for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
494+
let var_id = variables.variables.alloc(base.variables[*param].clone());
495+
if let Some(arg) = arg {
496+
statement.push(Statement::Const(StatementConst { value: arg.clone(), output: var_id }));
497+
continue;
461498
}
499+
parameters.push(var_id);
462500
}
501+
let location = LocationId::from_stable_location(
502+
db,
503+
specialized.base.base_semantic_function(db).stable_location(db),
504+
);
505+
let inputs =
506+
variables.variables.iter().map(|(var_id, _)| VarUsage { var_id, location }).collect();
507+
let outputs: Vec<VariableId> =
508+
chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
509+
.map(|ty| variables.new_var(VarRequest { ty, location }))
510+
.collect_vec();
511+
let mut block_builder = BlocksBuilder::new();
512+
let ret_usage =
513+
outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
514+
statement.push(Statement::Call(StatementCall {
515+
function: specialized.base.function_id(db)?,
516+
with_coupon: false,
517+
inputs,
518+
outputs,
519+
location,
520+
}));
521+
block_builder
522+
.alloc(Block { statements: statement, end: BlockEnd::Return(ret_usage, location) });
523+
Ok(Lowered {
524+
signature: specialized.signature(db)?,
525+
variables: variables.variables,
526+
blocks: block_builder.build().unwrap(),
527+
parameters,
528+
diagnostics: Default::default(),
529+
})
463530
}
464531

465532
/// Given the lowering of a function, returns the set of direct dependencies of that function,

crates/cairo-lang-lowering/src/graph_algorithms/cycles.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use cairo_lang_diagnostics::Maybe;
22
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
3+
use cairo_lang_utils::{Intern, LookupIntern};
34

45
use crate::db::{LoweringGroup, get_direct_callees};
5-
use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionWithBodyId};
6+
use crate::ids::{self, ConcreteFunctionWithBodyId, FunctionId, FunctionWithBodyId};
67
use crate::{DependencyType, LoweringStage};
78

89
/// Query implementation of
@@ -31,7 +32,20 @@ pub fn function_with_body_direct_function_with_body_callees(
3132
.collect::<Maybe<Vec<Option<_>>>>()?
3233
.into_iter()
3334
.flatten()
34-
.map(|x| x.function_with_body_id(db))
35+
.map(|x| {
36+
match x.lookup_intern(db) {
37+
ids::ConcreteFunctionWithBodyLongId::Semantic(id) => {
38+
ids::FunctionWithBodyLongId::Semantic(id.function_with_body_id(db))
39+
}
40+
ids::ConcreteFunctionWithBodyLongId::Generated(id) => {
41+
id.function_with_body_long_id(db)
42+
}
43+
ids::ConcreteFunctionWithBodyLongId::Specialized(_) => {
44+
unreachable!("Specialization of functions only occurs post concretization.")
45+
}
46+
}
47+
.intern(db)
48+
})
3549
.collect())
3650
}
3751

0 commit comments

Comments
 (0)