diff --git a/strahl-lib/src/denoise-pass-shader.ts b/strahl-lib/src/denoise-pass-shader.ts deleted file mode 100644 index 1773e7d..0000000 --- a/strahl-lib/src/denoise-pass-shader.ts +++ /dev/null @@ -1,25 +0,0 @@ -import denoisePassShader from "./denoise-pass-shader.wgsl?raw"; - -type Params = { - imageWidth: number; - imageHeight: number; - maxWorkgroupDimension: number; - maxBvhStackDepth: number; -}; - -const PARAM_PLACEHOLDER_MAP: Record = { - imageWidth: "imageWidth", - imageHeight: "imageHeight", - maxWorkgroupDimension: "maxWorkgroupDimension", - maxBvhStackDepth: "maxBvhStackDepth", -}; - -export default function build(params: Params) { - const placeholders = Object.entries(PARAM_PLACEHOLDER_MAP) as [ - keyof Params, - string, - ][]; - return placeholders.reduce((aggregate, [key, value]) => { - return aggregate.replaceAll(`\${${value}}`, `${params[key]}`); - }, denoisePassShader); -} diff --git a/strahl-lib/src/path-tracer.ts b/strahl-lib/src/path-tracer.ts index 073d27d..9586288 100644 --- a/strahl-lib/src/path-tracer.ts +++ b/strahl-lib/src/path-tracer.ts @@ -1,6 +1,6 @@ import { buildPathTracerShader } from "./shaders/tracer-shader.ts"; import { buildRenderShader } from "./shaders/render-shader"; -import buildDenoisePassShader from "./denoise-pass-shader.ts"; +import { buildDenoisePassShader } from "./shaders/denoise-pass-shader.ts"; import { logGroup } from "./benchmark/cpu-performance-logger.ts"; import { OpenPBRMaterial } from "./openpbr-material"; import { @@ -1036,10 +1036,9 @@ async function runPathTracer( }); const denoisePassShaderCode = buildDenoisePassShader({ - imageWidth: width, - imageHeight: height, - maxWorkgroupDimension, - maxBvhStackDepth: maxBvhDepth, + bvhParams: { + maxBvhStackDepth: maxBvhDepth, + }, }); const denoisePassDefinitions = makeShaderDataDefinitions( @@ -1131,6 +1130,11 @@ async function runPathTracer( compute: { module: computeShaderModule, entryPoint: "computeMain", + constants: { + wgSize: maxWorkgroupDimension, + imageWidth: width, + imageHeight: height, + }, }, }); diff --git a/strahl-lib/src/denoise-pass-shader.wgsl b/strahl-lib/src/shaders/denoise-pass-shader.ts similarity index 51% rename from strahl-lib/src/denoise-pass-shader.wgsl rename to strahl-lib/src/shaders/denoise-pass-shader.ts index 19b10db..780915d 100644 --- a/strahl-lib/src/denoise-pass-shader.wgsl +++ b/strahl-lib/src/shaders/denoise-pass-shader.ts @@ -1,8 +1,17 @@ -// Denoise Pass contains the intersection test and ray casting code of tracer-shader -// todo: consolidate with tracer-shader +import { buildBvhShader } from "./bvh"; +type Params = { + bvhParams: Parameters[0]; +}; + +export function buildDenoisePassShader({ bvhParams }: Params) { + return /* wgsl */ ` alias Color = vec3f; +override wgSize: u32 = 16; +override imageWidth: u32 = 512; +override imageHeight: u32 = 512; + struct Material { baseWeight: f32, baseColor: Color, @@ -170,195 +179,7 @@ fn randomF32(seed: ptr) -> f32 { return f32(*seed - 1u) * range; } -const TRIANGLE_EPSILON = 1.0e-6; - -// Möller–Trumbore intersection algorithm without culling -fn triangleHit(triangle: Triangle, ray: Ray, rayT: Interval, hitRecord: ptr) -> bool { - let edge1 = triangle.u; - let edge2 = triangle.v; - let pvec = cross(ray.direction, edge2); - let det = dot(edge1, pvec); - // No hit if ray is parallel to the triangle (ray lies in plane of triangle) - if (det > -TRIANGLE_EPSILON && det < TRIANGLE_EPSILON) { - return false; - } - let invDet = 1.0 / det; - let tvec = ray.origin - triangle.Q; - let u = dot(tvec, pvec) * invDet; - - if (u < 0.0 || u > 1.0) { - return false; - } - - let qvec = cross(tvec, edge1); - let v = dot(ray.direction, qvec) * invDet; - - if (v < 0.0 || u + v > 1.0) { - return false; - } - - let t = dot(edge2, qvec) * invDet; - - // check if the intersection point is within the ray's interval - if (t < (rayT).min || t > (rayT).max) { - return false; - } - - (*hitRecord).t = t; - (*hitRecord).point = rayAt(ray, t); - (*hitRecord).normal = normalize(triangle.normal0 * (1.0 - u - v) + triangle.normal1 * u + triangle.normal2 * v); - - (*hitRecord).material = triangle.material; - - return true; -} - -// Based on https://github.com/gkjohnson/three-mesh-bvh/blob/master/src/gpu/glsl/bvh_ray_functions.glsl.js -fn intersectsBounds(ray: Ray, boundsMin: vec3f, boundsMax: vec3f, dist: ptr) -> bool { - let invDir = vec3f(1.0) / ray.direction; - - let tMinPlane = invDir * (boundsMin - ray.origin); - let tMaxPlane = invDir * (boundsMax - ray.origin); - - let tMinHit = min(tMaxPlane, tMinPlane); - let tMaxHit = max(tMaxPlane, tMinPlane); - - var t = max(tMinHit.xx, tMinHit.yz); - let t0 = max(t.x, t.y); - - t = min(tMaxHit.xx, tMaxHit.yz); - let t1 = min(t.x, t.y); - - (*dist) = max(t0, 0.0); - - return t1 >= (*dist); -} - -fn intersectsBVHNodeBounds(ray: Ray, currNodeIndex: u32, dist: ptr) -> bool { - // 2 x x,y,z + unused alpha - let boundaries = bounds[currNodeIndex]; - let boundsMin = boundaries[0]; - let boundsMax = boundaries[1]; - return intersectsBounds(ray, boundsMin.xyz, boundsMax.xyz, dist); -} - -fn intersectTriangles(offset: u32, count: u32, ray: Ray, rayT: Interval, hitRecord: ptr) -> bool { - var found = false; - var localDist = hitRecord.t; - let l = offset + count; - - for (var i = offset; i < l; i += 1) { - let indAccess = indirectIndices[i]; - let indicesPackage = indices[indAccess]; - let v1Index = indicesPackage.x; - let v2Index = indicesPackage.y; - let v3Index = indicesPackage.z; - - let v1 = positions[v1Index]; - let v2 = positions[v2Index]; - let v3 = positions[v3Index]; - let x = v1[0]; - let y = v2[0]; - let z = v3[0]; - - let normalX = v1[1]; - let normalY = v2[1]; - let normalZ = v3[1]; - - let Q = x; - let u = y - x; - let v = z - x; - - let vIndexOffset = indAccess * 3; - var matchingObjectDefinition: ObjectDefinition = objectDefinitions[0]; - for (var j = 0; j < uniformData.objectDefinitionLength ; j++) { - let objectDefinition = objectDefinitions[j]; - if (objectDefinition.start <= vIndexOffset && objectDefinition.start + objectDefinition.count > vIndexOffset) { - matchingObjectDefinition = objectDefinition; - break; - } - } - let materialDefinition = matchingObjectDefinition.material; - - let triangle = Triangle(Q, u, v, materialDefinition, normalX, normalY, normalZ); - - var tmpRecord: HitRecord; - if (triangleHit(triangle, ray, Interval(rayT.min, localDist), &tmpRecord)) { - if (localDist < tmpRecord.t) { - continue; - } - (*hitRecord) = tmpRecord; - - localDist = (*hitRecord).t; - found = true; - } - } - return found; -} - -fn hittableListHit(ray: Ray, rayT: Interval, hitRecord: ptr) -> bool { - var tempRecord: HitRecord; - var hitAnything = false; - var closestSoFar = rayT.max; - - // Inspired by https://github.com/gkjohnson/three-mesh-bvh/blob/master/src/gpu/glsl/bvh_ray_functions.glsl.js - - // BVH Intersection Detection - var sPtr = 0; - var stack: array = array(); - stack[sPtr] = 0u; - - while (sPtr > -1 && sPtr < ${maxBvhStackDepth}) { - let currNodeIndex = stack[sPtr]; - sPtr -= 1; - - var boundsHitDistance: f32; - - if (!intersectsBVHNodeBounds(ray, currNodeIndex, &boundsHitDistance) || boundsHitDistance > closestSoFar) { - continue; - } - - let boundsInfo = contents[currNodeIndex]; - let boundsInfoX = boundsInfo.x; - let boundsInfoY = boundsInfo.y; - - let isLeaf = (boundsInfoX & 0xffff0000u) == 0xffff0000u; - - if (isLeaf) { - let count = boundsInfoX & 0x0000ffffu; - let offset = boundsInfoY; - - let found2 = intersectTriangles( - offset, - count, - ray, - rayT, - hitRecord - ); - if (found2) { - closestSoFar = (*hitRecord).t; - } - - hitAnything = hitAnything || found2; - } else { - // Left node is always the next node - let leftIndex = currNodeIndex + 1u; - let splitAxis = boundsInfoX & 0x0000ffffu; - let rightIndex = boundsInfoY; - - let leftToRight = ray.direction[splitAxis] > 0.0; - let c1 = select(rightIndex, leftIndex, leftToRight); - let c2 = select(leftIndex, rightIndex, leftToRight); - - sPtr += 1; - stack[sPtr] = c2; - sPtr += 1; - stack[sPtr] = c1; - } - } - - return hitAnything; -} +${buildBvhShader(bvhParams)} const TRIANGLE_MIN_DISTANCE_THRESHOLD = 0.0005; const TRIANGLE_MAX_DISTANCE_THRESHOLD = 10e37f; @@ -424,19 +245,21 @@ fn getPixelJitter(seed: ptr) -> vec2f { } @compute -@workgroup_size(${maxWorkgroupDimension}, ${maxWorkgroupDimension}, 1) +@workgroup_size(wgSize, wgSize, 1) fn computeMain(@builtin(global_invocation_id) globalId: vec3) { - var seed = globalId.x + globalId.y * ${imageWidth}; + var seed = globalId.x + globalId.y * imageWidth; seed ^= uniformData.seedOffset; let pixelOrigin = vec2f(f32(globalId.x), f32(globalId.y)); let pixel = pixelOrigin; - let ndc = -1.0 + 2.0*pixel / vec2(${imageWidth}, ${imageHeight}); + let ndc = -1.0 + 2.0 * pixel / vec2f(f32(imageWidth), f32(imageHeight)); var ray = ndcToCameraRay(ndc, uniformData.invModelMatrix * uniformData.cameraWorldMatrix, uniformData.invProjectionMatrix, &seed); ray.direction = normalize(ray.direction); let output = getRayOutput(ray, &seed); - hdrColor[i32(globalId.x) + i32(globalId.y) * ${imageWidth}] = output; + hdrColor[i32(globalId.x) + i32(globalId.y) * i32(imageWidth)] = output; +} + `; }