Skip to content

Commit d22dea5

Browse files
committed
[wgpu-core] ray tracing: use error handling helpers
1 parent 1abf3fe commit d22dea5

File tree

5 files changed

+103
-242
lines changed

5 files changed

+103
-242
lines changed

wgpu-core/src/command/ray_tracing.rs

Lines changed: 37 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,7 @@ impl Global {
8484

8585
let device = &cmd_buf.device;
8686

87-
if !device
88-
.features
89-
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
90-
{
91-
return Err(BuildAccelerationStructureError::MissingFeature);
92-
}
87+
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
9388

9489
let build_command_index = NonZeroU64::new(
9590
device
@@ -200,18 +195,13 @@ impl Global {
200195
let mut tlas_buf_storage = Vec::new();
201196

202197
for entry in tlas_iter {
203-
let instance_buffer = match hub.buffers.get(entry.instance_buffer_id).get() {
204-
Ok(buffer) => buffer,
205-
Err(_) => {
206-
return Err(BuildAccelerationStructureError::InvalidBufferId);
207-
}
208-
};
198+
let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
209199
let data = cmd_buf_data.trackers.buffers.set_single(
210200
&instance_buffer,
211201
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
212202
);
213203
tlas_buf_storage.push(TlasBufferStore {
214-
buffer: instance_buffer.clone(),
204+
buffer: instance_buffer,
215205
transition: data,
216206
entry: entry.clone(),
217207
});
@@ -222,14 +212,9 @@ impl Global {
222212
let instance_buffer = {
223213
let (instance_buffer, instance_pending) =
224214
(&mut tlas_buf.buffer, &mut tlas_buf.transition);
225-
let instance_raw = instance_buffer.raw.get(&snatch_guard).ok_or(
226-
BuildAccelerationStructureError::InvalidBuffer(instance_buffer.error_ident()),
227-
)?;
228-
if !instance_buffer.usage.contains(BufferUsages::TLAS_INPUT) {
229-
return Err(BuildAccelerationStructureError::MissingTlasInputUsageFlag(
230-
instance_buffer.error_ident(),
231-
));
232-
}
215+
let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
216+
instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;
217+
233218
if let Some(barrier) = instance_pending
234219
.take()
235220
.map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
@@ -239,11 +224,7 @@ impl Global {
239224
instance_raw
240225
};
241226

242-
let tlas = hub
243-
.tlas_s
244-
.get(entry.tlas_id)
245-
.get()
246-
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
227+
let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
247228
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
248229
if let Some(queue) = device.get_queue() {
249230
queue.pending_writes.lock().insert_tlas(&tlas);
@@ -267,7 +248,7 @@ impl Global {
267248
tlas,
268249
entries: hal::AccelerationStructureEntries::Instances(
269250
hal::AccelerationStructureInstances {
270-
buffer: Some(instance_buffer.as_ref()),
251+
buffer: Some(instance_buffer),
271252
offset: 0,
272253
count: entry.instance_count,
273254
},
@@ -307,9 +288,7 @@ impl Global {
307288
mode: hal::AccelerationStructureBuildMode::Build,
308289
flags: tlas.flags,
309290
source_acceleration_structure: None,
310-
destination_acceleration_structure: tlas.raw(&snatch_guard).ok_or(
311-
BuildAccelerationStructureError::InvalidTlas(tlas.error_ident()),
312-
)?,
291+
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
313292
scratch_buffer: scratch_buffer.raw(),
314293
scratch_buffer_offset: *scratch_buffer_offset,
315294
})
@@ -374,12 +353,7 @@ impl Global {
374353

375354
let device = &cmd_buf.device;
376355

377-
if !device
378-
.features
379-
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
380-
{
381-
return Err(BuildAccelerationStructureError::MissingFeature);
382-
}
356+
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
383357

384358
let build_command_index = NonZeroU64::new(
385359
device
@@ -512,17 +486,13 @@ impl Global {
512486
let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();
513487

514488
for package in tlas_iter {
515-
let tlas = hub
516-
.tlas_s
517-
.get(package.tlas_id)
518-
.get()
519-
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
489+
let tlas = hub.tlas_s.get(package.tlas_id).get()?;
520490
if let Some(queue) = device.get_queue() {
521491
queue.pending_writes.lock().insert_tlas(&tlas);
522492
}
523493
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
524494

525-
tlas_lock_store.push((Some(package), tlas.clone()))
495+
tlas_lock_store.push((Some(package), tlas))
526496
}
527497

528498
let mut scratch_buffer_tlas_size = 0;
@@ -549,12 +519,7 @@ impl Global {
549519
tlas.error_ident(),
550520
));
551521
}
552-
let blas = hub
553-
.blas_s
554-
.get(instance.blas_id)
555-
.get()
556-
.map_err(|_| BuildAccelerationStructureError::InvalidBlasIdForInstance)?
557-
.clone();
522+
let blas = hub.blas_s.get(instance.blas_id).get()?;
558523

559524
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
560525

@@ -569,7 +534,7 @@ impl Global {
569534
dependencies.push(blas.clone());
570535

571536
cmd_buf_data.blas_actions.push(BlasAction {
572-
blas: blas.clone(),
537+
blas,
573538
kind: crate::ray_tracing::BlasActionKind::Use,
574539
});
575540
}
@@ -642,13 +607,7 @@ impl Global {
642607
mode: hal::AccelerationStructureBuildMode::Build,
643608
flags: tlas.flags,
644609
source_acceleration_structure: None,
645-
destination_acceleration_structure: tlas
646-
.raw
647-
.get(&snatch_guard)
648-
.ok_or(BuildAccelerationStructureError::InvalidTlas(
649-
tlas.error_ident(),
650-
))?
651-
.as_ref(),
610+
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
652611
scratch_buffer: scratch_buffer.raw(),
653612
scratch_buffer_offset: *scratch_buffer_offset,
654613
})
@@ -828,9 +787,7 @@ impl CommandBufferMutable {
828787
action.tlas.error_ident(),
829788
));
830789
}
831-
if blas.raw.get(snatch_guard).is_none() {
832-
return Err(ValidateTlasActionsError::InvalidBlas(blas.error_ident()));
833-
}
790+
blas.try_raw(snatch_guard)?;
834791
}
835792
}
836793
}
@@ -850,11 +807,7 @@ fn iter_blas<'a>(
850807
) -> Result<(), BuildAccelerationStructureError> {
851808
let mut temp_buffer = Vec::new();
852809
for entry in blas_iter {
853-
let blas = hub
854-
.blas_s
855-
.get(entry.blas_id)
856-
.get()
857-
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
810+
let blas = hub.blas_s.get(entry.blas_id).get()?;
858811
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
859812
if let Some(queue) = device.get_queue() {
860813
queue.pending_writes.lock().insert_blas(&blas);
@@ -937,19 +890,13 @@ fn iter_blas<'a>(
937890
blas.error_ident(),
938891
));
939892
}
940-
let vertex_buffer = match hub.buffers.get(mesh.vertex_buffer).get() {
941-
Ok(buffer) => buffer,
942-
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
943-
};
893+
let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
944894
let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
945895
&vertex_buffer,
946896
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
947897
);
948898
let index_data = if let Some(index_id) = mesh.index_buffer {
949-
let index_buffer = match hub.buffers.get(index_id).get() {
950-
Ok(buffer) => buffer,
951-
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
952-
};
899+
let index_buffer = hub.buffers.get(index_id).get()?;
953900
if mesh.index_buffer_offset.is_none()
954901
|| mesh.size.index_count.is_none()
955902
|| mesh.size.index_count.is_none()
@@ -962,15 +909,12 @@ fn iter_blas<'a>(
962909
&index_buffer,
963910
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
964911
);
965-
Some((index_buffer.clone(), data))
912+
Some((index_buffer, data))
966913
} else {
967914
None
968915
};
969916
let transform_data = if let Some(transform_id) = mesh.transform_buffer {
970-
let transform_buffer = match hub.buffers.get(transform_id).get() {
971-
Ok(buffer) => buffer,
972-
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
973-
};
917+
let transform_buffer = hub.buffers.get(transform_id).get()?;
974918
if mesh.transform_buffer_offset.is_none() {
975919
return Err(BuildAccelerationStructureError::MissingAssociatedData(
976920
transform_buffer.error_ident(),
@@ -985,7 +929,7 @@ fn iter_blas<'a>(
985929
None
986930
};
987931
temp_buffer.push(TriangleBufferStore {
988-
vertex_buffer: vertex_buffer.clone(),
932+
vertex_buffer,
989933
vertex_transition: vertex_pending,
990934
index_buffer_transition: index_data,
991935
transform_buffer_transition: transform_data,
@@ -995,7 +939,7 @@ fn iter_blas<'a>(
995939
}
996940

997941
if let Some(last) = temp_buffer.last_mut() {
998-
last.ending_blas = Some(blas.clone());
942+
last.ending_blas = Some(blas);
999943
buf_storage.append(&mut temp_buffer);
1000944
}
1001945
}
@@ -1020,14 +964,9 @@ fn iter_buffers<'a, 'b>(
1020964
let mesh = &buf.geometry;
1021965
let vertex_buffer = {
1022966
let vertex_buffer = buf.vertex_buffer.as_ref();
1023-
let vertex_raw = vertex_buffer.raw.get(snatch_guard).ok_or(
1024-
BuildAccelerationStructureError::InvalidBuffer(vertex_buffer.error_ident()),
1025-
)?;
1026-
if !vertex_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
1027-
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
1028-
vertex_buffer.error_ident(),
1029-
));
1030-
}
967+
let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
968+
vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
969+
1031970
if let Some(barrier) = buf
1032971
.vertex_transition
1033972
.take()
@@ -1047,10 +986,7 @@ fn iter_buffers<'a, 'b>(
1047986
let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
1048987
cmd_buf_data.buffer_memory_init_actions.extend(
1049988
vertex_buffer.initialization_status.read().create_action(
1050-
&hub.buffers
1051-
.get(mesh.vertex_buffer)
1052-
.get()
1053-
.map_err(|_| BuildAccelerationStructureError::InvalidBufferId)?,
989+
&hub.buffers.get(mesh.vertex_buffer).get()?,
1054990
vertex_buffer_offset
1055991
..(vertex_buffer_offset
1056992
+ mesh.size.vertex_count as u64 * mesh.vertex_stride),
@@ -1062,14 +998,9 @@ fn iter_buffers<'a, 'b>(
1062998
let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
1063999
buf.index_buffer_transition
10641000
{
1065-
let index_raw = index_buffer.raw.get(snatch_guard).ok_or(
1066-
BuildAccelerationStructureError::InvalidBuffer(index_buffer.error_ident()),
1067-
)?;
1068-
if !index_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
1069-
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
1070-
index_buffer.error_ident(),
1071-
));
1072-
}
1001+
let index_raw = index_buffer.try_raw(snatch_guard)?;
1002+
index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1003+
10731004
if let Some(barrier) = index_pending
10741005
.take()
10751006
.map(|pending| pending.into_hal(index_buffer, snatch_guard))
@@ -1125,14 +1056,9 @@ fn iter_buffers<'a, 'b>(
11251056
transform_buffer.error_ident(),
11261057
));
11271058
}
1128-
let transform_raw = transform_buffer.raw.get(snatch_guard).ok_or(
1129-
BuildAccelerationStructureError::InvalidBuffer(transform_buffer.error_ident()),
1130-
)?;
1131-
if !transform_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
1132-
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
1133-
transform_buffer.error_ident(),
1134-
));
1135-
}
1059+
let transform_raw = transform_buffer.try_raw(snatch_guard)?;
1060+
transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1061+
11361062
if let Some(barrier) = transform_pending
11371063
.take()
11381064
.map(|pending| pending.into_hal(transform_buffer, snatch_guard))
@@ -1166,7 +1092,7 @@ fn iter_buffers<'a, 'b>(
11661092
};
11671093

