Skip to content

Commit 8e26bb0

Browse files
committed
snapshot: avoid leaking inference vars
1 parent 4f864db commit 8e26bb0

File tree

15 files changed

+348
-154
lines changed

15 files changed

+348
-154
lines changed

compiler/rustc_hir_typeck/src/coercion.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
10761076
let coerce = Coerce::new(self, cause, AllowTwoPhase::No);
10771077
coerce
10781078
.autoderef(rustc_span::DUMMY_SP, expr_ty)
1079-
.find_map(|(ty, steps)| self.probe(|_| coerce.unify(ty, target)).ok().map(|_| steps))
1079+
.find_map(|(ty, steps)| self.probe(|_| coerce.unify(ty, target).ok().map(|_| steps)))
10801080
}
10811081

10821082
/// Given a type, this function will calculate and return the type given

compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_hir_analysis::astconv::AstConv;
1313
use rustc_infer::infer;
1414
use rustc_infer::infer::error_reporting::sub_relations::SubRelations;
1515
use rustc_infer::infer::error_reporting::TypeErrCtxt;
16+
use rustc_infer::infer::snapshot::NoLeaksUnchecked;
1617
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
1718
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
1819
use rustc_middle::ty::{self, Const, Ty, TyCtxt, TypeVisitableExt};
@@ -177,11 +178,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
177178
if ocx.select_all_or_error().is_empty() {
178179
let normalized_fn_sig = self.resolve_vars_if_possible(normalized_fn_sig);
179180
if !normalized_fn_sig.has_infer() {
180-
return normalized_fn_sig;
181+
return NoLeaksUnchecked { value: normalized_fn_sig };
181182
}
182183
}
183-
fn_sig
184+
NoLeaksUnchecked { value: fn_sig }
184185
})
186+
.value
185187
}),
186188
autoderef_steps: Box::new(|ty| {
187189
let mut autoderef = self.autoderef(DUMMY_SP, ty).silence_errors();

compiler/rustc_hir_typeck/src/method/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use rustc_hir as hir;
1616
use rustc_hir::def::{CtorOf, DefKind, Namespace};
1717
use rustc_hir::def_id::DefId;
1818
use rustc_infer::infer::{self, InferOk};
19+
use rustc_infer::trivial_no_snapshot_leaks;
1920
use rustc_middle::query::Providers;
2021
use rustc_middle::traits::ObligationCause;
2122
use rustc_middle::ty::{self, GenericParamDefKind, Ty, TypeVisitableExt};
@@ -43,6 +44,8 @@ pub struct MethodCallee<'tcx> {
4344
pub sig: ty::FnSig<'tcx>,
4445
}
4546

47+
// FIXME: This is wrong, method error may leak inference vars.
48+
trivial_no_snapshot_leaks!('tcx, MethodError<'tcx>);
4649
#[derive(Debug)]
4750
pub enum MethodError<'tcx> {
4851
// Did not find an applicable method, but we did find various near-misses that may work.
@@ -79,8 +82,9 @@ pub struct NoMatchData<'tcx> {
7982
pub mode: probe::Mode,
8083
}
8184

82-
// A pared down enum describing just the places from which a method
83-
// candidate can arise. Used for error reporting only.
85+
trivial_no_snapshot_leaks!('tcx, CandidateSource);
86+
/// A pared down enum describing just the places from which a method
87+
/// candidate can arise. Used for error reporting only.
8488
#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
8589
pub enum CandidateSource {
8690
Impl(DefId),

compiler/rustc_hir_typeck/src/method/probe.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ use rustc_hir_analysis::autoderef::{self, Autoderef};
1313
use rustc_infer::infer::canonical::OriginalQueryValues;
1414
use rustc_infer::infer::canonical::{Canonical, QueryResponse};
1515
use rustc_infer::infer::error_reporting::TypeAnnotationNeeded::E0282;
16+
use rustc_infer::infer::snapshot::NoSnapshotLeaks;
1617
use rustc_infer::infer::DefineOpaqueTypes;
1718
use rustc_infer::infer::{self, InferOk, TyCtxtInferExt};
19+
use rustc_infer::trivial_no_snapshot_leaks;
1820
use rustc_middle::middle::stability;
1921
use rustc_middle::query::Providers;
2022
use rustc_middle::ty::fast_reject::{simplify_type, TreatParams};
@@ -97,6 +99,8 @@ impl<'a, 'tcx> Deref for ProbeContext<'a, 'tcx> {
9799
}
98100
}
99101

102+
// FIXME: This is wrong as this type may leak inference variables.`
103+
trivial_no_snapshot_leaks!('tcx, Candidate<'tcx>);
100104
#[derive(Debug, Clone)]
101105
pub(crate) struct Candidate<'tcx> {
102106
// Candidates are (I'm not quite sure, but they are mostly) basically
@@ -152,6 +156,7 @@ pub(crate) enum CandidateKind<'tcx> {
152156
),
153157
}
154158

