Skip to content

Commit c7abbcb

Browse files
Add SpecializedFunction.
1 parent 0639e14 commit c7abbcb

File tree

8 files changed

+308
-32
lines changed

8 files changed

+308
-32
lines changed

crates/cairo-lang-lowering/src/cache/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,9 @@ impl FunctionCached {
13901390
FunctionLongId::Generated(id) => {
13911391
FunctionCached::Generated(GeneratedFunctionCached::new(id, ctx))
13921392
}
1393+
FunctionLongId::Specialized(_) => {
1394+
unreachable!("Specialization of functions only occurs post concretization.")
1395+
}
13931396
}
13941397
}
13951398
fn embed(self, ctx: &mut CacheLoadingContext<'_>) -> FunctionId {

crates/cairo-lang-lowering/src/concretize/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ fn concretize_function(
2424
})
2525
.intern(db))
2626
}
27+
FunctionLongId::Specialized(_) => {
28+
unreachable!("Specialization of functions only occurs post concretization.")
29+
}
2730
}
2831
}
2932

crates/cairo-lang-lowering/src/db.rs

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
1616
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
1717
use cairo_lang_utils::{Intern, LookupIntern, Upcast};
1818
use defs::ids::NamedLanguageElementId;
19-
use itertools::{Itertools, chain};
19+
use itertools::{Itertools, chain, zip_eq};
2020
use num_traits::ToPrimitive;
2121

2222
use crate::add_withdraw_gas::add_withdraw_gas;
23-
use crate::blocks::Blocks;
23+
use crate::blocks::{Blocks, BlocksBuilder};
2424
use crate::borrow_check::{
2525
PotentialDestructCalls, borrow_check, borrow_check_possible_withdraw_gas,
2626
};
@@ -29,17 +29,19 @@ use crate::concretize::concretize_lowered;
2929
use crate::destructs::add_destructs;
3030
use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind};
3131
use crate::graph_algorithms::feedback_set::flag_add_withdraw_gas;
32-
use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId};
32+
use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId};
3333
use crate::inline::get_inline_diagnostics;
3434
use crate::inline::statements_weights::{ApproxCasmInlineWeight, InlineWeight};
35+
use crate::lower::context::{VarRequest, VariableAllocator};
3536
use crate::lower::{MultiLowering, lower_semantic_function};
3637
use crate::optimizations::config::OptimizationConfig;
3738
use crate::optimizations::scrub_units::scrub_units;
3839
use crate::optimizations::strategy::{OptimizationStrategy, OptimizationStrategyId};
3940
use crate::panic::lower_panics;
4041
use crate::utils::InliningStrategy;
4142
use crate::{
42-
BlockEnd, BlockId, DependencyType, Location, Lowered, LoweringStage, MatchInfo, Statement, ids,
43+
Block, BlockEnd, BlockId, DependencyType, Location, Lowered, LoweringStage, MatchInfo,
44+
Statement, StatementCall, StatementConst, VarUsage, VariableId, ids,
4345
};
4446

