Skip to content

Commit 9a757d6

Browse files
committed
add eq to InferCtxtExt
1 parent 660c283 commit 9a757d6

File tree

3 files changed

+49
-38
lines changed

3 files changed

+49
-38
lines changed
Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,28 @@
1+
use rustc_infer::infer::at::ToTrace;
12
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
2-
use rustc_infer::infer::InferCtxt;
3-
use rustc_middle::ty::Ty;
3+
use rustc_infer::infer::{InferCtxt, InferOk};
4+
use rustc_infer::traits::query::NoSolution;
5+
use rustc_infer::traits::ObligationCause;
6+
use rustc_middle::ty::{self, Ty};
47
use rustc_span::DUMMY_SP;
58

9+
use super::Goal;
10+
611
/// Methods used inside of the canonical queries of the solver.
12+
///
13+
/// Most notably these do not care about diagnostics information.
14+
/// If you find this while looking for methods to use outside of the
15+
/// solver, you may look at the implementation of these method for
16+
/// help.
717
pub(super) trait InferCtxtExt<'tcx> {
818
fn next_ty_infer(&self) -> Ty<'tcx>;
19+
20+
fn eq<T: ToTrace<'tcx>>(
21+
&self,
22+
param_env: ty::ParamEnv<'tcx>,
23+
lhs: T,
24+
rhs: T,
25+
) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution>;
926
}
1027

1128
impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
@@ -15,4 +32,23 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
1532
span: DUMMY_SP,
1633
})
1734
}
35+
36+
#[instrument(level = "debug", skip(self, param_env), ret)]
37+
fn eq<T: ToTrace<'tcx>>(
38+
&self,
39+
param_env: ty::ParamEnv<'tcx>,
40+
lhs: T,
41+
rhs: T,
42+
) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution> {
43+
self.at(&ObligationCause::dummy(), param_env)
44+
.define_opaque_types(false)
45+
.eq(lhs, rhs)
46+
.map(|InferOk { value: (), obligations }| {
47+
obligations.into_iter().map(|o| o.into()).collect()
48+
})
49+
.map_err(|e| {
50+
debug!(?e, "failed to equate");
51+
NoSolution
52+
})
53+
}
1854
}

compiler/rustc_trait_selection/src/solve/project_goals.rs

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use crate::traits::{specialization_graph, translate_substs};
22

33
use super::assembly::{self, Candidate, CandidateSource};
4+
use super::infcx_ext::InferCtxtExt;
45
use super::{Certainty, EvalCtxt, Goal, QueryResult};
56
use rustc_errors::ErrorGuaranteed;
67
use rustc_hir::def::DefKind;
78
use rustc_hir::def_id::DefId;
8-
use rustc_infer::infer::{InferCtxt, InferOk};
9+
use rustc_infer::infer::InferCtxt;
910
use rustc_infer::traits::query::NoSolution;
1011
use rustc_infer::traits::specialization_graph::LeafDef;
11-
use rustc_infer::traits::{ObligationCause, Reveal};
12+
use rustc_infer::traits::Reveal;
1213
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
1314
use rustc_middle::ty::ProjectionPredicate;
1415
use rustc_middle::ty::TypeVisitable;
@@ -112,23 +113,15 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
112113
let impl_substs = ecx.infcx.fresh_substs_for_item(DUMMY_SP, impl_def_id);
113114
let impl_trait_ref = impl_trait_ref.subst(tcx, impl_substs);
114115

115-
let Ok(InferOk { obligations, .. }) = ecx.infcx
116-
.at(&ObligationCause::dummy(), goal.param_env)
117-
.define_opaque_types(false)
118-
.eq(goal_trait_ref, impl_trait_ref)
119-
.map_err(|e| debug!("failed to equate trait refs: {e:?}"))
120-
else {
121-
return Err(NoSolution)
122-
};
116+
let mut nested_goals = ecx.infcx.eq(goal.param_env, goal_trait_ref, impl_trait_ref)?;
123117
let where_clause_bounds = tcx
124118
.predicates_of(impl_def_id)
125119
.instantiate(tcx, impl_substs)
126120
.predicates
127121
.into_iter()
128122
.map(|pred| goal.with(tcx, pred));
129123

130-
let nested_goals =
131-
obligations.into_iter().map(|o| o.into()).chain(where_clause_bounds).collect();
124+
nested_goals.extend(where_clause_bounds);
132125
let trait_ref_certainty = ecx.evaluate_all(nested_goals)?;
133126

134127
let Some(assoc_def) = fetch_eligible_assoc_item_def(
@@ -185,16 +178,8 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
185178
ty.map_bound(|ty| ty.into())
186179
};
187180

188-
let Ok(InferOk { obligations, .. }) = ecx.infcx
189-
.at(&ObligationCause::dummy(), goal.param_env)
190-
.define_opaque_types(false)
191-
.eq(goal.predicate.term, term.subst(tcx, substs))
192-
.map_err(|e| debug!("failed to equate trait refs: {e:?}"))
193-
else {
194-
return Err(NoSolution);
195-
};
196-
197-
let nested_goals = obligations.into_iter().map(|o| o.into()).collect();
181+
let nested_goals =
182+
ecx.infcx.eq(goal.param_env, goal.predicate.term, term.subst(tcx, substs))?;
198183
let rhs_certainty = ecx.evaluate_all(nested_goals)?;
199184

200185
Ok(trait_ref_certainty.unify_and(rhs_certainty))

compiler/rustc_trait_selection/src/solve/trait_goals.rs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
use std::iter;
44

55
use super::assembly::{self, Candidate, CandidateSource};
6+
use super::infcx_ext::InferCtxtExt;
67
use super::{Certainty, EvalCtxt, Goal, QueryResult};
78
use rustc_hir::def_id::DefId;
8-
use rustc_infer::infer::InferOk;
99
use rustc_infer::traits::query::NoSolution;
10-
use rustc_infer::traits::ObligationCause;
1110
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
1211
use rustc_middle::ty::TraitPredicate;
1312
use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -45,24 +44,15 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
4544
let impl_substs = ecx.infcx.fresh_substs_for_item(DUMMY_SP, impl_def_id);
4645
let impl_trait_ref = impl_trait_ref.subst(tcx, impl_substs);
4746

48-
let Ok(InferOk { obligations, .. }) = ecx.infcx
49-
.at(&ObligationCause::dummy(), goal.param_env)
50-
.define_opaque_types(false)
51-
.eq(goal.predicate.trait_ref, impl_trait_ref)
52-
.map_err(|e| debug!("failed to equate trait refs: {e:?}"))
53-
else {
54-
return Err(NoSolution);
55-
};
47+
let mut nested_goals =
48+
ecx.infcx.eq(goal.param_env, goal.predicate.trait_ref, impl_trait_ref)?;
5649
let where_clause_bounds = tcx
5750
.predicates_of(impl_def_id)
5851
.instantiate(tcx, impl_substs)
5952
.predicates
6053
.into_iter()
6154
.map(|pred| goal.with(tcx, pred));
62-
63-
let nested_goals =
64-
obligations.into_iter().map(|o| o.into()).chain(where_clause_bounds).collect();
65-
55+
nested_goals.extend(where_clause_bounds);
6656
ecx.evaluate_all(nested_goals)
6757
})
6858
}

0 commit comments

Comments
 (0)