Skip to content

Commit e36a31b

Browse files
committed
Port closure kind deduction logic from rustc
1 parent 7637141 commit e36a31b

File tree

3 files changed

+205
-10
lines changed

3 files changed

+205
-10
lines changed

crates/hir-ty/src/infer/closure.rs

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem};
55
use chalk_ir::{
66
cast::Cast,
77
fold::{FallibleTypeFolder, TypeFoldable},
8-
AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause,
8+
BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind,
99
};
1010
use either::Either;
1111
use hir_def::{
@@ -22,13 +22,14 @@ use stdx::never;
2222

2323
use crate::{
2424
db::{HirDatabase, InternedClosure},
25-
from_placeholder_idx, make_binders,
25+
from_chalk_trait_id, from_placeholder_idx, make_binders,
2626
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
2727
static_lifetime, to_chalk_trait_id,
2828
traits::FnTrait,
29-
utils::{self, generics, Generics},
30-
Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnAbi, FnPointer,
31-
FnSig, Interner, Substitution, Ty, TyExt,
29+
utils::{self, elaborate_clause_supertraits, generics, Generics},
30+
Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy,
31+
DynTyExt, FnAbi, FnPointer, FnSig, Interner, OpaqueTy, ProjectionTyExt, Substitution, Ty,
32+
TyExt, WhereClause,
3233
};
3334

3435
use super::{Expectation, InferenceContext};
@@ -47,6 +48,15 @@ impl InferenceContext<'_> {
4748
None => return,
4849
};
4950

51+
if let TyKind::Closure(closure_id, _) = closure_ty.kind(Interner) {
52+
if let Some(closure_kind) = self.deduce_closure_kind_from_expectations(&expected_ty) {
53+
self.result
54+
.closure_info
55+
.entry(*closure_id)
56+
.or_insert_with(|| (Vec::new(), closure_kind));
57+
}
58+
}
59+
5060
// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
5161
let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
5262

@@ -65,6 +75,60 @@ impl InferenceContext<'_> {
6575
}
6676
}
6777

78+
// Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
79+
// Might need to port closure sig deductions too.
80+
fn deduce_closure_kind_from_expectations(&mut self, expected_ty: &Ty) -> Option<FnTrait> {
81+
match expected_ty.kind(Interner) {
82+
TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => {
83+
let clauses = expected_ty
84+
.impl_trait_bounds(self.db)
85+
.into_iter()
86+
.flatten()
87+
.map(|b| b.into_value_and_skipped_binders().0);
88+
self.deduce_closure_kind_from_predicate_clauses(clauses)
89+
}
90+
TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| {
91+
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id))
92+
}),
93+
TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => {
94+
let clauses = self.clauses_for_self_ty(*ty);
95+
self.deduce_closure_kind_from_predicate_clauses(clauses.into_iter())
96+
}
97+
TyKind::Function(_) => Some(FnTrait::Fn),
98+
_ => None,
99+
}
100+
}
101+
102+
fn deduce_closure_kind_from_predicate_clauses(
103+
&self,
104+
clauses: impl DoubleEndedIterator<Item = WhereClause>,
105+
) -> Option<FnTrait> {
106+
let mut expected_kind = None;
107+
108+
for clause in elaborate_clause_supertraits(self.db, clauses.rev()) {
109+
let trait_id = match clause {
110+
WhereClause::AliasEq(AliasEq {
111+
alias: AliasTy::Projection(projection), ..
112+
}) => Some(projection.trait_(self.db)),
113+
WhereClause::Implemented(trait_ref) => {
114+
Some(from_chalk_trait_id(trait_ref.trait_id))
115+
}
116+
_ => None,
117+
};
118+
if let Some(closure_kind) =
119+
trait_id.and_then(|trait_id| self.fn_trait_kind_from_trait_id(trait_id))
120+
{
121+
// `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
122+
expected_kind = Some(
123+
expected_kind
124+
.map_or_else(|| closure_kind, |current| cmp::max(current, closure_kind)),
125+
);
126+
}
127+
}
128+
129+
expected_kind
130+
}
131+
68132
fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
69133
// Search for a predicate like `<$self as FnX<Args>>::Output == Ret`
70134

@@ -111,6 +175,18 @@ impl InferenceContext<'_> {
111175

112176
None
113177
}
178+
179+
fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
180+
utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate())
181+
.enumerate()
182+
.find_map(|(i, t)| (t == trait_id).then_some(i))
183+
.map(|i| match i {
184+
0 => FnTrait::Fn,
185+
1 => FnTrait::FnMut,
186+
2 => FnTrait::FnOnce,
187+
_ => unreachable!(),
188+
})
189+
}
114190
}
115191

