Skip to content

Commit 5f025f3

Browse files
committed
Auto merge of #141581 - lcnr:fold-clauses, r=compiler-errors
add additional `TypeFlags` fast paths Some crates, e.g. `diesel`, have items with a lot of where-clauses (more than 150). In these cases checking the `TypeFlags` of the whole `param_env` can be very beneficial. This adds `fn fold_clauses` to mirror the existing `fn visit_clauses` and then uses this in folders which fold `ParamEnv`s. Split out from #141451, depends on #141442. r? `@compiler-errors`
2 parents ebe9b00 + 0830ce0 commit 5f025f3

File tree

16 files changed

+138
-18
lines changed

16 files changed

+138
-18
lines changed

compiler/rustc_infer/src/infer/canonical/canonicalizer.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,10 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
497497
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
498498
if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
499499
}
500+
501+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
502+
if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c }
503+
}
500504
}
501505

502506
impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {

compiler/rustc_infer/src/infer/resolve.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
5555
ct.super_fold_with(self)
5656
}
5757
}
58+
59+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
60+
if !p.has_non_region_infer() { p } else { p.super_fold_with(self) }
61+
}
62+
63+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
64+
if !c.has_non_region_infer() { c } else { c.super_fold_with(self) }
65+
}
5866
}
5967

6068
/// The opportunistic region resolver opportunistically resolves regions

compiler/rustc_middle/src/ty/erase_regions.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,12 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for RegionEraserVisitor<'tcx> {
8686
p
8787
}
8888
}
89+
90+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
91+
if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) {
92+
c.super_fold_with(self)
93+
} else {
94+
c
95+
}
96+
}
8997
}

compiler/rustc_middle/src/ty/fold.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ where
177177
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
178178
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
179179
}
180+
181+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
182+
if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c }
183+
}
180184
}
181185

182186
impl<'tcx> TyCtxt<'tcx> {

compiler/rustc_middle/src/ty/predicate.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ impl<'tcx> Clause<'tcx> {
238238
}
239239
}
240240

241+
impl<'tcx> rustc_type_ir::inherent::Clauses<TyCtxt<'tcx>> for ty::Clauses<'tcx> {}
242+
241243
#[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)]
242244
impl<'tcx> ExistentialPredicate<'tcx> {
243245
/// Compares via an ordering that will not change if modules are reordered or other changes are

compiler/rustc_middle/src/ty/structural_impls.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,19 @@ impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clause<'tcx> {
570570
}
571571
}
572572

573+
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
574+
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
575+
self,
576+
folder: &mut F,
577+
) -> Result<Self, F::Error> {
578+
folder.try_fold_clauses(self)
579+
}
580+
581+
fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
582+
folder.fold_clauses(self)
583+
}
584+
}
585+
573586
impl<'tcx> TypeVisitable<TyCtxt<'tcx>> for ty::Predicate<'tcx> {
574587
fn visit_with<V: TypeVisitor<TyCtxt<'tcx>>>(&self, visitor: &mut V) -> V::Result {
575588
visitor.visit_predicate(*self)
@@ -615,6 +628,19 @@ impl<'tcx> TypeSuperVisitable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
615628
}
616629
}
617630

631+
impl<'tcx> TypeSuperFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
632+
fn try_super_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
633+
self,
634+
folder: &mut F,
635+
) -> Result<Self, F::Error> {
636+
ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
637+
}
638+
639+
fn super_fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
640+
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
641+
}
642+
}
643+
618644
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Const<'tcx> {
619645
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
620646
self,
@@ -775,7 +801,6 @@ macro_rules! list_fold {
775801
}
776802

777803
list_fold! {
778-
ty::Clauses<'tcx> : mk_clauses,
779804
&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates,
780805
&'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems,
781806
&'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns,

compiler/rustc_next_trait_solver/src/canonicalizer.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,4 +572,15 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
572572
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
573573
if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p }
574574
}
575+
576+
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
577+
match self.canonicalize_mode {
578+
CanonicalizeMode::Input { keep_static: true }
579+
| CanonicalizeMode::Response { max_input_universe: _ } => {}
580+
CanonicalizeMode::Input { keep_static: false } => {
581+
panic!("erasing 'static in env")
582+
}
583+
}
584+
if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c }
585+
}
575586
}

