From 7fa4299f0cff156194bb002d5ffb015574556ae1 Mon Sep 17 00:00:00 2001 From: Erich Gubler Date: Thu, 21 Nov 2024 22:50:26 -0500 Subject: [PATCH] refactor(core): extract `Global::validate_pass_timestamp_writes` --- wgpu-core/src/command/compute.rs | 69 +++++--------------------------- wgpu-core/src/command/mod.rs | 47 +++++++++++++++++++++- wgpu-core/src/command/render.rs | 48 ++-------------------- 3 files changed, 59 insertions(+), 105 deletions(-) diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index a0ce897f8a9..571e605d8a9 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -28,9 +28,7 @@ use crate::{ use thiserror::Error; use wgt::{BufferAddress, DynamicOffset}; -use super::{ - bind::BinderError, memory_init::CommandBufferTextureMemoryActions, SimplifiedQueryType, -}; +use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions}; use crate::ray_tracing::TlasAction; use std::sync::Arc; use std::{fmt, mem::size_of, str}; @@ -310,64 +308,15 @@ impl Global { Err(e) => return make_err(e, arc_desc), }; - arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes { - let &PassTimestampWrites { - query_set, - beginning_of_pass_write_index, - end_of_pass_write_index, - } = tw; - - match cmd_buf - .device - .require_features(wgt::Features::TIMESTAMP_QUERY) - { - Ok(()) => (), - Err(e) => return make_err(e.into(), arc_desc), - } - - let query_set = match hub.query_sets.get(query_set).get() { - Ok(query_set) => query_set, - Err(e) => return make_err(e.into(), arc_desc), - }; - - match query_set.same_device(&cmd_buf.device) { - Ok(()) => (), - Err(e) => return make_err(e.into(), arc_desc), - } - - for idx in [beginning_of_pass_write_index, end_of_pass_write_index] - .into_iter() - .flatten() - { - match query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None) { - Ok(()) => (), - Err(e) => return make_err(e.into(), arc_desc), - } - } - - if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) { - if begin == end { - return make_err( - CommandEncoderError::TimestampWriteIndicesEqual { idx: begin }, - arc_desc, - ); - } - } - - if beginning_of_pass_write_index - .or(end_of_pass_write_index) - .is_none() - { - return make_err(CommandEncoderError::TimestampWriteIndicesMissing, arc_desc); - } - - Some(ArcPassTimestampWrites { - query_set, - beginning_of_pass_write_index, - end_of_pass_write_index, + arc_desc.timestamp_writes = match desc + .timestamp_writes + .map(|tw| { + Self::validate_pass_timestamp_writes(&cmd_buf.device, &*hub.query_sets.read(), tw) }) - } else { - None + .transpose() + { + Ok(ok) => ok, + Err(e) => return make_err(e.into(), arc_desc), }; (ComputePass::new(Some(cmd_buf), arc_desc), None) diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index 4a1694fcf08..fab58515fb6 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -33,7 +33,8 @@ use crate::snatch::SnatchGuard; use crate::init_tracker::BufferInitTrackerAction; use crate::ray_tracing::{BlasAction, TlasAction}; -use crate::resource::{InvalidResourceError, Labeled}; +use crate::resource::{Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet}; +use crate::storage::Storage; use crate::track::{DeviceTracker, Tracker, UsageScope}; use crate::LabelHelpers; use crate::{api_log, global::Global, id, resource_log, Label}; @@ -782,6 +783,50 @@ impl Global { } Ok(()) } + + fn validate_pass_timestamp_writes( + device: &Device, + query_sets: &Storage>, + timestamp_writes: &PassTimestampWrites, + ) -> Result { + let &PassTimestampWrites { + query_set, + beginning_of_pass_write_index, + end_of_pass_write_index, + } = timestamp_writes; + + device.require_features(wgt::Features::TIMESTAMP_QUERY)?; + + let query_set = query_sets.get(query_set).get()?; + + query_set.same_device(device)?; + + for idx in [beginning_of_pass_write_index, end_of_pass_write_index] + .into_iter() + .flatten() + { + query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None)?; + } + + if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) { + if begin == end { + return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin }); + } + } + + if beginning_of_pass_write_index + .or(end_of_pass_write_index) + .is_none() + { + return Err(CommandEncoderError::TimestampWriteIndicesMissing); + } + + Ok(ArcPassTimestampWrites { + query_set, + beginning_of_pass_write_index, + end_of_pass_write_index, + }) + } } fn push_constant_clear(offset: u32, size_bytes: u32, mut push_fn: PushFn) diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index 017e0af14c4..b2d42ca2942 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -1,7 +1,6 @@ use crate::binding_model::BindGroup; use crate::command::{ validate_and_begin_occlusion_query, validate_and_begin_pipeline_statistics_query, - SimplifiedQueryType, }; use crate::init_tracker::BufferInitTrackerAction; use crate::pipeline::RenderPipeline; @@ -1392,49 +1391,10 @@ impl Global { None }; - arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes { - let &PassTimestampWrites { - query_set, - beginning_of_pass_write_index, - end_of_pass_write_index, - } = tw; - - let query_set = query_sets.get(query_set).get()?; - - device.require_features(wgt::Features::TIMESTAMP_QUERY)?; - - query_set.same_device(device)?; - - for idx in [beginning_of_pass_write_index, end_of_pass_write_index] - .into_iter() - .flatten() - { - query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None)?; - } - - if let Some((begin, end)) = - beginning_of_pass_write_index.zip(end_of_pass_write_index) - { - if begin == end { - return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin }); - } - } - - if beginning_of_pass_write_index - .or(end_of_pass_write_index) - .is_none() - { - return Err(CommandEncoderError::TimestampWriteIndicesMissing); - } - - Some(ArcPassTimestampWrites { - query_set, - beginning_of_pass_write_index, - end_of_pass_write_index, - }) - } else { - None - }; + arc_desc.timestamp_writes = desc + .timestamp_writes + .map(|tw| Global::validate_pass_timestamp_writes(device, &*query_sets, tw)) + .transpose()?; arc_desc.occlusion_query_set = if let Some(occlusion_query_set) = desc.occlusion_query_set {