Skip to content

Commit 1f92965

Browse files
committed
Auto merge of #13209 - lowr:feat/inference-for-generator, r=Veykril
feat: type inference for generators This PR implements basic type inference for generator and yield expressions. Things not included in this PR: - Generator upvars and generator witnesses are not implemented. They are only used to determine auto trait impls, so basic type inference should be fine without them, but method resolutions with auto trait bounds may not be resolved correctly. Open questions: - I haven't (yet) implemented `HirDisplay` for `TyKind::Generator`, so generator types are just shown as "{{generator}}" (in tests, inlay hints, hovers, etc.), which is not really nice. How should we show them? - I added moderate amount of stuffs to minicore. I especially didn't want to add `impl<T> Deref for &T` and `impl<T> Deref for &mut T` exclusively for tests for generators; should I move them into the test fixtures or can they be placed in minicore? cc #4309
2 parents 73ab709 + 9ede5f0 commit 1f92965

File tree

14 files changed

+388
-34
lines changed

14 files changed

+388
-34
lines changed

crates/hir-def/src/body/lower.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ use crate::{
2929
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint},
3030
db::DefDatabase,
3131
expr::{
32-
dummy_expr_id, Array, BindingAnnotation, Expr, ExprId, FloatTypeWrapper, Label, LabelId,
33-
Literal, MatchArm, Pat, PatId, RecordFieldPat, RecordLitField, Statement,
32+
dummy_expr_id, Array, BindingAnnotation, ClosureKind, Expr, ExprId, FloatTypeWrapper,
33+
Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat, RecordLitField,
34+
Statement,
3435
},
3536
intern::Interned,
3637
item_scope::BuiltinShadowMode,
@@ -97,6 +98,7 @@ pub(super) fn lower(
9798
name_to_pat_grouping: Default::default(),
9899
is_lowering_inside_or_pat: false,
99100
is_lowering_assignee_expr: false,
101+
is_lowering_generator: false,
100102
}
101103
.collect(params, body)
102104
}
@@ -111,6 +113,7 @@ struct ExprCollector<'a> {
111113
name_to_pat_grouping: FxHashMap<Name, Vec<PatId>>,
112114
is_lowering_inside_or_pat: bool,
113115
is_lowering_assignee_expr: bool,
116+
is_lowering_generator: bool,
114117
}
115118

116119
impl ExprCollector<'_> {
@@ -358,6 +361,7 @@ impl ExprCollector<'_> {
358361
self.alloc_expr(Expr::Return { expr }, syntax_ptr)
359362
}
360363
ast::Expr::YieldExpr(e) => {
364+
self.is_lowering_generator = true;
361365
let expr = e.expr().map(|e| self.collect_expr(e));
362366
self.alloc_expr(Expr::Yield { expr }, syntax_ptr)
363367
}
@@ -459,13 +463,31 @@ impl ExprCollector<'_> {
459463
.ret_type()
460464
.and_then(|r| r.ty())
461465
.map(|it| Interned::new(TypeRef::from_ast(&self.ctx(), it)));
466+
467+
let prev_is_lowering_generator = self.is_lowering_generator;
468+
self.is_lowering_generator = false;
469+
462470
let body = self.collect_expr_opt(e.body());
471+
472+
let closure_kind = if self.is_lowering_generator {
473+
let movability = if e.static_token().is_some() {
474+
Movability::Static
475+
} else {
476+
Movability::Movable
477+
};
478+
ClosureKind::Generator(movability)
479+
} else {
480+
ClosureKind::Closure
481+
};
482+
self.is_lowering_generator = prev_is_lowering_generator;
483+
463484
self.alloc_expr(
464485
Expr::Closure {
465486
args: args.into(),
466487
arg_types: arg_types.into(),
467488
ret_type,
468489
body,
490+
closure_kind,
469491
},
470492
syntax_ptr,
471493
)

crates/hir-def/src/body/pretty.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::fmt::{self, Write};
55
use syntax::ast::HasName;
66

