1
1
use crate :: config:: Config ;
2
+ use anyhow:: Context ;
2
3
use bytemuck:: Pod ;
3
- use futures:: { channel :: oneshot :: Canceled , executor:: block_on} ;
4
+ use futures:: executor:: block_on;
4
5
use spirv_builder:: { ModuleResult , SpirvBuilder } ;
5
6
use std:: {
6
7
borrow:: Cow ,
@@ -9,29 +10,14 @@ use std::{
9
10
io:: Write ,
10
11
path:: PathBuf ,
11
12
} ;
12
- use thiserror:: Error ;
13
- use wgpu:: { BufferAsyncError , PipelineCompilationOptions , util:: DeviceExt } ;
14
-
15
- #[ derive( Error , Debug ) ]
16
- pub enum ComputeError {
17
- #[ error( "Failed to find a suitable GPU adapter" ) ]
18
- AdapterNotFound ,
19
- #[ error( "Failed to create device: {0}" ) ]
20
- DeviceCreationFailed ( String ) ,
21
- #[ error( "Failed to load shader: {0}" ) ]
22
- ShaderLoadFailed ( String ) ,
23
- #[ error( "Mapping compute output future canceled: {0}" ) ]
24
- MappingCanceled ( Canceled ) ,
25
- #[ error( "Mapping compute output failed: {0}" ) ]
26
- MappingFailed ( BufferAsyncError ) ,
27
- }
13
+ use wgpu:: { PipelineCompilationOptions , util:: DeviceExt } ;
28
14
29
15
/// Trait that creates a shader module and provides its entry point.
30
16
pub trait ComputeShader {
31
17
fn create_module (
32
18
& self ,
33
19
device : & wgpu:: Device ,
34
- ) -> Result < ( wgpu:: ShaderModule , Option < String > ) , ComputeError > ;
20
+ ) -> anyhow :: Result < ( wgpu:: ShaderModule , Option < String > ) > ;
35
21
}
36
22
37
23
/// A compute shader written in Rust compiled with spirv-builder.
@@ -49,40 +35,33 @@ impl ComputeShader for RustComputeShader {
49
35
fn create_module (
50
36
& self ,
51
37
device : & wgpu:: Device ,
52
- ) -> Result < ( wgpu:: ShaderModule , Option < String > ) , ComputeError > {
38
+ ) -> anyhow :: Result < ( wgpu:: ShaderModule , Option < String > ) > {
53
39
let builder = SpirvBuilder :: new ( & self . path , "spirv-unknown-vulkan1.1" )
54
40
. print_metadata ( spirv_builder:: MetadataPrintout :: None )
55
41
. release ( true )
56
42
. multimodule ( false )
57
43
. shader_panic_strategy ( spirv_builder:: ShaderPanicStrategy :: SilentExit )
58
44
. preserve_bindings ( true ) ;
59
- let artifact = builder
60
- . build ( )
61
- . map_err ( |e| ComputeError :: ShaderLoadFailed ( e. to_string ( ) ) ) ?;
45
+ let artifact = builder. build ( ) . context ( "SpirvBuilder::build() failed" ) ?;
62
46
63
47
if artifact. entry_points . len ( ) != 1 {
64
- return Err ( ComputeError :: ShaderLoadFailed ( format ! (
48
+ anyhow :: bail !(
65
49
"Expected exactly one entry point, found {}" ,
66
50
artifact. entry_points. len( )
67
- ) ) ) ;
51
+ ) ;
68
52
}
69
53
let entry_point = artifact. entry_points . into_iter ( ) . next ( ) . unwrap ( ) ;
70
54
71
55
let shader_bytes = match artifact. module {
72
- ModuleResult :: SingleModule ( path) => {
73
- fs:: read ( & path) . map_err ( |e| ComputeError :: ShaderLoadFailed ( e. to_string ( ) ) ) ?
74
- }
56
+ ModuleResult :: SingleModule ( path) => fs:: read ( & path)
57
+ . with_context ( || format ! ( "reading spv file '{}' failed" , path. display( ) ) ) ?,
75
58
ModuleResult :: MultiModule ( _modules) => {
76
- return Err ( ComputeError :: ShaderLoadFailed (
77
- "Multiple modules produced" . to_string ( ) ,
78
- ) ) ;
59
+ anyhow:: bail!( "MultiModule modules produced" ) ;
79
60
}
80
61
} ;
81
62
82
63
if shader_bytes. len ( ) % 4 != 0 {
83
- return Err ( ComputeError :: ShaderLoadFailed (
84
- "SPIR-V binary length is not a multiple of 4" . to_string ( ) ,
85
- ) ) ;
64
+ anyhow:: bail!( "SPIR-V binary length is not a multiple of 4" ) ;
86
65
}
87
66
let shader_words: Vec < u32 > = bytemuck:: cast_slice ( & shader_bytes) . to_vec ( ) ;
88
67
let module = device. create_shader_module ( wgpu:: ShaderModuleDescriptor {
@@ -112,9 +91,9 @@ impl ComputeShader for WgslComputeShader {
112
91
fn create_module (
113
92
& self ,
114
93
device : & wgpu:: Device ,
115
- ) -> Result < ( wgpu:: ShaderModule , Option < String > ) , ComputeError > {
94
+ ) -> anyhow :: Result < ( wgpu:: ShaderModule , Option < String > ) > {
116
95
let shader_source = fs:: read_to_string ( & self . path )
117
- . map_err ( |e| ComputeError :: ShaderLoadFailed ( e . to_string ( ) ) ) ?;
96
+ . with_context ( || format ! ( "reading wgsl source file '{}'" , & self . path . display ( ) ) ) ?;
118
97
let module = device. create_shader_module ( wgpu:: ShaderModuleDescriptor {
119
98
label : Some ( "Compute Shader" ) ,
120
99
source : wgpu:: ShaderSource :: Wgsl ( Cow :: Owned ( shader_source) ) ,
@@ -142,7 +121,7 @@ where
142
121
}
143
122
}
144
123
145
- fn init ( ) -> Result < ( wgpu:: Device , wgpu:: Queue ) , ComputeError > {
124
+ fn init ( ) -> anyhow :: Result < ( wgpu:: Device , wgpu:: Queue ) > {
146
125
block_on ( async {
147
126
let instance = wgpu:: Instance :: new ( wgpu:: InstanceDescriptor {
148
127
#[ cfg( target_os = "linux" ) ]
@@ -160,7 +139,7 @@ where
160
139
force_fallback_adapter : false ,
161
140
} )
162
141
. await
163
- . ok_or ( ComputeError :: AdapterNotFound ) ?;
142
+ . context ( "Failed to find a suitable GPU adapter" ) ?;
164
143
let ( device, queue) = adapter
165
144
. request_device (
166
145
& wgpu:: DeviceDescriptor {
@@ -175,12 +154,12 @@ where
175
154
None ,
176
155
)
177
156
. await
178
- . map_err ( |e| ComputeError :: DeviceCreationFailed ( e . to_string ( ) ) ) ?;
157
+ . context ( "Failed to create device" ) ?;
179
158
Ok ( ( device, queue) )
180
159
} )
181
160
}
182
161
183
- fn run_internal < I > ( self , input : Option < I > ) -> Result < Vec < u8 > , ComputeError >
162
+ fn run_internal < I > ( self , input : Option < I > ) -> anyhow :: Result < Vec < u8 > >
184
163
where
185
164
I : Sized + Pod ,
186
165
{
@@ -278,42 +257,42 @@ where
278
257
} ) ;
279
258
device. poll ( wgpu:: Maintain :: Wait ) ;
280
259
block_on ( receiver)
281
- . map_err ( ComputeError :: MappingCanceled ) ?
282
- . map_err ( ComputeError :: MappingFailed ) ?;
260
+ . context ( "mapping canceled" ) ?
261
+ . context ( "mapping failed" ) ?;
283
262
let data = buffer_slice. get_mapped_range ( ) . to_vec ( ) ;
284
263
staging_buffer. unmap ( ) ;
285
264
Ok ( data)
286
265
}
287
266
288
267
/// Runs the compute shader with no input.
289
- pub fn run ( self ) -> Result < Vec < u8 > , ComputeError > {
268
+ pub fn run ( self ) -> anyhow :: Result < Vec < u8 > > {
290
269
self . run_internal :: < ( ) > ( None )
291
270
}
292
271
293
272
/// Runs the compute shader with provided input.
294
- pub fn run_with_input < I > ( self , input : I ) -> Result < Vec < u8 > , ComputeError >
273
+ pub fn run_with_input < I > ( self , input : I ) -> anyhow :: Result < Vec < u8 > >
295
274
where
296
275
I : Sized + Pod ,
297
276
{
298
277
self . run_internal ( Some ( input) )
299
278
}
300
279
301
280
/// Runs the compute shader with no input and writes the output to a file.
302
- pub fn run_test ( self , config : & Config ) -> Result < ( ) , ComputeError > {
281
+ pub fn run_test ( self , config : & Config ) -> anyhow :: Result < ( ) > {
303
282
let output = self . run ( ) ?;
304
- let mut f = File :: create ( & config. output_path ) . unwrap ( ) ;
305
- f. write_all ( & output) . unwrap ( ) ;
283
+ let mut f = File :: create ( & config. output_path ) ? ;
284
+ f. write_all ( & output) ? ;
306
285
Ok ( ( ) )
307
286
}
308
287
309
288
/// Runs the compute shader with provided input and writes the output to a file.
310
- pub fn run_test_with_input < I > ( self , config : & Config , input : I ) -> Result < ( ) , ComputeError >
289
+ pub fn run_test_with_input < I > ( self , config : & Config , input : I ) -> anyhow :: Result < ( ) >
311
290
where
312
291
I : Sized + Pod ,
313
292
{
314
293
let output = self . run_with_input ( input) ?;
315
- let mut f = File :: create ( & config. output_path ) . unwrap ( ) ;
316
- f. write_all ( & output) . unwrap ( ) ;
294
+ let mut f = File :: create ( & config. output_path ) ? ;
295
+ f. write_all ( & output) ? ;
317
296
Ok ( ( ) )
318
297
}
319
298
}
0 commit comments