Skip to content

Commit 6714b3c

Browse files
committed
Add FunctionNode.write_timestamps to allow writing timestamps before/after dispatch
1 parent 12bb816 commit 6714b3c

File tree

4 files changed

+74
-5
lines changed

4 files changed

+74
-5
lines changed

slangpy/core/function.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
)
1111

1212
from slangpy.reflection import SlangFunction, SlangType
13-
from slangpy import CommandEncoder, TypeConformance, uint3, Logger, NativeHandle, NativeHandleType
13+
from slangpy import (
14+
CommandEncoder,
15+
QueryPool,
16+
TypeConformance,
17+
uint3,
18+
Logger,
19+
NativeHandle,
20+
NativeHandleType,
21+
)
1422
from slangpy.slangpy import Shape
1523
from slangpy.bindings.typeregistry import PYTHON_SIGNATURES
1624

@@ -139,6 +147,12 @@ def cuda_stream(self, stream: NativeHandle) -> "FunctionNode":
139147
"""
140148
return FunctionNodeCUDAStream(self, stream)
141149

150+
def write_timestamps(self, write_timestamps: tuple[QueryPool, int, int]) -> "FunctionNode":
151+
"""
152+
Specify a query pool and a before/after query index to write timestamps before/after the dispatch.
153+
"""
154+
return FunctionNodeWriteTimestamps(self, write_timestamps)
155+
142156
def constants(self, constants: dict[str, Any]):
143157
"""
144158
Specify link time constants that should be set when the function is compiled. These are
@@ -427,6 +441,21 @@ def _populate_build_info(self, info: FunctionBuildInfo):
427441
info.options["cuda_stream"] = self.stream
428442

429443

444+
class FunctionNodeWriteTimestamps(FunctionNode):
445+
def __init__(
446+
self, parent: NativeFunctionNode, write_timestamps: tuple[QueryPool, int, int]
447+
) -> None:
448+
super().__init__(parent, FunctionNodeType.write_timestamps, write_timestamps)
449+
self.slangpy_signature = str(write_timestamps)
450+
451+
@property
452+
def write_timestamps(self):
453+
return cast(tuple[QueryPool, int, int], self._native_data)
454+
455+
def _populate_build_info(self, info: FunctionBuildInfo):
456+
info.options["write_timestamps"] = self.write_timestamps
457+
458+
430459
class FunctionNodeConstants(FunctionNode):
431460
def __init__(self, parent: NativeFunctionNode, constants: dict[str, Any]) -> None:
432461
super().__init__(parent, FunctionNodeType.kernelgen, constants)

src/slangpy_ext/utils/slangpy.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "sgl/utils/slangpy.h"
1111
#include "sgl/device/device.h"
1212
#include "sgl/device/kernel.h"
13+
#include "sgl/device/query.h"
1314
#include "sgl/device/command.h"
1415
#include "sgl/stl/bit.h" // Replace with <bit> when available on all platforms.
1516

@@ -643,10 +644,25 @@ nb::object NativeCallData::exec(
643644

644645
if (command_encoder == nullptr) {
645646
// If we are not appending to a command encoder, we can dispatch directly.
646-
m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, CommandQueueType::graphics, cuda_stream);
647+
m_kernel->dispatch(
648+
uint3(total_threads, 1, 1),
649+
bind_vars,
650+
CommandQueueType::graphics,
651+
cuda_stream,
652+
opts->get_query_pool(),
653+
opts->get_query_before_index(),
654+
opts->get_query_after_index()
655+
);
647656
} else {
648657
// If we are appending to a command encoder, we need to use the command encoder.
649-
m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, command_encoder);
658+
m_kernel->dispatch(
659+
uint3(total_threads, 1, 1),
660+
bind_vars,
661+
command_encoder,
662+
opts->get_query_pool(),
663+
opts->get_query_before_index(),
664+
opts->get_query_after_index()
665+
);
650666
}
651667

652668
// If command_encoder is not null, return early.

src/slangpy_ext/utils/slangpy.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,10 +598,25 @@ class NativeCallRuntimeOptions : Object {
598598
/// Set the CUDA stream.
599599
void set_cuda_stream(NativeHandle cuda_stream) { m_cuda_stream = cuda_stream; }
600600

601+
/// Get the query pool that timestamps are written to.
QueryPool* get_query_pool() const { return m_query_pool.get(); }
602+
603+
/// Set the query pool that timestamps are written to.
void set_query_pool(QueryPool* query_pool) { m_query_pool = ref(query_pool); }
604+
605+
/// Get the index of the query written before the dispatch.
uint32_t get_query_before_index() const { return m_query_before_index; }
606+
607+
/// Set the index of the query written before the dispatch.
void set_query_before_index(uint32_t query_before_index) { m_query_before_index = query_before_index; }
608+
609+
/// Get the index of the query written after the dispatch.
uint32_t get_query_after_index() const { return m_query_after_index; }
610+
611+
/// Set the index of the query written after the dispatch.
void set_query_after_index(uint32_t query_after_index) { m_query_after_index = query_after_index; }
612+
601613
private:
602614
nb::list m_uniforms;
603615
nb::object m_this{nb::none()};
604616
NativeHandle m_cuda_stream;
617+
ref<QueryPool> m_query_pool;
618+
uint32_t m_query_before_index{0};
619+
uint32_t m_query_after_index{0};
605620
};
606621

607622
/// Defines the common logging functions for a given log level.

src/slangpy_ext/utils/slangpyfunction.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@
1414

1515
#include "sgl/device/fwd.h"
1616
#include "sgl/device/resource.h"
17+
#include "sgl/device/query.h"
1718

1819
#include "utils/slangpy.h"
1920

2021
namespace sgl::slangpy {
2122

22-
enum class FunctionNodeType { unknown, uniforms, kernelgen, this_, cuda_stream };
23+
enum class FunctionNodeType { unknown, uniforms, kernelgen, this_, cuda_stream, write_timestamps };
2324
SGL_ENUM_INFO(
2425
FunctionNodeType,
2526
{{FunctionNodeType::unknown, "unknown"},
2627
{FunctionNodeType::uniforms, "uniforms"},
2728
{FunctionNodeType::kernelgen, "kernelgen"},
2829
{FunctionNodeType::this_, "this"},
29-
{FunctionNodeType::cuda_stream, "cuda_stream"}}
30+
{FunctionNodeType::cuda_stream, "cuda_stream"},
31+
{FunctionNodeType::write_timestamps, "write_timestamps"}}
3032
);
3133
SGL_ENUM_REGISTER(FunctionNodeType);
3234

@@ -72,6 +74,13 @@ class NativeFunctionNode : NativeObject {
7274
case sgl::slangpy::FunctionNodeType::cuda_stream:
7375
options->set_cuda_stream(nb::cast<NativeHandle>(m_data));
7476
break;
77+
case sgl::slangpy::FunctionNodeType::write_timestamps: {
78+
nb::tuple t = nb::cast<nb::tuple>(m_data);
79+
options->set_query_pool(nb::cast<QueryPool*>(t[0]));
80+
options->set_query_before_index(nb::cast<uint32_t>(t[1]));
81+
options->set_query_after_index(nb::cast<uint32_t>(t[2]));
82+
break;
83+
}
7584
default:
7685
break;
7786
}

0 commit comments

Comments
 (0)