Skip to content

Commit 6d32e0f

Browse files
committed
fix: Support generics in extract_function assist
This change attempts to resolve issue #7637: Extract into Function does not create a generic function with constraints when extracting generic code. In `FunctionBody::analyze_container`, when the ancestor matches `ast::Fn`, we can perserve both the `generic_param_list` and the `where_clause`. These can then be included in the newly extracted function output via `format_function`. From what I can tell, the only other ancestor type that could potentially have a generic param list would be `ast::ClosureExpr`. In this case, we perserve the `generic_param_list`, but no where clause is ever present. In this inital implementation, all the generic params and where clauses from the parent function will be copied to the newly extracted function. An obvious improvement would be to filter this output in some way to only include generic parameters that are actually used in the function body. I'm not experienced enough with this codebase to know how challenging doing this kind of filtration would be. I don't believe this implementation will work in contexts where the generic parameters and where clauses are defined multiple layers above the function being extracted, such as with nested function declarations. Resolving this seems like another obvious improvement, but one that will potentially require more significant changes to the structure of `analyze_container` that I wasn't comfortable trying to make as a first change.
1 parent 65874df commit 6d32e0f

File tree

1 file changed

+86
-15
lines changed

1 file changed

+86
-15
lines changed

crates/ide-assists/src/handlers/extract_function.rs

+86-15
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use syntax::{
1818
ast::{
1919
self,
2020
edit::{AstNodeEdit, IndentLevel},
21-
AstNode,
21+
AstNode, HasGenericParams,
2222
},
2323
match_ast, ted, SyntaxElement,
2424
SyntaxKind::{self, COMMENT},
@@ -266,6 +266,8 @@ struct ContainerInfo {
266266
parent_loop: Option<SyntaxNode>,
267267
/// The function's return type, const's type etc.
268268
ret_type: Option<hir::Type>,
269+
generic_param_list: Option<ast::GenericParamList>,
270+
where_clause: Option<ast::WhereClause>,
269271
}
270272

271273
/// Control flow that is exported from extracted function
@@ -676,11 +678,11 @@ impl FunctionBody {
676678
parent_loop.get_or_insert(loop_.syntax().clone());
677679
}
678680
};
679-
let (is_const, expr, ty) = loop {
681+
let (is_const, expr, ty, generic_param_list, where_clause) = loop {
680682
let anc = ancestors.next()?;
681683
break match_ast! {
682684
match anc {
683-
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body())),
685+
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body()), closure.generic_param_list(), None),
684686
ast::BlockExpr(block_expr) => {
685687
let (constness, block) = match block_expr.modifier() {
686688
Some(ast::BlockModifier::Const(_)) => (true, block_expr),
@@ -689,7 +691,7 @@ impl FunctionBody {
689691
_ => continue,
690692
};
691693
let expr = Some(ast::Expr::BlockExpr(block));
692-
(constness, expr.clone(), infer_expr_opt(expr))
694+
(constness, expr.clone(), infer_expr_opt(expr), None, None)
693695
},
694696
ast::Fn(fn_) => {
695697
let func = sema.to_def(&fn_)?;
@@ -699,23 +701,23 @@ impl FunctionBody {
699701
ret_ty = async_ret;
700702
}
701703
}
702-
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty))
704+
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty), fn_.generic_param_list(), fn_.where_clause())
703705
},
704706
ast::Static(statik) => {
705-
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)))
707+
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)), None, None)
706708
},
707709
ast::ConstArg(ca) => {
708-
(true, ca.expr(), infer_expr_opt(ca.expr()))
710+
(true, ca.expr(), infer_expr_opt(ca.expr()), None, None)
709711
},
710712
ast::Const(konst) => {
711-
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)))
713+
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)), None, None)
712714
},
713715
ast::ConstParam(cp) => {
714-
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)))
716+
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)), None, None)
715717
},
716718
ast::ConstBlockPat(cbp) => {
717719
let expr = cbp.block_expr().map(ast::Expr::BlockExpr);
718-
(true, expr.clone(), infer_expr_opt(expr))
720+
(true, expr.clone(), infer_expr_opt(expr), None, None)
719721
},
720722
ast::Variant(__) => return None,
721723
ast::Meta(__) => return None,
@@ -743,7 +745,14 @@ impl FunctionBody {
743745
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
744746
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
745747
});
746-
Some(ContainerInfo { is_in_tail, is_const, parent_loop, ret_type: ty })
748+
Some(ContainerInfo {
749+
is_in_tail,
750+
is_const,
751+
parent_loop,
752+
ret_type: ty,
753+
generic_param_list,
754+
where_clause,
755+
})
747756
}
748757