116192
// The below functions handle capture and closure kind (Fn, FnMut, ..)
@@ -962,8 +1038,14 @@ impl InferenceContext<'_> {
9621038
}
9631039
}
9641040
self.restrict_precision_for_unsafe();
965-
// closure_kind should be done before adjust_for_move_closure
966-
let closure_kind = self.closure_kind();
1041+
// `closure_kind` should be done before adjust_for_move_closure
1042+
// If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
1043+
// rustc also does diagnostics here if the latter is not a subtype of the former.
1044+
let closure_kind = self
1045+
.result
1046+
.closure_info
1047+
.get(&closure)
1048+
.map_or_else(|| self.closure_kind(), |info| info.1);
9671049
match capture_by {
9681050
CaptureBy::Value => self.adjust_for_move_closure(),
9691051
CaptureBy::Ref => (),

crates/hir-ty/src/infer/unify.rs

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ use chalk_solve::infer::ParameterEnaVariableExt;
1010
use either::Either;
1111
use ena::unify::UnifyKey;
1212
use hir_expand::name;
13+
use smallvec::SmallVec;
1314
use triomphe::Arc;
1415

1516
use super::{InferOk, InferResult, InferenceContext, TypeError};
1617
use crate::{
1718
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
1819
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
19-
DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
20-
Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
21-
TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
20+
DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal, GoalData, Guidance, InEnvironment,
21+
InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution,
22+
Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, WhereClause,
2223
};
2324

2425
impl InferenceContext<'_> {
@@ -31,6 +32,72 @@ impl InferenceContext<'_> {
3132
{
3233
self.table.canonicalize(t)
3334
}
35+
36+
pub(super) fn clauses_for_self_ty(
37+
&mut self,
38+
self_ty: InferenceVar,
39+
) -> SmallVec<[WhereClause; 4]> {
40+
self.table.resolve_obligations_as_possible();
41+
42+
let root = self.table.var_unification_table.inference_var_root(self_ty);
43+
let pending_obligations = mem::take(&mut self.table.pending_obligations);
44+
let obligations = pending_obligations
45+
.iter()
46+
.filter_map(|obligation| match obligation.value.value.goal.data(Interner) {
47+
GoalData::DomainGoal(DomainGoal::Holds(
48+
clause @ WhereClause::AliasEq(AliasEq {
49+
alias: AliasTy::Projection(projection),
50+
..
51+
}),
52+
)) => {
53+
let projection_self = projection.self_type_parameter(self.db);
54+
let uncanonical = chalk_ir::Substitute::apply(
55+
&obligation.free_vars,
56+
projection_self,
57+
Interner,
58+
);
59+
if matches!(
60+
self.resolve_ty_shallow(&uncanonical).kind(Interner),
61+
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
62+
) {
63+
Some(chalk_ir::Substitute::apply(
64+
&obligation.free_vars,
65+
clause.clone(),
66+
Interner,
67+
))
68+
} else {
69+
None
70+
}
71+
}
72+
GoalData::DomainGoal(DomainGoal::Holds(
73+
clause @ WhereClause::Implemented(trait_ref),
74+
)) => {
75+
let trait_ref_self = trait_ref.self_type_parameter(Interner);
76+
let uncanonical = chalk_ir::Substitute::apply(
77+
&obligation.free_vars,
78+
trait_ref_self,
79+
Interner,
80+
);
81+
if matches!(
82+
self.resolve_ty_shallow(&uncanonical).kind(Interner),
83+
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
84+
) {
85+
Some(chalk_ir::Substitute::apply(
86+
&obligation.free_vars,
87+
clause.clone(),
88+
Interner,
89+
))
90+
} else {
91+
None
92+
}
93+
}
94+
_ => None,
95+
})
96+
.collect();
97+
self.table.pending_obligations = pending_obligations;
98+
99+
obligations
100+
}
34101
}
35102

36103
#[derive(Debug, Clone)]

crates/hir-ty/src/utils.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,52 @@ impl Iterator for SuperTraits<'_> {
112112
}
113113
}
114114

115+
pub(super) fn elaborate_clause_supertraits(
116+
db: &dyn HirDatabase,
117+
clauses: impl Iterator<Item = WhereClause>,
118+
) -> ClauseElaborator<'_> {
119+
let mut elaborator = ClauseElaborator { db, stack: Vec::new(), seen: FxHashSet::default() };
120+
elaborator.extend_deduped(clauses);
121+
122+
elaborator
123+
}
124+
125+
pub(super) struct ClauseElaborator<'a> {
126+
db: &'a dyn HirDatabase,
127+
stack: Vec<WhereClause>,
128+
seen: FxHashSet<WhereClause>,
129+
}
130+
131+
impl<'a> ClauseElaborator<'a> {
132+
fn extend_deduped(&mut self, clauses: impl IntoIterator<Item = WhereClause>) {
133+
self.stack.extend(clauses.into_iter().filter(|c| self.seen.insert(c.clone())))
134+
}
135+
136+
fn elaborate_supertrait(&mut self, clause: &WhereClause) {
137+
if let WhereClause::Implemented(trait_ref) = clause {
138+
direct_super_trait_refs(self.db, trait_ref, |t| {
139+
let clause = WhereClause::Implemented(t);
140+
if self.seen.insert(clause.clone()) {
141+
self.stack.push(clause);
142+
}
143+
});
144+
}
145+
}
146+
}
147+
148+
impl Iterator for ClauseElaborator<'_> {
149+
type Item = WhereClause;
150+
151+
fn next(&mut self) -> Option<Self::Item> {
152+
if let Some(next) = self.stack.pop() {
153+
self.elaborate_supertrait(&next);
154+
Some(next)
155+
} else {
156+
None
157+
}
158+
}
159+
}
160+
115161
fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId, cb: impl FnMut(TraitId)) {
116162
let resolver = trait_.resolver(db);
117163
let generic_params = db.generic_params(trait_.into());

0 commit comments

Comments
 (0)