Skip to content

Commit fdfa707

Browse files
authored
Merge pull request #19747 from Veykril/push-kqxvxrxozswr
fix: Fix `move_bounds` assists not working for lifetimes
2 parents aaefc26 + f9c83ed commit fdfa707

File tree

4 files changed

+111
-18
lines changed

4 files changed

+111
-18
lines changed

crates/hir-expand/src/builtin/derive_macro.rs

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Builtin derives.
22
3+
use either::Either;
34
use intern::sym;
45
use itertools::{Itertools, izip};
56
use parser::SyntaxKind;
@@ -1179,10 +1180,10 @@ fn coerce_pointee_expand(
11791180
};
11801181
new_predicates.push(
11811182
make::where_pred(
1182-
make::ty_path(make::path_from_segments(
1183+
Either::Right(make::ty_path(make::path_from_segments(
11831184
[make::path_segment(new_bounds_target)],
11841185
false,
1185-
)),
1186+
))),
11861187
new_bounds,
11871188
)
11881189
.clone_for_update(),
@@ -1245,7 +1246,9 @@ fn coerce_pointee_expand(
12451246
substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
12461247
})
12471248
});
1248-
new_predicates.push(make::where_pred(pred_target, new_bounds).clone_for_update());
1249+
new_predicates.push(
1250+
make::where_pred(Either::Right(pred_target), new_bounds).clone_for_update(),
1251+
);
12491252
}
12501253
}
12511254