4547
/// A trait for estimation of the code size of a function.
@@ -394,6 +396,9 @@ fn priv_function_with_body_lowering(
394396
ids::FunctionWithBodyLongId::Generated { key, .. } => {
395397
multi_lowering.generated_lowerings[key].clone()
396398
}
399+
ids::FunctionWithBodyLongId::Specialized(_specialized) => {
400+
unreachable!("There is no generic version of a specialized function.")
401+
}
397402
};
398403
Ok(Arc::new(lowered))
399404
}
@@ -432,34 +437,94 @@ fn lowered_body(
432437
function: ids::ConcreteFunctionWithBodyId,
433438
stage: LoweringStage,
434439
) -> Maybe<Arc<Lowered>> {
435-
match stage {
436-
LoweringStage::Monomorphized => {
437-
let generic_function_id = function.function_with_body_id(db);
438-
db.function_with_body_lowering_diagnostics(generic_function_id)?.check_error_free()?;
439-
let mut lowered = (*db.function_with_body_lowering(generic_function_id)?).clone();
440-
concretize_lowered(db, &mut lowered, &function.substitution(db)?)?;
441-
Ok(Arc::new(lowered))
442-
}
440+
let lowered = match stage {
441+
LoweringStage::Monomorphized => match try_get_specialized_lowered(db, function)? {
442+
Some(lowered) => lowered,
443+
None => {
444+
let generic_function_id = function.function_with_body_id(db);
445+
db.function_with_body_lowering_diagnostics(generic_function_id)?
446+
.check_error_free()?;
447+
let mut lowered = (*db.function_with_body_lowering(generic_function_id)?).clone();
448+
concretize_lowered(db, &mut lowered, &function.substitution(db)?)?;
449+
lowered
450+
}
451+
},
443452
LoweringStage::PreOptimizations => {
444453
let mut lowered = (*db.lowered_body(function, LoweringStage::Monomorphized)?).clone();
445454
add_withdraw_gas(db, function, &mut lowered)?;
446455
lower_panics(db, function, &mut lowered)?;
447456
add_destructs(db, function, &mut lowered);
448457
scrub_units(db, &mut lowered);
449-
Ok(Arc::new(lowered))
458+
lowered
450459
}
451460
LoweringStage::PostBaseline => {
452461
let mut lowered =
453462
(*db.lowered_body(function, LoweringStage::PreOptimizations)?).clone();
454463
db.baseline_optimization_strategy().apply_strategy(db, function, &mut lowered)?;
455-
Ok(Arc::new(lowered))
464+
lowered
456465
}
457466
LoweringStage::Final => {
458467
let mut lowered = (*db.lowered_body(function, LoweringStage::PostBaseline)?).clone();
459468
db.final_optimization_strategy().apply_strategy(db, function, &mut lowered)?;
460-
Ok(Arc::new(lowered))
469+
lowered
461470
}
471+
};
472+
Ok(Arc::new(lowered))
473+
}
474+
475+
/// If the function is a specialized function, returns the lowering for that function.
476+
/// Otherwise, returns None.
477+
fn try_get_specialized_lowered(
478+
db: &dyn LoweringGroup,
479+
function: ConcreteFunctionWithBodyId,
480+
) -> Maybe<Option<Lowered>> {
481+
let ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) = function.lookup_intern(db)
482+
else {
483+
return Ok(None);
484+
};
485+
let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
486+
let base_semantic = specialized.base.base_semantic_function(db);
487+
let mut variables =
488+
VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
489+
let mut statement = vec![];
490+
let mut parameters = vec![];
491+
for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
492+
let var_id = variables.variables.alloc(base.variables[*param].clone());
493+
if let Some(arg) = arg {
494+
statement.push(Statement::Const(StatementConst { value: arg.clone(), output: var_id }));
495+
continue;
496+
}
497+
parameters.push(var_id);
462498
}
499+
let location = LocationId::from_stable_location(
500+
db,
501+
specialized.base.base_semantic_function(db).stable_location(db),
502+
);
503+
let inputs =
504+
variables.variables.iter().map(|(var_id, _)| VarUsage { var_id, location }).collect();
505+
let outputs: Vec<VariableId> =
506+
chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
507+
.map(|ty| variables.new_var(VarRequest { ty, location }))
508+
.collect_vec();
509+
let mut block_builder = BlocksBuilder::new();
510+
let ret_usage =
511+
outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
512+
statement.push(Statement::Call(StatementCall {
513+
function: specialized.base.function_id(db)?,
514+
with_coupon: false,
515+
inputs,
516+
outputs,
517+
location,
518+
}));
519+
block_builder
520+
.alloc(Block { statements: statement, end: BlockEnd::Return(ret_usage, location) });
521+
Ok(Some(Lowered {
522+
signature: function.signature(db)?,
523+
variables: variables.variables,
524+
blocks: block_builder.build().unwrap(),
525+
parameters,
526+
diagnostics: Default::default(),
527+
}))
463528
}
464529

465530
/// Given the lowering of a function, returns the set of direct dependencies of that function,

0 commit comments

Comments
 (0)