749758
fn return_ty(&self, ctx: &AssistContext) -> Option<RetType> {
@@ -1311,26 +1320,32 @@ fn format_function(
13111320
let const_kw = if fun.mods.is_const { "const " } else { "" };
13121321
let async_kw = if fun.control_flow.is_async { "async " } else { "" };
13131322
let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
1323+
let generic_params = format_generic_param_list(fun);
1324+
let where_clause = format_where_clause(fun);
13141325
match ctx.config.snippet_cap {
13151326
Some(_) => format_to!(
13161327
fn_def,
1317-
"\n\n{}{}{}{}fn $0{}{}",
1328+
"\n\n{}{}{}{}fn $0{}{}{}{}",
13181329
new_indent,
13191330
const_kw,
13201331
async_kw,
13211332
unsafe_kw,
13221333
fun.name,
1323-
params
1334+
generic_params,
1335+
params,
1336+
where_clause
13241337
),
13251338
None => format_to!(
13261339
fn_def,
1327-
"\n\n{}{}{}{}fn {}{}",
1340+
"\n\n{}{}{}{}fn {}{}{}{}",
13281341
new_indent,
13291342
const_kw,
13301343
async_kw,
13311344
unsafe_kw,
13321345
fun.name,
1333-
params
1346+
generic_params,
1347+
params,
1348+
where_clause,
13341349
),
13351350
}
13361351
if let Some(ret_ty) = ret_ty {
@@ -1341,6 +1356,20 @@ fn format_function(
13411356
fn_def
13421357
}
13431358

1359+
fn format_generic_param_list(fun: &Function) -> String {
1360+
match &fun.mods.generic_param_list {
1361+
Some(it) => format!("{}", it),
1362+
None => "".to_string(),
1363+
}
1364+
}
1365+
1366+
fn format_where_clause(fun: &Function) -> String {
1367+
match &fun.mods.where_clause {
1368+
Some(it) => format!(" {}", it),
1369+
None => "".to_string(),
1370+
}
1371+
}
1372+
13441373
impl Function {
13451374
fn make_param_list(&self, ctx: &AssistContext, module: hir::Module) -> ast::ParamList {
13461375
let self_param = self.self_param.clone();
@@ -4709,6 +4738,48 @@ fn $0fun_name() {
47094738
/* a comment */
47104739
let x = 0;
47114740
}
4741+
"#,
4742+
);
4743+
}
4744+
4745+
#[test]
4746+
fn preserve_generics() {
4747+
check_assist(
4748+
extract_function,
4749+
r#"
4750+
fn func<T: Debug>(i: T) {
4751+
$0foo(i);$0
4752+
}
4753+
"#,
4754+
r#"
4755+
fn func<T: Debug>(i: T) {
4756+
fun_name(i);
4757+
}
4758+
4759+
fn $0fun_name<T: Debug>(i: T) {
4760+
foo(i);
4761+
}
4762+
"#,
4763+
);
4764+
}
4765+
4766+
#[test]
4767+
fn preserve_where_clause() {
4768+
check_assist(
4769+
extract_function,
4770+
r#"
4771+
fn func<T>(i: T) where T: Debug {
4772+
$0foo(i);$0
4773+
}
4774+
"#,
4775+
r#"
4776+
fn func<T>(i: T) where T: Debug {
4777+
fun_name(i);
4778+
}
4779+
4780+
fn $0fun_name<T>(i: T) where T: Debug {
4781+
foo(i);
4782+
}
47124783
"#,
47134784
);
47144785
}

0 commit comments

Comments
 (0)