Skip to content

Commit 84d6723

Browse files
Infer async closure args from Fn bound even if there is no corresponding Future bound
1 parent 80eb5a8 commit 84d6723

File tree

2 files changed

+76
-11
lines changed

2 files changed

+76
-11
lines changed

compiler/rustc_hir_typeck/src/closure.rs

+27-11
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use rustc_middle::span_bug;
1414
use rustc_middle::ty::visit::{TypeVisitable, TypeVisitableExt};
1515
use rustc_middle::ty::{self, GenericArgs, Ty, TyCtxt, TypeSuperVisitable, TypeVisitor};
1616
use rustc_span::def_id::LocalDefId;
17-
use rustc_span::Span;
17+
use rustc_span::{Span, DUMMY_SP};
1818
use rustc_target::spec::abi::Abi;
1919
use rustc_trait_selection::error_reporting::traits::ArgKind;
2020
use rustc_trait_selection::traits;
@@ -539,6 +539,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
539539
/// we identify the `FnOnce<Args, Output = ?Fut>` bound, and if the output type is
540540
/// an inference variable `?Fut`, we check if that is bounded by a `Future<Output = Ty>`
541541
/// projection.
542+
///
543+
/// This function is actually best-effort with the return type; if we don't find a
544+
/// `Future` projection, we still will return arguments that we extracted from the `FnOnce`
545+
/// projection, and the output will be an unconstrained type variable instead.
542546
fn extract_sig_from_projection_and_future_bound(
543547
&self,
544548
cause_span: Option<Span>,
@@ -564,24 +568,36 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
564568
};
565569

566570
// FIXME: We may want to elaborate here, though I assume this will be exceedingly rare.
571+
let mut return_ty = None;
567572
for bound in self.obligations_for_self_ty(return_vid) {
568573
if let Some(ret_projection) = bound.predicate.as_projection_clause()
569574
&& let Some(ret_projection) = ret_projection.no_bound_vars()
570575
&& self.tcx.is_lang_item(ret_projection.def_id(), LangItem::FutureOutput)
571576
{
572-
let sig = projection.rebind(self.tcx.mk_fn_sig(
573-
input_tys,
574-
ret_projection.term.expect_type(),
575-
false,
576-
hir::Safety::Safe,
577-
Abi::Rust,
578-
));
579-
580-
return Some(ExpectedSig { cause_span, sig });
577+
return_ty = Some(ret_projection.term.expect_type());
578+
break;
581579
}
582580
}
583581

584-
None
582+
// SUBTLE: If we didn't find a `Future<Output = ...>` bound for the return
583+
// vid, we still want to attempt to provide inference guidance for the async
584+
// closure's arguments. Instantiate a new vid to plug into the output type.
585+
//
586+
// You may be wondering, what if it's higher-ranked? Well, given that we
587+
// found a type variable for the `FnOnce::Output` projection above, we know
588+
// that the output can't mention any of the vars.
589+
let return_ty =
590+
return_ty.unwrap_or_else(|| self.next_ty_var(cause_span.unwrap_or(DUMMY_SP)));
591+
592+
let sig = projection.rebind(self.tcx.mk_fn_sig(
593+
input_tys,
594+
return_ty,
595+
false,
596+
hir::Safety::Safe,
597+
Abi::Rust,
598+
));
599+
600+
return Some(ExpectedSig { cause_span, sig });
585601
}
586602

587603
fn sig_of_closure(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//@ check-pass
2+
//@ edition: 2021
3+
4+
// Make sure that we infer the args of an async closure even if it's passed to
5+
// a function that requires the async closure implement `Fn*` but does *not* have
6+
// a `Future` bound on the return type.
7+
8+
#![feature(async_closure)]
9+
10+
use std::future::Future;
11+
12+
trait TryStream {
13+
type Ok;
14+
type Err;
15+
}
16+
17+
trait TryFuture {
18+
type Ok;
19+
type Err;
20+
}
21+
22+
impl<F, T, E> TryFuture for F where F: Future<Output = Result<T, E>> {
23+
type Ok = T;
24+
type Err = E;
25+
}
26+
27+
trait TryStreamExt: TryStream {
28+
fn try_for_each<F, Fut>(&self, f: F)
29+
where
30+
F: FnMut(Self::Ok) -> Fut,
31+
Fut: TryFuture<Ok = (), Err = Self::Err>;
32+
}
33+
34+
impl<S> TryStreamExt for S where S: TryStream {
35+
fn try_for_each<F, Fut>(&self, f: F)
36+
where
37+
F: FnMut(Self::Ok) -> Fut,
38+
Fut: TryFuture<Ok = (), Err = Self::Err>,
39+
{ }
40+
}
41+
42+
fn test(stream: impl TryStream<Ok = &'static str, Err = ()>) {
43+
stream.try_for_each(async |s| {
44+
s.trim(); // Make sure we know the type of `s` at this point.
45+
Ok(())
46+
});
47+
}
48+
49+
fn main() {}

0 commit comments

Comments
 (0)