Skip to content

Commit f9c1202

Browse files
committed
autodiff: adding recursion max depth in typetree
Signed-off-by: Karan Janthe <[email protected]>
1 parent 731a98a commit f9c1202

File tree

7 files changed

+221
-90
lines changed

7 files changed

+221
-90
lines changed

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 124 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,8 +2257,33 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
22572257

22582258
/// Generate TypeTree for a specific type.
22592259
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
2260+
2261+
/// Maximum recursion depth for TypeTree generation to prevent stack overflow
2262+
/// from pathological deeply nested types. Combined with cycle detection.
2263+
const MAX_TYPETREE_DEPTH: usize = 32;
2264+
22602265
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
2261-
if ty.is_scalar() {
2266+
let mut visited = Vec::new();
2267+
typetree_from_ty_impl(tcx, ty, 0, &mut visited)
2268+
}
2269+
2270+
fn typetree_from_ty_impl<'tcx>(
2271+
tcx: TyCtxt<'tcx>,
2272+
ty: Ty<'tcx>,
2273+
depth: usize,
2274+
visited: &mut Vec<Ty<'tcx>>,
2275+
) -> TypeTree {
2276+
if depth > MAX_TYPETREE_DEPTH {
2277+
return TypeTree::new();
2278+
}
2279+
2280+
if visited.contains(&ty) {
2281+
return TypeTree::new();
2282+
}
2283+
2284+
visited.push(ty);
2285+
2286+
let result = if ty.is_scalar() {
22622287
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
22632288
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
22642289
} else if ty.is_floating_point() {
@@ -2267,116 +2292,118 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22672292
x if x == tcx.types.f32 => (Kind::Float, 4),
22682293
x if x == tcx.types.f64 => (Kind::Double, 8),
22692294
x if x == tcx.types.f128 => (Kind::F128, 16),
2270-
_ => return TypeTree::new(),
2295+
_ => (Kind::Integer, 0),
22712296
}
22722297
} else {
2273-
return TypeTree::new();
2298+
(Kind::Integer, 0)
22742299
};
22752300

