diff --git a/src/valid/function.rs b/src/valid/function.rs index 703c1474f7..ff39fcd573 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -135,6 +135,11 @@ bitflags::bitflags! { } } +struct BlockInfo { + stages: ShaderStages, + finished: bool, +} + struct BlockContext<'a> { abilities: ControlFlowAbility, info: &'a FunctionInfo, @@ -326,7 +331,7 @@ impl super::Validator { &mut self, statements: &[crate::Statement], context: &BlockContext, - ) -> Result { + ) -> Result { use crate::{Statement as S, TypeInner as Ti}; let mut finished = false; let mut stages = ShaderStages::all(); @@ -345,7 +350,9 @@ impl super::Validator { } } S::Block(ref block) => { - stages &= self.validate_block(block, context)?; + let info = self.validate_block(block, context)?; + stages &= info.stages; + finished = info.finished; } S::If { condition, @@ -359,8 +366,8 @@ impl super::Validator { } => {} _ => return Err(FunctionError::InvalidIfType(condition)), } - stages &= self.validate_block(accept, context)?; - stages &= self.validate_block(reject, context)?; + stages &= self.validate_block(accept, context)?.stages; + stages &= self.validate_block(reject, context)?.stages; } S::Switch { selector, @@ -385,9 +392,9 @@ impl super::Validator { let sub_context = context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK); for case in cases { - stages &= self.validate_block(&case.body, &sub_context)?; + stages &= self.validate_block(&case.body, &sub_context)?.stages; } - stages &= self.validate_block(default, &sub_context)?; + stages &= self.validate_block(default, &sub_context)?.stages; } S::Loop { ref body, @@ -397,18 +404,22 @@ impl super::Validator { // because the continuing{} block inherits the scope let base_expression_count = self.valid_expression_list.len(); let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN; - stages &= self.validate_block_impl( - body, - &context.with_abilities( - pass_through_abilities - | ControlFlowAbility::BREAK - | ControlFlowAbility::CONTINUE, - ), - )?; - stages &= self.validate_block_impl( - continuing, - &context.with_abilities(ControlFlowAbility::empty()), - )?; + stages &= self + .validate_block_impl( + body, + &context.with_abilities( + pass_through_abilities + | ControlFlowAbility::BREAK + | ControlFlowAbility::CONTINUE, + ), + )? + .stages; + stages &= self + .validate_block_impl( + continuing, + &context.with_abilities(ControlFlowAbility::empty()), + )? + .stages; for handle in self.valid_expression_list.drain(base_expression_count..) { self.valid_expression_set.remove(handle.index()); } @@ -593,20 +604,20 @@ impl super::Validator { } } } - Ok(stages) + Ok(BlockInfo { stages, finished }) } fn validate_block( &mut self, statements: &[crate::Statement], context: &BlockContext, - ) -> Result { + ) -> Result { let base_expression_count = self.valid_expression_list.len(); - let stages = self.validate_block_impl(statements, context)?; + let info = self.validate_block_impl(statements, context)?; for handle in self.valid_expression_list.drain(base_expression_count..) { self.valid_expression_set.remove(handle.index()); } - Ok(stages) + Ok(info) } fn validate_local_var( @@ -694,10 +705,12 @@ impl super::Validator { } if self.flags.contains(ValidationFlags::BLOCKS) { - let stages = self.validate_block( - &fun.body, - &BlockContext::new(fun, module, &info, &mod_info.functions), - )?; + let stages = self + .validate_block( + &fun.body, + &BlockContext::new(fun, module, &info, &mod_info.functions), + )? + .stages; info.available_stages &= stages; } Ok(info) diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index 44c5706e0d..941865ab9d 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -848,3 +848,34 @@ fn invalid_local_vars() { if local_var_name == "not_okay" } } + +#[test] +fn dead_code() { + check_validation_error! { + " + fn dead_code_after_if(condition: bool) -> i32 { + if (condition) { + return 1; + } else { + return 2; + } + return 3; + } + ": + Ok(_) + } + check_validation_error! { + " + fn dead_code_after_block() -> i32 { + { + return 1; + } + return 2; + } + ": + Err(naga::valid::ValidationError::Function { + error: naga::valid::FunctionError::InstructionsAfterReturn, + .. + }) + } +}