Skip to content

Commit 4f864db

Browse files
committed
yyeet
1 parent 9fb91aa commit 4f864db

File tree

4 files changed

+157
-33
lines changed

4 files changed

+157
-33
lines changed

compiler/rustc_infer/src/infer/snapshot/fudge.rs

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,26 +43,7 @@ fn const_vars_since_snapshot<'tcx>(
4343
)
4444
}
4545

46-
struct VariableLengths {
47-
type_var_len: usize,
48-
const_var_len: usize,
49-
int_var_len: usize,
50-
float_var_len: usize,
51-
region_constraints_len: usize,
52-
}
53-
5446
impl<'tcx> InferCtxt<'tcx> {
55-
fn variable_lengths(&self) -> VariableLengths {
56-
let mut inner = self.inner.borrow_mut();
57-
VariableLengths {
58-
type_var_len: inner.type_variables().num_vars(),
59-
const_var_len: inner.const_unification_table().len(),
60-
int_var_len: inner.int_unification_table().len(),
61-
float_var_len: inner.float_unification_table().len(),
62-
region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
63-
}
64-
}
65-
6647
/// This rather funky routine is used while processing expected
6748
/// types. What happens here is that we want to propagate a
6849
/// coercion through the return type of a fn to its
@@ -109,7 +90,7 @@ impl<'tcx> InferCtxt<'tcx> {
10990
T: TypeFoldable<TyCtxt<'tcx>>,
11091
{
11192
let variable_lengths = self.variable_lengths();
112-
let (mut fudger, value) = self.probe(|_| {
93+
let (mut fudger, value) = self.probe_unchecked(|_| {
11394
match f() {
11495
Ok(value) => {
11596
let value = self.resolve_vars_if_possible(value);
@@ -122,21 +103,21 @@ impl<'tcx> InferCtxt<'tcx> {
122103

123104
let mut inner = self.inner.borrow_mut();
124105
let type_vars =
125-
inner.type_variables().vars_since_snapshot(variable_lengths.type_var_len);
106+
inner.type_variables().vars_since_snapshot(variable_lengths.type_vars);
126107
let int_vars = vars_since_snapshot(
127108
&mut inner.int_unification_table(),
128-
variable_lengths.int_var_len,
109+
variable_lengths.int_vars,
129110
);
130111
let float_vars = vars_since_snapshot(
131112
&mut inner.float_unification_table(),
132-
variable_lengths.float_var_len,
113+
variable_lengths.float_vars,
133114
);
134115
let region_vars = inner
135116
.unwrap_region_constraints()
136-
.vars_since_snapshot(variable_lengths.region_constraints_len);
117+
.vars_since_snapshot(variable_lengths.region_vars);
137118
let const_vars = const_vars_since_snapshot(
138119
&mut inner.const_unification_table(),
139-
variable_lengths.const_var_len,
120+
variable_lengths.const_vars,
140121
);
141122

142123
let fudger = InferenceFudger {

compiler/rustc_infer/src/infer/snapshot/mod.rs

Lines changed: 148 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
use std::ops::ControlFlow;
2+
13
use super::region_constraints::RegionSnapshot;
24
use super::InferCtxt;
35
use rustc_data_structures::undo_log::UndoLogs;
4-
use rustc_middle::ty;
6+
use rustc_middle::ty::{TypeFoldable, TypeSuperVisitable};
7+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitor};
58

69
mod fudge;
710
pub(crate) mod undo_log;
@@ -57,27 +60,42 @@ impl<'tcx> InferCtxt<'tcx> {
5760

5861
/// Execute `f` and commit the bindings if closure `f` returns `Ok(_)`.
5962
#[instrument(skip(self, f), level = "debug")]
60-
pub fn commit_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
63+
pub fn commit_if_ok<T, E, F, CX>(&self, f: F) -> Result<T, E>
6164
where
6265
F: FnOnce(&CombinedSnapshot<'tcx>) -> Result<T, E>,
66+
E: NoSnapshotLeaks<'tcx, CX>
6367
{
68+
let no_leaks_data = E::mk_data(self);
6469
let snapshot = self.start_snapshot();
6570
let r = f(&snapshot);
6671
debug!("commit_if_ok() -- r.is_ok() = {}", r.is_ok());
6772
match r {
68-
Ok(_) => {
73+
Ok(value) => {
6974
self.commit_from(snapshot);
75+
Ok(value)
7076
}
71-
Err(_) => {
77+
Err(err) => {
7278
self.rollback_to(snapshot);
79+
Err(E::avoid_leaks(err, no_leaks_data))
7380
}
7481
}
75-
r
7682
}
7783

7884
/// Execute `f` then unroll any bindings it creates.
7985
#[instrument(skip(self, f), level = "debug")]
80-
pub fn probe<R, F>(&self, f: F) -> R
86+
pub fn probe<R, F, CX>(&self, f: F) -> R
87+
where
88+
F: FnOnce(&CombinedSnapshot<'tcx>) -> R,
89+
R: NoSnapshotLeaks<'tcx, CX>,
90+
{
91+
let no_leaks_data = R::mk_data(self);
92+
let snapshot = self.start_snapshot();
93+
let r = f(&snapshot);
94+
self.rollback_to(snapshot);
95+
R::avoid_leaks(r, no_leaks_data)
96+
}
97+
98+
pub fn probe_unchecked<R, F>(&self, f: F) -> R
8199
where
82100
F: FnOnce(&CombinedSnapshot<'tcx>) -> R,
83101
{
@@ -99,4 +117,128 @@ impl<'tcx> InferCtxt<'tcx> {
99117
pub fn opaque_types_added_in_snapshot(&self, snapshot: &CombinedSnapshot<'tcx>) -> bool {
100118
self.inner.borrow().undo_log.opaque_types_in_snapshot(&snapshot.undo_snapshot)
101119
}
120+
121+
fn variable_lengths(&self) -> VariableLengths {
122+
let mut inner = self.inner.borrow_mut();
123+
VariableLengths {
124+
type_vars: inner.type_variables().num_vars(),
125+
const_vars: inner.const_unification_table().len(),
126+
int_vars: inner.int_unification_table().len(),
127+
float_vars: inner.float_unification_table().len(),
128+
region_vars: inner.unwrap_region_constraints().num_region_vars(),
129+
}
130+
}
131+
}
132+
133+
trait NoSnapshotLeaks<'tcx, CX> {
134+
fn mk_data(infcx: &InferCtxt<'tcx>) -> CX;
135+
fn avoid_leaks(self, data: CX) -> Self;
136+
}
137+
138+
pub struct CheckLeaks(HasSnapshotLeaksVisitor);
139+
impl<'tcx, T: TypeFoldable<TyCtxt<'tcx>>> NoSnapshotLeaks<'tcx, CheckLeaks> for T {
140+
fn mk_data(infcx: &InferCtxt<'tcx>) -> CheckLeaks {
141+
CheckLeaks(HasSnapshotLeaksVisitor {
142+
universe: infcx.universe(),
143+
variable_lengths: infcx.variable_lengths(),
144+
})
145+
}
146+
fn avoid_leaks(self, CheckLeaks(mut visitor): CheckLeaks) -> Self {
147+
if cfg!(debug_assertions) && self.visit_with(&mut visitor).is_break() {
148+
bug!("leaking vars from snapshot: {:?}", self)
149+
} else {
150+
self
151+
}
152+
}
153+
}
154+
155+
struct VariableLengths {
156+
region_vars: usize,
157+
type_vars: usize,
158+
int_vars: usize,
159+
float_vars: usize,
160+
const_vars: usize,
161+
}
162+
163+
struct HasSnapshotLeaksVisitor {
164+
universe: ty::UniverseIndex,
165+
variable_lengths: VariableLengths,
166+
}
167+
168+
fn continue_if(b: bool) -> ControlFlow<()> {
169+
if b { ControlFlow::Continue(()) } else { ControlFlow::Continue(()) }
170+
}
171+
172+
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HasSnapshotLeaksVisitor {
173+
type Result = ControlFlow<()>;
174+
175+
fn visit_region(&mut self, r: ty::Region<'tcx>) -> Self::Result {
176+
match r.kind() {
177+
ty::ReVar(var) => continue_if(var.as_usize() < self.variable_lengths.region_vars),
178+
ty::RePlaceholder(p) => continue_if(self.universe.can_name(p.universe)),
179+
ty::ReEarlyParam(_)
180+
| ty::ReBound(_, _)
181+
| ty::ReLateParam(_)
182+
| ty::ReStatic
183+
| ty::ReErased
184+
| ty::ReError(_) => ControlFlow::Continue(()),
185+
}
186+
}
187+
fn visit_ty(&mut self, t: Ty<'tcx>) -> Self::Result {
188+
match t.kind() {
189+
ty::Infer(ty::TyVar(var)) => {
190+
continue_if(var.as_usize() < self.variable_lengths.type_vars)
191+
}
192+
ty::Infer(ty::IntVar(var)) => {
193+
continue_if(var.as_usize() < self.variable_lengths.int_vars)
194+
}
195+
ty::Infer(ty::FloatVar(var)) => {
196+
continue_if(var.as_usize() < self.variable_lengths.float_vars)
197+
}
198+
ty::Placeholder(p) => continue_if(self.universe.can_name(p.universe)),
199+
ty::Infer(ty::FreshTy(..) | ty::FreshIntTy(..) | ty::FreshFloatTy(..))
200+
| ty::Bool
201+
| ty::Char
202+
| ty::Int(_)
203+
| ty::Uint(_)
204+
| ty::Float(_)
205+
| ty::Adt(_, _)
206+
| ty::Foreign(_)
207+
| ty::Str
208+
| ty::Array(_, _)
209+
| ty::Slice(_)
210+
| ty::RawPtr(_)
211+
| ty::Ref(_, _, _)
212+
| ty::FnDef(_, _)
213+
| ty::FnPtr(_)
214+
| ty::Dynamic(_, _, _)
215+
| ty::Closure(_, _)
216+
| ty::CoroutineClosure(_, _)
217+
| ty::Coroutine(_, _)
218+
| ty::CoroutineWitness(_, _)
219+
| ty::Never
220+
| ty::Tuple(_)
221+
| ty::Alias(_, _)
222+
| ty::Param(_)
223+
| ty::Bound(_, _)
224+
| ty::Error(_) => t.super_visit_with(self),
225+
}
226+
}
227+
fn visit_const(&mut self, c: ty::Const<'tcx>) -> Self::Result {
228+
match c.kind() {
229+
ty::ConstKind::Infer(ty::InferConst::Var(var)) => {
230+
continue_if(var.as_usize() < self.variable_lengths.const_vars)
231+
}
232+
// FIXME(const_trait_impl): need to handle effect vars here and in `fudge_inference_if_ok`.
233+
ty::ConstKind::Infer(ty::InferConst::EffectVar(_)) => ControlFlow::Continue(()),
234+
ty::ConstKind::Placeholder(p) => continue_if(self.universe.can_name(p.universe)),
235+
ty::ConstKind::Infer(ty::InferConst::Fresh(_))
236+
| ty::ConstKind::Param(_)
237+
| ty::ConstKind::Bound(_, _)
238+
| ty::ConstKind::Unevaluated(_)
239+
| ty::ConstKind::Value(_)
240+
| ty::ConstKind::Expr(_)
241+
| ty::ConstKind::Error(_) => c.super_visit_with(self),
242+
}
243+
}
102244
}

compiler/rustc_infer/src/traits/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ pub type Selection<'tcx> = ImplSource<'tcx, PredicateObligation<'tcx>>;
124124
pub type ObligationInspector<'tcx> =
125125
fn(&InferCtxt<'tcx>, &PredicateObligation<'tcx>, Result<Certainty, NoSolution>);
126126

127+
#[derive(Clone, TypeVisitable)]
127128
pub struct FulfillmentError<'tcx> {
128129
pub obligation: PredicateObligation<'tcx>,
129130
pub code: FulfillmentErrorCode<'tcx>,
@@ -133,7 +134,7 @@ pub struct FulfillmentError<'tcx> {
133134
pub root_obligation: PredicateObligation<'tcx>,
134135
}
135136

136-
#[derive(Clone)]
137+
#[derive(Clone, TypeVisitable)]
137138
pub enum FulfillmentErrorCode<'tcx> {
138139
/// Inherently impossible to fulfill; this trait is implemented if and only
139140
/// if it is already implemented.

compiler/rustc_infer/src/traits/project.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub use rustc_middle::traits::{EvaluationResult, Reveal};
1515
pub(crate) type UndoLog<'tcx> =
1616
snapshot_map::UndoLog<ProjectionCacheKey<'tcx>, ProjectionCacheEntry<'tcx>>;
1717

18-
#[derive(Clone)]
18+
#[derive(Clone, TypeVisitable)]
1919
pub struct MismatchedProjectionTypes<'tcx> {
2020
pub err: ty::error::TypeError<'tcx>,
2121
}

0 commit comments

Comments
 (0)