11681094
let triangles = hal::AccelerationStructureTriangles {
1169-
vertex_buffer: Some(vertex_buffer.as_ref()),
1095+
vertex_buffer: Some(vertex_buffer),
11701096
vertex_format: mesh.size.vertex_format,
11711097
first_vertex: mesh.first_vertex,
11721098
vertex_count: mesh.size.vertex_count,
@@ -1175,13 +1101,13 @@ fn iter_buffers<'a, 'b>(
11751101
dyn hal::DynBuffer,
11761102
> {
11771103
format: mesh.size.index_format.unwrap(),
1178-
buffer: Some(index_buffer.as_ref()),
1104+
buffer: Some(index_buffer),
11791105
offset: mesh.index_buffer_offset.unwrap() as u32,
11801106
count: mesh.size.index_count.unwrap(),
11811107
}),
11821108
transform: transform_buffer.map(|transform_buffer| {
11831109
hal::AccelerationStructureTriangleTransform {
1184-
buffer: transform_buffer.as_ref(),
1110+
buffer: transform_buffer,
11851111
offset: mesh.transform_buffer_offset.unwrap() as u32,
11861112
}
11871113
}),
@@ -1231,13 +1157,7 @@ fn map_blas<'a>(
12311157
mode: hal::AccelerationStructureBuildMode::Build,
12321158
flags: blas.flags,
12331159
source_acceleration_structure: None,
1234-
destination_acceleration_structure: blas
1235-
.raw
1236-
.get(snatch_guard)
1237-
.ok_or(BuildAccelerationStructureError::InvalidBlas(
1238-
blas.error_ident(),
1239-
))?
1240-
.as_ref(),
1160+
destination_acceleration_structure: blas.try_raw(snatch_guard)?,
12411161
scratch_buffer,
12421162
scratch_buffer_offset: *scratch_buffer_offset,
12431163
})

0 commit comments

Comments
 (0)