Skip to content

Commit 095ffe7

Browse files
committed
difftest: replace error enums with anyhow
1 parent fab1f2f commit 095ffe7

File tree

5 files changed

+40
-64
lines changed

5 files changed

+40
-64
lines changed

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/difftests/lib/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ wgpu = { version = "23", features = ["spirv", "vulkan-portability"] }
2323
tempfile = "3.5"
2424
futures = "0.3.31"
2525
bytemuck = "1.21.0"
26-
thiserror = "1.0"
26+
anyhow = "1.0.98"
2727

2828
[lints]
2929
workspace = true

tests/difftests/lib/src/config.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
11
use serde::Deserialize;
22
use std::{fs, path::Path};
3-
use thiserror::Error;
4-
5-
#[derive(Error, Debug)]
6-
pub enum ConfigError {
7-
#[error("I/O error: {0}")]
8-
Io(#[from] std::io::Error),
9-
#[error("JSON error: {0}")]
10-
Json(#[from] serde_json::Error),
11-
}
123

134
#[derive(Debug, Deserialize)]
145
pub struct Config {
156
pub output_path: std::path::PathBuf,
167
}
178

189
impl Config {
19-
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
10+
pub fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
2011
let content = fs::read_to_string(path)?;
2112
let config = serde_json::from_str(&content)?;
2213
Ok(config)

tests/difftests/lib/src/scaffold/compute/wgpu.rs

Lines changed: 28 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::config::Config;
2+
use anyhow::Context;
23
use bytemuck::Pod;
3-
use futures::{channel::oneshot::Canceled, executor::block_on};
4+
use futures::executor::block_on;
45
use spirv_builder::{ModuleResult, SpirvBuilder};
56
use std::{
67
borrow::Cow,
@@ -9,29 +10,14 @@ use std::{
910
io::Write,
1011
path::PathBuf,
1112
};
12-
use thiserror::Error;
13-
use wgpu::{BufferAsyncError, PipelineCompilationOptions, util::DeviceExt};
14-
15-
#[derive(Error, Debug)]
16-
pub enum ComputeError {
17-
#[error("Failed to find a suitable GPU adapter")]
18-
AdapterNotFound,
19-
#[error("Failed to create device: {0}")]
20-
DeviceCreationFailed(String),
21-
#[error("Failed to load shader: {0}")]
22-
ShaderLoadFailed(String),
23-
#[error("Mapping compute output future canceled: {0}")]
24-
MappingCanceled(Canceled),
25-
#[error("Mapping compute output failed: {0}")]
26-
MappingFailed(BufferAsyncError),
27-
}
13+
use wgpu::{PipelineCompilationOptions, util::DeviceExt};
2814

2915
/// Trait that creates a shader module and provides its entry point.
3016
pub trait ComputeShader {
3117
fn create_module(
3218
&self,
3319
device: &wgpu::Device,
34-
) -> Result<(wgpu::ShaderModule, Option<String>), ComputeError>;
20+
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)>;
3521
}
3622

3723
/// A compute shader written in Rust compiled with spirv-builder.
@@ -49,40 +35,33 @@ impl ComputeShader for RustComputeShader {
4935
fn create_module(
5036
&self,
5137
device: &wgpu::Device,
52-
) -> Result<(wgpu::ShaderModule, Option<String>), ComputeError> {
38+
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
5339
let builder = SpirvBuilder::new(&self.path, "spirv-unknown-vulkan1.1")
5440
.print_metadata(spirv_builder::MetadataPrintout::None)
5541
.release(true)
5642
.multimodule(false)
5743
.shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit)
5844
.preserve_bindings(true);
59-
let artifact = builder
60-
.build()
61-
.map_err(|e| ComputeError::ShaderLoadFailed(e.to_string()))?;
45+
let artifact = builder.build().context("SpirvBuilder::build() failed")?;
6246

6347
if artifact.entry_points.len() != 1 {
64-
return Err(ComputeError::ShaderLoadFailed(format!(
48+
anyhow::bail!(
6549
"Expected exactly one entry point, found {}",
6650
artifact.entry_points.len()
67-
)));
51+
);
6852
}
6953
let entry_point = artifact.entry_points.into_iter().next().unwrap();
7054

7155
let shader_bytes = match artifact.module {
72-
ModuleResult::SingleModule(path) => {
73-
fs::read(&path).map_err(|e| ComputeError::ShaderLoadFailed(e.to_string()))?
74-
}
56+
ModuleResult::SingleModule(path) => fs::read(&path)
57+
.with_context(|| format!("reading spv file '{}' failed", path.display()))?,
7558
ModuleResult::MultiModule(_modules) => {
76-
return Err(ComputeError::ShaderLoadFailed(
77-
"Multiple modules produced".to_string(),
78-
));
59+
anyhow::bail!("MultiModule modules produced");
7960
}
8061
};
8162

8263
if shader_bytes.len() % 4 != 0 {
83-
return Err(ComputeError::ShaderLoadFailed(
84-
"SPIR-V binary length is not a multiple of 4".to_string(),
85-
));
64+
anyhow::bail!("SPIR-V binary length is not a multiple of 4");
8665
}
8766
let shader_words: Vec<u32> = bytemuck::cast_slice(&shader_bytes).to_vec();
8867
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
@@ -112,9 +91,9 @@ impl ComputeShader for WgslComputeShader {
11291
fn create_module(
11392
&self,
11493
device: &wgpu::Device,
115-
) -> Result<(wgpu::ShaderModule, Option<String>), ComputeError> {
94+
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
11695
let shader_source = fs::read_to_string(&self.path)
117-
.map_err(|e| ComputeError::ShaderLoadFailed(e.to_string()))?;
96+
.with_context(|| format!("reading wgsl source file '{}'", &self.path.display()))?;
11897
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
11998
label: Some("Compute Shader"),
12099
source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_source)),
@@ -142,7 +121,7 @@ where
142121
}
143122
}
144123

145-
fn init() -> Result<(wgpu::Device, wgpu::Queue), ComputeError> {
124+
fn init() -> anyhow::Result<(wgpu::Device, wgpu::Queue)> {
146125
block_on(async {
147126
let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
148127
#[cfg(target_os = "linux")]
@@ -160,7 +139,7 @@ where
160139
force_fallback_adapter: false,
161140
})
162141
.await
163-
.ok_or(ComputeError::AdapterNotFound)?;
142+
.context("Failed to find a suitable GPU adapter")?;
164143
let (device, queue) = adapter
165144
.request_device(
166145
&wgpu::DeviceDescriptor {
@@ -175,12 +154,12 @@ where
175154
None,
176155
)
177156
.await
178-
.map_err(|e| ComputeError::DeviceCreationFailed(e.to_string()))?;
157+
.context("Failed to create device")?;
179158
Ok((device, queue))
180159
})
181160
}
182161

183-
fn run_internal<I>(self, input: Option<I>) -> Result<Vec<u8>, ComputeError>
162+
fn run_internal<I>(self, input: Option<I>) -> anyhow::Result<Vec<u8>>
184163
where
185164
I: Sized + Pod,
186165
{
@@ -278,42 +257,42 @@ where
278257
});
279258
device.poll(wgpu::Maintain::Wait);
280259
block_on(receiver)
281-
.map_err(ComputeError::MappingCanceled)?
282-
.map_err(ComputeError::MappingFailed)?;
260+
.context("mapping canceled")?
261+
.context("mapping failed")?;
283262
let data = buffer_slice.get_mapped_range().to_vec();
284263
staging_buffer.unmap();
285264
Ok(data)
286265
}
287266

