Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 37 additions & 117 deletions wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,7 @@ impl Global {

let device = &cmd_buf.device;

if !device
.features
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
{
return Err(BuildAccelerationStructureError::MissingFeature);
}
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

let build_command_index = NonZeroU64::new(
device
Expand Down Expand Up @@ -200,18 +195,13 @@ impl Global {
let mut tlas_buf_storage = Vec::new();

for entry in tlas_iter {
let instance_buffer = match hub.buffers.get(entry.instance_buffer_id).get() {
Ok(buffer) => buffer,
Err(_) => {
return Err(BuildAccelerationStructureError::InvalidBufferId);
}
};
let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
let data = cmd_buf_data.trackers.buffers.set_single(
&instance_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
tlas_buf_storage.push(TlasBufferStore {
buffer: instance_buffer.clone(),
buffer: instance_buffer,
transition: data,
entry: entry.clone(),
});
Expand All @@ -222,14 +212,9 @@ impl Global {
let instance_buffer = {
let (instance_buffer, instance_pending) =
(&mut tlas_buf.buffer, &mut tlas_buf.transition);
let instance_raw = instance_buffer.raw.get(&snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(instance_buffer.error_ident()),
)?;
if !instance_buffer.usage.contains(BufferUsages::TLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingTlasInputUsageFlag(
instance_buffer.error_ident(),
));
}
let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;

if let Some(barrier) = instance_pending
.take()
.map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
Expand All @@ -239,11 +224,7 @@ impl Global {
instance_raw
};

let tlas = hub
.tlas_s
.get(entry.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
Expand All @@ -267,7 +248,7 @@ impl Global {
tlas,
entries: hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: Some(instance_buffer.as_ref()),
buffer: Some(instance_buffer),
offset: 0,
count: entry.instance_count,
},
Expand Down Expand Up @@ -307,9 +288,7 @@ impl Global {
mode: hal::AccelerationStructureBuildMode::Build,
flags: tlas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: tlas.raw(&snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidTlas(tlas.error_ident()),
)?,
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
scratch_buffer: scratch_buffer.raw(),
scratch_buffer_offset: *scratch_buffer_offset,
})
Expand Down Expand Up @@ -374,12 +353,7 @@ impl Global {

let device = &cmd_buf.device;

if !device
.features
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
{
return Err(BuildAccelerationStructureError::MissingFeature);
}
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

let build_command_index = NonZeroU64::new(
device
Expand Down Expand Up @@ -512,17 +486,13 @@ impl Global {
let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

for package in tlas_iter {
let tlas = hub
.tlas_s
.get(package.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
let tlas = hub.tlas_s.get(package.tlas_id).get()?;
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
}
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());

tlas_lock_store.push((Some(package), tlas.clone()))
tlas_lock_store.push((Some(package), tlas))
}

let mut scratch_buffer_tlas_size = 0;
Expand All @@ -549,12 +519,7 @@ impl Global {
tlas.error_ident(),
));
}
let blas = hub
.blas_s
.get(instance.blas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasIdForInstance)?
.clone();
let blas = hub.blas_s.get(instance.blas_id).get()?;

cmd_buf_data.trackers.blas_s.set_single(blas.clone());

Expand All @@ -569,7 +534,7 @@ impl Global {
dependencies.push(blas.clone());

cmd_buf_data.blas_actions.push(BlasAction {
blas: blas.clone(),
blas,
kind: crate::ray_tracing::BlasActionKind::Use,
});
}
Expand Down Expand Up @@ -642,13 +607,7 @@ impl Global {
mode: hal::AccelerationStructureBuildMode::Build,
flags: tlas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: tlas
.raw
.get(&snatch_guard)
.ok_or(BuildAccelerationStructureError::InvalidTlas(
tlas.error_ident(),
))?
.as_ref(),
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
scratch_buffer: scratch_buffer.raw(),
scratch_buffer_offset: *scratch_buffer_offset,
})
Expand Down Expand Up @@ -828,9 +787,7 @@ impl CommandBufferMutable {
action.tlas.error_ident(),
));
}
if blas.raw.get(snatch_guard).is_none() {
return Err(ValidateTlasActionsError::InvalidBlas(blas.error_ident()));
}
blas.try_raw(snatch_guard)?;
}
}
}
Expand All @@ -850,11 +807,7 @@ fn iter_blas<'a>(
) -> Result<(), BuildAccelerationStructureError> {
let mut temp_buffer = Vec::new();
for entry in blas_iter {
let blas = hub
.blas_s
.get(entry.blas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
let blas = hub.blas_s.get(entry.blas_id).get()?;
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_blas(&blas);
Expand Down Expand Up @@ -937,19 +890,13 @@ fn iter_blas<'a>(
blas.error_ident(),
));
}
let vertex_buffer = match hub.buffers.get(mesh.vertex_buffer).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
&vertex_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
let index_data = if let Some(index_id) = mesh.index_buffer {
let index_buffer = match hub.buffers.get(index_id).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
let index_buffer = hub.buffers.get(index_id).get()?;
if mesh.index_buffer_offset.is_none()
|| mesh.size.index_count.is_none()
|| mesh.size.index_count.is_none()
Expand All @@ -962,15 +909,12 @@ fn iter_blas<'a>(
&index_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
Some((index_buffer.clone(), data))
Some((index_buffer, data))
} else {
None
};
let transform_data = if let Some(transform_id) = mesh.transform_buffer {
let transform_buffer = match hub.buffers.get(transform_id).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
let transform_buffer = hub.buffers.get(transform_id).get()?;
if mesh.transform_buffer_offset.is_none() {
return Err(BuildAccelerationStructureError::MissingAssociatedData(
transform_buffer.error_ident(),
Expand All @@ -985,7 +929,7 @@ fn iter_blas<'a>(
None
};
temp_buffer.push(TriangleBufferStore {
vertex_buffer: vertex_buffer.clone(),
vertex_buffer,
vertex_transition: vertex_pending,
index_buffer_transition: index_data,
transform_buffer_transition: transform_data,
Expand All @@ -995,7 +939,7 @@ fn iter_blas<'a>(
}

if let Some(last) = temp_buffer.last_mut() {
last.ending_blas = Some(blas.clone());
last.ending_blas = Some(blas);
buf_storage.append(&mut temp_buffer);
}
}
Expand All @@ -1020,14 +964,9 @@ fn iter_buffers<'a, 'b>(
let mesh = &buf.geometry;
let vertex_buffer = {
let vertex_buffer = buf.vertex_buffer.as_ref();
let vertex_raw = vertex_buffer.raw.get(snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(vertex_buffer.error_ident()),
)?;
if !vertex_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
vertex_buffer.error_ident(),
));
}
let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

if let Some(barrier) = buf
.vertex_transition
.take()
Expand All @@ -1047,10 +986,7 @@ fn iter_buffers<'a, 'b>(
let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
cmd_buf_data.buffer_memory_init_actions.extend(
vertex_buffer.initialization_status.read().create_action(
&hub.buffers
.get(mesh.vertex_buffer)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBufferId)?,
&hub.buffers.get(mesh.vertex_buffer).get()?,
vertex_buffer_offset
..(vertex_buffer_offset
+ mesh.size.vertex_count as u64 * mesh.vertex_stride),
Expand All @@ -1062,14 +998,9 @@ fn iter_buffers<'a, 'b>(
let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
buf.index_buffer_transition
{
let index_raw = index_buffer.raw.get(snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(index_buffer.error_ident()),
)?;
if !index_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
index_buffer.error_ident(),
));
}
let index_raw = index_buffer.try_raw(snatch_guard)?;
index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

if let Some(barrier) = index_pending
.take()
.map(|pending| pending.into_hal(index_buffer, snatch_guard))
Expand Down Expand Up @@ -1125,14 +1056,9 @@ fn iter_buffers<'a, 'b>(
transform_buffer.error_ident(),
));
}
let transform_raw = transform_buffer.raw.get(snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(transform_buffer.error_ident()),
)?;
if !transform_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
transform_buffer.error_ident(),
));
}
let transform_raw = transform_buffer.try_raw(snatch_guard)?;
transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

