Skip to content

Commit

Permalink
Consolidate denoise pass shader
Browse files Browse the repository at this point in the history
  • Loading branch information
StuckiSimon committed Aug 24, 2024
1 parent 548e20b commit 4d85ec2
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 225 deletions.
25 changes: 0 additions & 25 deletions strahl-lib/src/denoise-pass-shader.ts

This file was deleted.

14 changes: 9 additions & 5 deletions strahl-lib/src/path-tracer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { buildPathTracerShader } from "./shaders/tracer-shader.ts";
import { buildRenderShader } from "./shaders/render-shader";
import buildDenoisePassShader from "./denoise-pass-shader.ts";
import { buildDenoisePassShader } from "./shaders/denoise-pass-shader.ts";
import { logGroup } from "./benchmark/cpu-performance-logger.ts";
import { OpenPBRMaterial } from "./openpbr-material";
import {
Expand Down Expand Up @@ -1036,10 +1036,9 @@ async function runPathTracer(
});

const denoisePassShaderCode = buildDenoisePassShader({
imageWidth: width,
imageHeight: height,
maxWorkgroupDimension,
maxBvhStackDepth: maxBvhDepth,
bvhParams: {
maxBvhStackDepth: maxBvhDepth,
},
});

const denoisePassDefinitions = makeShaderDataDefinitions(
Expand Down Expand Up @@ -1131,6 +1130,11 @@ async function runPathTracer(
compute: {
module: computeShaderModule,
entryPoint: "computeMain",
constants: {
wgSize: maxWorkgroupDimension,
imageWidth: width,
imageHeight: height,
},
},
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
// Denoise Pass contains the intersection test and ray casting code of tracer-shader
// todo: consolidate with tracer-shader
import { buildBvhShader } from "./bvh";

type Params = {
bvhParams: Parameters<typeof buildBvhShader>[0];
};

export function buildDenoisePassShader({ bvhParams }: Params) {
return /* wgsl */ `
alias Color = vec3f;
override wgSize: u32 = 16;
override imageWidth: u32 = 512;
override imageHeight: u32 = 512;
struct Material {
baseWeight: f32,
baseColor: Color,
Expand Down Expand Up @@ -170,195 +179,7 @@ fn randomF32(seed: ptr<function, u32>) -> f32 {
return f32(*seed - 1u) * range;
}
const TRIANGLE_EPSILON = 1.0e-6;

// Möller–Trumbore intersection algorithm without culling
fn triangleHit(triangle: Triangle, ray: Ray, rayT: Interval, hitRecord: ptr<function, HitRecord>) -> bool {
let edge1 = triangle.u;
let edge2 = triangle.v;
let pvec = cross(ray.direction, edge2);
let det = dot(edge1, pvec);
// No hit if ray is parallel to the triangle (ray lies in plane of triangle)
if (det > -TRIANGLE_EPSILON && det < TRIANGLE_EPSILON) {
return false;
}
let invDet = 1.0 / det;
let tvec = ray.origin - triangle.Q;
let u = dot(tvec, pvec) * invDet;

if (u < 0.0 || u > 1.0) {
return false;
}

let qvec = cross(tvec, edge1);
let v = dot(ray.direction, qvec) * invDet;

if (v < 0.0 || u + v > 1.0) {
return false;
}

let t = dot(edge2, qvec) * invDet;

// check if the intersection point is within the ray's interval
if (t < (rayT).min || t > (rayT).max) {
return false;
}

(*hitRecord).t = t;
(*hitRecord).point = rayAt(ray, t);
(*hitRecord).normal = normalize(triangle.normal0 * (1.0 - u - v) + triangle.normal1 * u + triangle.normal2 * v);

(*hitRecord).material = triangle.material;

return true;
}

// Based on https://github.com/gkjohnson/three-mesh-bvh/blob/master/src/gpu/glsl/bvh_ray_functions.glsl.js
fn intersectsBounds(ray: Ray, boundsMin: vec3f, boundsMax: vec3f, dist: ptr<function, f32>) -> bool {
let invDir = vec3f(1.0) / ray.direction;

let tMinPlane = invDir * (boundsMin - ray.origin);
let tMaxPlane = invDir * (boundsMax - ray.origin);

let tMinHit = min(tMaxPlane, tMinPlane);
let tMaxHit = max(tMaxPlane, tMinPlane);

var t = max(tMinHit.xx, tMinHit.yz);
let t0 = max(t.x, t.y);

t = min(tMaxHit.xx, tMaxHit.yz);
let t1 = min(t.x, t.y);

(*dist) = max(t0, 0.0);

return t1 >= (*dist);
}

fn intersectsBVHNodeBounds(ray: Ray, currNodeIndex: u32, dist: ptr<function, f32>) -> bool {
// 2 x x,y,z + unused alpha
let boundaries = bounds[currNodeIndex];
let boundsMin = boundaries[0];
let boundsMax = boundaries[1];
return intersectsBounds(ray, boundsMin.xyz, boundsMax.xyz, dist);
}

fn intersectTriangles(offset: u32, count: u32, ray: Ray, rayT: Interval, hitRecord: ptr<function, HitRecord>) -> bool {
var found = false;
var localDist = hitRecord.t;
let l = offset + count;

for (var i = offset; i < l; i += 1) {
let indAccess = indirectIndices[i];
let indicesPackage = indices[indAccess];
let v1Index = indicesPackage.x;
let v2Index = indicesPackage.y;
let v3Index = indicesPackage.z;

let v1 = positions[v1Index];
let v2 = positions[v2Index];
let v3 = positions[v3Index];
let x = v1[0];
let y = v2[0];
let z = v3[0];

let normalX = v1[1];
let normalY = v2[1];
let normalZ = v3[1];

let Q = x;
let u = y - x;
let v = z - x;

let vIndexOffset = indAccess * 3;
var matchingObjectDefinition: ObjectDefinition = objectDefinitions[0];
for (var j = 0; j < uniformData.objectDefinitionLength ; j++) {
let objectDefinition = objectDefinitions[j];
if (objectDefinition.start <= vIndexOffset && objectDefinition.start + objectDefinition.count > vIndexOffset) {
matchingObjectDefinition = objectDefinition;
break;
}
}
let materialDefinition = matchingObjectDefinition.material;

let triangle = Triangle(Q, u, v, materialDefinition, normalX, normalY, normalZ);

var tmpRecord: HitRecord;
if (triangleHit(triangle, ray, Interval(rayT.min, localDist), &tmpRecord)) {
if (localDist < tmpRecord.t) {
continue;
}
(*hitRecord) = tmpRecord;

localDist = (*hitRecord).t;
found = true;
}
}
return found;
}

fn hittableListHit(ray: Ray, rayT: Interval, hitRecord: ptr<function, HitRecord>) -> bool {
var tempRecord: HitRecord;
var hitAnything = false;
var closestSoFar = rayT.max;

// Inspired by https://github.com/gkjohnson/three-mesh-bvh/blob/master/src/gpu/glsl/bvh_ray_functions.glsl.js

// BVH Intersection Detection
var sPtr = 0;
var stack: array<u32, ${maxBvhStackDepth}> = array<u32, ${maxBvhStackDepth}>();
stack[sPtr] = 0u;

while (sPtr > -1 && sPtr < ${maxBvhStackDepth}) {
let currNodeIndex = stack[sPtr];
sPtr -= 1;

var boundsHitDistance: f32;

if (!intersectsBVHNodeBounds(ray, currNodeIndex, &boundsHitDistance) || boundsHitDistance > closestSoFar) {
continue;
}

let boundsInfo = contents[currNodeIndex];
let boundsInfoX = boundsInfo.x;
let boundsInfoY = boundsInfo.y;

let isLeaf = (boundsInfoX & 0xffff0000u) == 0xffff0000u;

if (isLeaf) {
let count = boundsInfoX & 0x0000ffffu;
let offset = boundsInfoY;

let found2 = intersectTriangles(
offset,
count,
ray,
rayT,
hitRecord
);
if (found2) {
closestSoFar = (*hitRecord).t;
}

hitAnything = hitAnything || found2;
} else {
// Left node is always the next node
let leftIndex = currNodeIndex + 1u;
let splitAxis = boundsInfoX & 0x0000ffffu;
let rightIndex = boundsInfoY;

let leftToRight = ray.direction[splitAxis] > 0.0;
let c1 = select(rightIndex, leftIndex, leftToRight);
let c2 = select(leftIndex, rightIndex, leftToRight);

sPtr += 1;
stack[sPtr] = c2;
sPtr += 1;
stack[sPtr] = c1;
}
}

return hitAnything;
}
${buildBvhShader(bvhParams)}
const TRIANGLE_MIN_DISTANCE_THRESHOLD = 0.0005;
const TRIANGLE_MAX_DISTANCE_THRESHOLD = 10e37f;
Expand Down Expand Up @@ -424,19 +245,21 @@ fn getPixelJitter(seed: ptr<function, u32>) -> vec2f {
}
@compute
@workgroup_size(${maxWorkgroupDimension}, ${maxWorkgroupDimension}, 1)
@workgroup_size(wgSize, wgSize, 1)
fn computeMain(@builtin(global_invocation_id) globalId: vec3<u32>) {
var seed = globalId.x + globalId.y * ${imageWidth};
var seed = globalId.x + globalId.y * imageWidth;
seed ^= uniformData.seedOffset;
let pixelOrigin = vec2f(f32(globalId.x), f32(globalId.y));
let pixel = pixelOrigin;
let ndc = -1.0 + 2.0*pixel / vec2<f32>(${imageWidth}, ${imageHeight});
let ndc = -1.0 + 2.0 * pixel / vec2f(f32(imageWidth), f32(imageHeight));
var ray = ndcToCameraRay(ndc, uniformData.invModelMatrix * uniformData.cameraWorldMatrix, uniformData.invProjectionMatrix, &seed);
ray.direction = normalize(ray.direction);
let output = getRayOutput(ray, &seed);
hdrColor[i32(globalId.x) + i32(globalId.y) * ${imageWidth}] = output;
hdrColor[i32(globalId.x) + i32(globalId.y) * i32(imageWidth)] = output;
}
`;
}

0 comments on commit 4d85ec2

Please sign in to comment.