Skip to content

Commit

Permalink
refactor(core): extract Global::validate_pass_timestamp_writes
Browse files Browse the repository at this point in the history
  • Loading branch information
ErichDonGubler committed Nov 22, 2024
1 parent 99a7402 commit f8fbd71
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 105 deletions.
69 changes: 9 additions & 60 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ use crate::{
use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};

use super::{
bind::BinderError, memory_init::CommandBufferTextureMemoryActions, SimplifiedQueryType,
};
use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions};
use crate::ray_tracing::TlasAction;
use std::sync::Arc;
use std::{fmt, mem::size_of, str};
Expand Down Expand Up @@ -310,64 +308,15 @@ impl Global {
Err(e) => return make_err(e, arc_desc),
};

arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let &PassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
} = tw;

match cmd_buf
.device
.require_features(wgt::Features::TIMESTAMP_QUERY)
{
Ok(()) => (),
Err(e) => return make_err(e.into(), arc_desc),
}

let query_set = match hub.query_sets.get(query_set).get() {
Ok(query_set) => query_set,
Err(e) => return make_err(e.into(), arc_desc),
};

match query_set.same_device(&cmd_buf.device) {
Ok(()) => (),
Err(e) => return make_err(e.into(), arc_desc),
}

for idx in [beginning_of_pass_write_index, end_of_pass_write_index]
.into_iter()
.flatten()
{
match query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None) {
Ok(()) => (),
Err(e) => return make_err(e.into(), arc_desc),
}
}

if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) {
if begin == end {
return make_err(
CommandEncoderError::TimestampWriteIndicesEqual { idx: begin },
arc_desc,
);
}
}

if beginning_of_pass_write_index
.or(end_of_pass_write_index)
.is_none()
{
return make_err(CommandEncoderError::TimestampWriteIndicesMissing, arc_desc);
}

Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
arc_desc.timestamp_writes = match desc
.timestamp_writes
.map(|tw| {
Self::validate_pass_timestamp_writes(&cmd_buf.device, &hub.query_sets.read(), tw)
})
} else {
None
.transpose()
{
Ok(ok) => ok,
Err(e) => return make_err(e, arc_desc),
};

(ComputePass::new(Some(cmd_buf), arc_desc), None)
Expand Down
47 changes: 46 additions & 1 deletion wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ use crate::snatch::SnatchGuard;

use crate::init_tracker::BufferInitTrackerAction;
use crate::ray_tracing::{BlasAction, TlasAction};
use crate::resource::{InvalidResourceError, Labeled};
use crate::resource::{Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet};
use crate::storage::Storage;
use crate::track::{DeviceTracker, Tracker, UsageScope};
use crate::LabelHelpers;
use crate::{api_log, global::Global, id, resource_log, Label};
Expand Down Expand Up @@ -782,6 +783,50 @@ impl Global {
}
Ok(())
}

fn validate_pass_timestamp_writes(
device: &Device,
query_sets: &Storage<Fallible<QuerySet>>,
timestamp_writes: &PassTimestampWrites,
) -> Result<ArcPassTimestampWrites, CommandEncoderError> {
let &PassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
} = timestamp_writes;

device.require_features(wgt::Features::TIMESTAMP_QUERY)?;

let query_set = query_sets.get(query_set).get()?;

query_set.same_device(device)?;

for idx in [beginning_of_pass_write_index, end_of_pass_write_index]
.into_iter()
.flatten()
{
query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None)?;
}

if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) {
if begin == end {
return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin });
}
}

if beginning_of_pass_write_index
.or(end_of_pass_write_index)
.is_none()
{
return Err(CommandEncoderError::TimestampWriteIndicesMissing);
}

Ok(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
})
}
}

fn push_constant_clear<PushFn>(offset: u32, size_bytes: u32, mut push_fn: PushFn)
Expand Down
48 changes: 4 additions & 44 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::binding_model::BindGroup;
use crate::command::{
validate_and_begin_occlusion_query, validate_and_begin_pipeline_statistics_query,
SimplifiedQueryType,
};
use crate::init_tracker::BufferInitTrackerAction;
use crate::pipeline::RenderPipeline;
Expand Down Expand Up @@ -1392,49 +1391,10 @@ impl Global {
None
};

arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let &PassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
} = tw;

let query_set = query_sets.get(query_set).get()?;

device.require_features(wgt::Features::TIMESTAMP_QUERY)?;

query_set.same_device(device)?;

for idx in [beginning_of_pass_write_index, end_of_pass_write_index]
.into_iter()
.flatten()
{
query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None)?;
}

if let Some((begin, end)) =
beginning_of_pass_write_index.zip(end_of_pass_write_index)
{
if begin == end {
return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin });
}
}

if beginning_of_pass_write_index
.or(end_of_pass_write_index)
.is_none()
{
return Err(CommandEncoderError::TimestampWriteIndicesMissing);
}

Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
})
} else {
None
};
arc_desc.timestamp_writes = desc
.timestamp_writes
.map(|tw| Global::validate_pass_timestamp_writes(device, &query_sets, tw))
.transpose()?;

arc_desc.occlusion_query_set =
if let Some(occlusion_query_set) = desc.occlusion_query_set {
Expand Down

0 comments on commit f8fbd71

Please sign in to comment.