2276-
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
2277-
}
2278-
2279-
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
2280-
let inner_ty = if let Some(inner) = ty.builtin_deref(true) {
2281-
inner
2301+
if size > 0 {
2302+
TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }])
22822303
} else {
2283-
return TypeTree::new();
2284-
};
2285-
2286-
let child = typetree_from_ty(tcx, inner_ty);
2287-
return TypeTree(vec![Type {
2288-
offset: -1,
2289-
size: tcx.data_layout.pointer_size().bytes_usize(),
2290-
kind: Kind::Pointer,
2291-
child,
2292-
}]);
2293-
}
2294-
2295-
if ty.is_array() {
2304+
TypeTree::new()
2305+
}
2306+
} else if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
2307+
if let Some(inner_ty) = ty.builtin_deref(true) {
2308+
let child = typetree_from_ty_impl(tcx, inner_ty, depth + 1, visited);
2309+
TypeTree(vec![Type {
2310+
offset: -1,
2311+
size: tcx.data_layout.pointer_size().bytes_usize(),
2312+
kind: Kind::Pointer,
2313+
child,
2314+
}])
2315+
} else {
2316+
TypeTree::new()
2317+
}
2318+
} else if ty.is_array() {
22962319
if let ty::Array(element_ty, len_const) = ty.kind() {
22972320
let len = len_const.try_to_target_usize(tcx).unwrap_or(0);
22982321
if len == 0 {
2299-
return TypeTree::new();
2300-
}
2301-
2302-
let element_tree = typetree_from_ty(tcx, *element_ty);
2303-
2304-
let element_layout = tcx
2305-
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
2306-
.ok()
2307-
.map(|layout| layout.size.bytes_usize())
2308-
.unwrap_or(0);
2309-
2310-
if element_layout == 0 {
2311-
return TypeTree::new();
2322+
TypeTree::new()
2323+
} else {
2324+
let element_tree = typetree_from_ty_impl(tcx, *element_ty, depth + 1, visited);
2325+
let element_layout = tcx
2326+
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
2327+
.ok()
2328+
.map(|layout| layout.size.bytes_usize())
2329+
.unwrap_or(0);
2330+
2331+
if element_layout == 0 {
2332+
TypeTree::new()
2333+
} else {
2334+
// For homogeneous arrays, use offset -1 instead of individual entries
2335+
if element_tree.0.len() == 1 && element_tree.0[0].offset == -1 {
2336+
TypeTree(vec![Type {
2337+
offset: -1,
2338+
size: element_tree.0[0].size,
2339+
kind: element_tree.0[0].kind,
2340+
child: element_tree.0[0].child.clone(),
2341+
}])
2342+
} else {
2343+
let mut types = Vec::new();
2344+
for i in 0..len {
2345+
let base_offset = (i as usize * element_layout) as isize;
2346+
2347+
for elem_type in &element_tree.0 {
2348+
types.push(Type {
2349+
offset: if elem_type.offset == -1 {
2350+
base_offset
2351+
} else {
2352+
base_offset + elem_type.offset
2353+
},
2354+
size: elem_type.size,
2355+
kind: elem_type.kind,
2356+
child: elem_type.child.clone(),
2357+
});
2358+
}
2359+
}
2360+
TypeTree(types)
2361+
}
2362+
}
23122363
}
2313-
2364+
} else {
2365+
TypeTree::new()
2366+
}
2367+
} else if ty.is_slice() {
2368+
if let ty::Slice(element_ty) = ty.kind() {
2369+
typetree_from_ty_impl(tcx, *element_ty, depth + 1, visited)
2370+
} else {
2371+
TypeTree::new()
2372+
}
2373+
} else if let ty::Tuple(tuple_types) = ty.kind() {
2374+
if tuple_types.is_empty() {
2375+
TypeTree::new()
2376+
} else {
23142377
let mut types = Vec::new();
2315-
for i in 0..len {
2316-
let base_offset = (i as usize * element_layout) as isize;
2378+
let mut current_offset = 0;
2379+
2380+
for tuple_ty in tuple_types.iter() {
2381+
let element_tree = typetree_from_ty_impl(tcx, tuple_ty, depth + 1, visited);
2382+
let element_layout = tcx
2383+
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
2384+
.ok()
2385+
.map(|layout| layout.size.bytes_usize())
2386+
.unwrap_or(0);
23172387

23182388
for elem_type in &element_tree.0 {
23192389
types.push(Type {
23202390
offset: if elem_type.offset == -1 {
2321-
base_offset
2391+
current_offset as isize
23222392
} else {
2323-
base_offset + elem_type.offset
2393+
current_offset as isize + elem_type.offset
23242394
},
23252395
size: elem_type.size,
23262396
kind: elem_type.kind,
23272397
child: elem_type.child.clone(),
23282398
});
23292399
}
2330-
}
23312400

2332-
return TypeTree(types);
2333-
}
2334-
}
2335-
2336-
if ty.is_slice() {
2337-
if let ty::Slice(element_ty) = ty.kind() {
2338-
let element_tree = typetree_from_ty(tcx, *element_ty);
2339-
return element_tree;
2340-
}
2341-
}
2342-
2343-
if let ty::Tuple(tuple_types) = ty.kind() {
2344-
if tuple_types.is_empty() {
2345-
return TypeTree::new();
2346-
}
2347-
2348-
let mut types = Vec::new();
2349-
let mut current_offset = 0;
2350-
2351-
for tuple_ty in tuple_types.iter() {
2352-
let element_tree = typetree_from_ty(tcx, tuple_ty);
2353-
2354-
let element_layout = tcx
2355-
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
2356-
.ok()
2357-
.map(|layout| layout.size.bytes_usize())
2358-
.unwrap_or(0);
2359-
2360-
for elem_type in &element_tree.0 {
2361-
types.push(Type {
2362-
offset: if elem_type.offset == -1 {
2363-
current_offset as isize
2364-
} else {
2365-
current_offset as isize + elem_type.offset
2366-
},
2367-
size: elem_type.size,
2368-
kind: elem_type.kind,
2369-
child: elem_type.child.clone(),
2370-
});
2401+
current_offset += element_layout;
23712402
}
23722403

2373-
current_offset += element_layout;
2404+
TypeTree(types)
23742405
}
2375-
2376-
return TypeTree(types);
2377-
}
2378-
2379-
if let ty::Adt(adt_def, args) = ty.kind() {
2406+
} else if let ty::Adt(adt_def, args) = ty.kind() {
23802407
if adt_def.is_struct() {
23812408
let struct_layout =
23822409
tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
@@ -2385,7 +2412,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23852412

23862413
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
23872414
let field_ty = field_def.ty(tcx, args);
2388-
let field_tree = typetree_from_ty(tcx, field_ty);
2415+
let field_tree = typetree_from_ty_impl(tcx, field_ty, depth + 1, visited);
23892416

23902417
let field_offset = layout.fields.offset(field_idx).bytes_usize();
23912418

@@ -2403,10 +2430,17 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
24032430
}
24042431
}
24052432

2406-
return TypeTree(types);
2433+
TypeTree(types)
2434+
} else {
2435+
TypeTree::new()
24072436
}
2437+
} else {
2438+
TypeTree::new()
24082439
}
2409-
}
2440+
} else {
2441+
TypeTree::new()
2442+
};
24102443

