From 20a26d2010eb3b4bc818822e717bab3af75f5a88 Mon Sep 17 00:00:00 2001 From: Christian Helgeson <62450112+cmhhelgeson@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:03:05 -0700 Subject: [PATCH] Added bitonic sort example using compute shaders. (#301) * Added bitonic sort example * Changed shaped of hover cursor in shader to be more readable * Changed names of values and added auto-complete sort functionality * Removed unused argKeys value * Implemented suggested changes * Implemented lolokung suggested changes * Removed references to reticle * Implemented non-extant austinEng suggestions * Removed unused enums (may add back later), changed type of completeSortIntervalID, opened Sort Controls folder on init for sake of clarity * Removed createWGSLUniforms * Minor shader fix --- src/pages/samples/[slug].tsx | 1 + .../bitonicSort/bitonicDisplay.frag.wgsl | 36 ++ src/sample/bitonicSort/bitonicDisplay.ts | 88 +++ src/sample/bitonicSort/computeShader.ts | 94 ++++ src/sample/bitonicSort/main.ts | 518 ++++++++++++++++++ src/sample/bitonicSort/utils.ts | 212 +++++++ 6 files changed, 949 insertions(+) create mode 100644 src/sample/bitonicSort/bitonicDisplay.frag.wgsl create mode 100644 src/sample/bitonicSort/bitonicDisplay.ts create mode 100644 src/sample/bitonicSort/computeShader.ts create mode 100644 src/sample/bitonicSort/main.ts create mode 100644 src/sample/bitonicSort/utils.ts diff --git a/src/pages/samples/[slug].tsx b/src/pages/samples/[slug].tsx index 7dbc7234..79623cd9 100644 --- a/src/pages/samples/[slug].tsx +++ b/src/pages/samples/[slug].tsx @@ -47,6 +47,7 @@ export const pages: PageComponentType = { renderBundles: dynamic(() => import('../../sample/renderBundles/main')), worker: dynamic(() => import('../../sample/worker/main')), 'A-buffer': dynamic(() => import('../../sample/a-buffer/main')), + bitonicSort: dynamic(() => import('../../sample/bitonicSort/main')), }; function Page({ slug }: Props): JSX.Element { diff --git 
a/src/sample/bitonicSort/bitonicDisplay.frag.wgsl b/src/sample/bitonicSort/bitonicDisplay.frag.wgsl new file mode 100644 index 00000000..3f4a17ea --- /dev/null +++ b/src/sample/bitonicSort/bitonicDisplay.frag.wgsl @@ -0,0 +1,36 @@ +struct Uniforms { + width: f32, + height: f32, +} + +struct VertexOutput { + @builtin(position) Position: vec4, + @location(0) fragUV: vec2 +} + +@group(0) @binding(0) var uniforms: Uniforms; +@group(1) @binding(0) var data: array; + +@fragment +fn frag_main(input: VertexOutput) -> @location(0) vec4 { + var uv: vec2 = vec2( + input.fragUV.x * uniforms.width, + input.fragUV.y * uniforms.height + ); + + var pixel: vec2 = vec2( + u32(floor(uv.x)), + u32(floor(uv.y)), + ); + + var elementIndex = u32(uniforms.width) * pixel.y + pixel.x; + var colorChanger = data[elementIndex]; + + var subtracter = f32(colorChanger) / (uniforms.width * uniforms.height); + + var color: vec3 = vec3f( + 1.0 - subtracter + ); + + return vec4(color.rgb, 1.0); +} diff --git a/src/sample/bitonicSort/bitonicDisplay.ts b/src/sample/bitonicSort/bitonicDisplay.ts new file mode 100644 index 00000000..210ef232 --- /dev/null +++ b/src/sample/bitonicSort/bitonicDisplay.ts @@ -0,0 +1,88 @@ +import { + BindGroupsObjectsAndLayout, + createBindGroupDescriptor, + Base2DRendererClass, +} from './utils'; + +import bitonicDisplay from './bitonicDisplay.frag.wgsl'; + +interface BitonicDisplayRenderArgs { + width: number; + height: number; +} + +export default class BitonicDisplayRenderer extends Base2DRendererClass { + static sourceInfo = { + name: __filename.substring(__dirname.length + 1), + contents: __SOURCE__, + }; + + switchBindGroup: (name: string) => void; + setArguments: (args: BitonicDisplayRenderArgs) => void; + computeBGDescript: BindGroupsObjectsAndLayout; + + constructor( + device: GPUDevice, + presentationFormat: GPUTextureFormat, + renderPassDescriptor: GPURenderPassDescriptor, + bindGroupNames: string[], + computeBGDescript: BindGroupsObjectsAndLayout, + label: 
string + ) { + super(); + this.renderPassDescriptor = renderPassDescriptor; + this.computeBGDescript = computeBGDescript; + + const uniformBuffer = device.createBuffer({ + size: Float32Array.BYTES_PER_ELEMENT * 2, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + const bgDescript = createBindGroupDescriptor( + [0], + [GPUShaderStage.FRAGMENT], + ['buffer'], + [{ type: 'uniform' }], + [[{ buffer: uniformBuffer }]], + label, + device + ); + + this.currentBindGroup = bgDescript.bindGroups[0]; + this.currentBindGroupName = bindGroupNames[0]; + + this.bindGroupMap = {}; + + bgDescript.bindGroups.forEach((bg, idx) => { + this.bindGroupMap[bindGroupNames[idx]] = bg; + }); + + this.pipeline = super.create2DRenderPipeline( + device, + label, + [bgDescript.bindGroupLayout, this.computeBGDescript.bindGroupLayout], + bitonicDisplay, + presentationFormat + ); + + this.switchBindGroup = (name: string) => { + this.currentBindGroup = this.bindGroupMap[name]; + this.currentBindGroupName = name; + }; + + this.setArguments = (args: BitonicDisplayRenderArgs) => { + super.setUniformArguments(device, uniformBuffer, args, [ + 'width', + 'height', + ]); + }; + } + + startRun(commandEncoder: GPUCommandEncoder, args: BitonicDisplayRenderArgs) { + this.setArguments(args); + super.executeRun(commandEncoder, this.renderPassDescriptor, this.pipeline, [ + this.currentBindGroup, + this.computeBGDescript.bindGroups[0], + ]); + } +} diff --git a/src/sample/bitonicSort/computeShader.ts b/src/sample/bitonicSort/computeShader.ts new file mode 100644 index 00000000..4011eb41 --- /dev/null +++ b/src/sample/bitonicSort/computeShader.ts @@ -0,0 +1,94 @@ +export const computeArgKeys = ['width', 'height', 'algo', 'blockHeight']; + +export const NaiveBitonicCompute = (threadsPerWorkgroup: number) => { + if (threadsPerWorkgroup % 2 !== 0 || threadsPerWorkgroup > 256) { + threadsPerWorkgroup = 256; + } + // Ensure that workgroupSize is half the number of elements + return ` + +struct 
Uniforms { + width: f32, + height: f32, + algo: u32, + blockHeight: u32, +} + +// Create local workgroup data that can contain all elements + +var local_data: array; + +//Compare and swap values in local_data +fn compare_and_swap(idx_before: u32, idx_after: u32) { + //idx_before should always be < idx_after + if (local_data[idx_after] < local_data[idx_before]) { + var temp: u32 = local_data[idx_before]; + local_data[idx_before] = local_data[idx_after]; + local_data[idx_after] = temp; + } + return; +} + +// thread_id goes from 0 to threadsPerWorkgroup +fn prepare_flip(thread_id: u32, block_height: u32) { + let q: u32 = ((2 * thread_id) / block_height) * block_height; + let half_height = block_height / 2; + var idx: vec2 = vec2( + thread_id % half_height, block_height - (thread_id % half_height) - 1, + ); + idx.x += q; + idx.y += q; + compare_and_swap(idx.x, idx.y); +} + +fn prepare_disperse(thread_id: u32, block_height: u32) { + var q: u32 = ((2 * thread_id) / block_height) * block_height; + let half_height = block_height / 2; + var idx: vec2 = vec2( + thread_id % half_height, (thread_id % half_height) + half_height + ); + idx.x += q; + idx.y += q; + compare_and_swap(idx.x, idx.y); +} + +@group(0) @binding(0) var input_data: array; +@group(0) @binding(1) var output_data: array; +@group(0) @binding(2) var uniforms: Uniforms; + +// Our compute shader will execute specified # of threads or elements / 2 threads +@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1) +fn computeMain( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + //Each thread will populate the workgroup data... (1 thread for every 2 elements) + local_data[local_id.x * 2] = input_data[local_id.x * 2]; + local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1]; + + //...and wait for each other to finish their own bit of data population. 
+ workgroupBarrier(); + + var num_elements = uniforms.width * uniforms.height; + + switch uniforms.algo { + case 1: { //Local Flip + prepare_flip(local_id.x, uniforms.blockHeight); + } + case 2: { //Local Disperse + prepare_disperse(local_id.x, uniforms.blockHeight); + } + default: { + + } + } + + //Ensure that all threads have swapped their own regions of data + workgroupBarrier(); + + //Repopulate global data with local data + output_data[local_id.x * 2] = local_data[local_id.x * 2]; + output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1]; + +}`; +}; diff --git a/src/sample/bitonicSort/main.ts b/src/sample/bitonicSort/main.ts new file mode 100644 index 00000000..cc8e5865 --- /dev/null +++ b/src/sample/bitonicSort/main.ts @@ -0,0 +1,518 @@ +import { makeSample, SampleInit } from '../../components/SampleLayout'; +import { SampleInitFactoryWebGPU } from './utils'; +import { createBindGroupDescriptor } from './utils'; +import BitonicDisplayRenderer from './bitonicDisplay'; +import bitonicDisplay from './bitonicDisplay.frag.wgsl'; +import { NaiveBitonicCompute } from './computeShader'; +import fullscreenTexturedQuad from '../../shaders/fullscreenTexturedQuad.wgsl'; + +// Type of step that will be executed in our shader +enum StepEnum { + NONE = 0, + FLIP_LOCAL = 1, + DISPERSE_LOCAL = 2, + FLIP_DISPERSE_LOCAL = 3, +} + +// String access to StepEnum +type StepType = + | 'NONE' + | 'FLIP_LOCAL' + | 'DISPERSE_LOCAL' + | 'FLIP_DISPERSE_LOCAL'; + +// Gui settings object +interface SettingsInterface { + 'Total Elements': number; + 'Grid Width': number; + 'Grid Height': number; + 'Total Threads': number; + hoveredElement: number; + swappedElement: number; + 'Prev Step': StepType; + 'Next Step': StepType; + 'Prev Swap Span': number; + 'Next Swap Span': number; + workLoads: number; + executeStep: boolean; + 'Randomize Values': () => void; + 'Execute Sort Step': () => void; + 'Log Elements': () => void; + 'Complete Sort': () => void; + sortSpeed: number; +} + +let 
init: SampleInit; +SampleInitFactoryWebGPU( + async ({ pageState, device, gui, presentationFormat, context, canvas }) => { + const maxWorkgroupsX = device.limits.maxComputeWorkgroupSizeX; + + const totalElementLengths = []; + for (let i = maxWorkgroupsX * 2; i >= 4; i /= 2) { + totalElementLengths.push(i); + } + + const settings: SettingsInterface = { + // number of cellElements. Must equal gridWidth * gridHeight and 'Total Threads' * 2 + 'Total Elements': 16, + // width of screen in cells. + 'Grid Width': 4, + // height of screen in cells + 'Grid Height': 4, + // number of threads to execute in a workgroup ('Total Threads', 1, 1) + 'Total Threads': 16 / 2, + // currently highlighted element + hoveredElement: 0, + // element the hoveredElement just swapped with, + swappedElement: 1, + // Previously executed step + 'Prev Step': 'NONE', + // Next step to execute + 'Next Step': 'FLIP_LOCAL', + // Max thread span of previous block + 'Prev Swap Span': 0, + // Max thread span of next block + 'Next Swap Span': 2, + // workloads to dispatch per frame, + workLoads: 1, + // Whether we will dispatch a workload this frame + executeStep: false, + 'Randomize Values': () => { + return; + }, + 'Execute Sort Step': () => { + return; + }, + 'Log Elements': () => { + return; + }, + 'Complete Sort': () => { + return; + }, + sortSpeed: 200, + }; + + // Initialize initial elements array + let elements = new Uint32Array( + Array.from({ length: settings['Total Elements'] }, (_, i) => i) + ); + + // Initialize elementsBuffer and elementsStagingBuffer + const elementsBufferSize = Float32Array.BYTES_PER_ELEMENT * 512; + // Initialize input, output, staging buffers + const elementsInputBuffer = device.createBuffer({ + size: elementsBufferSize, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + const elementsOutputBuffer = device.createBuffer({ + size: elementsBufferSize, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + const elementsStagingBuffer = 
device.createBuffer({ + size: elementsBufferSize, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + + // Create uniform buffer for compute shader + const computeUniformsBuffer = device.createBuffer({ + // width, height, blockHeight, algo + size: Float32Array.BYTES_PER_ELEMENT * 4, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + const computeBGDescript = createBindGroupDescriptor( + [0, 1, 2], + [ + GPUShaderStage.COMPUTE | GPUShaderStage.FRAGMENT, + GPUShaderStage.COMPUTE, + GPUShaderStage.COMPUTE, + ], + ['buffer', 'buffer', 'buffer'], + [{ type: 'read-only-storage' }, { type: 'storage' }, { type: 'uniform' }], + [ + [ + { buffer: elementsInputBuffer }, + { buffer: elementsOutputBuffer }, + { buffer: computeUniformsBuffer }, + ], + ], + 'NaiveBitonicSort', + device + ); + + let computePipeline = device.createComputePipeline({ + layout: device.createPipelineLayout({ + bindGroupLayouts: [computeBGDescript.bindGroupLayout], + }), + compute: { + module: device.createShaderModule({ + code: NaiveBitonicCompute(settings['Total Threads']), + }), + entryPoint: 'computeMain', + }, + }); + + // Create bitonic debug renderer + const renderPassDescriptor: GPURenderPassDescriptor = { + colorAttachments: [ + { + view: undefined, // Assigned later + + clearValue: { r: 0.1, g: 0.4, b: 0.5, a: 1.0 }, + loadOp: 'clear', + storeOp: 'store', + }, + ], + }; + + const bitonicDisplayRenderer = new BitonicDisplayRenderer( + device, + presentationFormat, + renderPassDescriptor, + ['default'], + computeBGDescript, + 'BitonicDisplay' + ); + + const resetExecutionInformation = () => { + totalThreadsCell.setValue(settings['Total Elements'] / 2); + + // Get new width and height of screen display in cells + const newCellWidth = + Math.sqrt(settings['Total Elements']) % 2 === 0 + ? 
Math.floor(Math.sqrt(settings['Total Elements'])) + : Math.floor(Math.sqrt(settings['Total Elements'] / 2)); + const newCellHeight = settings['Total Elements'] / newCellWidth; + gridWidthCell.setValue(newCellWidth); + gridHeightCell.setValue(newCellHeight); + + // Set prevStep to None (restart) and next step to FLIP + prevStepCell.setValue('NONE'); + nextStepCell.setValue('FLIP_LOCAL'); + + // Reset block heights + prevBlockHeightCell.setValue(0); + nextBlockHeightCell.setValue(2); + highestBlockHeight = 2; + }; + + const randomizeElementArray = () => { + let currentIndex = elements.length; + // While there are elements to shuffle + while (currentIndex !== 0) { + // Pick a remaining element + const randomIndex = Math.floor(Math.random() * currentIndex); + currentIndex -= 1; + [elements[currentIndex], elements[randomIndex]] = [ + elements[randomIndex], + elements[currentIndex], + ]; + } + }; + + const resizeElementArray = () => { + // Recreate elements array with new length + elements = new Uint32Array( + Array.from({ length: settings['Total Elements'] }, (_, i) => i) + ); + + resetExecutionInformation(); + + // Create new shader invocation with workgroupSize that reflects number of threads + computePipeline = device.createComputePipeline({ + layout: device.createPipelineLayout({ + bindGroupLayouts: [computeBGDescript.bindGroupLayout], + }), + compute: { + module: device.createShaderModule({ + code: NaiveBitonicCompute(settings['Total Elements'] / 2), + }), + entryPoint: 'computeMain', + }, + }); + // Randomize array elements + randomizeElementArray(); + highestBlockHeight = 2; + }; + + randomizeElementArray(); + + const setSwappedElement = () => { + let swappedIndex: number; + switch (settings['Next Step']) { + case 'FLIP_LOCAL': + { + const blockHeight = settings['Next Swap Span']; + const p2 = Math.floor(settings.hoveredElement / blockHeight) + 1; + const p3 = settings.hoveredElement % blockHeight; + swappedIndex = blockHeight * p2 - p3 - 1; + 
swappedElementCell.setValue(swappedIndex);
+          }
+          break;
+        case 'DISPERSE_LOCAL':
+          {
+            const blockHeight = settings['Next Swap Span'];
+            const halfHeight = blockHeight / 2;
+            swappedIndex =
+              settings.hoveredElement % blockHeight < halfHeight
+                ? settings.hoveredElement + halfHeight
+                : settings.hoveredElement - halfHeight;
+            swappedElementCell.setValue(swappedIndex);
+          }
+          break;
+        // 'NONE' and any unrecognised step have no swap partner, so the
+        // hovered element is reported as its own swapped index. 'NONE' is an
+        // empty label that deliberately shares the default body.
+        case 'NONE':
+        default:
+          {
+            swappedIndex = settings.hoveredElement;
+            swappedElementCell.setValue(swappedIndex);
+          }
+          break;
+      }
+    };
+
+    // Handle of the running 'Complete Sort' timer, or null when none is active.
+    let completeSortIntervalID: ReturnType<typeof setInterval> | null = null;
+    // Stops the 'Complete Sort' timer if one is running. Safe to call repeatedly.
+    const endSortInterval = () => {
+      if (completeSortIntervalID !== null) {
+        clearInterval(completeSortIntervalID);
+        completeSortIntervalID = null;
+      }
+    };
+    // Starts a timer that requests one sort step every `settings.sortSpeed`
+    // milliseconds until 'Next Step' becomes 'NONE'.
+    const startSortInterval = () => {
+      // Cancel any timer that is already running. Otherwise a second press of
+      // 'Complete Sort' overwrites the stored handle and the first timer can
+      // never be cleared.
+      endSortInterval();
+      completeSortIntervalID = setInterval(() => {
+        if (settings['Next Step'] === 'NONE') {
+          // Sort finished: stop the timer and do not request another step.
+          endSortInterval();
+          return;
+        }
+        settings.executeStep = true;
+        setSwappedElement();
+      }, settings.sortSpeed);
+    };
+
+    // At top level, basic information about the number of elements sorted and the number of threads
+    // deployed per workgroup.
+ gui.add(settings, 'Total Elements', totalElementLengths).onChange(() => { + endSortInterval(); + resizeElementArray(); + }); + const totalThreadsCell = gui.add(settings, 'Total Threads'); + + // Folder with functions that control the execution of the sort + const controlFolder = gui.addFolder('Sort Controls'); + controlFolder.add(settings, 'Execute Sort Step').onChange(() => { + endSortInterval(); + settings.executeStep = true; + }); + controlFolder.add(settings, 'Randomize Values').onChange(() => { + endSortInterval(); + randomizeElementArray(); + resetExecutionInformation(); + }); + controlFolder + .add(settings, 'Log Elements') + .onChange(() => console.log(elements)); + controlFolder.add(settings, 'Complete Sort').onChange(startSortInterval); + controlFolder.open(); + + // Folder with indexes of the hovered element + const hoverFolder = gui.addFolder('Hover Information'); + const hoveredElementCell = hoverFolder + .add(settings, 'hoveredElement') + .onChange(setSwappedElement); + const swappedElementCell = hoverFolder.add(settings, 'swappedElement'); + + // Additional Information about the execution state of the sort + const executionInformationFolder = gui.addFolder('Execution Information'); + const prevStepCell = executionInformationFolder.add(settings, 'Prev Step'); + const nextStepCell = executionInformationFolder.add(settings, 'Next Step'); + const prevBlockHeightCell = executionInformationFolder.add( + settings, + 'Prev Swap Span' + ); + const nextBlockHeightCell = executionInformationFolder.add( + settings, + 'Next Swap Span' + ); + const gridWidthCell = executionInformationFolder.add( + settings, + 'Grid Width' + ); + const gridHeightCell = executionInformationFolder.add( + settings, + 'Grid Height' + ); + + // Adjust styles of Function List Elements within GUI + const liFunctionElements = document.getElementsByClassName('cr function'); + for (let i = 0; i < liFunctionElements.length; i++) { + (liFunctionElements[i].children[0] as 
HTMLElement).style.display = 'flex'; + (liFunctionElements[i].children[0] as HTMLElement).style.justifyContent = + 'center'; + ( + liFunctionElements[i].children[0].children[1] as HTMLElement + ).style.position = 'absolute'; + } + + canvas.addEventListener('mousemove', (event) => { + const currWidth = canvas.getBoundingClientRect().width; + const currHeight = canvas.getBoundingClientRect().height; + const cellSize: [number, number] = [ + currWidth / settings['Grid Width'], + currHeight / settings['Grid Height'], + ]; + const xIndex = Math.floor(event.offsetX / cellSize[0]); + const yIndex = + settings['Grid Height'] - 1 - Math.floor(event.offsetY / cellSize[1]); + hoveredElementCell.setValue(yIndex * settings['Grid Width'] + xIndex); + settings.hoveredElement = yIndex * settings['Grid Width'] + xIndex; + }); + + // Deactivate interaction with select GUI elements + prevStepCell.domElement.style.pointerEvents = 'none'; + prevBlockHeightCell.domElement.style.pointerEvents = 'none'; + nextStepCell.domElement.style.pointerEvents = 'none'; + nextBlockHeightCell.domElement.style.pointerEvents = 'none'; + totalThreadsCell.domElement.style.pointerEvents = 'none'; + gridWidthCell.domElement.style.pointerEvents = 'none'; + gridHeightCell.domElement.style.pointerEvents = 'none'; + + let highestBlockHeight = 2; + + async function frame() { + if (!pageState.active) return; + + // Write elements buffer + device.queue.writeBuffer( + elementsInputBuffer, + 0, + elements.buffer, + elements.byteOffset, + elements.byteLength + ); + + const dims = new Float32Array([ + settings['Grid Width'], + settings['Grid Height'], + ]); + const stepDetails = new Uint32Array([ + StepEnum[settings['Next Step']], + settings['Next Swap Span'], + ]); + device.queue.writeBuffer( + computeUniformsBuffer, + 0, + dims.buffer, + dims.byteOffset, + dims.byteLength + ); + + device.queue.writeBuffer(computeUniformsBuffer, 8, stepDetails); + + renderPassDescriptor.colorAttachments[0].view = context + 
.getCurrentTexture() + .createView(); + + const commandEncoder = device.createCommandEncoder(); + bitonicDisplayRenderer.startRun(commandEncoder, { + width: settings['Grid Width'], + height: settings['Grid Height'], + }); + if ( + settings.executeStep && + highestBlockHeight !== settings['Total Elements'] * 2 + ) { + const computePassEncoder = commandEncoder.beginComputePass(); + computePassEncoder.setPipeline(computePipeline); + computePassEncoder.setBindGroup(0, computeBGDescript.bindGroups[0]); + computePassEncoder.dispatchWorkgroups(1); + computePassEncoder.end(); + + prevStepCell.setValue(settings['Next Step']); + prevBlockHeightCell.setValue(settings['Next Swap Span']); + nextBlockHeightCell.setValue(settings['Next Swap Span'] / 2); + if (settings['Next Swap Span'] === 1) { + highestBlockHeight *= 2; + nextStepCell.setValue( + highestBlockHeight === settings['Total Elements'] * 2 + ? 'NONE' + : 'FLIP_LOCAL' + ); + nextBlockHeightCell.setValue( + highestBlockHeight === settings['Total Elements'] * 2 + ? 
0 + : highestBlockHeight + ); + } else { + nextStepCell.setValue('DISPERSE_LOCAL'); + } + commandEncoder.copyBufferToBuffer( + elementsOutputBuffer, + 0, + elementsStagingBuffer, + 0, + elementsBufferSize + ); + } + device.queue.submit([commandEncoder.finish()]); + + if (settings.executeStep) { + // Copy GPU element data to CPU + await elementsStagingBuffer.mapAsync( + GPUMapMode.READ, + 0, + elementsBufferSize + ); + const copyElementsBuffer = elementsStagingBuffer.getMappedRange( + 0, + elementsBufferSize + ); + // Get correct range of data from CPU copy of GPU Data + const elementsData = copyElementsBuffer.slice( + 0, + Uint32Array.BYTES_PER_ELEMENT * settings['Total Elements'] + ); + // Extract data + const elementsOutput = new Uint32Array(elementsData); + elementsStagingBuffer.unmap(); + elements = elementsOutput; + setSwappedElement(); + } + settings.executeStep = false; + requestAnimationFrame(frame); + } + requestAnimationFrame(frame); + } +).then((resultInit) => (init = resultInit)); + +const bitonicSortExample: () => JSX.Element = () => + makeSample({ + name: 'Bitonic Sort', + description: + "A naive bitonic sort algorithm executed on the GPU, based on tgfrerer's implementation at poniesandlight.co.uk/reflect/bitonic_merge_sort/. Each invocation of the bitonic sort shader dispatches a workgroup containing elements/2 threads. The GUI's Execution Information folder contains information about the sort's current state. 
The visualizer displays the sort's results as colored cells sorted from brightest to darkest.", + init, + gui: true, + sources: [ + { + name: __filename.substring(__dirname.length + 1), + contents: __SOURCE__, + }, + BitonicDisplayRenderer.sourceInfo, + { + name: '../../../shaders/fullscreenTexturedQuad.vert.wgsl', + contents: fullscreenTexturedQuad, + }, + { + name: './bitonicDisplay.frag.wgsl', + contents: bitonicDisplay, + }, + { + name: './bitonicCompute.frag.wgsl', + contents: NaiveBitonicCompute(16), + }, + ], + filename: __filename, + }); + +export default bitonicSortExample; diff --git a/src/sample/bitonicSort/utils.ts b/src/sample/bitonicSort/utils.ts new file mode 100644 index 00000000..fea2992f --- /dev/null +++ b/src/sample/bitonicSort/utils.ts @@ -0,0 +1,212 @@ +import { SampleInit } from '../../components/SampleLayout'; +import type { GUI } from 'dat.gui'; +import fullscreenTexturedQuad from '../../shaders/fullscreenTexturedQuad.wgsl'; + +type BindGroupBindingLayout = + | GPUBufferBindingLayout + | GPUTextureBindingLayout + | GPUSamplerBindingLayout + | GPUStorageTextureBindingLayout + | GPUExternalTextureBindingLayout; + +export type BindGroupsObjectsAndLayout = { + bindGroups: GPUBindGroup[]; + bindGroupLayout: GPUBindGroupLayout; +}; + +type ResourceTypeName = + | 'buffer' + | 'texture' + | 'sampler' + | 'externalTexture' + | 'storageTexture'; + +/** + * @param {number[]} bindings - The binding value of each resource in the bind group. + * @param {number[]} visibilities - The GPUShaderStage visibility of the resource at the corresponding index. + * @param {ResourceTypeName[]} resourceTypes - The resourceType at the corresponding index. + * @returns {BindGroupsObjectsAndLayout} An object containing an array of bindGroups and the bindGroupLayout they implement. 
+ */
+export const createBindGroupDescriptor = (
+  bindings: number[],
+  visibilities: number[],
+  resourceTypes: ResourceTypeName[],
+  resourceLayouts: BindGroupBindingLayout[],
+  resources: GPUBindingResource[][],
+  label: string,
+  device: GPUDevice
+): BindGroupsObjectsAndLayout => {
+  // One layout entry per binding. `visibilities` is cycled with a modulo, so a
+  // single-element array applies the same visibility to every binding.
+  const layoutEntries: GPUBindGroupLayoutEntry[] = [];
+  for (let i = 0; i < bindings.length; i++) {
+    const layoutEntry: any = {};
+    layoutEntry.binding = bindings[i];
+    layoutEntry.visibility = visibilities[i % visibilities.length];
+    // The resource type name ('buffer', 'texture', ...) is the entry's key.
+    layoutEntry[resourceTypes[i]] = resourceLayouts[i];
+    layoutEntries.push(layoutEntry);
+  }
+
+  const bindGroupLayout = device.createBindGroupLayout({
+    label: `${label}.bindGroupLayout`,
+    entries: layoutEntries,
+  });
+
+  const bindGroups: GPUBindGroup[] = [];
+  // i is the index of the bind group being built (one per row of `resources`);
+  // j is the position of a resource within that row.
+  // resources[i][j] is bound at binding number bindings[j], so each entry
+  // matches the layout entry created for the same j above.
+  // NOTE: i is not the shader's @group(n). That index is decided by where the
+  // layout sits in the pipeline layout and the index passed to setBindGroup.
+  for (let i = 0; i < resources.length; i++) {
+    const groupEntries: GPUBindGroupEntry[] = [];
+    // Every group must provide exactly the bindings declared in the layout.
+    for (let j = 0; j < bindings.length; j++) {
+      groupEntries.push({
+        binding: bindings[j],
+        resource: resources[i][j],
+      });
+    }
+    const newBindGroup = device.createBindGroup({
+      label: `${label}.bindGroup${i}`,
+      layout: bindGroupLayout,
+      entries: groupEntries,
+    });
+    bindGroups.push(newBindGroup);
+  }
+
+  return {
+    bindGroups,
+    bindGroupLayout,
+  };
+};
+
+export type ShaderKeyInterface = {
+  [K in T[number]]: number;
+};
+
+export type SampleInitParams = {
+  canvas: HTMLCanvasElement;
+  pageState: { active: boolean };
+  gui?: GUI;
+  stats?: Stats;
+};
+
+interface DeviceInitParms {
+  device: GPUDevice;
+}
+
+interface DeviceInit3DParams extends DeviceInitParms {
+  context: GPUCanvasContext;
+  presentationFormat: GPUTextureFormat;
+}
+
+type 
CallbackSync3D = (params: SampleInitParams & DeviceInit3DParams) => void;
+type CallbackAsync3D = (
+  params: SampleInitParams & DeviceInit3DParams
+) => Promise<void>;
+
+type SampleInitCallback3D = CallbackSync3D | CallbackAsync3D;
+
+/**
+ * Wraps `callback` in a `SampleInit` that performs the WebGPU setup shared by
+ * samples: it requests an adapter and device, sizes the canvas backing store
+ * by the device pixel ratio, configures the 'webgpu' canvas context with the
+ * preferred canvas format, and then invokes `callback` with those objects.
+ *
+ * @param callback - Sample body. Receives the canvas, page state and optional
+ * gui/stats, plus the device, context and presentation format.
+ * @returns The `SampleInit` function. Its promise settles only after
+ * `callback` has finished, and rejects if no adapter is available or if
+ * `callback` throws or rejects.
+ */
+export const SampleInitFactoryWebGPU = async (
+  callback: SampleInitCallback3D
+): Promise<SampleInit> => {
+  const init: SampleInit = async ({ canvas, pageState, gui, stats }) => {
+    const adapter = await navigator.gpu.requestAdapter();
+    // requestAdapter() resolves to null when no suitable adapter exists. Fail
+    // with a clear message instead of a TypeError on the next line.
+    if (adapter === null) {
+      throw new Error(
+        'WebGPU adapter unavailable: requestAdapter() returned null.'
+      );
+    }
+    const device = await adapter.requestDevice();
+    // Bail out if the sample is no longer active once the device is ready.
+    if (!pageState.active) return;
+    const context = canvas.getContext('webgpu') as GPUCanvasContext;
+    const devicePixelRatio = window.devicePixelRatio || 1;
+    canvas.width = canvas.clientWidth * devicePixelRatio;
+    canvas.height = canvas.clientHeight * devicePixelRatio;
+    const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
+    context.configure({
+      device,
+      format: presentationFormat,
+      alphaMode: 'premultiplied',
+    });
+
+    // Await the callback so that a rejection from an async sample body
+    // propagates to whoever awaits init instead of going unhandled.
+    await callback({
+      canvas,
+      pageState,
+      gui,
+      device,
+      context,
+      presentationFormat,
+      stats,
+    });
+  };
+  return init;
+};
+
+export abstract class Base2DRendererClass {
+  abstract switchBindGroup(name: string): void;
+  abstract startRun(commandEncoder: GPUCommandEncoder, ...args: any[]): void;
+  renderPassDescriptor: GPURenderPassDescriptor;
+  pipeline: GPURenderPipeline;
+  bindGroupMap: Record;
+  currentBindGroup: GPUBindGroup;
+  currentBindGroupName: string;
+
+  executeRun(
+    commandEncoder: GPUCommandEncoder,
+    renderPassDescriptor: GPURenderPassDescriptor,
+    pipeline: GPURenderPipeline,
+    bindGroups: GPUBindGroup[]
+  ) {
+    const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
+    passEncoder.setPipeline(pipeline);
+    for (let i = 0; i < bindGroups.length; i++) {
+      passEncoder.setBindGroup(i, bindGroups[i]);
+    }
+    passEncoder.draw(6, 1, 0, 0);
+    passEncoder.end();
+  }
+
+  setUniformArguments(
+    device: GPUDevice,
+    uniformBuffer: GPUBuffer,
+    instance: T,
+    keys: K
+  ) {
+    for (let i = 0; i < keys.length; i++) {
+ device.queue.writeBuffer( + uniformBuffer, + i * 4, + new Float32Array([instance[keys[i]]]) + ); + } + } + + create2DRenderPipeline( + device: GPUDevice, + label: string, + bgLayouts: GPUBindGroupLayout[], + code: string, + presentationFormat: GPUTextureFormat + ) { + return device.createRenderPipeline({ + label: `${label}.pipeline`, + layout: device.createPipelineLayout({ + bindGroupLayouts: bgLayouts, + }), + vertex: { + module: device.createShaderModule({ + code: fullscreenTexturedQuad, + }), + entryPoint: 'vert_main', + }, + fragment: { + module: device.createShaderModule({ + code: code, + }), + entryPoint: 'frag_main', + targets: [ + { + format: presentationFormat, + }, + ], + }, + primitive: { + topology: 'triangle-list', + cullMode: 'none', + }, + }); + } +}