Skip to content

Commit

Permalink
O(n) vertex normals kernel implemented, appears to work, still some i…
Browse files Browse the repository at this point in the history
…ssues with timing failing and dispatch for large models exceeding 65535 workgroups in x dimension
  • Loading branch information
John Owens committed Oct 4, 2024
1 parent 1e0576d commit a919ea6
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 14 deletions.
116 changes: 104 additions & 12 deletions catmull-clark.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,25 @@ import {
makeStructuredView,
} from "https://greggman.github.io/webgpu-utils/dist/1.x/webgpu-utils.module.js";

// Runtime parameters are read from the query string of the page URL.
const urlParams = new URL(window.location.href).searchParams;
// URLSearchParams.get() yields the value as a string, or null when the
// parameter is absent (it never yields undefined).
const debug = urlParams.get("debug"); // string or null
let frameCount = urlParams.get("frameCount");
// Loose == matches null as well as undefined, so a missing parameter becomes -1.
frameCount = frameCount == undefined ? -1 : parseInt(frameCount, 10);
// The next two hold the raw parameter string (or null), not a boolean.
const separateComputePasses = urlParams.get("separateComputePasses");
const timingEnabled = urlParams.get("timing"); // gates the "timestamp-query" device feature

// Acquire the GPU adapter and device. navigator.gpu is undefined in browsers
// without WebGPU and requestAdapter() may resolve to null, so every access to
// `adapter` is guarded; a missing adapter then reaches the fail() call below
// instead of throwing a TypeError on `adapter.features` first.
const adapter = await navigator.gpu?.requestAdapter();
// Always a boolean, even when there is no adapter.
const canTimestamp = adapter?.features.has("timestamp-query") ?? false;
const device = await adapter?.requestDevice({
  // Request timestamp queries only when the adapter offers them AND the
  // "timing" URL parameter is set; the spread adds the entry conditionally.
  requiredFeatures: [
    ...(canTimestamp && timingEnabled ? ["timestamp-query"] : []),
  ],
});
if (!device) {
  fail("Fatal error: Device does not support WebGPU.");
}

// We can set runtime params from the input URL!
// NOTE(review): urlParams, debug, frameCount and separateComputePasses are
// also declared earlier in this listing; presumably one of the two copies is
// the pre-change side of the diff — confirm which one the real file keeps.
const urlParams = new URL(window.location.href).searchParams;
const debug = urlParams.get("debug"); // string, or null when the parameter is absent
let frameCount = urlParams.get("frameCount");
// Loose == matches null as well as undefined, so a missing parameter becomes -1.
frameCount = frameCount == undefined ? -1 : parseInt(frameCount, 10);
// Loose != rejects null too, so this is true exactly when the parameter is
// present in the URL (even with an empty value).
const separateComputePasses =
urlParams.get("separateComputePasses") != undefined; // true or false

