@@ -2,7 +2,7 @@ use std::iter;
2
2
3
3
use ast:: make;
4
4
use either:: Either ;
5
- use hir:: { HirDisplay , InFile , Local , ModuleDef , Semantics , TypeInfo } ;
5
+ use hir:: { HirDisplay , InFile , Local , ModuleDef , PathResolution , Semantics , TypeInfo , TypeParam } ;
6
6
use ide_db:: {
7
7
defs:: { Definition , NameRefClass } ,
8
8
famous_defs:: FamousDefs ,
@@ -469,6 +469,24 @@ impl FunctionBody {
469
469
}
470
470
}
471
471
472
+ fn descendants ( & self ) -> impl Iterator < Item = SyntaxNode > {
473
+ match self {
474
+ FunctionBody :: Expr ( expr) => expr. syntax ( ) . descendants ( ) ,
475
+ FunctionBody :: Span { parent, .. } => parent. syntax ( ) . descendants ( ) ,
476
+ }
477
+ }
478
+
479
+ fn descendant_paths ( & self ) -> impl Iterator < Item = ast:: Path > {
480
+ self . descendants ( ) . filter_map ( |node| {
481
+ match_ast ! {
482
+ match node {
483
+ ast:: Path ( it) => Some ( it) ,
484
+ _ => None
485
+ }
486
+ }
487
+ } )
488
+ }
489
+
472
490
fn from_expr ( expr : ast:: Expr ) -> Option < Self > {
473
491
match expr {
474
492
ast:: Expr :: BreakExpr ( it) => it. expr ( ) . map ( Self :: Expr ) ,
@@ -678,11 +696,16 @@ impl FunctionBody {
678
696
parent_loop. get_or_insert ( loop_. syntax ( ) . clone ( ) ) ;
679
697
}
680
698
} ;
681
- let ( is_const, expr, ty, generic_param_list, where_clause) = loop {
699
+
700
+ let ( mut generic_param_list, mut where_clause) = ( None , None ) ;
701
+ let ( is_const, expr, ty) = loop {
682
702
let anc = ancestors. next ( ) ?;
683
703
break match_ast ! {
684
704
match anc {
685
- ast:: ClosureExpr ( closure) => ( false , closure. body( ) , infer_expr_opt( closure. body( ) ) , closure. generic_param_list( ) , None ) ,
705
+ ast:: ClosureExpr ( closure) => {
706
+ generic_param_list = closure. generic_param_list( ) ;
707
+ ( false , closure. body( ) , infer_expr_opt( closure. body( ) ) )
708
+ } ,
686
709
ast:: BlockExpr ( block_expr) => {
687
710
let ( constness, block) = match block_expr. modifier( ) {
688
711
Some ( ast:: BlockModifier :: Const ( _) ) => ( true , block_expr) ,
@@ -691,7 +714,7 @@ impl FunctionBody {
691
714
_ => continue ,
692
715
} ;
693
716
let expr = Some ( ast:: Expr :: BlockExpr ( block) ) ;
694
- ( constness, expr. clone( ) , infer_expr_opt( expr) , None , None )
717
+ ( constness, expr. clone( ) , infer_expr_opt( expr) )
695
718
} ,
696
719
ast:: Fn ( fn_) => {
697
720
let func = sema. to_def( & fn_) ?;
@@ -701,23 +724,25 @@ impl FunctionBody {
701
724
ret_ty = async_ret;
702
725
}
703
726
}
704
- ( fn_. const_token( ) . is_some( ) , fn_. body( ) . map( ast:: Expr :: BlockExpr ) , Some ( ret_ty) , fn_. generic_param_list( ) , fn_. where_clause( ) )
727
+ generic_param_list = fn_. generic_param_list( ) ;
728
+ where_clause = fn_. where_clause( ) ;
729
+ ( fn_. const_token( ) . is_some( ) , fn_. body( ) . map( ast:: Expr :: BlockExpr ) , Some ( ret_ty) )
705
730
} ,
706
731
ast:: Static ( statik) => {
707
- ( true , statik. body( ) , Some ( sema. to_def( & statik) ?. ty( sema. db) ) , None , None )
732
+ ( true , statik. body( ) , Some ( sema. to_def( & statik) ?. ty( sema. db) ) )
708
733
} ,
709
734
ast:: ConstArg ( ca) => {
710
- ( true , ca. expr( ) , infer_expr_opt( ca. expr( ) ) , None , None )
735
+ ( true , ca. expr( ) , infer_expr_opt( ca. expr( ) ) )
711
736
} ,
712
737
ast:: Const ( konst) => {
713
- ( true , konst. body( ) , Some ( sema. to_def( & konst) ?. ty( sema. db) ) , None , None )
738
+ ( true , konst. body( ) , Some ( sema. to_def( & konst) ?. ty( sema. db) ) )
714
739
} ,
715
740
ast:: ConstParam ( cp) => {
716
- ( true , cp. default_val( ) , Some ( sema. to_def( & cp) ?. ty( sema. db) ) , None , None )
741
+ ( true , cp. default_val( ) , Some ( sema. to_def( & cp) ?. ty( sema. db) ) )
717
742
} ,
718
743
ast:: ConstBlockPat ( cbp) => {
719
744
let expr = cbp. block_expr( ) . map( ast:: Expr :: BlockExpr ) ;
720
- ( true , expr. clone( ) , infer_expr_opt( expr) , None , None )
745
+ ( true , expr. clone( ) , infer_expr_opt( expr) )
721
746
} ,
722
747
ast:: Variant ( __) => return None ,
723
748
ast:: Meta ( __) => return None ,
@@ -1320,8 +1345,7 @@ fn format_function(
1320
1345
let const_kw = if fun. mods . is_const { "const " } else { "" } ;
1321
1346
let async_kw = if fun. control_flow . is_async { "async " } else { "" } ;
1322
1347
let unsafe_kw = if fun. control_flow . is_unsafe { "unsafe " } else { "" } ;
1323
- let generic_params = format_generic_param_list ( fun) ;
1324
- let where_clause = format_where_clause ( fun) ;
1348
+ let ( generic_params, where_clause) = format_generic_params_and_where_clause ( ctx, fun) ;
1325
1349
match ctx. config . snippet_cap {
1326
1350
Some ( _) => format_to ! (
1327
1351
fn_def,
@@ -1356,13 +1380,50 @@ fn format_function(
1356
1380
fn_def
1357
1381
}
1358
1382
1359
- fn format_generic_param_list ( fun : & Function ) -> String {
1383
+ fn format_generic_params_and_where_clause ( ctx : & AssistContext , fun : & Function ) -> ( String , String ) {
1384
+ ( format_generic_param_list ( fun, ctx) , format_where_clause ( fun) )
1385
+ }
1386
+
1387
+ fn format_generic_param_list ( fun : & Function , ctx : & AssistContext ) -> String {
1388
+ let type_params_in_descendant_paths =
1389
+ fun. body . descendant_paths ( ) . filter_map ( |it| match ctx. sema . resolve_path ( & it) {
1390
+ Some ( PathResolution :: TypeParam ( type_param) ) => Some ( type_param) ,
1391
+ _ => None ,
1392
+ } ) ;
1393
+
1394
+ let type_params_in_params = fun. params . iter ( ) . filter_map ( |p| p. ty . as_type_param ( ctx. db ( ) ) ) ;
1395
+
1396
+ let used_type_params: Vec < TypeParam > =
1397
+ type_params_in_descendant_paths. chain ( type_params_in_params) . collect ( ) ;
1398
+
1360
1399
match & fun. mods . generic_param_list {
1361
- Some ( it) => format ! ( "{}" , it) ,
1400
+ Some ( list) => {
1401
+ let filtered_generic_params = filter_generic_param_list ( ctx, list, used_type_params) ;
1402
+ if filtered_generic_params. is_empty ( ) {
1403
+ return "" . to_string ( ) ;
1404
+ }
1405
+ format ! ( "{}" , make:: generic_param_list( filtered_generic_params) )
1406
+ }
1362
1407
None => "" . to_string ( ) ,
1363
1408
}
1364
1409
}
1365
1410
1411
+ fn filter_generic_param_list (
1412
+ ctx : & AssistContext ,
1413
+ list : & ast:: GenericParamList ,
1414
+ used_type_params : Vec < TypeParam > ,
1415
+ ) -> Vec < ast:: GenericParam > {
1416
+ list. generic_params ( )
1417
+ . filter ( |p| match p {
1418
+ ast:: GenericParam :: ConstParam ( _) | ast:: GenericParam :: LifetimeParam ( _) => true ,
1419
+ ast:: GenericParam :: TypeParam ( type_param) => match & ctx. sema . to_def ( type_param) {
1420
+ Some ( def) => used_type_params. iter ( ) . contains ( def) ,
1421
+ _ => false ,
1422
+ } ,
1423
+ } )
1424
+ . collect ( )
1425
+ }
1426
+
1366
1427
fn format_where_clause ( fun : & Function ) -> String {
1367
1428
match & fun. mods . where_clause {
1368
1429
Some ( it) => format ! ( " {}" , it) ,
@@ -4763,6 +4824,73 @@ fn $0fun_name<T: Debug>(i: T) {
4763
4824
) ;
4764
4825
}
4765
4826
4827
+ #[ test]
4828
+ fn preserve_generics_from_body ( ) {
4829
+ check_assist (
4830
+ extract_function,
4831
+ r#"
4832
+ fn func<T: Default>() -> T {
4833
+ $0T::default()$0
4834
+ }
4835
+ "# ,
4836
+ r#"
4837
+ fn func<T: Default>() -> T {
4838
+ fun_name()
4839
+ }
4840
+
4841
+ fn $0fun_name<T: Default>() -> T {
4842
+ T::default()
4843
+ }
4844
+ "# ,
4845
+ ) ;
4846
+ }
4847
+
4848
+ #[ test]
4849
+ fn filter_unused_generics ( ) {
4850
+ check_assist (
4851
+ extract_function,
4852
+ r#"
4853
+ fn func<T: Debug, U: Copy>(i: T, u: U) {
4854
+ bar(u);
4855
+ $0foo(i);$0
4856
+ }
4857
+ "# ,
4858
+ r#"
4859
+ fn func<T: Debug, U: Copy>(i: T, u: U) {
4860
+ bar(u);
4861
+ fun_name(i);
4862
+ }
4863
+
4864
+ fn $0fun_name<T: Debug>(i: T) {
4865
+ foo(i);
4866
+ }
4867
+ "# ,
4868
+ ) ;
4869
+ }
4870
+
4871
+ #[ test]
4872
+ fn empty_generic_param_list ( ) {
4873
+ check_assist (
4874
+ extract_function,
4875
+ r#"
4876
+ fn func<T: Debug>(t: T, i: u32) {
4877
+ bar(t);
4878
+ $0foo(i);$0
4879
+ }
4880
+ "# ,
4881
+ r#"
4882
+ fn func<T: Debug>(t: T, i: u32) {
4883
+ bar(t);
4884
+ fun_name(i);
4885
+ }
4886
+
4887
+ fn $0fun_name(i: u32) {
4888
+ foo(i);
4889
+ }
4890
+ "# ,
4891
+ ) ;
4892
+ }
4893
+
4766
4894
#[ test]
4767
4895
fn preserve_where_clause ( ) {
4768
4896
check_assist (
0 commit comments