77
use crate::{
8-
expr::{Array, BindingAnnotation, Literal, Statement},
8+
expr::{Array, BindingAnnotation, ClosureKind, Literal, Movability, Statement},
99
pretty::{print_generic_args, print_path, print_type_ref},
1010
type_ref::TypeRef,
1111
};
@@ -362,7 +362,10 @@ impl<'a> Printer<'a> {
362362
self.print_expr(*index);
363363
w!(self, "]");
364364
}
365-
Expr::Closure { args, arg_types, ret_type, body } => {
365+
Expr::Closure { args, arg_types, ret_type, body, closure_kind } => {
366+
if let ClosureKind::Generator(Movability::Static) = closure_kind {
367+
w!(self, "static ");
368+
}
366369
w!(self, "|");
367370
for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() {
368371
if i != 0 {

crates/hir-def/src/expr.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ pub enum Expr {
198198
arg_types: Box<[Option<Interned<TypeRef>>]>,
199199
ret_type: Option<Interned<TypeRef>>,
200200
body: ExprId,
201+
closure_kind: ClosureKind,
201202
},
202203
Tuple {
203204
exprs: Box<[ExprId]>,
@@ -211,6 +212,18 @@ pub enum Expr {
211212
Underscore,
212213
}
213214

215+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
216+
pub enum ClosureKind {
217+
Closure,
218+
Generator(Movability),
219+
}
220+
221+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
222+
pub enum Movability {
223+
Static,
224+
Movable,
225+
}
226+
214227
#[derive(Debug, Clone, Eq, PartialEq)]
215228
pub enum Array {
216229
ElementList { elements: Box<[ExprId]>, is_assignee_expr: bool },

crates/hir-ty/src/builder.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use chalk_ir::{
99
AdtId, BoundVar, DebruijnIndex, Scalar,
1010
};
1111
use hir_def::{
12-
builtin_type::BuiltinType, generics::TypeOrConstParamData, ConstParamId, GenericDefId, TraitId,
13-
TypeAliasId,
12+
builtin_type::BuiltinType, generics::TypeOrConstParamData, ConstParamId, DefWithBodyId,
13+
GenericDefId, TraitId, TypeAliasId,
1414
};
1515
use smallvec::SmallVec;
1616

@@ -205,6 +205,38 @@ impl TyBuilder<()> {
205205
)
206206
}
207207

208+
/// Creates a `TyBuilder` to build `Substitution` for a generator defined in `parent`.
209+
///
210+
/// A generator's substitution consists of:
211+
/// - generic parameters in scope on `parent`
212+
/// - resume type of generator
213+
/// - yield type of generator ([`Generator::Yield`](std::ops::Generator::Yield))
214+
/// - return type of generator ([`Generator::Return`](std::ops::Generator::Return))
215+
/// in this order.
216+
///
217+
/// This method prepopulates the builder with placeholder substitution of `parent`, so you
218+
/// should only push exactly 3 `GenericArg`s before building.
219+
pub fn subst_for_generator(db: &dyn HirDatabase, parent: DefWithBodyId) -> TyBuilder<()> {
220+
let parent_subst = match parent.as_generic_def_id() {
221+
Some(parent) => generics(db.upcast(), parent).placeholder_subst(db),
222+
// Static initializers *may* contain generators.
223+
None => Substitution::empty(Interner),
224+
};
225+
let builder = TyBuilder::new(
226+
(),
227+
parent_subst
228+
.iter(Interner)
229+
.map(|arg| match arg.constant(Interner) {
230+
Some(c) => ParamKind::Const(c.data(Interner).ty.clone()),
231+
None => ParamKind::Type,
232+
})
233+
// These represent resume type, yield type, and return type of generator.
234+
.chain(std::iter::repeat(ParamKind::Type).take(3))
235+
.collect(),
236+
);
237+
builder.use_parent_substs(&parent_subst)
238+
}
239+
208240
pub fn build(self) -> Substitution {
209241
let ((), subst) = self.build_internal();
210242
subst

crates/hir-ty/src/chalk_db.rs

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use chalk_solve::rust_ir::{self, OpaqueTyDatumBound, WellKnownTrait};
1111

1212
use base_db::CrateId;
1313
use hir_def::{
14+
expr::Movability,
1415
lang_item::{lang_attr, LangItemTarget},
1516
AssocItemId, GenericDefId, HasModule, ItemContainerId, Lookup, ModuleId, TypeAliasId,
1617
};
@@ -26,9 +27,9 @@ use crate::{
2627
to_assoc_type_id, to_chalk_trait_id,
2728
traits::ChalkContext,
2829
utils::generics,
29-
AliasEq, AliasTy, BoundVar, CallableDefId, DebruijnIndex, FnDefId, Interner, ProjectionTy,
30-
ProjectionTyExt, QuantifiedWhereClause, Substitution, TraitRef, TraitRefExt, Ty, TyBuilder,
31-
TyExt, TyKind, WhereClause,
30+
wrap_empty_binders, AliasEq, AliasTy, BoundVar, CallableDefId, DebruijnIndex, FnDefId,
31+
Interner, ProjectionTy, ProjectionTyExt, QuantifiedWhereClause, Substitution, TraitRef,
32+
TraitRefExt, Ty, TyBuilder, TyExt, TyKind, WhereClause,
3233
};
3334

3435
pub(crate) type AssociatedTyDatum = chalk_solve::rust_ir::AssociatedTyDatum<Interner>;
@@ -372,17 +373,63 @@ impl<'a> chalk_solve::RustIrDatabase<Interner> for ChalkContext<'a> {
372373
}
373374
fn generator_datum(
374375
&self,
375-
_: chalk_ir::GeneratorId<Interner>,
376+
id: chalk_ir::GeneratorId<Interner>,
376377
) -> std::sync::Arc<chalk_solve::rust_ir::GeneratorDatum<Interner>> {
377-
// FIXME
378-
unimplemented!()
378+
let (parent, expr) = self.db.lookup_intern_generator(id.into());
379+
380+
// We fill substitution with unknown type, because we only need to know whether the generic
381+
// params are types or consts to build `Binders` and those being filled up are for
382+
// `resume_type`, `yield_type`, and `return_type` of the generator in question.
383+
let subst = TyBuilder::subst_for_generator(self.db, parent).fill_with_unknown().build();
384+
385+
let len = subst.len(Interner);
386+
let input_output = rust_ir::GeneratorInputOutputDatum {
387+
resume_type: TyKind::BoundVar(BoundVar::new(DebruijnIndex::INNERMOST, len - 3))
388+
.intern(Interner),
389+
yield_type: TyKind::BoundVar(BoundVar::new(DebruijnIndex::INNERMOST, len - 2))
390+
.intern(Interner),
391+
return_type: TyKind::BoundVar(BoundVar::new(DebruijnIndex::INNERMOST, len - 1))
392+
.intern(Interner),
393+
// FIXME: calculate upvars
394+
upvars: vec![],
395+
};
396+
397+
let it = subst
398+
.iter(Interner)
399+
.map(|it| it.constant(Interner).map(|c| c.data(Interner).ty.clone()));
400+
let input_output = crate::make_type_and_const_binders(it, input_output);
401+
402+
let movability = match self.db.body(parent)[expr] {
403+
hir_def::expr::Expr::Closure {
404+
closure_kind: hir_def::expr::ClosureKind::Generator(movability),
405+
..
406+
} => movability,
407+
_ => unreachable!("non generator expression interned as generator"),
408+
};
409+
let movability = match movability {
410+
Movability::Static => rust_ir::Movability::Static,
411+
Movability::Movable => rust_ir::Movability::Movable,
412+
};
413+
414+
Arc::new(rust_ir::GeneratorDatum { movability, input_output })
379415
}
380416
fn generator_witness_datum(
381417
&self,
382-
_: chalk_ir::GeneratorId<Interner>,
418+
id: chalk_ir::GeneratorId<Interner>,
383419
) -> std::sync::Arc<chalk_solve::rust_ir::GeneratorWitnessDatum<Interner>> {
384-
// FIXME
385-
unimplemented!()
420+
// FIXME: calculate inner types
421+
let inner_types =
422+
rust_ir::GeneratorWitnessExistential { types: wrap_empty_binders(vec![]) };
423+
424+
let (parent, _) = self.db.lookup_intern_generator(id.into());
425+
// See the comment in `generator_datum()` for unknown types.
426+
let subst = TyBuilder::subst_for_generator(self.db, parent).fill_with_unknown().build();
427+
let it = subst
428+
.iter(Interner)
429+
.map(|it| it.constant(Interner).map(|c| c.data(Interner).ty.clone()));
430+
let inner_types = crate::make_type_and_const_binders(it, inner_types);
431+
432+
Arc::new(rust_ir::GeneratorWitnessDatum { inner_types })
386433
}
387434

388435
fn unification_database(&self) -> &dyn chalk_ir::UnificationDatabase<Interner> {

crates/hir-ty/src/db.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
120120
fn intern_impl_trait_id(&self, id: ImplTraitId) -> InternedOpaqueTyId;
121121
#[salsa::interned]
122122
fn intern_closure(&self, id: (DefWithBodyId, ExprId)) -> InternedClosureId;
123+
#[salsa::interned]
124+
fn intern_generator(&self, id: (DefWithBodyId, ExprId)) -> InternedGeneratorId;
123125

124126
#[salsa::invoke(chalk_db::associated_ty_data_query)]
125127
fn associated_ty_data(&self, id: chalk_db::AssocTypeId) -> Arc<chalk_db::AssociatedTyDatum>;
@@ -233,6 +235,10 @@ impl_intern_key!(InternedOpaqueTyId);
233235
pub struct InternedClosureId(salsa::InternId);
234236
impl_intern_key!(InternedClosureId);
235237

238+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
239+
pub struct InternedGeneratorId(salsa::InternId);
240+
impl_intern_key!(InternedGeneratorId);
241+
236242
/// This exists just for Chalk, because Chalk just has a single `FnDefId` where
237243
/// we have different IDs for struct and enum variant constructors.
238244
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]

crates/hir-ty/src/display.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use hir_def::{
2020
};
2121
use hir_expand::{hygiene::Hygiene, name::Name};
2222
use itertools::Itertools;
23+
use smallvec::SmallVec;
2324
use syntax::SmolStr;
2425

2526
use crate::{
@@ -221,6 +222,7 @@ pub enum DisplaySourceCodeError {
221222
PathNotFound,
222223
UnknownType,
223224
Closure,
225+
Generator,
224226
}
225227

226228
pub enum HirDisplayError {
@@ -783,7 +785,34 @@ impl HirDisplay for Ty {
783785
write!(f, "{{unknown}}")?;
784786
}
785787
TyKind::InferenceVar(..) => write!(f, "_")?,
786-
TyKind::Generator(..) => write!(f, "{{generator}}")?,
788+
TyKind::Generator(_, subst) => {
789+
if f.display_target.is_source_code() {
790+
return Err(HirDisplayError::DisplaySourceCodeError(
791+
DisplaySourceCodeError::Generator,
792+
));
793+
}
794+
795+
let subst = subst.as_slice(Interner);
796+
let a: Option<SmallVec<[&Ty; 3]>> = subst
797+
.get(subst.len() - 3..)
798+
.map(|args| args.iter().map(|arg| arg.ty(Interner)).collect())
799+
.flatten();
800+
801+
if let Some([resume_ty, yield_ty, ret_ty]) = a.as_deref() {
802+
write!(f, "|")?;
803+
resume_ty.hir_fmt(f)?;
804+
write!(f, "|")?;
805+
806+
write!(f, " yields ")?;
807+
yield_ty.hir_fmt(f)?;
808+
809+
write!(f, " -> ")?;
810+
ret_ty.hir_fmt(f)?;
811+
} else {
812+
// This *should* be unreachable, but fallback just in case.
813+
write!(f, "{{generator}}")?;
814+
}
815+
}
787816
TyKind::GeneratorWitness(..) => write!(f, "{{generator witness}}")?,
788817
}
789818
Ok(())

crates/hir-ty/src/infer.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ pub struct InferenceResult {
339339
/// unresolved or missing subpatterns or subpatterns of mismatched types.
340340
pub type_of_pat: ArenaMap<PatId, Ty>,
341341
type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>,
342-
/// Interned Unknown to return references to.
342+
/// Interned common types to return references to.
343343
standard_types: InternedStandardTypes,
344344
/// Stores the types which were implicitly dereferenced in pattern binding modes.
345345
pub pat_adjustments: FxHashMap<PatId, Vec<Ty>>,
@@ -419,6 +419,8 @@ pub(crate) struct InferenceContext<'a> {
419419
/// closures, but currently this is the only field that will change there,
420420
/// so it doesn't make sense.
421421
return_ty: Ty,
422+
/// The resume type and the yield type, respectively, of the generator being inferred.
423+
resume_yield_tys: Option<(Ty, Ty)>,
422424
diverges: Diverges,
423425
breakables: Vec<BreakableContext>,
424426
}
@@ -483,6 +485,7 @@ impl<'a> InferenceContext<'a> {
483485
table: unify::InferenceTable::new(db, trait_env.clone()),
484486
trait_env,
485487
return_ty: TyKind::Error.intern(Interner), // set in collect_fn_signature
488+
resume_yield_tys: None,
486489
db,
487490
owner,
488491
body,

crates/hir-ty/src/infer/closure.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::{
1212
use super::{Expectation, InferenceContext};
1313

1414
impl InferenceContext<'_> {
15+
// This function handles both closures and generators.
1516
pub(super) fn deduce_closure_type_from_expectations(
1617
&mut self,
1718
closure_expr: ExprId,
@@ -27,6 +28,11 @@ impl InferenceContext<'_> {
2728
// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
2829
let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
2930

31+
// Generators are not Fn* so return early.
32+
if matches!(closure_ty.kind(Interner), TyKind::Generator(..)) {
33+
return;
34+
}
35+
3036
// Deduction based on the expected `dyn Fn` is done separately.
3137
if let TyKind::Dyn(dyn_ty) = expected_ty.kind(Interner) {
3238
if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) {

0 commit comments

Comments
 (0)