diff --git a/src/nodes/code/FunctionCallNode.js b/src/nodes/code/FunctionCallNode.js index 9f22cd11b4ff1d..9dd6d9d757752d 100644 --- a/src/nodes/code/FunctionCallNode.js +++ b/src/nodes/code/FunctionCallNode.js @@ -104,15 +104,38 @@ class FunctionCallNode extends TempNode { const inputs = functionNode.getInputs( builder ); const parameters = this.parameters; + // Track storage pointer bindings for this function + const storageBindingMap = {}; + const generateInput = ( node, inputNode ) => { const type = inputNode.type; const pointer = type === 'pointer'; + const storagePointer = type === 'storagePointer'; let output; - if ( pointer ) output = '&' + node.build( builder ); - else output = node.build( builder, type ); + if ( storagePointer ) { + + // Build the storage buffer node - this registers it as a uniform/binding + // and returns the property name (e.g., "nodeU1") + output = node.build( builder ); + + // Store the mapping from parameter name to the generated binding name + storageBindingMap[ inputNode.name ] = output; + + // Return null to indicate this parameter should not be in the call + return null; + + } else if ( pointer ) { + + output = '&' + node.build( builder ); + + } else { + + output = node.build( builder, type ); + + } return output; @@ -140,7 +163,14 @@ class FunctionCallNode extends TempNode { for ( let i = 0; i < parameters.length; i ++ ) { - params.push( generateInput( parameters[ i ], inputs[ i ] ) ); + const result = generateInput( parameters[ i ], inputs[ i ] ); + + // Only add non-storage pointer parameters to the call + if ( result !== null ) { + + params.push( result ); + + } } @@ -152,13 +182,25 @@ class FunctionCallNode extends TempNode { if ( node !== undefined ) { - params.push( generateInput( node, inputNode ) ); + const result = generateInput( node, inputNode ); + + // Only add non-storage pointer parameters to the call + if ( result !== null ) { + + params.push( result ); + + } } else { - error( `TSL: Input '${ inputNode.name }' not found in \'Fn()\'.` ); + // Only error for non-storage pointer parameters + if ( inputNode.type !== 'storagePointer' ) { + + error( `TSL: Input '${ inputNode.name }' not found in \'Fn()\'.` ); - params.push( generateInput( float( 0 ), inputNode ) ); + params.push( generateInput( float( 0 ), inputNode ) ); + + } } @@ -166,6 +208,14 @@ class FunctionCallNode extends TempNode { } + // Store storage binding map in builder's node data for FunctionNode to use + if ( Object.keys( storageBindingMap ).length > 0 ) { + + const functionNodeData = builder.getDataFromNode( functionNode ); + functionNodeData.storageBindingMap = storageBindingMap; + + } + const functionName = functionNode.build( builder, 'property' ); return `${ functionName }( ${ params.join( ', ' ) } )`; diff --git a/src/nodes/code/FunctionNode.js b/src/nodes/code/FunctionNode.js index 190e120b0ffc32..4e482fce71a88b 100644 --- a/src/nodes/code/FunctionNode.js +++ b/src/nodes/code/FunctionNode.js @@ -1,4 +1,5 @@ import CodeNode from './CodeNode.js'; +import { nodeObject } from '../tsl/TSLBase.js'; /** * This class represents a native shader function. It can be used to implement @@ -45,6 +46,15 @@ class FunctionNode extends CodeNode { super( code, includes, language ); + /** + * This flag can be used for type testing. + * + * @type {boolean} + * @readonly + * @default true + */ + this.isFunctionNode = true; + } /** @@ -114,6 +124,35 @@ class FunctionNode extends CodeNode { generate( builder, output ) { + // Get storage binding map if set by FunctionCallNode (before building includes) + const nodeData = builder.getDataFromNode( this ); + const storageBindingMap = nodeData.storageBindingMap || null; + + // Propagate storage binding map to included functions before building them + if ( storageBindingMap !== null ) { + + const includes = this.getIncludes( builder ); + + for ( const include of includes ) { + + if ( include.isFunctionNode || ( include.value && include.value.isFunctionNode ) ) { + + const includeNode = include.isFunctionNode ? include : include.value; + const includeData = builder.getDataFromNode( includeNode ); + + // Merge storage binding maps (don't overwrite if already set) + if ( ! includeData.storageBindingMap ) { + + includeData.storageBindingMap = storageBindingMap; + + } + + } + + } + + } + super.generate( builder ); const nodeFunction = this.getNodeFunction( builder ); @@ -133,7 +172,7 @@ class FunctionNode extends CodeNode { const propertyName = builder.getPropertyName( nodeCode ); - const code = this.getNodeFunction( builder ).getCode( propertyName ); + const code = this.getNodeFunction( builder ).getCode( propertyName, storageBindingMap ); nodeCode.code = code + '\n'; diff --git a/src/renderers/webgpu/nodes/WGSLNodeFunction.js b/src/renderers/webgpu/nodes/WGSLNodeFunction.js index e984f9952d3b0f..106c51a398aad9 100644 --- a/src/renderers/webgpu/nodes/WGSLNodeFunction.js +++ b/src/renderers/webgpu/nodes/WGSLNodeFunction.js @@ -75,6 +75,30 @@ const wgslTypeLib = { }; +/** + * Parses a storage pointer type like "ptr, read>" and extracts metadata. + * + * @param {string} ptrType - The full ptr type string. + * @return {Object|null} Object with { storageAccess, baseType } or null if not a storage pointer. + */ +const parseStoragePointer = ( ptrType ) => { + + // Match ptr or ptr + const match = ptrType.match( /^ptr\s*<\s*storage\s*,\s*(.+?)(?:\s*,\s*(read|write|read_write))?\s*>$/i ); + + if ( match ) { + + return { + baseType: match[ 1 ].trim(), + storageAccess: match[ 2 ] || 'read' + }; + + } + + return null; + +}; + const parse = ( source ) => { source = source.trim(); @@ -100,10 +124,24 @@ const parse = ( source ) => { const { name, type } = propsMatches[ i ]; let resolvedType = type; + let storagePointerInfo = null; if ( resolvedType.startsWith( 'ptr' ) ) { - resolvedType = 'pointer'; + // Check if it's a storage pointer + storagePointerInfo = parseStoragePointer( type ); + + if ( storagePointerInfo !== null ) { + + // Storage pointer - mark as special type + resolvedType = 'storagePointer'; + + } else { + + // Other pointer types (function, private, workgroup) + resolvedType = 'pointer'; + + } } else { @@ -117,7 +155,17 @@ const parse = ( source ) => { } - inputs.push( new NodeFunctionInput( resolvedType, name ) ); + const input = new NodeFunctionInput( resolvedType, name ); + + // Store storage pointer metadata if applicable + if ( storagePointerInfo !== null ) { + + input.storageAccess = storagePointerInfo.storageAccess; + input.storageBaseType = storagePointerInfo.baseType; + + } + + inputs.push( input ); } @@ -172,13 +220,159 @@ class WGSLNodeFunction extends NodeFunction { * This method returns the WGSL code of the node function. * * @param {string} [name=this.name] - The function's name. + * @param {Object} [storageBindingMap=null] - Map of storage pointer parameter names to their bound variable names. * @return {string} The shader code. */ - getCode( name = this.name ) { + getCode( name = this.name, storageBindingMap = null ) { const outputType = this.outputType !== 'void' ? '-> ' + this.outputType : ''; - return `fn ${ name } ( ${ this.inputsCode.trim() } ) ${ outputType }` + this.blockCode; + let inputsCode = this.inputsCode.trim(); + let blockCode = this.blockCode; + + // If we have storage bindings, filter them from the parameter list + // and replace references in the function body + if ( storageBindingMap !== null && Object.keys( storageBindingMap ).length > 0 ) { + + // Build list of non-storage parameters + // Split by comma but respect angle bracket nesting (for types like texture_storage_3d) + const filteredParams = []; + const storageParamNames = new Set(); + const inputParts = []; + let depth = 0; + let current = ''; + + for ( let i = 0; i < inputsCode.length; i ++ ) { + + const char = inputsCode[ i ]; + + if ( char === '<' ) { + + depth ++; + current += char; + + } else if ( char === '>' ) { + + depth --; + current += char; + + } else if ( char === ',' && depth === 0 ) { + + inputParts.push( current.trim() ); + current = ''; + + } else { + + current += char; + + } + + } + + if ( current.trim() !== '' ) { + + inputParts.push( current.trim() ); + + } + + for ( const part of inputParts ) { + + const trimmed = part.trim(); + if ( trimmed === '' ) continue; + + // Extract parameter name + const colonIndex = trimmed.indexOf( ':' ); + if ( colonIndex === - 1 ) continue; + + const paramName = trimmed.substring( 0, colonIndex ).trim(); + + // Check if this is a storage pointer parameter (either in map or has ptr someFunc(otherArg) + // We need to remove arguments that are now global bindings + for ( const paramName in storageBindingMap ) { + + const boundName = storageBindingMap[ paramName ]; + + // Remove as first argument: func( boundName, ... ) -> func( ... ) + blockCode = blockCode.replace( + new RegExp( '(\\w+\\s*\\()\\s*' + this._escapeRegex( boundName ) + '\\s*,\\s*', 'g' ), + '$1' + ); + + // Remove as middle argument: func( ..., boundName, ... ) -> func( ..., ... ) + blockCode = blockCode.replace( + new RegExp( ',\\s*' + this._escapeRegex( boundName ) + '\\s*,', 'g' ), + ',' + ); + + // Remove as last argument: func( ..., boundName ) -> func( ... ) + blockCode = blockCode.replace( + new RegExp( ',\\s*' + this._escapeRegex( boundName ) + '\\s*\\)', 'g' ), + ')' + ); + + // Remove as only argument: func( boundName ) -> func( ) + blockCode = blockCode.replace( + new RegExp( '(\\w+\\s*\\()\\s*' + this._escapeRegex( boundName ) + '\\s*\\)', 'g' ), + '$1)' + ); + + } + + } + + return `fn ${ name } ( ${ inputsCode } ) ${ outputType }` + blockCode; + + } + + /** + * Escapes special regex characters in a string. + * @private + */ + _escapeRegex( str ) { + + return str.replace( /[.*+?^${}()|[\]\\]/g, '\\$&' ); }