if let Some(barrier) = transform_pending
.take()
.map(|pending| pending.into_hal(transform_buffer, snatch_guard))
Expand Down Expand Up @@ -1166,7 +1092,7 @@ fn iter_buffers<'a, 'b>(
};

let triangles = hal::AccelerationStructureTriangles {
vertex_buffer: Some(vertex_buffer.as_ref()),
vertex_buffer: Some(vertex_buffer),
vertex_format: mesh.size.vertex_format,
first_vertex: mesh.first_vertex,
vertex_count: mesh.size.vertex_count,
Expand All @@ -1175,13 +1101,13 @@ fn iter_buffers<'a, 'b>(
dyn hal::DynBuffer,
> {
format: mesh.size.index_format.unwrap(),
buffer: Some(index_buffer.as_ref()),
buffer: Some(index_buffer),
offset: mesh.index_buffer_offset.unwrap() as u32,
count: mesh.size.index_count.unwrap(),
}),
transform: transform_buffer.map(|transform_buffer| {
hal::AccelerationStructureTriangleTransform {
buffer: transform_buffer.as_ref(),
buffer: transform_buffer,
offset: mesh.transform_buffer_offset.unwrap() as u32,
}
}),
Expand Down Expand Up @@ -1231,13 +1157,7 @@ fn map_blas<'a>(
mode: hal::AccelerationStructureBuildMode::Build,
flags: blas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: blas
.raw
.get(snatch_guard)
.ok_or(BuildAccelerationStructureError::InvalidBlas(
blas.error_ident(),
))?
.as_ref(),
destination_acceleration_structure: blas.try_raw(snatch_guard)?,
scratch_buffer,
scratch_buffer_offset: *scratch_buffer_offset,
})
Expand Down
Loading