@@ -1260,10 +1263,10 @@ fn coerce_pointee_expand(
12601263
// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
12611264
where_clause.add_predicate(
12621265
make::where_pred(
1263-
make::ty_path(make::path_from_segments(
1266+
Either::Right(make::ty_path(make::path_from_segments(
12641267
[make::path_segment(make::name_ref(&pointee_param_name.text()))],
12651268
false,
1266-
)),
1269+
))),
12671270
[make::type_bound(make::ty_path(make::path_from_segments(
12681271
[
12691272
make::path_segment(make::name_ref("core")),

crates/ide-assists/src/handlers/move_bounds.rs

+40-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use either::Either;
12
use syntax::{
23
ast::{
34
self, AstNode, HasName, HasTypeBounds,
@@ -30,10 +31,11 @@ pub(crate) fn move_bounds_to_where_clause(
3031
) -> Option<()> {
3132
let type_param_list = ctx.find_node_at_offset::<ast::GenericParamList>()?;
3233

33-
let mut type_params = type_param_list.type_or_const_params();
34+
let mut type_params = type_param_list.generic_params();
3435
if type_params.all(|p| match p {
35-
ast::TypeOrConstParam::Type(t) => t.type_bound_list().is_none(),
36-
ast::TypeOrConstParam::Const(_) => true,
36+
ast::GenericParam::TypeParam(t) => t.type_bound_list().is_none(),
37+
ast::GenericParam::LifetimeParam(l) => l.type_bound_list().is_none(),
38+
ast::GenericParam::ConstParam(_) => true,
3739
}) {
3840
return None;
3941
}
@@ -53,20 +55,23 @@ pub(crate) fn move_bounds_to_where_clause(
5355
match parent {
5456
ast::Fn(it) => it.get_or_create_where_clause(),
5557
ast::Trait(it) => it.get_or_create_where_clause(),
58+
ast::TraitAlias(it) => it.get_or_create_where_clause(),
5659
ast::Impl(it) => it.get_or_create_where_clause(),
5760
ast::Enum(it) => it.get_or_create_where_clause(),
5861
ast::Struct(it) => it.get_or_create_where_clause(),
62+
ast::TypeAlias(it) => it.get_or_create_where_clause(),
5963
_ => return,
6064
}
6165
};
6266

63-
for toc_param in type_param_list.type_or_const_params() {
64-
let type_param = match toc_param {
65-
ast::TypeOrConstParam::Type(x) => x,
66-
ast::TypeOrConstParam::Const(_) => continue,
67+
for generic_param in type_param_list.generic_params() {
68+
let param: &dyn HasTypeBounds = match &generic_param {
69+
ast::GenericParam::TypeParam(t) => t,
70+
ast::GenericParam::LifetimeParam(l) => l,
71+
ast::GenericParam::ConstParam(_) => continue,
6772
};
68-
if let Some(tbl) = type_param.type_bound_list() {
69-
if let Some(predicate) = build_predicate(type_param) {
73+
if let Some(tbl) = param.type_bound_list() {
74+
if let Some(predicate) = build_predicate(generic_param) {
7075
where_clause.add_predicate(predicate)
7176
}
7277
tbl.remove()
@@ -76,9 +81,23 @@ pub(crate) fn move_bounds_to_where_clause(
7681
)
7782
}
7883

79-
fn build_predicate(param: ast::TypeParam) -> Option<ast::WherePred> {
80-
let path = make::ext::ident_path(&param.name()?.syntax().to_string());
81-
let predicate = make::where_pred(make::ty_path(path), param.type_bound_list()?.bounds());
84+
fn build_predicate(param: ast::GenericParam) -> Option<ast::WherePred> {
85+
let target = match &param {
86+
ast::GenericParam::TypeParam(t) => {
87+
Either::Right(make::ty_path(make::ext::ident_path(&t.name()?.to_string())))
88+
}
89+
ast::GenericParam::LifetimeParam(l) => Either::Left(l.lifetime()?),
90+
ast::GenericParam::ConstParam(_) => return None,
91+
};
92+
let predicate = make::where_pred(
93+
target,
94+
match param {
95+
ast::GenericParam::TypeParam(t) => t.type_bound_list()?,
96+
ast::GenericParam::LifetimeParam(l) => l.type_bound_list()?,
97+
ast::GenericParam::ConstParam(_) => return None,
98+
}
99+
.bounds(),
100+
);
82101
Some(predicate.clone_for_update())
83102
}
84103

@@ -123,4 +142,13 @@ mod tests {
123142
r#"struct Pair<T>(T, T) where T: u32;"#,
124143
);
125144
}
145+
146+
#[test]
147+
fn move_bounds_to_where_clause_trait() {
148+
check_assist(
149+
move_bounds_to_where_clause,
150+
r#"trait T<'a: 'static, $0T: u32> {}"#,
151+
r#"trait T<'a, T> where 'a: 'static, T: u32 {}"#,
152+
);
153+
}
126154
}

crates/syntax/src/ast/edit_in_place.rs

+61
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,67 @@ impl GenericParamsOwnerEdit for ast::Trait {
109109
}
110110
}
111111

112+
impl GenericParamsOwnerEdit for ast::TraitAlias {
113+
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
114+
match self.generic_param_list() {
115+
Some(it) => it,
116+
None => {
117+
let position = if let Some(name) = self.name() {
118+
Position::after(name.syntax)
119+
} else if let Some(trait_token) = self.trait_token() {
120+
Position::after(trait_token)
121+
} else {
122+
Position::last_child_of(self.syntax())
123+
};
124+
create_generic_param_list(position)
125+
}
126+
}
127+
}
128+
129+
fn get_or_create_where_clause(&self) -> ast::WhereClause {
130+
if self.where_clause().is_none() {
131+
let position = match self.semicolon_token() {
132+
Some(tok) => Position::before(tok),
133+
None => Position::last_child_of(self.syntax()),
134+
};
135+
create_where_clause(position);
136+
}
137+
self.where_clause().unwrap()
138+
}
139+
}
140+
141+
impl GenericParamsOwnerEdit for ast::TypeAlias {
142+
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
143+
match self.generic_param_list() {
144+
Some(it) => it,
145+
None => {
146+
let position = if let Some(name) = self.name() {
147+
Position::after(name.syntax)
148+
} else if let Some(trait_token) = self.type_token() {
149+
Position::after(trait_token)
150+
} else {
151+
Position::last_child_of(self.syntax())
152+
};
153+
create_generic_param_list(position)
154+
}
155+
}
156+
}
157+
158+
fn get_or_create_where_clause(&self) -> ast::WhereClause {
159+
if self.where_clause().is_none() {
160+
let position = match self.eq_token() {
161+
Some(tok) => Position::before(tok),
162+
None => match self.semicolon_token() {
163+
Some(tok) => Position::before(tok),
164+
None => Position::last_child_of(self.syntax()),
165+
},
166+
};
167+
create_where_clause(position);
168+
}
169+
self.where_clause().unwrap()
170+
}
171+
}
172+
112173
impl GenericParamsOwnerEdit for ast::Struct {
113174
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
114175
match self.generic_param_list() {

crates/syntax/src/ast/make.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
1414
mod quote;
1515

16+
use either::Either;
1617
use itertools::Itertools;
1718
use parser::{Edition, T};
1819
use rowan::NodeOrToken;
@@ -881,7 +882,7 @@ pub fn match_arm_list(arms: impl IntoIterator<Item = ast::MatchArm>) -> ast::Mat
881882
}
882883

883884
pub fn where_pred(
884-
path: ast::Type,
885+
path: Either<ast::Lifetime, ast::Type>,
885886
bounds: impl IntoIterator<Item = ast::TypeBound>,
886887
) -> ast::WherePred {
887888
let bounds = bounds.into_iter().join(" + ");

0 commit comments

Comments
 (0)