Skip to content

Commit 8d019e4

Browse files
authored
Detect recursion in inline.rs and bail (#770)
* Detect recursion in inline.rs and bail * Bail on inline error
1 parent e5c2953 commit 8d019e4

File tree

5 files changed

+181
-33
lines changed

5 files changed

+181
-33
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,20 @@
66
77
use super::apply_rewrite_rules;
88
use super::simple_passes::outgoing_edges;
9+
use super::{get_name, get_names};
910
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
1011
use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
1112
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
13+
use rustc_session::Session;
1214
use std::mem::take;
1315

1416
type FunctionMap = FxHashMap<Word, Function>;
1517

16-
pub fn inline(module: &mut Module) {
18+
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
19+
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
20+
if module_has_recursion(sess, module) {
21+
return Err(rustc_errors::ErrorReported);
22+
}
1723
let functions = module
1824
.functions
1925
.iter()
@@ -58,6 +64,89 @@ pub fn inline(module: &mut Module) {
5864
inliner.inline_fn(function);
5965
fuse_trivial_branches(function);
6066
}
67+
Ok(())
68+
}
69+
70+
// https://stackoverflow.com/a/53995651
71+
fn module_has_recursion(sess: &Session, module: &Module) -> bool {
72+
let func_to_index: FxHashMap<Word, usize> = module
73+
.functions
74+
.iter()
75+
.enumerate()
76+
.map(|(index, func)| (func.def_id().unwrap(), index))
77+
.collect();
78+
let mut discovered = vec![false; module.functions.len()];
79+
let mut finished = vec![false; module.functions.len()];
80+
let mut has_recursion = false;
81+
for index in 0..module.functions.len() {
82+
if !discovered[index] && !finished[index] {
83+
visit(
84+
sess,
85+
module,
86+
index,
87+
&mut discovered,
88+
&mut finished,
89+
&mut has_recursion,
90+
&func_to_index,
91+
);
92+
}
93+
}
94+
95+
fn visit(
96+
sess: &Session,
97+
module: &Module,
98+
current: usize,
99+
discovered: &mut Vec<bool>,
100+
finished: &mut Vec<bool>,
101+
has_recursion: &mut bool,
102+
func_to_index: &FxHashMap<Word, usize>,
103+
) {
104+
discovered[current] = true;
105+
106+
for next in calls(&module.functions[current], func_to_index) {
107+
if discovered[next] {
108+
let names = get_names(module);
109+
let current_name = get_name(&names, module.functions[current].def_id().unwrap());
110+
let next_name = get_name(&names, module.functions[next].def_id().unwrap());
111+
sess.err(&format!(
112+
"module has recursion, which is not allowed: `{}` calls `{}`",
113+
current_name, next_name
114+
));
115+
*has_recursion = true;
116+
break;
117+
}
118+
119+
if !finished[next] {
120+
visit(
121+
sess,
122+
module,
123+
next,
124+
discovered,
125+
finished,
126+
has_recursion,
127+
func_to_index,
128+
);
129+
}
130+
}
131+
132+
discovered[current] = false;
133+
finished[current] = true;
134+
}
135+
136+
fn calls<'a>(
137+
func: &'a Function,
138+
func_to_index: &'a FxHashMap<Word, usize>,
139+
) -> impl Iterator<Item = usize> + 'a {
140+
func.all_inst_iter()
141+
.filter(|inst| inst.class.opcode == Op::FunctionCall)
142+
.map(move |inst| {
143+
*func_to_index
144+
.get(&inst.operands[0].id_ref_any().unwrap())
145+
.unwrap()
146+
})
147+
}
148+
149+
has_recursion
61150
}
62151

63152
fn compute_disallowed_argument_and_return_types(

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ mod specializer;
1515
mod structurizer;
1616
mod zombies;
1717

18+
use std::borrow::Cow;
19+
1820
use crate::codegen_cx::SpirvMetadata;
1921
use crate::decorations::{CustomDecoration, UnrollLoopsDecoration};
2022
use rspirv::binary::{Assemble, Consumer};
@@ -77,6 +79,27 @@ fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Bloc
7779
}
7880
}
7981