159+
trivial_no_snapshot_leaks!('tcx, ProbeResult);
155160
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
156161
enum ProbeResult {
157162
NoMatch,
@@ -195,6 +200,8 @@ impl AutorefOrPtrAdjustment {
195200
}
196201
}
197202

203+
// FIXME: This is wrong as this type may leak inference variables.`
204+
trivial_no_snapshot_leaks!('tcx, Pick<'tcx>);
198205
#[derive(Debug, Clone)]
199206
pub struct Pick<'tcx> {
200207
pub item: ty::AssocItem,
@@ -368,6 +375,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
368375
op: OP,
369376
) -> Result<R, MethodError<'tcx>>
370377
where
378+
R: NoSnapshotLeaks<'tcx>,
371379
OP: FnOnce(ProbeContext<'_, 'tcx>) -> Result<R, MethodError<'tcx>>,
372380
{
373381
let mut orig_values = OriginalQueryValues::default();

compiler/rustc_infer/src/infer/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ mod projection;
6262
pub mod region_constraints;
6363
mod relate;
6464
pub mod resolve;
65-
pub(crate) mod snapshot;
65+
pub mod snapshot;
6666
pub mod type_variable;
6767

6868
#[must_use]

compiler/rustc_infer/src/infer/snapshot/fudge.rs

Lines changed: 101 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use ut::UnifyKey;
1212

1313
use std::ops::Range;
1414

15+
use super::{NoSnapshotLeaks, VariableLengths};
16+
1517
fn vars_since_snapshot<'tcx, T>(
1618
table: &mut UnificationTable<'_, 'tcx, T>,
1719
snapshot_var_len: usize,
@@ -88,82 +90,108 @@ impl<'tcx> InferCtxt<'tcx> {
8890
where
8991
F: FnOnce() -> Result<T, E>,
9092
T: TypeFoldable<TyCtxt<'tcx>>,
93+
E: NoSnapshotLeaks<'tcx>,
9194
{
92-
let variable_lengths = self.variable_lengths();
93-
let (mut fudger, value) = self.probe_unchecked(|_| {
94-
match f() {
95-
Ok(value) => {
96-
let value = self.resolve_vars_if_possible(value);
97-
98-
// At this point, `value` could in principle refer
99-
// to inference variables that have been created during
100-
// the snapshot. Once we exit `probe()`, those are
101-
// going to be popped, so we will have to
102-
// eliminate any references to them.
103-
104-
let mut inner = self.inner.borrow_mut();
105-
let type_vars =
106-
inner.type_variables().vars_since_snapshot(variable_lengths.type_vars);
107-
let int_vars = vars_since_snapshot(
108-
&mut inner.int_unification_table(),
109-
variable_lengths.int_vars,
110-
);
111-
let float_vars = vars_since_snapshot(
112-
&mut inner.float_unification_table(),
113-
variable_lengths.float_vars,
114-
);
115-
let region_vars = inner
116-
.unwrap_region_constraints()
117-
.vars_since_snapshot(variable_lengths.region_vars);
118-
let const_vars = const_vars_since_snapshot(
119-
&mut inner.const_unification_table(),
120-
variable_lengths.const_vars,
121-
);
122-
123-
let fudger = InferenceFudger {
124-
infcx: self,
125-
type_vars,
126-
int_vars,
127-
float_vars,
128-
region_vars,
129-
const_vars,
130-
};
95+
self.probe(|_| f().map(|value| FudgeInference(self.resolve_vars_if_possible(value))))
96+
.map(|FudgeInference(value)| value)
97+
}
98+
}
13199

132-
Ok((fudger, value))
100+
#[macro_export]
101+
macro_rules! fudge_vars_no_snapshot_leaks {
102+
($tcx:lifetime, $t:ty) => {
103+
const _: () = {
104+
use $crate::infer::snapshot::fudge::InferenceFudgeData;
105+
impl<$tcx> $crate::infer::snapshot::NoSnapshotLeaks<$tcx> for $t {
106+
type DataStart = $crate::infer::snapshot::VariableLengths;
107+
type DataEnd = InferenceFudgeData;
108+
fn mk_data_snapshot_start(infcx: &InferCtxt<$tcx>) -> Self::DataStart {
109+
infcx.variable_lengths()
110+
}
111+
fn mk_data_snapshot_end(
112+
infcx: &InferCtxt<$tcx>,
113+
variable_lengths: Self::DataStart,
114+
) -> Self::DataEnd {
115+
InferenceFudgeData::new(infcx, variable_lengths)
116+
}
117+
fn avoid_leaks(self, infcx: &InferCtxt<'tcx>, fudge_data: Self::DataEnd) -> Self {
118+
fudge_data.fudge_inference(infcx, self)
133119
}
134-
Err(e) => Err(e),
135120
}
136-
})?;
137-
138-
// At this point, we need to replace any of the now-popped
139-
// type/region variables that appear in `value` with a fresh
140-
// variable of the appropriate kind. We can't do this during
141-
// the probe because they would just get popped then too. =)
142-
143-
// Micro-optimization: if no variables have been created, then
144-
// `value` can't refer to any of them. =) So we can just return it.
145-
if fudger.type_vars.0.is_empty()
146-
&& fudger.int_vars.is_empty()
147-
&& fudger.float_vars.is_empty()
148-
&& fudger.region_vars.0.is_empty()
149-
&& fudger.const_vars.0.is_empty()
150-
{
151-
Ok(value)
152-
} else {
153-
Ok(value.fold_with(&mut fudger))
154-
}
121+
};
122+
};
123+
}
124+
125+
struct FudgeInference<T>(T);
126+
impl<'tcx, T: TypeFoldable<TyCtxt<'tcx>>> NoSnapshotLeaks<'tcx> for FudgeInference<T> {
127+
type DataStart = VariableLengths;
128+
type DataEnd = InferenceFudgeData;
129+
fn mk_data_snapshot_start(infcx: &InferCtxt<'tcx>) -> Self::DataStart {
130+
infcx.variable_lengths()
131+
}
132+
fn mk_data_snapshot_end(
133+
infcx: &InferCtxt<'tcx>,
134+
variable_lengths: Self::DataStart,
135+
) -> Self::DataEnd {
136+
InferenceFudgeData::new(infcx, variable_lengths)
137+
}
138+
fn avoid_leaks(self, infcx: &InferCtxt<'tcx>, fudge_data: Self::DataEnd) -> Self {
139+
FudgeInference(fudge_data.fudge_inference(infcx, self.0))
155140
}
156141
}
157142

158-
pub struct InferenceFudger<'a, 'tcx> {
159-
infcx: &'a InferCtxt<'tcx>,
143+
pub struct InferenceFudgeData {
160144
type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
161145
int_vars: Range<IntVid>,
162146
float_vars: Range<FloatVid>,
163147
region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
164148
const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
165149
}
166150

151+
impl InferenceFudgeData {
152+
pub fn new<'tcx>(
153+
infcx: &InferCtxt<'tcx>,
154+
variable_lengths: VariableLengths,
155+
) -> InferenceFudgeData {
156+
let mut inner = infcx.inner.borrow_mut();
157+
let type_vars = inner.type_variables().vars_since_snapshot(variable_lengths.type_vars);
158+
let int_vars =
159+
vars_since_snapshot(&mut inner.int_unification_table(), variable_lengths.int_vars);
160+
let float_vars =
161+
vars_since_snapshot(&mut inner.float_unification_table(), variable_lengths.float_vars);
162+
let region_vars =
163+
inner.unwrap_region_constraints().vars_since_snapshot(variable_lengths.region_vars);
164+
let const_vars = const_vars_since_snapshot(
165+
&mut inner.const_unification_table(),
166+
variable_lengths.const_vars,
167+
);
168+
169+
InferenceFudgeData { type_vars, int_vars, float_vars, region_vars, const_vars }
170+
}
171+
172+
pub fn fudge_inference<'tcx, T: TypeFoldable<TyCtxt<'tcx>>>(
173+
self,
174+
infcx: &InferCtxt<'tcx>,
175+
value: T,
176+
) -> T {
177+
if self.type_vars.0.is_empty()
178+
&& self.int_vars.is_empty()
179+
&& self.float_vars.is_empty()
180+
&& self.region_vars.0.is_empty()
181+
&& self.const_vars.0.is_empty()
182+
{
183+
value
184+
} else {
185+
value.fold_with(&mut InferenceFudger { infcx, data: self })
186+
}
187+
}
188+
}
189+
190+
struct InferenceFudger<'a, 'tcx> {
191+
infcx: &'a InferCtxt<'tcx>,
192+
data: InferenceFudgeData,
193+
}
194+
167195
impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
168196
fn interner(&self) -> TyCtxt<'tcx> {
169197
self.infcx.tcx
@@ -172,11 +200,11 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
172200
fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
173201
match *ty.kind() {
174202
ty::Infer(ty::InferTy::TyVar(vid)) => {
175-
if self.type_vars.0.contains(&vid) {
203+
if self.data.type_vars.0.contains(&vid) {
176204
// This variable was created during the fudging.
177205
// Recreate it with a fresh variable here.
178-
let idx = vid.as_usize() - self.type_vars.0.start.as_usize();
179-
let origin = self.type_vars.1[idx];
206+
let idx = vid.as_usize() - self.data.type_vars.0.start.as_usize();
207+
let origin = self.data.type_vars.1[idx];
180208
self.infcx.next_ty_var(origin)
181209
} else {
182210
// This variable was created before the
@@ -191,14 +219,14 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
191219
}
192220
}
193221
ty::Infer(ty::InferTy::IntVar(vid)) => {
194-
if self.int_vars.contains(&vid) {
222+
if self.data.int_vars.contains(&vid) {
195223
self.infcx.next_int_var()
196224
} else {
197225
ty
198226
}
199227
}
200228
ty::Infer(ty::InferTy::FloatVar(vid)) => {
201-
if self.float_vars.contains(&vid) {
229+
if self.data.float_vars.contains(&vid) {
202230
self.infcx.next_float_var()
203231
} else {
204232
ty
@@ -210,22 +238,22 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
210238

211239
fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
212240
if let ty::ReVar(vid) = *r
213-
&& self.region_vars.0.contains(&vid)
241+
&& self.data.region_vars.0.contains(&vid)
214242
{
215-
let idx = vid.index() - self.region_vars.0.start.index();
216-
let origin = self.region_vars.1[idx];
243+
let idx = vid.index() - self.data.region_vars.0.start.index();
244+
let origin = self.data.region_vars.1[idx];
217245
return self.infcx.next_region_var(origin);
218246
}
219247
r
220248
}
221249

222250
fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
223251
if let ty::ConstKind::Infer(ty::InferConst::Var(vid)) = ct.kind() {
224-
if self.const_vars.0.contains(&vid) {
252+
if self.data.const_vars.0.contains(&vid) {
225253
// This variable was created during the fudging.
226254
// Recreate it with a fresh variable here.
227-
let idx = vid.index() - self.const_vars.0.start.index();
228-
let origin = self.const_vars.1[idx];
255+
let idx = vid.index() - self.data.const_vars.0.start.index();
256+
let origin = self.data.const_vars.1[idx];
229257
self.infcx.next_const_var(ct.ty(), origin)
230258
} else {
231259
ct

0 commit comments

Comments
 (0)