288267
/// Runs the compute shader with no input.
289-
pub fn run(self) -> Result<Vec<u8>, ComputeError> {
268+
pub fn run(self) -> anyhow::Result<Vec<u8>> {
290269
self.run_internal::<()>(None)
291270
}
292271

293272
/// Runs the compute shader with provided input.
294-
pub fn run_with_input<I>(self, input: I) -> Result<Vec<u8>, ComputeError>
273+
pub fn run_with_input<I>(self, input: I) -> anyhow::Result<Vec<u8>>
295274
where
296275
I: Sized + Pod,
297276
{
298277
self.run_internal(Some(input))
299278
}
300279

301280
/// Runs the compute shader with no input and writes the output to a file.
302-
pub fn run_test(self, config: &Config) -> Result<(), ComputeError> {
281+
pub fn run_test(self, config: &Config) -> anyhow::Result<()> {
303282
let output = self.run()?;
304-
let mut f = File::create(&config.output_path).unwrap();
305-
f.write_all(&output).unwrap();
283+
let mut f = File::create(&config.output_path)?;
284+
f.write_all(&output)?;
306285
Ok(())
307286
}
308287

309288
/// Runs the compute shader with provided input and writes the output to a file.
310-
pub fn run_test_with_input<I>(self, config: &Config, input: I) -> Result<(), ComputeError>
289+
pub fn run_test_with_input<I>(self, config: &Config, input: I) -> anyhow::Result<()>
311290
where
312291
I: Sized + Pod,
313292
{
314293
let output = self.run_with_input(input)?;
315-
let mut f = File::create(&config.output_path).unwrap();
316-
f.write_all(&output).unwrap();
294+
let mut f = File::create(&config.output_path)?;
295+
f.write_all(&output)?;
317296
Ok(())
318297
}
319298
}

tests/difftests/tests/Cargo.lock

Lines changed: 7 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)