Skip to content

Commit

Permalink
multi-timing appears to work, cool
Browse files Browse the repository at this point in the history
  • Loading branch information
John Owens committed Oct 4, 2024
1 parent 2cb51c5 commit 1e0576d
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 68 deletions.
119 changes: 79 additions & 40 deletions catmull-clark.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ const device = await adapter?.requestDevice({
if (!device) {
fail("Fatal error: Device does not support WebGPU.");
}
const timingHelper = new TimingHelper(device);

// We can set runtime params from the input URL!
const urlParams = new URL(window.location.href).searchParams;
const debug = urlParams.get("debug"); // string, or null if the parameter is absent
let frameCount = urlParams.get("frameCount");
frameCount = frameCount == undefined ? -1 : parseInt(frameCount, 10);
const separateComputePasses =
urlParams.get("separateComputePasses") != undefined; // true or false

// if we want more:
// Object.fromEntries(new URL(window.location.href).searchParams.entries());
// if url is `https://foo.com/bar.html?abc=123&def=456&xyz=banana` then params is
Expand Down Expand Up @@ -60,14 +64,16 @@ uni.set({
TOGGLE_DURATION: 400.0, // number of timesteps between model toggle
WIGGLE_MAGNITUDE: 0, // 0.002, //0.025, // how much vertices are perturbed
WIGGLE_SPEED: 0.05, // how quickly perturbations occur
subdivLevel: 0,
subdivLevel: urlParams.get("subdivLevel")
? parseInt(urlParams.get("subdivLevel"), 10)
: 0,
level: 0,
time: 0.0,
timestep: 1.0,
});

const modelParams = {
model: "square_pyramid", // default starting point
model: urlParams.get("model") ? urlParams.get("model") : "square_pyramid", // default starting point
};

const modelToURL = {
Expand Down Expand Up @@ -932,20 +938,45 @@ async function frame() {
transformationMatrix.byteLength
);

/**
 * Marks a boundary between two compute kernels.
 *
 * When `separateComputePasses` is truthy, the most recent pass in
 * `computePasses` is ended and a new pass, obtained from
 * `timingHelper.beginComputePass(encoder, ...)`, is appended to the array.
 * When it is falsy this function does nothing, so later dispatches keep being
 * recorded into the same (single) pass.
 *
 * @param {boolean} separateComputePasses - whether each kernel gets its own pass.
 * @param {object} timingHelper - object exposing `beginComputePass(encoder, descriptor)`.
 * @param {Array} computePasses - passes begun so far; mutated in place. Must be
 *   non-empty when `separateComputePasses` is truthy.
 * @param {object} encoder - command encoder handed through to `beginComputePass`.
 * @returns {void}
 */
function passBoundary(
  separateComputePasses,
  timingHelper,
  computePasses,
  encoder
) {
  if (!separateComputePasses) {
    return;
  }

  // Close the pass that is currently recording before opening the next one.
  computePasses.at(-1).end();

  // The new pass lands at index `computePasses.length` once pushed, so that is
  // the number used in its label (using `length - 1` here would repeat the
  // index of the pass that was just ended).
  const passIndex = computePasses.length;
  computePasses.push(
    timingHelper.beginComputePass(encoder, {
      label: `compute pass ${passIndex}, all compute kernels`,
    })
  );
}

// Encode commands to do the computation
const encoder = device.createCommandEncoder({
label:
"overall computation (perturb, face, edge, vertex, normals) + graphics encoder",
});

const computePass = timingHelper.beginComputePass(encoder, {
label: "compute pass, all compute kernels",
});
computePass.setPipeline(perturbPipeline);
computePass.setBindGroup(0, ctx.perturbBindGroup);
computePass.dispatchWorkgroups(
Math.ceil(mesh.levelCount[0].v / WORKGROUP_SIZE)
const kernels = separateComputePasses ? uni.views.subdivLevel[0] * 3 + 3 : 1;

const timingHelper = new TimingHelper(device, kernels);
const computePasses = [];
computePasses.push(
timingHelper.beginComputePass(encoder, {
label: `compute pass ${computePasses.length - 1}, all compute kernels`,
})
);
computePasses.at(-1).setPipeline(perturbPipeline);
computePasses.at(-1).setBindGroup(0, ctx.perturbBindGroup);
computePasses
.at(-1)
.dispatchWorkgroups(Math.ceil(mesh.levelCount[0].v / WORKGROUP_SIZE));
passBoundary(separateComputePasses, timingHelper, computePasses, encoder);

/** The face, edge, and vertex kernels run once per level */
for (var level = 1; level <= uni.views.subdivLevel[0]; level++) {
Expand All @@ -956,37 +987,41 @@ async function frame() {
// data structures
device.queue.writeBuffer(uniformsBuffer, 0, uni.arrayBuffer);

computePass.setPipeline(facePipeline);
computePass.setBindGroup(0, ctx.faceBindGroup);
computePass.dispatchWorkgroups(
Math.ceil(mesh.levelCount[level].f / WORKGROUP_SIZE)
);

computePass.setPipeline(edgePipeline);
computePass.setBindGroup(0, ctx.edgeBindGroup);
computePass.dispatchWorkgroups(
Math.ceil(mesh.levelCount[level].e / WORKGROUP_SIZE)
);

computePass.setPipeline(vertexPipeline);
computePass.setBindGroup(0, ctx.vertexBindGroup);
computePass.dispatchWorkgroups(
Math.ceil(mesh.levelCount[level].v / WORKGROUP_SIZE)
);
computePasses.at(-1).setPipeline(facePipeline);
computePasses.at(-1).setBindGroup(0, ctx.faceBindGroup);
computePasses
.at(-1)
.dispatchWorkgroups(Math.ceil(mesh.levelCount[level].f / WORKGROUP_SIZE));
passBoundary(separateComputePasses, timingHelper, computePasses, encoder);

computePasses.at(-1).setPipeline(edgePipeline);
computePasses.at(-1).setBindGroup(0, ctx.edgeBindGroup);
computePasses
.at(-1)
.dispatchWorkgroups(Math.ceil(mesh.levelCount[level].e / WORKGROUP_SIZE));
passBoundary(separateComputePasses, timingHelper, computePasses, encoder);

computePasses.at(-1).setPipeline(vertexPipeline);
computePasses.at(-1).setBindGroup(0, ctx.vertexBindGroup);
computePasses
.at(-1)
.dispatchWorkgroups(Math.ceil(mesh.levelCount[level].v / WORKGROUP_SIZE));
passBoundary(separateComputePasses, timingHelper, computePasses, encoder);
}

computePass.setPipeline(facetNormalsPipeline);
computePass.setBindGroup(0, ctx.facetNormalsBindGroup);
computePass.dispatchWorkgroups(
Math.ceil(mesh.facetNormals.length / WORKGROUP_SIZE)
);
computePasses.at(-1).setPipeline(facetNormalsPipeline);
computePasses.at(-1).setBindGroup(0, ctx.facetNormalsBindGroup);
computePasses
.at(-1)
.dispatchWorkgroups(Math.ceil(mesh.facetNormals.length / WORKGROUP_SIZE));

computePass.setPipeline(vertexNormalsPipeline);
computePass.setBindGroup(0, ctx.vertexNormalsBindGroup);
computePass.dispatchWorkgroups(
Math.ceil(mesh.vertexNormals.length / WORKGROUP_SIZE)
);
computePass.end();
passBoundary(separateComputePasses, timingHelper, computePasses, encoder);
computePasses.at(-1).setPipeline(vertexNormalsPipeline);
computePasses.at(-1).setBindGroup(0, ctx.vertexNormalsBindGroup);
computePasses
.at(-1)
.dispatchWorkgroups(Math.ceil(mesh.vertexNormals.length / WORKGROUP_SIZE));
computePasses.at(-1).end();

// Encode a command to copy the results to a mappable buffer.
// this is (from, to)
Expand Down Expand Up @@ -1077,12 +1112,16 @@ async function frame() {

/* is this correct for getting timing info? */
timingHelper.getResult().then((res) => {
// console.log("timing helper result", res);
console.log("Compute pass time:", res, "ns");
});

uni.views.time[0] = uni.views.time[0] + uni.views.timestep[0];
// console.log("time", uni.views.time[0]);
// return;
if (frameCount == 0) {
return;
} else if (frameCount > 0) {
frameCount--;
}
requestAnimationFrame(frame);
}
requestAnimationFrame(frame);
Expand Down
93 changes: 65 additions & 28 deletions webgpufundamentals-timing.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// copied from
// https://webgpufundamentals.org/webgpu/lessons/webgpu-timing.html by gman@

function assert(cond, msg = '') {
function assert(cond, msg = "") {
if (!cond) {
throw new Error(msg);
}
Expand All @@ -14,16 +14,20 @@ class TimingHelper {
#resolveBuffer;
#resultBuffer;
#resultBuffers = [];
// state can be 'free', 'need resolve', 'wait for result'
#state = 'free';
#passNumber;
#maxPasses;
// state can be 'free', 'in progress', 'need resolve', 'wait for result'
#state = "free";

constructor(device) {
constructor(device, maxPasses = 1) {
this.#device = device;
this.#canTimestamp = device.features.has('timestamp-query');
this.#passNumber = 0;
this.#maxPasses = maxPasses;
this.#canTimestamp = device.features.has("timestamp-query");
if (this.#canTimestamp) {
this.#querySet = device.createQuerySet({
type: 'timestamp',
count: 2,
type: "timestamp",
count: maxPasses * 2,
});
this.#resolveBuffer = device.createBuffer({
size: this.#querySet.count * 8,
Expand All @@ -34,23 +38,35 @@ class TimingHelper {

#beginTimestampPass(encoder, fnName, descriptor) {
if (this.#canTimestamp) {
assert(this.#state === 'free', 'state not free');
this.#state = 'need resolve';
assert(
/* haven't started or finished all passes yet */
this.#state === "free" || this.#state == "in progress",
`state not free (state = ${this.#state})`
);

const pass = encoder[fnName]({
...descriptor,
...{
timestampWrites: {
querySet: this.#querySet,
beginningOfPassWriteIndex: 0,
endOfPassWriteIndex: 1,
beginningOfPassWriteIndex: this.#passNumber * 2,
endOfPassWriteIndex: this.#passNumber * 2 + 1,
},
},
});

this.#passNumber++;
if (this.#passNumber == this.#maxPasses) {
/* finished all passes */
this.#state = "need resolve";
} else {
/* still have passes to do */
this.#state = "in progress";
}

const resolve = () => this.#resolveTiming(encoder);
pass.end = (function(origFn) {
return function() {
pass.end = (function (origFn) {
return function () {
origFn.call(this);
resolve();
};
Expand All @@ -63,42 +79,63 @@ class TimingHelper {
}

beginRenderPass(encoder, descriptor = {}) {
return this.#beginTimestampPass(encoder, 'beginRenderPass', descriptor);
return this.#beginTimestampPass(encoder, "beginRenderPass", descriptor);
}

beginComputePass(encoder, descriptor = {}) {
return this.#beginTimestampPass(encoder, 'beginComputePass', descriptor);
return this.#beginTimestampPass(encoder, "beginComputePass", descriptor);
}

#resolveTiming(encoder) {
if (!this.#canTimestamp) {
return;
}
assert(this.#state === 'need resolve', 'must call addTimestampToPass');
this.#state = 'wait for result';
if (this.#passNumber != this.#maxPasses) {
return;
}
assert(this.#state === "need resolve", "must call addTimestampToPass");
this.#state = "wait for result";

this.#resultBuffer = this.#resultBuffers.pop() || this.#device.createBuffer({
size: this.#resolveBuffer.size,
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
});
this.#resultBuffer =
this.#resultBuffers.pop() ||
this.#device.createBuffer({
size: this.#resolveBuffer.size,
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
});

encoder.resolveQuerySet(this.#querySet, 0, this.#querySet.count, this.#resolveBuffer, 0);
encoder.copyBufferToBuffer(this.#resolveBuffer, 0, this.#resultBuffer, 0, this.#resultBuffer.size);
encoder.resolveQuerySet(
this.#querySet,
0,
this.#querySet.count,
this.#resolveBuffer,
0
);
encoder.copyBufferToBuffer(
this.#resolveBuffer,
0,
this.#resultBuffer,
0,
this.#resultBuffer.size
);
}

async getResult() {
if (!this.#canTimestamp) {
return 0;
}
assert(this.#state === 'wait for result', 'must call resolveTiming');
this.#state = 'free';
assert(this.#state === "wait for result", "must call resolveTiming");
this.#state = "free";

const resultBuffer = this.#resultBuffer;
await resultBuffer.mapAsync(GPUMapMode.READ);
const times = new BigInt64Array(resultBuffer.getMappedRange());
const duration = Number(times[1] - times[0]);
/* I need to read about functional programming in JS to make below pretty */
const durations = [];
for (var idx = 0; idx < times.length; idx += 2) {
durations.push(Number(times[idx + 1] - times[idx]));
}
resultBuffer.unmap();
this.#resultBuffers.push(resultBuffer);
return duration;
return durations;
}
}
}

0 comments on commit 1e0576d

Please sign in to comment.