compiler/rustc_next_trait_solver/src/resolve.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::delegate::SolverDelegate;
1111
// EAGER RESOLUTION
1212

1313
/// Resolves ty, region, and const vars to their inferred values or their root vars.
14-
pub struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
14+
struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
1515
where
1616
D: SolverDelegate<Interner = I>,
1717
I: Interner,
@@ -22,8 +22,20 @@ where
2222
cache: DelayedMap<I::Ty, I::Ty>,
2323
}
2424

25+
pub fn eager_resolve_vars<D: SolverDelegate, T: TypeFoldable<D::Interner>>(
26+
delegate: &D,
27+
value: T,
28+
) -> T {
29+
if value.has_infer() {
30+
let mut folder = EagerResolver::new(delegate);
31+
value.fold_with(&mut folder)
32+
} else {
33+
value
34+
}
35+
}
36+
2537
impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
26-
pub fn new(delegate: &'a D) -> Self {
38+
fn new(delegate: &'a D) -> Self {
2739
EagerResolver { delegate, cache: Default::default() }
2840
}
2941
}
@@ -90,4 +102,8 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
90102
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
91103
if p.has_infer() { p.super_fold_with(self) } else { p }
92104
}
105+
106+
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
107+
if c.has_infer() { c.super_fold_with(self) } else { c }
108+
}
93109
}

compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use tracing::{debug, instrument, trace};
2222

2323
use crate::canonicalizer::Canonicalizer;
2424
use crate::delegate::SolverDelegate;
25-
use crate::resolve::EagerResolver;
25+
use crate::resolve::eager_resolve_vars;
2626
use crate::solve::eval_ctxt::CurrentGoalKind;
2727
use crate::solve::{
2828
CanonicalInput, CanonicalResponse, Certainty, EvalCtxt, ExternalConstraintsData, Goal,
@@ -61,8 +61,7 @@ where
6161
// so we only canonicalize the lookup table and ignore
6262
// duplicate entries.
6363
let opaque_types = self.delegate.clone_opaque_types_lookup_table();
64-
let (goal, opaque_types) =
65-
(goal, opaque_types).fold_with(&mut EagerResolver::new(self.delegate));
64+
let (goal, opaque_types) = eager_resolve_vars(self.delegate, (goal, opaque_types));
6665

6766
let mut orig_values = Default::default();
6867
let canonical = Canonicalizer::canonicalize_input(
@@ -162,8 +161,8 @@ where
162161

163162
let external_constraints =
164163
self.compute_external_query_constraints(certainty, normalization_nested_goals);
165-
let (var_values, mut external_constraints) = (self.var_values, external_constraints)
166-
.fold_with(&mut EagerResolver::new(self.delegate));
164+
let (var_values, mut external_constraints) =
165+
eager_resolve_vars(self.delegate, (self.var_values, external_constraints));
167166

168167
// Remove any trivial or duplicated region constraints once we've resolved regions
169168
let mut unique = HashSet::default();
@@ -474,7 +473,7 @@ where
474473
{
475474
let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) };
476475
let state = inspect::State { var_values, data };
477-
let state = state.fold_with(&mut EagerResolver::new(delegate));
476+
let state = eager_resolve_vars(delegate, state);
478477
Canonicalizer::canonicalize_response(delegate, max_input_universe, &mut vec![], state)
479478
}
480479

compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,22 @@ where
925925
}
926926
}
927927
}
928+
929+
fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result {
930+
if p.has_non_region_infer() || p.has_placeholders() {
931+
p.super_visit_with(self)
932+
} else {
933+
ControlFlow::Continue(())
934+
}
935+
}
936+
937+
fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
938+
if c.has_non_region_infer() || c.has_placeholders() {
939+
c.super_visit_with(self)
940+
} else {
941+
ControlFlow::Continue(())
942+
}
943+
}
928944
}
929945

930946
let mut visitor = ContainsTermOrNotNameable {

0 commit comments

Comments
 (0)