2411-
TypeTree::new()
2444+
visited.pop();
2445+
result
24122446
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
; Check that mixed struct with large array compiles successfully
2+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@float}"{{.*}}@test_mixed_struct{{.*}}"enzyme_type"="{[]:Pointer}"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc()
8+
.input("test.rs")
9+
.arg("-Zautodiff=Enable")
10+
.arg("-Zautodiff=NoPostopt")
11+
.opt_level("0")
12+
.emit("llvm-ir")
13+
.run();
14+
15+
llvm_filecheck().patterns("mixed.check").stdin_buf(rfs::read("test.ll")).run();
16+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[repr(C)]
6+
struct Container {
7+
header: i64,
8+
data: [f32; 1000],
9+
}
10+
11+
#[autodiff_reverse(d_test, Duplicated, Active)]
12+
#[no_mangle]
13+
#[inline(never)]
14+
fn test_mixed_struct(container: &Container) -> f32 {
15+
container.data[0] + container.data[999]
16+
}
17+
18+
fn main() {
19+
let container = Container { header: 42, data: [1.0; 1000] };
20+
let mut d_container = Container { header: 0, data: [0.0; 1000] };
21+
let result = d_test(&container, &mut d_container, 1.0);
22+
std::hint::black_box(result);
23+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; Check that recursive types compile successfully without infinite loops
2+
; The recursion depth limit should prevent stack overflow during TypeTree generation
3+
4+
; Function should compile and have enzyme_type attributes
5+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@float}"{{.*}}@test_recursion_depth{{.*}}"enzyme_type"="{[]:Pointer}"
6+
7+
; Compilation should complete without hanging or crashing
8+
; The mere fact that this test runs means recursion limiting is working
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
// This test ensures that recursive types don't cause infinite loops
8+
// The compiler should complete successfully due to recursion limits
9+
rustc()
10+
.input("test.rs")
11+
.arg("-Zautodiff=Enable")
12+
.arg("-Zautodiff=NoPostopt")
13+
.opt_level("0")
14+
.emit("llvm-ir")
15+
.run();
16+
17+
llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
18+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
// Create mutually recursive types that would cause cycles
6+
#[repr(C)]
7+
struct NodeA {
8+
value: f32,
9+
b_ref: Option<Box<NodeB>>,
10+
}
11+
12+
#[repr(C)]
13+
struct NodeB {
14+
value: f64,
15+
a_ref: Option<Box<NodeA>>, // Mutual recursion: A -> B -> A -> B...
16+
}
17+
18+
#[autodiff_reverse(d_test, Duplicated, Active)]
19+
#[no_mangle]
20+
#[inline(never)]
21+
fn test_recursion_depth(node: &NodeA) -> f32 {
22+
node.value
23+
}
24+
25+
fn main() {
26+
let node = NodeA { value: 1.0, b_ref: None };
27+
let mut d_node = NodeA { value: 0.0, b_ref: None };
28+
let result = d_test(&node, &mut d_node, 1.0);
29+
std::hint::black_box(result);
30+
}

0 commit comments

Comments
 (0)