// if we want more:
// Object.fromEntries(new URL(window.location.href).searchParams.entries());
// if url is 'https://foo.com/bar.html?abc=123&def=456&xyz=banana` then params is
Expand Down Expand Up @@ -388,8 +390,8 @@ const facetNormalsModule = device.createShaderModule({
}`,
});

const vertexNormalsModule = device.createShaderModule({
label: "compute vertex normals module",
const vertexNormalsON2Module = device.createShaderModule({
label: "compute vertex normals (O(n^2)) module",
code: /* wgsl */ `
${uniformsCode} /* this specifies @group(0) @binding(0) */
/* output */
Expand Down Expand Up @@ -418,6 +420,33 @@ const vertexNormalsModule = device.createShaderModule({
}`,
});

// Shader module for the O(n) vertex-normals kernel: each invocation handles
// one vertex and sums the facet normals of the triangles listed for that
// vertex in the vertexToTriangles adjacency data, then normalizes the sum.
const vertexNormalsModule = device.createShaderModule({
  label: "compute vertex normals (O(n)) module",
  code: /* wgsl */ `
    ${uniformsCode} /* this specifies @group(0) @binding(0) */
    /* output */
    @group(0) @binding(1) var<storage, read_write> vertexNormals: array<vec3f>;
    /* input */
    @group(0) @binding(2) var<storage, read> facetNormals: array<vec3f>;
    @group(0) @binding(3) var<storage, read> vertexToTriangles: array<u32>;
    @group(0) @binding(4) var<storage, read> vertexToTrianglesOffset: array<u32>;
    @group(0) @binding(5) var<storage, read> vertexToTrianglesValence: array<u32>;
    /* vertexToTriangles[offset[v] .. offset[v] + valence[v]) holds the
     * indices of the triangles that touch vertex v */
    @compute @workgroup_size(${WORKGROUP_SIZE}) fn vertexNormalsKernel(
      @builtin(global_invocation_id) id: vec3u) {
        let vtx = id.x;
        /* invocations past the end of the output array have nothing to do */
        if (vtx >= arrayLength(&vertexNormals)) {
          return;
        }
        /* read the loop bounds once instead of on every iteration */
        let first = vertexToTrianglesOffset[vtx];
        let end = first + vertexToTrianglesValence[vtx];
        /* accumulate in a function-local variable; the storage buffer is
         * written exactly once, after the loop */
        var normal = vec3f(0, 0, 0);
        for (var neighbor = first; neighbor < end; neighbor++) {
          normal += facetNormals[vertexToTriangles[neighbor]];
        }
        vertexNormals[vtx] = normalize(normal);
    }`,
});

const renderModule = device.createShaderModule({
label: "render module",
code: /* wgsl */ `
Expand Down Expand Up @@ -503,6 +532,14 @@ const facetNormalsPipeline = device.createComputePipeline({
},
});

// Compute pipeline built from the O(n^2) vertex-normals shader module; the
// pipeline layout is left to "auto".
const vertexNormalsON2Pipeline = device.createComputePipeline({
  compute: { module: vertexNormalsON2Module },
  layout: "auto",
  label: "vertex normals O(N^2) compute pipeline",
});

const vertexNormalsPipeline = device.createComputePipeline({
label: "vertex normals compute pipeline",
layout: "auto",
Expand Down Expand Up @@ -669,6 +706,24 @@ class GPUContext {
GPUBufferUsage.INDEX | GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});

// Three storage buffers carrying the mesh's vertex-to-triangle adjacency.
// Each is sized from the byteLength of the matching typed array on `mesh`;
// STORAGE lets a shader bind it, COPY_DST lets queue.writeBuffer fill it.

// Flattened list of triangle indices, grouped per vertex.
this.vertexToTrianglesBuffer = device.createBuffer({
label: "vertex to triangles buffer",
size: mesh.vertexToTriangles.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});

// Per vertex: where that vertex's group starts in the flattened list.
this.vertexToTrianglesOffsetBuffer = device.createBuffer({
label: "vertex to triangles offset buffer",
size: mesh.vertexToTrianglesOffset.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});

// Per vertex: how many triangle indices belong to that vertex's group.
this.vertexToTrianglesValenceBuffer = device.createBuffer({
label: "vertex to triangles valence buffer",
size: mesh.vertexToTrianglesValence.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});

// vertex buffer is both input and output
this.verticesBuffer = device.createBuffer({
label: "vertex buffer",
Expand Down Expand Up @@ -795,6 +850,25 @@ class GPUContext {
// Bind group for the vertex-normals compute pipeline. Binding 0 (the uniforms
// buffer) is not bound here; the buffers below occupy bindings 1..5 in the
// order listed, so the array position plus one is the binding number.
this.vertexNormalsBindGroup = device.createBindGroup({
  label: "bindGroup for computing vertex normals",
  layout: vertexNormalsPipeline.getBindGroupLayout(0),
  entries: [
    this.vertexNormalsBuffer, // binding 1
    this.facetNormalsBuffer, // binding 2
    this.vertexToTrianglesBuffer, // binding 3
    this.vertexToTrianglesOffsetBuffer, // binding 4
    this.vertexToTrianglesValenceBuffer, // binding 5
  ].map((buffer, slot) => ({ binding: slot + 1, resource: { buffer } })),
});

this.vertexNormalsON2BindGroup = device.createBindGroup({
label: "bindGroup for computing vertex normals",
layout: vertexNormalsON2Pipeline.getBindGroupLayout(0),
entries: [
// { binding: 0, resource: { buffer: uniformsBuffer } },
{ binding: 1, resource: { buffer: this.vertexNormalsBuffer } },
Expand Down Expand Up @@ -832,6 +906,21 @@ class GPUContext {
device.queue.writeBuffer(this.vertexIndexBuffer, 0, mesh.vertexIndex);
device.queue.writeBuffer(this.triangleIndicesBuffer, 0, mesh.triangles);
device.queue.writeBuffer(this.verticesBuffer, 0, mesh.vertices);
// Upload the vertex-to-triangle adjacency arrays from `mesh` into their GPU
// buffers, each starting at byte offset 0.
device.queue.writeBuffer(
this.vertexToTrianglesBuffer,
0,
mesh.vertexToTriangles
);
device.queue.writeBuffer(
this.vertexToTrianglesValenceBuffer,
0,
mesh.vertexToTrianglesValence
);
device.queue.writeBuffer(
this.vertexToTrianglesOffsetBuffer,
0,
mesh.vertexToTrianglesOffset
);
device.queue.writeBuffer(this.facetNormalsBuffer, 0, mesh.facetNormals);
device.queue.writeBuffer(this.vertexNormalsBuffer, 0, mesh.vertexNormals);
uni.set({ levelCount: mesh.levelCount, levelBasePtr: mesh.levelBasePtr });
Expand All @@ -850,6 +939,9 @@ class GPUContext {
this.vertexIndexBuffer.destroy();
this.triangleIndicesBuffer.destroy();
this.verticesBuffer.destroy();
// Destroy the three vertex-to-triangle adjacency buffers as well.
this.vertexToTrianglesBuffer.destroy();
this.vertexToTrianglesValenceBuffer.destroy();
this.vertexToTrianglesOffsetBuffer.destroy();
this.facetNormalsBuffer.destroy();
this.vertexNormalsBuffer.destroy();
this.mappableVerticesResultBuffer.destroy();
Expand Down Expand Up @@ -1112,7 +1204,7 @@ async function frame() {

/* is this correct for getting timing info? */
timingHelper.getResult().then((res) => {
console.log("Compute pass time:", res, "ns");
// console.log("Compute pass time:", res, "ns");
});

uni.views.time[0] = uni.views.time[0] + uni.views.timestep[0];
Expand Down
51 changes: 49 additions & 2 deletions objload.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ class Level {
}

class SubdivMesh {
exclusive_prefix_sum(inList) {
const outList = [];
var sum = 0;
for (var i = 0; i < inList.length; i++) {
outList.push(sum);
sum += inList[i];
}
return outList;
}

constructor(verticesIn, facesIn) {
/* everything prefixed with "this." is a data structure that will go to the GPU */
/* everything else is internal-only and will not be externally visible */
Expand All @@ -30,12 +40,15 @@ class SubdivMesh {
this.vertexOffset = [];
this.vertexValence = [];
this.vertexIndex = [];
this.vertexToTriangles = new Array(); // indexed by vertex number, has list of tri neighbors
this.vertexToTrianglesOffset = [];
this.vertexToTrianglesValence = [];
this.vertexSize = 4; // # elements per vertex (ignore w coord for now)
this.normalSize = 4; // float4s (ignore w coord for now)

/** levelCount[L].x is the number of point type x in level L
* levelBasePtr[L].x is the starting index into the vertices array for level L, point type x
* levelBasePtr ~= exclusive-scan(levelCount), mostly
* levelBasePtr ~= exclusive-sum-scan(levelCount), mostly
* now: populate level 0 of levelCount and levelBasePtr
* assumes manifold surface!
*/
Expand All @@ -52,7 +65,7 @@ class SubdivMesh {

this.scaleInput = true; // if true, scales output into [-1,1]^3 box
this.largestInput = 0.0;
this.maxLevel = 3; // valid levels are <= maxLevel
this.maxLevel = 4; // valid levels are <= maxLevel

/** OBJ stores faces in CCW order
* The OBJ (or .OBJ) file format stores vertices in a counterclockwise order by default.
Expand Down Expand Up @@ -128,6 +141,14 @@ class SubdivMesh {
/* triangles: (-3, -2, -1)
* quads: (-4, -3, -2) (-4, -2, -1) */
);
// Index of the last triangle currently in this.triangles (three vertex
// indices per triangle) — presumably the one pushed just above.
const triID = this.triangles.length / 3 - 1;
// The final three entries of this.triangles are that triangle's vertices;
// add the triangle to each of those vertices' adjacency lists.
for (let k = -3; k < 0; k++) {
const vtxID = this.triangles.at(k);
// Create the per-vertex list the first time this vertex is seen.
if (this.vertexToTriangles[vtxID] == undefined) {
this.vertexToTriangles[vtxID] = [];
}
this.vertexToTriangles[vtxID].push(triID);
}
}
this.levelCount[0].t += valence - 2;
}
Expand Down Expand Up @@ -299,6 +320,22 @@ class SubdivMesh {
v(j),
e(j, mod(j + 1, valence))
);
// The last six entries of this.triangles are treated as two triangles.
// Start with the index of the second-to-last triangle.
var triID = this.triangles.length / 3 - 2;
// Entries -6..-4 are that triangle's vertices; add it to each vertex's
// adjacency list, creating the list on first use.
for (let k = -6; k < -3; k++) {
const vtxID = this.triangles.at(k);
if (this.vertexToTriangles[vtxID] == undefined) {
this.vertexToTriangles[vtxID] = [];
}
this.vertexToTriangles[vtxID].push(triID);
}
// Move on to the last triangle, whose vertices are entries -3..-1.
triID++;
for (let k = -3; k < 0; k++) {
const vtxID = this.triangles.at(k);
if (this.vertexToTriangles[vtxID] == undefined) {
this.vertexToTriangles[vtxID] = [];
}
this.vertexToTriangles[vtxID].push(triID);
}
}
this.levelCount[level].t += valence * 2;
}
Expand Down Expand Up @@ -375,6 +412,16 @@ class SubdivMesh {
this.vertexOffset = new Uint32Array(this.vertexOffset);
this.vertexIndex = new Uint32Array(this.vertexIndex);
this.vertexNeighbors = new Uint32Array(this.vertexNeighbors.flat());
// Convert the per-vertex triangle lists into the three flat typed arrays.
// vertexToTriangles is filled by index and can contain holes for vertices
// that were never referenced by a triangle. Array.prototype.map would keep
// those holes (feeding undefined into the prefix sum), so the valences are
// built with Uint32Array.from, which visits every index; a hole gets
// valence 0.
this.vertexToTrianglesValence = Uint32Array.from(
  this.vertexToTriangles,
  (tris) => tris?.length ?? 0
);
// Offset of each vertex's group = exclusive prefix sum of the valences.
this.vertexToTrianglesOffset = new Uint32Array(
  this.exclusive_prefix_sum(this.vertexToTrianglesValence)
);
// flat() skips holes, which matches the zero valence assigned to them above.
this.vertexToTriangles = new Uint32Array(this.vertexToTriangles.flat());
/* the next two arrays are empty and will be filled by the GPU */
this.vertexNormals = new Float32Array(this.verticesSize * this.normalSize);
this.facetNormals = new Float32Array(
Expand Down

0 comments on commit a919ea6

Please sign in to comment.