diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs index 2cd67cc4da213..029452753927e 100644 --- a/compiler/rustc_infer/src/infer/at.rs +++ b/compiler/rustc_infer/src/infer/at.rs @@ -238,6 +238,34 @@ impl<'a, 'tcx> At<'a, 'tcx> { } } + // FIXME(arbitrary_self_types): remove this interface + // when the new solver is stabilised. + /// Almost like `eq_trace` except this type relating procedure will + /// also generate the obligations arising from equating projection + /// candidates. + pub fn eq_with_proj( + self, + define_opaque_types: DefineOpaqueTypes, + expected: T, + actual: T, + ) -> InferResult<'tcx, ()> + where + T: ToTrace<'tcx>, + { + assert!(!self.infcx.next_trait_solver); + let trace = ToTrace::to_trace(self.cause, expected, actual); + let mut op = TypeRelating::new( + self.infcx, + trace, + self.param_env, + define_opaque_types, + ty::Invariant, + ) + .through_projections(true); + op.relate(expected, actual)?; + Ok(InferOk { value: (), obligations: op.into_obligations() }) + } + pub fn relate( self, define_opaque_types: DefineOpaqueTypes, @@ -369,6 +397,18 @@ impl<'tcx> ToTrace<'tcx> for ty::Term<'tcx> { } } +impl<'tcx> ToTrace<'tcx> for ty::Binder<'tcx, ty::Term<'tcx>> { + fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> { + TypeTrace { + cause: cause.clone(), + values: ValuePairs::Terms(ExpectedFound { + expected: a.skip_binder(), + found: b.skip_binder(), + }), + } + } +} + impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> { fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> { TypeTrace { cause: cause.clone(), values: ValuePairs::TraitRefs(ExpectedFound::new(a, b)) } diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs index 009271a8378f3..9c3a41a50cb39 100644 --- a/compiler/rustc_infer/src/infer/relate/type_relating.rs +++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs @@ -22,6 +22,11 @@ pub(crate) struct TypeRelating<'infcx, 'tcx> { param_env: ty::ParamEnv<'tcx>, define_opaque_types: DefineOpaqueTypes, + /// This indicates whether the relation should + /// report obligations arising from equating aliasing terms + /// involving associated types, instead of rejection. + through_projections: bool, + // Mutable fields. ambient_variance: ty::Variance, obligations: PredicateObligations<'tcx>, @@ -67,9 +72,15 @@ impl<'infcx, 'tcx> TypeRelating<'infcx, 'tcx> { ambient_variance, obligations: PredicateObligations::new(), cache: Default::default(), + through_projections: false, } } + pub(crate) fn through_projections(mut self, walk_through: bool) -> Self { + self.through_projections = walk_through; + self + } + pub(crate) fn into_obligations(self) -> PredicateObligations<'tcx> { self.obligations } @@ -128,6 +139,7 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, 'tcx> { if self.cache.contains(&(self.ambient_variance, a, b)) { return Ok(a); } + let mut relate_result = a; match (a.kind(), b.kind()) { (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => { @@ -201,6 +213,34 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, 'tcx> { )?); } + ( + ty::Alias(ty::Projection | ty::Opaque, _), + ty::Alias(ty::Projection | ty::Opaque, _), + ) => { + super_combine_tys(infcx, self, a, b)?; + } + + (&ty::Alias(ty::Projection, ty::AliasTy { def_id, args, .. }), _) + if matches!(self.ambient_variance, ty::Variance::Invariant) + && self.through_projections => + { + self.register_predicates([ty::ProjectionPredicate { + projection_term: ty::AliasTerm::new(self.cx(), def_id, args), + term: b.into(), + }]); + relate_result = b; + } + + (_, &ty::Alias(ty::Projection, ty::AliasTy { def_id, args, .. })) + if matches!(self.ambient_variance, ty::Variance::Invariant) + && self.through_projections => + { + self.register_predicates([ty::ProjectionPredicate { + projection_term: ty::AliasTerm::new(self.cx(), def_id, args), + term: a.into(), + }]); + } + _ => { super_combine_tys(infcx, self, a, b)?; } @@ -208,7 +248,7 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, 'tcx> { assert!(self.cache.insert((self.ambient_variance, a, b))); - Ok(a) + Ok(relate_result) } #[instrument(skip(self), level = "trace")] diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 0dce504903ca4..c0be9128a539c 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -1,5 +1,6 @@ //! Code for projecting associated types out of trait references. +use std::iter::Extend; use std::ops::ControlFlow; use rustc_data_structures::sso::SsoHashSet; @@ -65,6 +66,7 @@ enum ProjectionCandidate<'tcx> { Select(Selection<'tcx>), } +#[derive(Debug)] enum ProjectionCandidateSet<'tcx> { None, Single(ProjectionCandidate<'tcx>), @@ -648,15 +650,26 @@ fn project<'cx, 'tcx>( } let mut candidates = ProjectionCandidateSet::None; + let mut derived_obligations = PredicateObligations::default(); // Make sure that the following procedures are kept in order. ParamEnv // needs to be first because it has highest priority, and Select checks // the return value of push_candidate which assumes it's ran at last. - assemble_candidates_from_param_env(selcx, obligation, &mut candidates); + assemble_candidates_from_param_env( + selcx, + obligation, + &mut candidates, + &mut derived_obligations, + ); assemble_candidates_from_trait_def(selcx, obligation, &mut candidates); - assemble_candidates_from_object_ty(selcx, obligation, &mut candidates); + assemble_candidates_from_object_ty( + selcx, + obligation, + &mut candidates, + &mut derived_obligations, + ); if let ProjectionCandidateSet::Single(ProjectionCandidate::Object(_)) = candidates { // Avoid normalization cycle from selection (see @@ -669,7 +682,13 @@ fn project<'cx, 'tcx>( match candidates { ProjectionCandidateSet::Single(candidate) => { - confirm_candidate(selcx, obligation, candidate) + confirm_candidate(selcx, obligation, candidate).map(move |proj| { + if let Projected::Progress(progress) = proj { + Projected::Progress(progress.with_addl_obligations(derived_obligations)) + } else { + proj + } + }) } ProjectionCandidateSet::None => { let tcx = selcx.tcx(); @@ -691,6 +710,7 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, candidate_set: &mut ProjectionCandidateSet<'tcx>, + derived_obligations: &mut impl Extend>, ) { assemble_candidates_from_predicates( selcx, @@ -698,7 +718,7 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>( candidate_set, ProjectionCandidate::ParamEnv, obligation.param_env.caller_bounds().iter(), - false, + derived_obligations, ); } @@ -712,6 +732,7 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>( /// ``` /// /// Here, for example, we could conclude that the result is `i32`. +#[instrument(level = "debug", skip(selcx))] fn assemble_candidates_from_trait_def<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, @@ -774,6 +795,7 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, candidate_set: &mut ProjectionCandidateSet<'tcx>, + derived_obligations: &mut impl Extend>, ) { debug!("assemble_candidates_from_object_ty(..)"); @@ -806,21 +828,18 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>( candidate_set, ProjectionCandidate::Object, env_predicates, - false, + derived_obligations, ); } -#[instrument( - level = "debug", - skip(selcx, candidate_set, ctor, env_predicates, potentially_unnormalized_candidates) -)] +#[instrument(level = "debug", skip(selcx, env_predicates, derived_obligations))] fn assemble_candidates_from_predicates<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, candidate_set: &mut ProjectionCandidateSet<'tcx>, ctor: fn(ty::PolyProjectionPredicate<'tcx>) -> ProjectionCandidate<'tcx>, env_predicates: impl Iterator>, - potentially_unnormalized_candidates: bool, + derived_obligations: &mut impl Extend>, ) { let infcx = selcx.infcx; let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx()); @@ -838,28 +857,39 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>( continue; } - let is_match = infcx.probe(|_| { - selcx.match_projection_projections( - obligation, - data, - potentially_unnormalized_candidates, - ) - }); + let is_match = + infcx.probe(|_| selcx.match_projection_projections(obligation, data, false)); match is_match { ProjectionMatchesProjection::Yes => { - candidate_set.push_candidate(ctor(data)); - - if potentially_unnormalized_candidates - && !obligation.predicate.has_non_region_infer() + debug!(?data, "push"); + if let ProjectionCandidateSet::Single( + ProjectionCandidate::ParamEnv(proj) + | ProjectionCandidate::Object(proj) + | ProjectionCandidate::TraitDef(proj), + ) = candidate_set { - // HACK: Pick the first trait def candidate for a fully - // inferred predicate. This is to allow duplicates that - // differ only in normalization. - return; + match infcx.commit_if_ok(|_| { + infcx.at(&obligation.cause, obligation.param_env).eq_with_proj( + DefineOpaqueTypes::No, + data.term(), + proj.term(), + ) + }) { + Ok(InferOk { value: (), obligations }) => { + derived_obligations.extend(obligations); + } + Err(e) => { + debug!(?e, "refuse to unify candidates"); + candidate_set.push_candidate(ctor(data)); + } + } + } else { + candidate_set.push_candidate(ctor(data)); } } ProjectionMatchesProjection::Ambiguous => { + debug!("mark ambiguous"); candidate_set.mark_ambiguous(); } ProjectionMatchesProjection::No => {} @@ -868,7 +898,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>( } } -#[instrument(level = "debug", skip(selcx, obligation, candidate_set))] +#[instrument(level = "debug", skip(selcx))] fn assemble_candidates_from_impls<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, diff --git a/tests/ui/associated-types/associated-type-projection-nonambig-supertrait.rs b/tests/ui/associated-types/associated-type-projection-nonambig-supertrait.rs new file mode 100644 index 0000000000000..c57a535f594ad --- /dev/null +++ b/tests/ui/associated-types/associated-type-projection-nonambig-supertrait.rs @@ -0,0 +1,48 @@ +//@ revisions: traditional next_solver +//@ [next_solver] compile-flags: -Znext-solver +//@ check-pass + +use std::marker::PhantomData; + +pub trait Receiver { + type Target: ?Sized; +} + +pub trait Deref: Receiver::Target> { + type Target: ?Sized; + fn deref(&self) -> &::Target; +} + +impl Receiver for T { + type Target = ::Target; +} + +// === +pub struct Type(PhantomData<(Id, T)>); +pub struct AliasRef>(PhantomData<(Id, T)>); + +pub trait TypePtr: Deref::Id, Self>> + Sized { + // ^ the impl head here provides the first candidate + // ::Target := Type<::Id> + type Id; +} + +pub struct Alias(PhantomData<(Id, T)>); + +impl Deref for Alias +where + T: TypePtr + Deref>, + // ^ the impl head here provides the second candidate + // ::Target := Type + // and additionally a normalisation is mandatory due to + // the following supertrait relation trait + // Deref: Receiver::Target> +{ + type Target = AliasRef; + + fn deref(&self) -> &::Target { + todo!() + } +} + +fn main() {}