82+
fn get_names(module: &Module) -> FxHashMap<Word, &str> {
83+
module
84+
.debug_names
85+
.iter()
86+
.filter(|i| i.class.opcode == Op::Name)
87+
.map(|i| {
88+
(
89+
i.operands[0].unwrap_id_ref(),
90+
i.operands[1].unwrap_literal_string(),
91+
)
92+
})
93+
.collect()
94+
}
95+
96+
fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
97+
names
98+
.get(&id)
99+
.map(|&s| Cow::Borrowed(s))
100+
.unwrap_or_else(|| Cow::Owned(format!("Unnamed function ID %{}", id)))
101+
}
102+
80103
pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<LinkResult> {
81104
let mut output = {
82105
let _timer = sess.timer("link_merge");
@@ -178,7 +201,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
178201

179202
{
180203
let _timer = sess.timer("link_inline");
181-
inline::inline(&mut output);
204+
inline::inline(sess, &mut output)?;
182205
}
183206

184207
if opts.dce {

crates/rustc_codegen_spirv/src/linker/zombies.rs

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
//! See documentation on `CodegenCx::zombie` for a description of the zombie system.
22
3+
use super::{get_name, get_names};
34
use crate::decorations::{CustomDecoration, ZombieDecoration};
45
use rspirv::dr::{Instruction, Module};
5-
use rspirv::spirv::{Op, Word};
6+
use rspirv::spirv::Word;
67
use rustc_data_structures::fx::FxHashMap;
78
use rustc_session::Session;
89
use rustc_span::{Span, DUMMY_SP};
@@ -103,20 +104,6 @@ fn spread_zombie(module: &mut Module, zombie: &mut FxHashMap<Word, ZombieInfo<'_
103104
any
104105
}
105106

106-
fn get_names(module: &Module) -> FxHashMap<Word, &str> {
107-
module
108-
.debug_names
109-
.iter()
110-
.filter(|i| i.class.opcode == Op::Name)
111-
.map(|i| {
112-
(
113-
i.operands[0].unwrap_id_ref(),
114-
i.operands[1].unwrap_literal_string(),
115-
)
116-
})
117-
.collect()
118-
}
119-
120107
// If an entry point references a zombie'd value, then the entry point would normally get removed.
121108
// That's an absolutely horrible experience to debug, though, so instead, create a nice error
122109
// message containing the stack trace of how the entry point got to the zombie value.
@@ -125,12 +112,10 @@ fn report_error_zombies(sess: &Session, module: &Module, zombie: &FxHashMap<Word
125112
for root in super::dce::collect_roots(module) {
126113
if let Some(reason) = zombie.get(&root) {
127114
let names = names.get_or_insert_with(|| get_names(module));
128-
let stack = reason.stack.iter().map(|s| {
129-
names
130-
.get(s)
131-
.map(|&n| n.to_string())
132-
.unwrap_or_else(|| format!("Unnamed function ID %{}", s))
133-
});
115+
let stack = reason
116+
.stack
117+
.iter()
118+
.map(|&s| get_name(names, s).into_owned());
134119
let stack_note = once("Stack:".to_string())
135120
.chain(stack)
136121
.collect::<Vec<_>>()
@@ -174,18 +159,10 @@ pub fn remove_zombies(sess: &Session, module: &mut Module) {
174159
}
175160

176161
if env::var("PRINT_ZOMBIE").is_ok() {
162+
let names = get_names(module);
177163
for f in &module.functions {
178164
if let Some(reason) = is_zombie(f.def.as_ref().unwrap(), &zombies) {
179-
let name_id = f.def_id().unwrap();
180-
let name = module.debug_names.iter().find(|inst| {
181-
inst.class.opcode == Op::Name && inst.operands[0].unwrap_id_ref() == name_id
182-
});
183-
let name = match name {
184-
Some(Instruction { ref operands, .. }) => {
185-
operands[1].unwrap_literal_string().to_string()
186-
}
187-
_ => format!("{}", name_id),
188-
};
165+
let name = get_name(&names, f.def_id().unwrap());
189166
println!("Function removed {:?} because {:?}", name, reason.reason);
190167
}
191168
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// build-pass
2+
3+
use glam::UVec3;
4+
use spirv_std::glam;
5+
use spirv_std::glam::{Mat3, Vec3, Vec4};
6+
7+
fn index_to_transform(index: usize, raw_data: &[u8]) -> Transform2D {
8+
Transform2D {
9+
own_transform: Mat3::IDENTITY,
10+
parent_offset: 0,
11+
}
12+
}
13+
14+
const SIZE_OF_TRANSFORM: usize = core::mem::size_of::<Transform2D>();
15+
16+
#[derive(Clone)]
17+
struct Transform2D {
18+
own_transform: Mat3,
19+
parent_offset: i32,
20+
}
21+
22+
trait GivesFinalTransform {
23+
fn get_final_transform(&self, raw_data: &[u8]) -> Mat3;
24+
}
25+
26+
impl GivesFinalTransform for (i32, Transform2D) {
27+
fn get_final_transform(&self, raw_data: &[u8]) -> Mat3 {
28+
if self.1.parent_offset == 0 {
29+
self.1.own_transform
30+
} else {
31+
let parent_index = self.0 + self.1.parent_offset;
32+
self.1.own_transform.mul_mat3(
33+
&((
34+
parent_index as i32,
35+
index_to_transform(parent_index as usize, raw_data),
36+
)
37+
.get_final_transform(raw_data)),
38+
)
39+
}
40+
}
41+
}
42+
43+
#[spirv(compute(threads(64)))]
44+
pub fn main_cs(
45+
#[spirv(global_invocation_id)] id: UVec3,
46+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] raw_data: &mut [u8],
47+
#[spirv(position)] output_position: &mut Vec4,
48+
) {
49+
let index = id.x as usize;
50+
let final_transform =
51+
(index as i32, index_to_transform(index, raw_data)).get_final_transform(raw_data);
52+
*output_position = final_transform
53+
.mul_vec3(Vec3::new(0.1, 0.2, 0.3))
54+
.extend(0.0);
55+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
error: module has recursion, which is not allowed: `<(i32, issue_764::Transform2D) as issue_764::GivesFinalTransform>::get_final_transform` calls `<(i32, issue_764::Transform2D) as issue_764::GivesFinalTransform>::get_final_transform`
2+
3+
error: aborting due to previous error
4+

0 commit comments

Comments
 (0)