Skip to content

Commit 2ca9e71

Browse files
Add SpecializedFunction.
1 parent a961427 commit 2ca9e71

File tree

8 files changed

+305
-25
lines changed

8 files changed

+305
-25
lines changed

crates/cairo-lang-lowering/src/cache/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,7 @@ impl FunctionCached {
13901390
FunctionLongId::Generated(id) => {
13911391
FunctionCached::Generated(GeneratedFunctionCached::new(id, ctx))
13921392
}
1393+
FunctionLongId::Specialized(_) => todo!(),
13931394
}
13941395
}
13951396
fn embed(self, ctx: &mut CacheLoadingContext<'_>) -> FunctionId {

crates/cairo-lang-lowering/src/concretize/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ fn concretize_function(
2424
})
2525
.intern(db))
2626
}
27+
FunctionLongId::Specialized(_) => unreachable!("This should not be called."),
2728
}
2829
}
2930

crates/cairo-lang-lowering/src/db.rs

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
1616
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
1717
use cairo_lang_utils::{Intern, LookupIntern, Upcast};
1818
use defs::ids::NamedLanguageElementId;
19-
use itertools::{Itertools, chain};
19+
use itertools::{Itertools, chain, zip_eq};
2020
use num_traits::ToPrimitive;
2121

2222
use crate::add_withdraw_gas::add_withdraw_gas;
23-
use crate::blocks::Blocks;
23+
use crate::blocks::{Blocks, BlocksBuilder};
2424
use crate::borrow_check::{
2525
PotentialDestructCalls, borrow_check, borrow_check_possible_withdraw_gas,
2626
};
@@ -29,17 +29,19 @@ use crate::concretize::concretize_lowered;
2929
use crate::destructs::add_destructs;
3030
use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind};
3131
use crate::graph_algorithms::feedback_set::flag_add_withdraw_gas;
32-
use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId};
32+
use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId};
3333
use crate::inline::get_inline_diagnostics;
3434
use crate::inline::statements_weights::{ApproxCasmInlineWeight, InlineWeight};
35+
use crate::lower::context::{VarRequest, VariableAllocator};
3536
use crate::lower::{MultiLowering, lower_semantic_function};
3637
use crate::optimizations::config::OptimizationConfig;
3738
use crate::optimizations::scrub_units::scrub_units;
3839
use crate::optimizations::strategy::{OptimizationStrategy, OptimizationStrategyId};
3940
use crate::panic::lower_panics;
4041
use crate::utils::InliningStrategy;
4142
use crate::{
42-
BlockEnd, BlockId, DependencyType, Location, Lowered, LoweringStage, MatchInfo, Statement, ids,
43+
Block, BlockEnd, BlockId, DependencyType, Location, Lowered, LoweringStage, MatchInfo,
44+
Statement, StatementCall, StatementConst, VarUsage, VariableId, ids,
4345
};
4446

4547
/// A trait for estimation of the code size of a function.
@@ -394,6 +396,9 @@ fn priv_function_with_body_lowering(
394396
ids::FunctionWithBodyLongId::Generated { key, .. } => {
395397
multi_lowering.generated_lowerings[key].clone()
396398
}
399+
ids::FunctionWithBodyLongId::Specialized(_specialized) => {
400+
unreachable!("There is no generic version of a specialized function.")
401+
}
397402
};
398403
Ok(Arc::new(lowered))
399404
}
@@ -434,10 +439,83 @@ fn lowered_body(
434439
) -> Maybe<Arc<Lowered>> {
435440
match stage {
436441
LoweringStage::Monomorphized => {
437-
let generic_function_id = function.function_with_body_id(db);
438-
db.function_with_body_lowering_diagnostics(generic_function_id)?.check_error_free()?;
439-
let mut lowered = (*db.function_with_body_lowering(generic_function_id)?).clone();
440-
concretize_lowered(db, &mut lowered, &function.substitution(db)?)?;
442+
let lowered = if let ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) =
443+
function.lookup_intern(db)
444+
{
445+
let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
446+
447+
let base_semantic = specialized.base.base_semantic_function(db);
448+
449+
let mut variables = VariableAllocator::new(
450+
db,
451+
base_semantic.function_with_body_id(db),
452+
Default::default(),
453+
)?;
454+
let mut statement = vec![];
455+
let mut parameters = vec![];
456+
for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
457+
let var_id = variables.variables.alloc(base.variables[*param].clone());
458+
if let Some(arg) = arg {
459+
statement.push(Statement::Const(StatementConst {
460+
value: arg.clone(),
461+
output: var_id,
462+
}));
463+
continue;
464+
}
465+
parameters.push(var_id);
466+
}
467+
468+
let location = LocationId::from_stable_location(
469+
db,
470+
specialized.base.base_semantic_function(db).stable_location(db),
471+
);
472+
let inputs = variables
473+
.variables
474+
.iter()
475+
.map(|(var_id, _)| VarUsage { var_id, location })
476+
.collect();
477+
478+
let outputs: Vec<VariableId> = chain!(
479+
base.signature.extra_rets.iter().map(|ret| ret.ty()),
480+
[base.signature.return_type]
481+
)
482+
.map(|ty| variables.new_var(VarRequest { ty, location }))
483+
.collect_vec();
484+
485+
let mut block_builder = BlocksBuilder::new();
486+
487+
let ret_usage = outputs
488+
.iter()
489+
.map(|var_id| VarUsage { var_id: *var_id, location })
490+
.collect_vec();
491+
492+
statement.push(Statement::Call(StatementCall {
493+
function: specialized.base.function_id(db)?,
494+
with_coupon: false,
495+
inputs,
496+
outputs,
497+
location,
498+
}));
499+
block_builder.alloc(Block {
500+
statements: statement,
501+
end: BlockEnd::Return(ret_usage, location),
502+
});
503+
Lowered {
504+
signature: function.signature(db)?,
505+
variables: variables.variables,
506+
blocks: block_builder.build().unwrap(),
507+
parameters,
508+
diagnostics: Default::default(),
509+
}
510+
} else {
511+
let generic_function_id = function.function_with_body_id(db);
512+
db.function_with_body_lowering_diagnostics(generic_function_id)?
513+
.check_error_free()?;
514+
let mut lowered = (*db.function_with_body_lowering(generic_function_id)?).clone();
515+
concretize_lowered(db, &mut lowered, &function.substitution(db)?)?;
516+
lowered
517+
};
518+
441519
Ok(Arc::new(lowered))
442520
}
443521
LoweringStage::PreOptimizations => {

0 commit comments

Comments
 (0)