From 47fe867591c1783afcb6af33451e4318428d52d5 Mon Sep 17 00:00:00 2001 From: Xiangfei Ding Date: Thu, 10 Apr 2025 15:38:46 +0200 Subject: [PATCH] generate obligations instead of reporting ambiguity A new mode of type relating is introduced so that obligations are generated instead of outright rejecting projection clauses. This allows project candidates that are sourced from more than one predicates, such as supertrait bounds, provided that they do not conflict each other. --- compiler/rustc_infer/src/infer/at.rs | 40 +++++++++ .../src/infer/relate/type_relating.rs | 42 +++++++++- .../src/traits/project.rs | 82 +++++++++++++------ ...ted-type-projection-nonambig-supertrait.rs | 48 +++++++++++ 4 files changed, 185 insertions(+), 27 deletions(-) create mode 100644 tests/ui/associated-types/associated-type-projection-nonambig-supertrait.rs diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs index 2cd67cc4da213..029452753927e 100644 --- a/compiler/rustc_infer/src/infer/at.rs +++ b/compiler/rustc_infer/src/infer/at.rs @@ -238,6 +238,34 @@ impl<'a, 'tcx> At<'a, 'tcx> { } } + // FIXME(arbitrary_self_types): remove this interface + // when the new solver is stabilised. + /// Almost like `eq_trace` except this type relating procedure will + /// also generate the obligations arising from equating projection + /// candidates. + pub fn eq_with_proj( + self, + define_opaque_types: DefineOpaqueTypes, + expected: T, + actual: T, + ) -> InferResult<'tcx, ()> + where + T: ToTrace<'tcx>, + { + assert!(!self.infcx.next_trait_solver); + let trace = ToTrace::to_trace(self.cause, expected, actual); + let mut op = TypeRelating::new( + self.infcx, + trace, + self.param_env, + define_opaque_types, + ty::Invariant, + ) + .through_projections(true); + op.relate(expected, actual)?; + Ok(InferOk { value: (), obligations: op.into_obligations() }) + } + pub fn relate( self, define_opaque_types: DefineOpaqueTypes, @@ -369,6 +397,18 @@ impl<'tcx> ToTrace<'tcx> for ty::Term<'tcx> { } } +impl<'tcx> ToTrace<'tcx> for ty::Binder<'tcx, ty::Term<'tcx>> { + fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> { + TypeTrace { + cause: cause.clone(), + values: ValuePairs::Terms(ExpectedFound { + expected: a.skip_binder(), + found: b.skip_binder(), + }), + } + } +} + impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> { fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> { TypeTrace { cause: cause.clone(), values: ValuePairs::TraitRefs(ExpectedFound::new(a, b)) } diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs index 009271a8378f3..9c3a41a50cb39 100644 --- a/compiler/rustc_infer/src/infer/relate/type_relating.rs +++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs @@ -22,6 +22,11 @@ pub(crate) struct TypeRelating<'infcx, 'tcx> { param_env: ty::ParamEnv<'tcx>, define_opaque_types: DefineOpaqueTypes, + /// This indicates whether the relation should + /// report obligations arising from equating aliasing terms + /// involving associated types, instead of rejection. + through_projections: bool, + // Mutable fields. ambient_variance: ty::Variance, obligations: PredicateObligations<'tcx>, @@ -67,9 +72,15 @@ impl<'infcx, 'tcx> TypeRelating<'infcx, 'tcx> { ambient_variance, obligations: PredicateObligations::new(), cache: Default::default(), + through_projections: false, } } + pub(crate) fn through_projections(mut self, walk_through: bool) -> Self { + self.through_projections = walk_through; + self + } + pub(crate) fn into_obligations(self) -> PredicateObligations<'tcx> { self.obligations } @@ -128,6 +139,7 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, 'tcx> { if self.cache.contains(&(self.ambient_variance, a, b)) { return Ok(a); } + let mut relate_result = a; match (a.kind(), b.kind()) { (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => { @@ -201,6 +213,34 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, 'tcx> { )?); } + ( + ty::Alias(ty::Projection | ty::Opaque, _), + ty::Alias(ty::Projection | ty::Opaque, _), + ) => { + super_combine_tys(infcx, self, a, b)?; + } + + (&ty::Alias(ty::Projection, ty::AliasTy { def_id, args, .. }), _) + if matches!(self.ambient_variance, ty::Variance::Invariant) + && self.through_projections => + { + self.register_predicates([ty::ProjectionPredicate { + projection_term: ty::AliasTerm::new(self.cx(), def_id, args), + term: b.into(), + }]); + relate_result = b; + } + + (_, &ty::Alias(ty::Projection, ty::AliasTy { def_id, args, .. })) + if matches!(self.ambient_variance, ty::Variance::Invariant) + && self.through_projections => + { + self.register_predicates([ty::ProjectionPredicate { + projection_term: ty::AliasTerm::new(self.cx(), def_id, args), + term: a.into(), + }]); + } + _ => { super_combine_tys(infcx, self, a, b)?; } @@ -208,7 +248,7 @@ impl<'tcx> TypeRelation> for TypeRelating<'_, 'tcx> { assert!(self.cache.insert((self.ambient_variance, a, b))); - Ok(a) + Ok(relate_result) } #[instrument(skip(self), level = "trace")] diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 0dce504903ca4..c0be9128a539c 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -1,5 +1,6 @@ //! Code for projecting associated types out of trait references. +use std::iter::Extend; use std::ops::ControlFlow; use rustc_data_structures::sso::SsoHashSet; @@ -65,6 +66,7 @@ enum ProjectionCandidate<'tcx> { Select(Selection<'tcx>), } +#[derive(Debug)] enum ProjectionCandidateSet<'tcx> { None, Single(ProjectionCandidate<'tcx>), @@ -648,15 +650,26 @@ fn project<'cx, 'tcx>( } let mut candidates = ProjectionCandidateSet::None; + let mut derived_obligations = PredicateObligations::default(); // Make sure that the following procedures are kept in order. ParamEnv // needs to be first because it has highest priority, and Select checks // the return value of push_candidate which assumes it's ran at last. - assemble_candidates_from_param_env(selcx, obligation, &mut candidates); + assemble_candidates_from_param_env( + selcx, + obligation, + &mut candidates, + &mut derived_obligations, + ); assemble_candidates_from_trait_def(selcx, obligation, &mut candidates); - assemble_candidates_from_object_ty(selcx, obligation, &mut candidates); + assemble_candidates_from_object_ty( + selcx, + obligation, + &mut candidates, + &mut derived_obligations, + ); if let ProjectionCandidateSet::Single(ProjectionCandidate::Object(_)) = candidates { // Avoid normalization cycle from selection (see @@ -669,7 +682,13 @@ fn project<'cx, 'tcx>( match candidates { ProjectionCandidateSet::Single(candidate) => { - confirm_candidate(selcx, obligation, candidate) + confirm_candidate(selcx, obligation, candidate).map(move |proj| { + if let Projected::Progress(progress) = proj { + Projected::Progress(progress.with_addl_obligations(derived_obligations)) + } else { + proj + } + }) } ProjectionCandidateSet::None => { let tcx = selcx.tcx(); @@ -691,6 +710,7 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, candidate_set: &mut ProjectionCandidateSet<'tcx>, + derived_obligations: &mut impl Extend>, ) { assemble_candidates_from_predicates( selcx, @@ -698,7 +718,7 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>( candidate_set, ProjectionCandidate::ParamEnv, obligation.param_env.caller_bounds().iter(), - false, + derived_obligations, ); } @@ -712,6 +732,7 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>( /// ``` /// /// Here, for example, we could conclude that the result is `i32`. +#[instrument(level = "debug", skip(selcx))] fn assemble_candidates_from_trait_def<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, @@ -774,6 +795,7 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, candidate_set: &mut ProjectionCandidateSet<'tcx>, + derived_obligations: &mut impl Extend>, ) { debug!("assemble_candidates_from_object_ty(..)"); @@ -806,21 +828,18 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>( candidate_set, ProjectionCandidate::Object, env_predicates, - false, + derived_obligations, ); } -#[instrument( - level = "debug", - skip(selcx, candidate_set, ctor, env_predicates, potentially_unnormalized_candidates) -)] +#[instrument(level = "debug", skip(selcx, env_predicates, derived_obligations))] fn assemble_candidates_from_predicates<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, candidate_set: &mut ProjectionCandidateSet<'tcx>, ctor: fn(ty::PolyProjectionPredicate<'tcx>) -> ProjectionCandidate<'tcx>, env_predicates: impl Iterator>, - potentially_unnormalized_candidates: bool, + derived_obligations: &mut impl Extend>, ) { let infcx = selcx.infcx; let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx()); @@ -838,28 +857,39 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>( continue; } - let is_match = infcx.probe(|_| { - selcx.match_projection_projections( - obligation, - data, - potentially_unnormalized_candidates, - ) - }); + let is_match = + infcx.probe(|_| selcx.match_projection_projections(obligation, data, false)); match is_match { ProjectionMatchesProjection::Yes => { - candidate_set.push_candidate(ctor(data)); - - if potentially_unnormalized_candidates - && !obligation.predicate.has_non_region_infer() + debug!(?data, "push"); + if let ProjectionCandidateSet::Single( + ProjectionCandidate::ParamEnv(proj) + | ProjectionCandidate::Object(proj) + | ProjectionCandidate::TraitDef(proj), + ) = candidate_set { - // HACK: Pick the first trait def candidate for a fully - // inferred predicate. This is to allow duplicates that - // differ only in normalization. - return; + match infcx.commit_if_ok(|_| { + infcx.at(&obligation.cause, obligation.param_env).eq_with_proj( + DefineOpaqueTypes::No, + data.term(), + proj.term(), + ) + }) { + Ok(InferOk { value: (), obligations }) => { + derived_obligations.extend(obligations); + } + Err(e) => { + debug!(?e, "refuse to unify candidates"); + candidate_set.push_candidate(ctor(data)); + } + } + } else { + candidate_set.push_candidate(ctor(data)); } } ProjectionMatchesProjection::Ambiguous => { + debug!("mark ambiguous"); candidate_set.mark_ambiguous(); } ProjectionMatchesProjection::No => {} @@ -868,7 +898,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>( } } -#[instrument(level = "debug", skip(selcx, obligation, candidate_set))] +#[instrument(level = "debug", skip(selcx))] fn assemble_candidates_from_impls<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTermObligation<'tcx>, diff --git a/tests/ui/associated-types/associated-type-projection-nonambig-supertrait.rs b/tests/ui/associated-types/associated-type-projection-nonambig-supertrait.rs new file mode 100644 index 0000000000000..c57a535f594ad --- /dev/null +++ b/tests/ui/associated-types/associated-type-projection-nonambig-supertrait.rs @@ -0,0 +1,48 @@ +//@ revisions: traditional next_solver +//@ [next_solver] compile-flags: -Znext-solver +//@ check-pass + +use std::marker::PhantomData; + +pub trait Receiver { + type Target: ?Sized; +} + +pub trait Deref: Receiver::Target> { + type Target: ?Sized; + fn deref(&self) -> &::Target; +} + +impl Receiver for T { + type Target = ::Target; +} + +// === +pub struct Type(PhantomData<(Id, T)>); +pub struct AliasRef>(PhantomData<(Id, T)>); + +pub trait TypePtr: Deref::Id, Self>> + Sized { + // ^ the impl head here provides the first candidate + // ::Target := Type<::Id> + type Id; +} + +pub struct Alias(PhantomData<(Id, T)>); + +impl Deref for Alias +where + T: TypePtr + Deref>, + // ^ the impl head here provides the second candidate + // ::Target := Type + // and additionally a normalisation is mandatory due to + // the following supertrait relation trait + // Deref: Receiver::Target> +{ + type Target = AliasRef; + + fn deref(&self) -> &::Target { + todo!() + } +} + +fn main() {}