diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul-template.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul-template.ts
index 26acc4b349e3c..168b1a7fe6586 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/matmul-template.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul-template.ts
@@ -703,7 +703,8 @@ export class SubgroupMemoryCacheHelper extends TSGBlockCacheHelperBase implement
   cacheMemoryModuleDefinitionWGSL() {
     return `
 // ${this.cache_name}:
-// cache_loading_points_per_subgroup: rows ${this.cache_loading_points_per_subgroup.rows}, cols ${this.cache_loading_points_per_subgroup.cols}
+// cache_loading_points_per_subgroup: rows ${this.cache_loading_points_per_subgroup.rows}, cols ${
+    this.cache_loading_points_per_subgroup.cols}
 // cache_loading_points_number_per_subgroup: ${this.cache_loading_points_number_per_subgroup}
 const ${this.cache_name}_loading_points_per_thread = ${this.cache_loading_points_per_thread}u;
 var<private> ${this.cache_name}: array<${this.loading_point_WGSL_type}, ${this.cache_name}_loading_points_per_thread>;
@@ -748,7 +749,7 @@ var<private> ${this.cache_name}: array<${this.loading_point_WGSL_type}, ${this.c
             'loading_point_id_subgroup', this.loading_points_prefer_major, loading_points_major_stride)
             .loading_point_col;
     return `
-// Update shared memory cache for ${this.variable_name}
+// Update subgroup memory cache for ${this.variable_name}
 workgroupBarrier();
 {
 let loading_point_row_cache_step_TSG_base: i32 = i32(${this.loading_point_row_cache_step_TSG_base_WGSL});
@@ -1044,12 +1045,6 @@ export function templatedMatMulProgram(
 
   const has_bias = typeof input_Bias !== 'undefined';
 
-  /*
-  assert(
-      output.aggregated_batches === input_A.aggregated_batches * input_B.aggregated_batches,
-      `Expect output.batches ${output.aggregated_batches} === input_A.batches ${input_A.aggregated_batches} *
-input_B.batches ${ input_B.aggregated_batches}`);
-  */
   assert(
       !has_bias || output.aggregated_batches === input_Bias.aggregated_batches,
       `Expect input_bias.batches ${input_Bias?.aggregated_batches} === output.batches ${output.aggregated_batches}`);
@@ -1159,7 +1154,7 @@ dividable by compute_block_shape ${PrintMKN(compute_block_shape)}`);
   assert(
       (!threads_within_TSG_major_M_not_N) || buffer_B_cache_type !== 'subgroup',
       `Using subgroup cache for input B require !threads_within_TSG_major_M_not_N.`);
-  const spatial_loop_order: 'M_outer'|'N_outer' = 'N_outer' as 'M_outer'|'N_outer';
+  const spatial_loop_order: 'M_outer'|'N_outer' = 'N_outer' as 'M_outer' | 'N_outer';  // Might be parameter
 
   const unfold_spatial_M_loop: boolean = buffer_A_cache_type === 'subgroup';
   const unfold_spatial_N_loop: boolean = buffer_B_cache_type === 'subgroup';
@@ -1207,15 +1202,17 @@ dividable by compute_block_shape ${PrintMKN(compute_block_shape)}`);
   // -----------------------------------------------------------------------------
   // Uniform definition
   const uniforms_WGSL_info: UniformsArrayType = [
-    // { name: 'input_A_batches', type: 'u32' },
-    // {name: 'input_B_batches', type: 'u32'},
+    {name: 'compute_blocks_per_thread_K', type: 'u32'},
   ];
   appendActivationUniforms(activation, uniforms_WGSL_info);
 
   // Uniforms values
   const programUniforms: ProgramUniform[] = [
-    // { /* input_A_batches */ type: DataType.uint32, data: input_A.batches },
-    // {/* input_B_batches */ type: DataType.uint32, data: input_B.aggregated_batches},
+    {
+      // compute_blocks_per_thread_K
+      type: DataType.uint32,
+      data: Math.ceil(compute_blocks_per_workgroup.K / tensor_slice_factor)
+    },
   ];
   appendActivationUniformsData(activation, programUniforms);
   // Tensor shapes of all input/output variables must be pushed into programUniforms
@@ -1301,9 +1298,15 @@ ${
     input_A_tensor_loading_points_shape, input_B_tensor_loading_points_shape,
     ...(has_bias ? [output_tensor_loading_points_shape] : [])
   ];
+
+  // Push unaggregated batch dims for batch helper
+  const batch_helpers_dims = [output_batch_dims, input_A_batch_dims, input_B_batch_dims];
+  programUniforms.push(...createTensorShapeVariables(...batch_helpers_dims));
+
+  // Push aggregated inputs and output shape
   programUniforms.push(...createTensorShapeVariables(...input_tensors_shape));
   programUniforms.push(...createTensorShapeVariables(output_tensor_loading_points_shape));
+
   const trival_shader = (shader_helper: ShaderHelper) => {
 
     // -----------------------------------------------------------------------------
@@ -1319,6 +1322,22 @@ ${
     const output_variable = outputVariable('output', tensor_data_type, 3, output.buffer_layout.packed_vector_size);
     const input_variables = [input_A_variable, input_B_variable, ...(has_bias ? [input_Bias_variable!] : [])];
+
+    // -----------------------------------------------------------------------------
+    // Batch helper variables
+    // -----------------------------------------------------------------------------
+    const output_batch_helper =
+        internalVariable('output_batch_helper', tensor_data_type, output_batch_dims.length, 1);
+    const input_A_batch_helper =
+        internalVariable('input_A_batch_helper', tensor_data_type, input_A_batch_dims.length, 1);
+    const input_B_batch_helper =
+        internalVariable('input_B_batch_helper', tensor_data_type, input_B_batch_dims.length, 1);
+    // Currently assume bias has the same batches as output
+    // const input_Bias_batch_helper = has_bias? internalVariable('input_B_batch_helper',
+    // tensor_data_type,input_Bias_batch_dims.length, 1):undefined;
+
+    const batch_helper_variables = [output_batch_helper, input_A_batch_helper, input_B_batch_helper];
+
     // -----------------------------------------------------------------------------
     // WGSL types
     // -----------------------------------------------------------------------------
@@ -1350,31 +1369,40 @@ ${
         `
 // Convert output aggregated batch to vec2(input_A_aggregated_batch, input_B_aggregated_batch)
 fn ConvertOutputBatchToInputBatch(output_batch: u32) -> vec2<u32> {
-  var input_A_batch: u32 = 0;
-  var input_B_batch: u32 = 0;
-  var degregating_output_batch = output_batch;
+  let batch_indices = ${output_batch_helper.offsetToIndices('output_batch')};
+
+  var input_A_indices: ${input_A_batch_helper.type.indices};
 ${
-          arrayMap(
-              Math.max(input_A_batch_dims.length, input_B_batch_dims.length),
-              i => `
-  {
-    let dim = degregating_output_batch % ${output_batch_dims[output_batch_dims.length - 1 - i]};
-    degregating_output_batch = degregating_output_batch / ${output_batch_dims[output_batch_dims.length - 1 - i]};
+            u32LoopUpFrom0WGSL(
+                'index_A', input_A_batch_dims.length,
+                (index_A: number) => `
 ${
-                  ((i < input_A_batch_dims.length) &&
-                   (!input_A_broadcasted_dims.includes(input_A_batch_dims.length - 1 - i))) ?
-                      `
-    input_A_batch += dim * ${input_A_batch_dims.slice(input_A_batch_dims.length - i).reduce((a, b) => a * b, 1)};
-` :
-                      ''}
+                input_A_batch_helper.indicesSet(
+                    'input_A_indices', index_A,
+                    input_A_broadcasted_dims.includes(index_A) ?
+                        0 :
+                        output_batch_helper.indicesGet(
+                            'batch_indices', index_A + output_batch_dims.length - input_A_batch_dims.length))}
+`,
+                false, 4)}
+
+  var input_B_indices: ${input_B_batch_helper.type.indices};
+  ${
+            u32LoopUpFrom0WGSL(
+                'index_B', input_B_batch_dims.length,
+                (index_B: number) => `
 ${
-                  ((i < input_B_batch_dims.length) &&
-                   (!input_B_broadcasted_dims.includes(input_B_batch_dims.length - 1 - i))) ?
-                      `
-    input_B_batch += dim * ${input_B_batch_dims.slice(input_B_batch_dims.length - i).reduce((a, b) => a * b, 1)};
-` :
-                      ''}
-  }`).join('\n')}
+                input_B_batch_helper.indicesSet(
+                    'input_B_indices', index_B,
+                    input_B_broadcasted_dims.includes(index_B) ?
+                        0 :
+                        output_batch_helper.indicesGet(
+                            'batch_indices', index_B + output_batch_dims.length - input_B_batch_dims.length))}
+`,
+                false, 4)}
+
+  let input_A_batch: u32 = ${input_A_batch_helper.indicesToOffset('input_A_indices')};
+  let input_B_batch: u32 = ${input_B_batch_helper.indicesToOffset('input_B_indices')};
   return vec2(input_A_batch, input_B_batch);
 }`);
@@ -1384,14 +1412,16 @@ fn ConvertOutputBatchToInputBatch(output_batch: u32) -> vec2<u32> {
     const createLoadBufferHelperFunction = (
         function_name: string,
         buffer_variable: IndicesHelper,
         packed_type_in_buffer_WGSL: string,
-        buffer_layout: BufferLayoutInfo,
+        // buffer_layout: BufferLayoutInfo,
+        loading_points_layout: 'NHW'|'NWH',
+        loading_points_matrix_dims_vec2i_WGSL: string,
     ) => {
       if (buffer_variable.rank !== 3) {
         throw new Error(`Expected buffer A of rank 3, got ${buffer_variable.rank}.`);
       }
       assert(!helper_functions.has(function_name), `Redefining function ${function_name}`);
-      const {loading_points_layout, buffer_inner_boundary, buffer_outer_boundary} = buffer_layout;
+      // const {loading_points_layout} = buffer_layout;
 
       helper_functions.set(
           function_name,
@@ -1404,8 +1434,8 @@ fn ${function_name}(loading_point_row: i32, loading_point_col: i32, batch: u32)
   let inner_dim = ${loading_points_layout === 'NHW' ? 'loading_point_row' : 'loading_point_col'};
   let outer_dim = ${loading_points_layout === 'NHW' ? 'loading_point_col' : 'loading_point_row'};
 
-  if (inner_dim >= 0 && inner_dim < ${buffer_inner_boundary} && outer_dim >= 0 && outer_dim < ${
-      buffer_outer_boundary})
+  if (inner_dim >= 0 && inner_dim < ${loading_points_matrix_dims_vec2i_WGSL}.x && outer_dim >= 0 && outer_dim < ${
+      loading_points_matrix_dims_vec2i_WGSL}.y)
   {
     // Within boundary, read from buffer[batch][inner_dim][outer_dim].
     var indices: ${buffer_variable.type.indices};
@@ -1420,26 +1450,31 @@ fn ${function_name}(loading_point_row: i32, loading_point_col: i32, batch: u32)
     };
 
     const buffer_A_loading_function_name = createLoadBufferHelperFunction(
-        'LoadFromBufferA', input_A_variable, input_A_packed_type_WGSL, input_A.buffer_layout);
+        'LoadFromBufferA', input_A_variable, input_A_packed_type_WGSL, input_A.buffer_layout.loading_points_layout,
+        'input_A_loading_points_matrix_dims');
     const buffer_B_loading_function_name = createLoadBufferHelperFunction(
-        'LoadFromBufferB', input_B_variable, input_B_packed_type_WGSL, input_B.buffer_layout);
+        'LoadFromBufferB', input_B_variable, input_B_packed_type_WGSL, input_B.buffer_layout.loading_points_layout,
+        'input_B_loading_points_matrix_dims');
     const buffer_Bias_loading_function_name = has_bias ?
        createLoadBufferHelperFunction(
-            'LoadFromBufferBias', input_Bias_variable!, output_packed_type_WGSL, input_Bias.buffer_layout) :
+            'LoadFromBufferBias', input_Bias_variable!, output_packed_type_WGSL,
+            input_Bias.buffer_layout.loading_points_layout, 'input_Bias_loading_points_matrix_dims') :
         undefined;
 
     const createStoreBufferHelperFunction = (
         function_name: string,
         buffer_variable: IndicesHelper,
         packed_type_in_buffer_WGSL: string,
-        buffer_layout: BufferLayoutInfo,
+        // buffer_layout: BufferLayoutInfo,
+        loading_points_layout: 'NHW'|'NWH',
+        loading_points_matrix_dims_vec2i_WGSL: string,
    ) => {
       if (buffer_variable.rank !== 3) {
         throw new Error(`Expected buffer A of rank 3, got ${buffer_variable.rank}.`);
       }
       assert(!helper_functions.has(function_name), `Redefining function ${function_name}`);
-      const {loading_points_layout, buffer_inner_boundary, buffer_outer_boundary} = buffer_layout;
+      // const {loading_points_layout, buffer_inner_boundary, buffer_outer_boundary} = buffer_layout;
 
       helper_functions.set(
           function_name,
@@ -1450,8 +1485,8 @@ fn ${function_name}(value: ${packed_type_in_buffer_WGSL}, loading_point_row: i32
   let inner_dim = ${loading_points_layout === 'NHW' ? 'loading_point_row' : 'loading_point_col'};
   let outer_dim = ${loading_points_layout === 'NHW' ? 'loading_point_col' : 'loading_point_row'};
 
-  if (inner_dim >= 0 && inner_dim < ${buffer_inner_boundary} && outer_dim >= 0 && outer_dim < ${
-      buffer_outer_boundary})
+  if (inner_dim >= 0 && inner_dim < ${loading_points_matrix_dims_vec2i_WGSL}.x && outer_dim >= 0 && outer_dim < ${
+      loading_points_matrix_dims_vec2i_WGSL}.y)
   {
     // Within boundary, write to buffer[batch][inner_dim][outer_dim].
     var indices: ${buffer_variable.type.indices};
@@ -1465,7 +1500,8 @@ fn ${function_name}(value: ${packed_type_in_buffer_WGSL}, loading_point_row: i32
     };
 
     const buffer_Output_storing_function_name = createStoreBufferHelperFunction(
-        'StoreToBufferOutput', output_variable, output_packed_type_WGSL, output.buffer_layout);
+        'StoreToBufferOutput', output_variable, output_packed_type_WGSL, output.buffer_layout.loading_points_layout,
+        'output_loading_points_matrix_dims');
 
     const callBufferLoadingFunctionExprBuilder = (buffer_loading_function: string, batch_variable: string) =>
         ((loading_point_row: number|string, loading_point_col: number|string) =>
@@ -1927,10 +1963,10 @@ let compute_block_thread_K_in_TSG = compute_block_thread_K_outer_loop + ${
               valid_loop_vars_or_values.compute_block_thread_K_inner!};
 // Compute block K is unbiased from workgroup to global
 let compute_block_K_global_biased = compute_block_thread_K_in_TSG + compute_blocks_TSG_workgroup_base_K;
-if (compute_block_thread_K_in_TSG < compute_blocks_per_thread_K) {
+if (compute_block_thread_K_in_TSG < uniforms.compute_blocks_per_thread_K) {
 `,
       loop_body_after_inner_loop: () => `
-  // End of condition (compute_block_thread_K_in_TSG < compute_blocks_per_thread_K)
+  // End of condition (compute_block_thread_K_in_TSG < uniforms.compute_blocks_per_thread_K)
 }
 // End of compute_block_thread_K_inner loop`,
       disable_unfold: !unfold_K_inner_loop,
@@ -2018,7 +2054,7 @@ const compute_block_shape_K = ${compute_block_shape.K}u;
 const compute_block_shape_N = ${compute_block_shape.N}u;
 
 const compute_blocks_per_workgroup_M = ${compute_blocks_per_workgroup.M}u;
-const compute_blocks_per_workgroup_K = ${compute_blocks_per_workgroup.K}u;
+// const compute_blocks_per_workgroup_K = ${compute_blocks_per_workgroup.K}u;
 const compute_blocks_per_workgroup_N = ${compute_blocks_per_workgroup.N}u;
 
 const A_loading_point_rows_per_compute_block = ${loading_points_per_compute_block_A.rows}u;
@@ -2032,7 +2068,8 @@ const threads_along_M_per_TSG = ${threads_along_M_per_TSG}u;
 const threads_along_N_per_TSG = ${threads_along_N_per_TSG}u;
 const_assert(threads_along_M_per_TSG * threads_along_N_per_TSG == threads_per_TSG);
 
-const compute_blocks_per_thread_K = ${Math.ceil(compute_blocks_per_workgroup.K / tensor_slice_factor)}u;
+// compute_blocks_per_thread_K = ceil(compute_blocks_per_workgroup_K / tensor_slice_factor) comes from uniform
+// to allow more shader reusing.
 const compute_blocks_per_thread_M = ${
     compute_blocks_per_thread_M}u;  // compute_blocks_per_workgroup_M / threads_along_M_per_TSG
 const compute_blocks_per_thread_N = ${
@@ -2046,7 +2083,23 @@ alias OutputPackedType = ${output_packed_type_WGSL};  // Bias, if any, should
 // 2D array type holding a whole output block
 alias ComputeBlockOutputType = ${compute_block_output_type_WGSL};
 
-${shader_helper.registerUniforms(uniforms_WGSL_info).declareVariables(...input_variables, output_variable)}
+${
+        shader_helper.registerUniforms(uniforms_WGSL_info)
+            .registerInternalVariables(...batch_helper_variables)
+            .declareVariables(...input_variables, output_variable)}
+
+// Workgroup-cached uniforms
+var<workgroup> input_A_loading_points_matrix_dims: vec2<i32>;
+var<workgroup> input_B_loading_points_matrix_dims: vec2<i32>;
+${has_bias ? 'var<workgroup> input_Bias_loading_points_matrix_dims: vec2<i32>;' : ''}
+var<workgroup> output_loading_points_matrix_dims: vec2<i32>;
+
+fn WorkgroupInit() {
+  input_A_loading_points_matrix_dims = vec2<i32>(${input_A_variable.shape}.yz);
+  input_B_loading_points_matrix_dims = vec2<i32>(${input_B_variable.shape}.yz);
+  ${has_bias ? `input_Bias_loading_points_matrix_dims = vec2<i32>(${input_Bias_variable!.shape}.yz);` : ''}
+  output_loading_points_matrix_dims = vec2<i32>(${output_variable.shape}.yz);
+}
 
 // Invocation scope variables
 var<private> tensor_slice_group: u32;
@@ -2121,9 +2174,14 @@ ${buffer_B_cache.cacheMemoryModuleDefinitionWGSL()}
 fn main(
     @builtin(workgroup_id) wid: vec3u,
     @builtin(local_invocation_id) lid: vec3u,
-    ${subgroup_cache_params.use_subgroups?'@builtin(subgroup_invocation_id) subgroup_id: u32,':''}
-    ${subgroup_cache_params.use_subgroups?'@builtin(subgroup_size) subgroup_size: u32,':''}
+    ${subgroup_cache_params.use_subgroups ? '@builtin(subgroup_invocation_id) subgroup_id: u32,' : ''}
+    ${subgroup_cache_params.use_subgroups ? '@builtin(subgroup_size) subgroup_size: u32,' : ''}
 ) {
+  if (all(lid == vec3u())) {
+    WorkgroupInit();
+  }
+  workgroupBarrier();
+
   let compute_blocks_workgroup_global_base_M = wid.x * compute_blocks_per_workgroup_M;
   // compute_blocks_workgroup_global_base_K == 0
   let compute_blocks_workgroup_global_base_N = wid.y * compute_blocks_per_workgroup_N;
@@ -2148,8 +2206,8 @@ fn main(
                      'thread_id_in_TSG / threads_along_M_per_TSG'};
 
   loading_point_A_TSG_global_row_base = compute_blocks_workgroup_global_base_M * A_loading_point_rows_per_compute_block;
-  loading_point_A_TSG_global_col_base = tensor_slice_group * compute_blocks_per_thread_K * A_loading_point_cols_per_compute_block;
-  loading_point_B_TSG_global_row_base = tensor_slice_group * compute_blocks_per_thread_K * B_loading_point_rows_per_compute_block;
+  loading_point_A_TSG_global_col_base = tensor_slice_group * uniforms.compute_blocks_per_thread_K * A_loading_point_cols_per_compute_block;
+  loading_point_B_TSG_global_row_base = tensor_slice_group * uniforms.compute_blocks_per_thread_K * B_loading_point_rows_per_compute_block;
   loading_point_B_TSG_global_col_base = compute_blocks_workgroup_global_base_N * B_loading_point_cols_per_compute_block;
   loading_point_Output_TSG_global_row_base = compute_blocks_workgroup_global_base_M * Output_loading_point_rows_per_compute_block;
   loading_point_Output_TSG_global_col_base = compute_blocks_workgroup_global_base_N * Output_loading_point_cols_per_compute_block;
@@ -2157,7 +2215,7 @@ fn main(
   compute_blocks_thread_TSG_base_M = thread_M_in_TSG * compute_blocks_per_thread_M;
   compute_blocks_thread_TSG_base_N = thread_N_in_TSG * compute_blocks_per_thread_N;
 
-  let compute_blocks_TSG_workgroup_base_K = tensor_slice_group * compute_blocks_per_thread_K;
+  let compute_blocks_TSG_workgroup_base_K = tensor_slice_group * uniforms.compute_blocks_per_thread_K;
 
   // Buffer cache function-scope definition, if any
   ${buffer_A_cache.cacheMemoryFunctionDefinitionWGSL()}
@@ -2167,7 +2225,7 @@ fn main(
 
   for (
       var compute_block_thread_K_outer_loop: u32 = 0;
-      compute_block_thread_K_outer_loop < compute_blocks_per_thread_K;
+      compute_block_thread_K_outer_loop < uniforms.compute_blocks_per_thread_K;
       compute_block_thread_K_outer_loop += compute_block_thread_K_inner_loop_step
   ) {
     // Handle cache steps along K, if any
@@ -2219,17 +2277,27 @@ fn main(
     activation,
     workgroup_params,
     compute_block_shape,
-    compute_blocks_per_workgroup,
+    output_compute_blocks_per_workgroup: {M: compute_blocks_per_workgroup.M, N: compute_blocks_per_workgroup.N},
     tensor_slice_factor,
-    input_A_layout: input_A.buffer_layout,
-    input_B_layout: input_B.buffer_layout,
-    output_layout: output.buffer_layout,
+    input_A_layout: input_A.buffer_layout.loading_points_layout,
+    input_B_layout: input_B.buffer_layout.loading_points_layout,
+    output_layout: output.buffer_layout.loading_points_layout,
+    packing_size: {
+      input_A: input_A.buffer_layout.packed_vector_size,
+      input_B: input_B.buffer_layout.packed_vector_size,
+      output: output.buffer_layout.packed_vector_size,
+    },
+    batch_indices: {
+      input_A_broadcasted_dims,
+      input_B_broadcasted_dims,
+    },
   };
 
   // LOG(`fatal`, `templatedMatMulProgram Return`);
   return {
-    name: `MatMulTemplatedTrival-${Object.entries(cacheKey).map((entry) => entry.map(x => JSON.stringify(x)).join(':')).join('-')};`,
+    name: `MatMulTemplatedTrival-${
+        Object.entries(cacheKey).map((entry) => entry.map(x => JSON.stringify(x)).join(':')).join('-')};`,
     shaderCache: {
       // hint: `${activation.activation};${Object.entries(workgroup_params).map((entry) =>
      // entry.join(':')).join(';')};`,
@@ -2267,7 +2335,7 @@ export const templatedMatMulDriver = (
   assert(
       batchAggregatedInputs[0].dims[0] * batchAggregatedInputs[1].dims[0] === batchAggregatedOutputShape[0],
       `MatMul output batches ${batchAggregatedOutputShape[0]} should be equal to \
-input A batches ${batchAggregatedInputs[0].dims[0]} * input B batches ${batchAggregatedInputs[1].dims[0]}`);
+      input A batches ${batchAggregatedInputs[0].dims[0]} * input B batches ${batchAggregatedInputs[1].dims[0]}`);
  */
 
   const createMatrixTensorInfo = (
@@ -2349,12 +2417,7 @@ input A batches ${batchAggregatedInputs[0].dims[0]} * input B batches ${batchAgg
 
   LOG(`fatal`, `Before templatedMatMulProgram`);
 
-  return templatedMatMulProgram(
-      op_params,
-      workgroup_params,
-      schedule_params,
-      subgroup_cache_params
-  );
+  return templatedMatMulProgram(op_params, workgroup_params, schedule_params, subgroup_cache_params);
 }
 
 export const createComputeBlockMatmulNaiveProgramInfo =
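
Review note on the ConvertOutputBatchToInputBatch rewrite: the generated WGSL now decomposes the output's
aggregated batch offset into per-dimension indices via the batch helpers, forces each input's broadcasted
dimensions to index 0, and recomposes a flat offset per input. A minimal host-side TypeScript sketch of the
same mapping, for reference only; `inputBatchOffset` and its parameter names are illustrative and not part
of the patch:

// Map a flat output batch offset to one input's flat batch offset:
// decompose over the output batch dims (row-major, last dim fastest),
// zero the input's broadcasted dims, then recompose over the input dims.
function inputBatchOffset(
    outputBatch: number, outputBatchDims: readonly number[], inputBatchDims: readonly number[],
    broadcastedDims: readonly number[]): number {
  const indices = new Array<number>(outputBatchDims.length).fill(0);
  let rest = outputBatch;
  for (let i = outputBatchDims.length - 1; i >= 0; --i) {
    indices[i] = rest % outputBatchDims[i];
    rest = Math.floor(rest / outputBatchDims[i]);
  }
  // Input batch dims align with the trailing output batch dims.
  const rankDiff = outputBatchDims.length - inputBatchDims.length;
  let offset = 0;
  for (let d = 0; d < inputBatchDims.length; ++d) {
    const index = broadcastedDims.includes(d) ? 0 : indices[d + rankDiff];
    offset = offset * inputBatchDims[d] + index;
  }
  return offset;
}

// Output batches [2, 3, 4], input A batches [3, 1] with dim 1 broadcasted:
// output batch 17 -> indices (1, 1, 1) -> A indices (1, 0) -> offset 1.
console.log(inputBatchOffset(17, [2, 3, 4], [3, 1], [1]));  // 1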
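
Review note on moving compute_blocks_per_thread_K into uniforms ("to allow more shader reusing"): with the
value baked into the WGSL source, every distinct K yields a distinct shader string and therefore a distinct
compiled pipeline; as a uniform, dispatches that differ only in K share one pipeline. A toy TypeScript
illustration of that effect; `buildShaderSource` is a hypothetical stand-in for the code generator, not an
API from this file:

// A stand-in for shader codegen: the loop bound is either baked into the
// source text or read from a uniform at runtime.
const buildShaderSource = (kAsUniform: boolean, computeBlocksPerThreadK: number): string =>
    kAsUniform ? 'if (k < uniforms.compute_blocks_per_thread_K) { /* ... */ }' :
                 `if (k < ${computeBlocksPerThreadK}u) { /* ... */ }`;

// A pipeline cache keyed by shader source, analogous to the backend's program cache.
const pipelineCache = new Set<string>();
for (const k of [8, 16, 32]) pipelineCache.add(buildShaderSource(false, k));
console.log(pipelineCache.size);  // 3: one compile per K when K is baked in

pipelineCache.clear();
for (const k of [8, 16, 32]) pipelineCache.add(buildShaderSource(true, k));
console.log(pipelineCache.size);  // 1: a single pipeline, K supplied per dispatch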