@@ -16,11 +16,11 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
16
16
use cairo_lang_utils:: unordered_hash_set:: UnorderedHashSet ;
17
17
use cairo_lang_utils:: { Intern , LookupIntern , Upcast } ;
18
18
use defs:: ids:: NamedLanguageElementId ;
19
- use itertools:: { Itertools , chain} ;
19
+ use itertools:: { Itertools , chain, zip_eq } ;
20
20
use num_traits:: ToPrimitive ;
21
21
22
22
use crate :: add_withdraw_gas:: add_withdraw_gas;
23
- use crate :: blocks:: Blocks ;
23
+ use crate :: blocks:: { Blocks , BlocksBuilder } ;
24
24
use crate :: borrow_check:: {
25
25
PotentialDestructCalls , borrow_check, borrow_check_possible_withdraw_gas,
26
26
} ;
@@ -29,17 +29,19 @@ use crate::concretize::concretize_lowered;
29
29
use crate :: destructs:: add_destructs;
30
30
use crate :: diagnostic:: { LoweringDiagnostic , LoweringDiagnosticKind } ;
31
31
use crate :: graph_algorithms:: feedback_set:: flag_add_withdraw_gas;
32
- use crate :: ids:: { ConcreteFunctionWithBodyId , FunctionId , FunctionLongId } ;
32
+ use crate :: ids:: { ConcreteFunctionWithBodyId , FunctionId , FunctionLongId , LocationId } ;
33
33
use crate :: inline:: get_inline_diagnostics;
34
34
use crate :: inline:: statements_weights:: { ApproxCasmInlineWeight , InlineWeight } ;
35
+ use crate :: lower:: context:: { VarRequest , VariableAllocator } ;
35
36
use crate :: lower:: { MultiLowering , lower_semantic_function} ;
36
37
use crate :: optimizations:: config:: OptimizationConfig ;
37
38
use crate :: optimizations:: scrub_units:: scrub_units;
38
39
use crate :: optimizations:: strategy:: { OptimizationStrategy , OptimizationStrategyId } ;
39
40
use crate :: panic:: lower_panics;
40
41
use crate :: utils:: InliningStrategy ;
41
42
use crate :: {
42
- BlockEnd , BlockId , DependencyType , Location , Lowered , LoweringStage , MatchInfo , Statement , ids,
43
+ Block , BlockEnd , BlockId , DependencyType , Location , Lowered , LoweringStage , MatchInfo ,
44
+ Statement , StatementCall , StatementConst , VarUsage , VariableId , ids,
43
45
} ;
44
46
45
47
/// A trait for estimation of the code size of a function.
@@ -394,6 +396,9 @@ fn priv_function_with_body_lowering(
394
396
ids:: FunctionWithBodyLongId :: Generated { key, .. } => {
395
397
multi_lowering. generated_lowerings [ key] . clone ( )
396
398
}
399
+ ids:: FunctionWithBodyLongId :: Specialized ( _specialized) => {
400
+ unreachable ! ( "There is no generic version of a specialized function." )
401
+ }
397
402
} ;
398
403
Ok ( Arc :: new ( lowered) )
399
404
}
@@ -432,34 +437,94 @@ fn lowered_body(
432
437
function : ids:: ConcreteFunctionWithBodyId ,
433
438
stage : LoweringStage ,
434
439
) -> Maybe < Arc < Lowered > > {
435
- match stage {
436
- LoweringStage :: Monomorphized => {
437
- let generic_function_id = function. function_with_body_id ( db) ;
438
- db. function_with_body_lowering_diagnostics ( generic_function_id) ?. check_error_free ( ) ?;
439
- let mut lowered = ( * db. function_with_body_lowering ( generic_function_id) ?) . clone ( ) ;
440
- concretize_lowered ( db, & mut lowered, & function. substitution ( db) ?) ?;
441
- Ok ( Arc :: new ( lowered) )
442
- }
440
+ let lowered = match stage {
441
+ LoweringStage :: Monomorphized => match try_get_specialized_lowered ( db, function) ? {
442
+ Some ( lowered) => lowered,
443
+ None => {
444
+ let generic_function_id = function. function_with_body_id ( db) ;
445
+ db. function_with_body_lowering_diagnostics ( generic_function_id) ?
446
+ . check_error_free ( ) ?;
447
+ let mut lowered = ( * db. function_with_body_lowering ( generic_function_id) ?) . clone ( ) ;
448
+ concretize_lowered ( db, & mut lowered, & function. substitution ( db) ?) ?;
449
+ lowered
450
+ }
451
+ } ,
443
452
LoweringStage :: PreOptimizations => {
444
453
let mut lowered = ( * db. lowered_body ( function, LoweringStage :: Monomorphized ) ?) . clone ( ) ;
445
454
add_withdraw_gas ( db, function, & mut lowered) ?;
446
455
lower_panics ( db, function, & mut lowered) ?;
447
456
add_destructs ( db, function, & mut lowered) ;
448
457
scrub_units ( db, & mut lowered) ;
449
- Ok ( Arc :: new ( lowered) )
458
+ lowered
450
459
}
451
460
LoweringStage :: PostBaseline => {
452
461
let mut lowered =
453
462
( * db. lowered_body ( function, LoweringStage :: PreOptimizations ) ?) . clone ( ) ;
454
463
db. baseline_optimization_strategy ( ) . apply_strategy ( db, function, & mut lowered) ?;
455
- Ok ( Arc :: new ( lowered) )
464
+ lowered
456
465
}
457
466
LoweringStage :: Final => {
458
467
let mut lowered = ( * db. lowered_body ( function, LoweringStage :: PostBaseline ) ?) . clone ( ) ;
459
468
db. final_optimization_strategy ( ) . apply_strategy ( db, function, & mut lowered) ?;
460
- Ok ( Arc :: new ( lowered) )
469
+ lowered
461
470
}
471
+ } ;
472
+ Ok ( Arc :: new ( lowered) )
473
+ }
474
+
475
+ /// If the function is a specialized function, returns the lowering for that function.
476
+ /// Otherwise, returns None.
477
+ fn try_get_specialized_lowered (
478
+ db : & dyn LoweringGroup ,
479
+ function : ConcreteFunctionWithBodyId ,
480
+ ) -> Maybe < Option < Lowered > > {
481
+ let ids:: ConcreteFunctionWithBodyLongId :: Specialized ( specialized) = function. lookup_intern ( db)
482
+ else {
483
+ return Ok ( None ) ;
484
+ } ;
485
+ let base = db. lowered_body ( specialized. base , LoweringStage :: Monomorphized ) ?;
486
+ let base_semantic = specialized. base . base_semantic_function ( db) ;
487
+ let mut variables =
488
+ VariableAllocator :: new ( db, base_semantic. function_with_body_id ( db) , Default :: default ( ) ) ?;
489
+ let mut statement = vec ! [ ] ;
490
+ let mut parameters = vec ! [ ] ;
491
+ for ( param, arg) in zip_eq ( & base. parameters , specialized. args . iter ( ) ) {
492
+ let var_id = variables. variables . alloc ( base. variables [ * param] . clone ( ) ) ;
493
+ if let Some ( arg) = arg {
494
+ statement. push ( Statement :: Const ( StatementConst { value : arg. clone ( ) , output : var_id } ) ) ;
495
+ continue ;
496
+ }
497
+ parameters. push ( var_id) ;
462
498
}
499
+ let location = LocationId :: from_stable_location (
500
+ db,
501
+ specialized. base . base_semantic_function ( db) . stable_location ( db) ,
502
+ ) ;
503
+ let inputs =
504
+ variables. variables . iter ( ) . map ( |( var_id, _) | VarUsage { var_id, location } ) . collect ( ) ;
505
+ let outputs: Vec < VariableId > =
506
+ chain ! ( base. signature. extra_rets. iter( ) . map( |ret| ret. ty( ) ) , [ base. signature. return_type] )
507
+ . map ( |ty| variables. new_var ( VarRequest { ty, location } ) )
508
+ . collect_vec ( ) ;
509
+ let mut block_builder = BlocksBuilder :: new ( ) ;
510
+ let ret_usage =
511
+ outputs. iter ( ) . map ( |var_id| VarUsage { var_id : * var_id, location } ) . collect_vec ( ) ;
512
+ statement. push ( Statement :: Call ( StatementCall {
513
+ function : specialized. base . function_id ( db) ?,
514
+ with_coupon : false ,
515
+ inputs,
516
+ outputs,
517
+ location,
518
+ } ) ) ;
519
+ block_builder
520
+ . alloc ( Block { statements : statement, end : BlockEnd :: Return ( ret_usage, location) } ) ;
521
+ Ok ( Some ( Lowered {
522
+ signature : function. signature ( db) ?,
523
+ variables : variables. variables ,
524
+ blocks : block_builder. build ( ) . unwrap ( ) ,
525
+ parameters,
526
+ diagnostics : Default :: default ( ) ,
527
+ } ) )
463
528
}
464
529
465
530
/// Given the lowering of a function, returns the set of direct dependencies of that function,
0 commit comments