Use uniform for batch broadcasting and K
jiangzhaoming committed Nov 7, 2024
1 parent 08d98ee commit d567dfe
Showing 1 changed file with 133 additions and 70 deletions.
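This commit makes two changes: compute_blocks_per_thread_K moves from a WGSL compile-time constant into a shader uniform, so one compiled module can be reused across dispatches that differ only along K, and batch broadcasting is reworked to map an output batch offset to per-input batch offsets through indices helpers driven by uniform shape data. The sketch below illustrates the const-to-uniform pattern in isolation; it is runnable TypeScript with hypothetical names (dispatchBaked, dispatchUniform), not code from this file.

// Why moving a value from a WGSL const into a uniform enables shader reuse.
type Dispatch = { wgsl: string; uniforms: Uint32Array };

// Before: K is interpolated into the WGSL source, so each distinct K yields
// a distinct shader string, and therefore a distinct compiled pipeline.
function dispatchBaked(computeBlocksPerThreadK: number): Dispatch {
  const wgsl = `const compute_blocks_per_thread_K = ${computeBlocksPerThreadK}u;`;
  return { wgsl, uniforms: new Uint32Array(0) };
}

// After: the WGSL is shape-independent; K travels in the uniform buffer.
const SHARED_WGSL = '/* loops read uniforms.compute_blocks_per_thread_K */';
function dispatchUniform(computeBlocksPerThreadK: number): Dispatch {
  return { wgsl: SHARED_WGSL, uniforms: new Uint32Array([computeBlocksPerThreadK]) };
}

console.log(dispatchBaked(4).wgsl === dispatchBaked(8).wgsl);     // false: two pipelines
console.log(dispatchUniform(4).wgsl === dispatchUniform(8).wgsl); // true: one pipeline reused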
203 changes: 133 additions & 70 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul-template.ts
@@ -703,7 +703,8 @@ export class SubgroupMemoryCacheHelper extends TSGBlockCacheHelperBase implement
 cacheMemoryModuleDefinitionWGSL() {
   return `
 // ${this.cache_name}:
-// cache_loading_points_per_subgroup: rows ${this.cache_loading_points_per_subgroup.rows}, cols ${this.cache_loading_points_per_subgroup.cols}
+// cache_loading_points_per_subgroup: rows ${this.cache_loading_points_per_subgroup.rows}, cols ${
+    this.cache_loading_points_per_subgroup.cols}
 // cache_loading_points_number_per_subgroup: ${this.cache_loading_points_number_per_subgroup}
 const ${this.cache_name}_loading_points_per_thread = ${this.cache_loading_points_per_thread}u;
 var<private> ${this.cache_name}: array<${this.loading_point_WGSL_type}, ${this.cache_name}_loading_points_per_thread>;
@@ -748,7 +749,7 @@ var<private> ${this.cache_name}: array<${this.loading_point_WGSL_type}, ${this.c
         'loading_point_id_subgroup', this.loading_points_prefer_major, loading_points_major_stride)
         .loading_point_col;
     return `
-// Update shared memory cache for ${this.variable_name}
+// Update subgroup memory cache for ${this.variable_name}
 workgroupBarrier();
 {
   let loading_point_row_cache_step_TSG_base: i32 = i32(${this.loading_point_row_cache_step_TSG_base_WGSL});
@@ -1044,12 +1045,6 @@ export function templatedMatMulProgram(

 const has_bias = typeof input_Bias !== 'undefined';

-/*
-assert(
-    output.aggregated_batches === input_A.aggregated_batches * input_B.aggregated_batches,
-    `Expect output.batches ${output.aggregated_batches} === input_A.batches ${input_A.aggregated_batches} *
-    input_B.batches ${ input_B.aggregated_batches}`);
-*/
 assert(
     !has_bias || output.aggregated_batches === input_Bias.aggregated_batches,
     `Expect input_bias.batches ${input_Bias?.aggregated_batches} === output.batches ${output.aggregated_batches}`);
@@ -1159,7 +1154,7 @@ dividable by compute_block_shape ${PrintMKN(compute_block_shape)}`);
 assert(
     (!threads_within_TSG_major_M_not_N) || buffer_B_cache_type !== 'subgroup',
     `Using subgroup cache for input B require !threads_within_TSG_major_M_not_N.`);
-const spatial_loop_order: 'M_outer'|'N_outer' = 'N_outer' as 'M_outer'|'N_outer';
+const spatial_loop_order: 'M_outer'|'N_outer' = 'N_outer' as 'M_outer' | 'N_outer';
 // Might be parameter
 const unfold_spatial_M_loop: boolean = buffer_A_cache_type === 'subgroup';
 const unfold_spatial_N_loop: boolean = buffer_B_cache_type === 'subgroup';
@@ -1207,15 +1202,17 @@ dividable by compute_block_shape ${PrintMKN(compute_block_shape)}`);
 // -----------------------------------------------------------------------------
 // Uniform definition
 const uniforms_WGSL_info: UniformsArrayType = [
-  // { name: 'input_A_batches', type: 'u32' },
-  // {name: 'input_B_batches', type: 'u32'},
+  {name: 'compute_blocks_per_thread_K', type: 'u32'},
 ];
 appendActivationUniforms(activation, uniforms_WGSL_info);

 // Uniforms values
 const programUniforms: ProgramUniform[] = [
-  // { /* input_A_batches */ type: DataType.uint32, data: input_A.batches },
-  // {/* input_B_batches */ type: DataType.uint32, data: input_B.aggregated_batches},
+  {
+    // compute_blocks_per_thread_K
+    type: DataType.uint32,
+    data: Math.ceil(compute_blocks_per_workgroup.K / tensor_slice_factor)
+  },
 ];
 appendActivationUniformsData(activation, programUniforms);
 // Tensor shapes of all input/output variables must be pushed into programUniforms
@@ -1301,9 +1298,15 @@ ${
       input_A_tensor_loading_points_shape, input_B_tensor_loading_points_shape,
       ...(has_bias ? [output_tensor_loading_points_shape] : [])
     ];

+    // Push unaggregated batch dims for batch helper
+    const batch_helpers_dims = [output_batch_dims, input_A_batch_dims, input_B_batch_dims];
+    programUniforms.push(...createTensorShapeVariables(...batch_helpers_dims));
+    // Push aggregated inputs and output shape
     programUniforms.push(...createTensorShapeVariables(...input_tensors_shape));
     programUniforms.push(...createTensorShapeVariables(output_tensor_loading_points_shape));

+
     const trival_shader =
         (shader_helper: ShaderHelper) => {
 // -----------------------------------------------------------------------------
@@ -1319,6 +1322,22 @@ ${
     const output_variable = outputVariable('output', tensor_data_type, 3, output.buffer_layout.packed_vector_size);

     const input_variables = [input_A_variable, input_B_variable, ...(has_bias ? [input_Bias_variable!] : [])];
+
+    // -----------------------------------------------------------------------------
+    // Batch helper variables
+    // -----------------------------------------------------------------------------
+    const output_batch_helper =
+        internalVariable('output_batch_helper', tensor_data_type, output_batch_dims.length, 1);
+    const input_A_batch_helper =
+        internalVariable('input_A_batch_helper', tensor_data_type, input_A_batch_dims.length, 1);
+    const input_B_batch_helper =
+        internalVariable('input_B_batch_helper', tensor_data_type, input_B_batch_dims.length, 1);
+    // Currently assume bias has the same batches as output
+    // const input_Bias_batch_helper = has_bias? internalVariable('input_B_batch_helper',
+    //     tensor_data_type,input_Bias_batch_dims.length, 1):undefined;
+
+    const batch_helper_variables = [output_batch_helper, input_A_batch_helper, input_B_batch_helper];
+
     // -----------------------------------------------------------------------------
     // WGSL types
     // -----------------------------------------------------------------------------
@@ -1350,31 +1369,40 @@ ${
 `
 // Convert output aggregated batch to vec2(input_A_aggregated_batch, input_B_aggregated_batch)
 fn ConvertOutputBatchToInputBatch(output_batch: u32) -> vec2<u32> {
-  var input_A_batch: u32 = 0;
-  var input_B_batch: u32 = 0;
-  var degregating_output_batch = output_batch;
+  let batch_indices = ${output_batch_helper.offsetToIndices('output_batch')};
+  var input_A_indices: ${input_A_batch_helper.type.indices};
 ${
-    arrayMap(
-        Math.max(input_A_batch_dims.length, input_B_batch_dims.length),
-        i => `
-  {
-    let dim = degregating_output_batch % ${output_batch_dims[output_batch_dims.length - 1 - i]};
-    degregating_output_batch = degregating_output_batch / ${output_batch_dims[output_batch_dims.length - 1 - i]};
+    u32LoopUpFrom0WGSL(
+        'index_A', input_A_batch_dims.length,
+        (index_A: number) => `
     ${
-      ((i < input_A_batch_dims.length) &&
-       (!input_A_broadcasted_dims.includes(input_A_batch_dims.length - 1 - i))) ?
-          `
-    input_A_batch += dim * ${input_A_batch_dims.slice(input_A_batch_dims.length - i).reduce((a, b) => a * b, 1)};
-    ` :
-          ''}
+          input_A_batch_helper.indicesSet(
+              'input_A_indices', index_A,
+              input_A_broadcasted_dims.includes(index_A) ?
+                  0 :
+                  output_batch_helper.indicesGet(
+                      'batch_indices', index_A + output_batch_dims.length - input_A_batch_dims.length))}
+        `,
+        false, 4)}
+  var input_B_indices: ${input_B_batch_helper.type.indices};
+${
+    u32LoopUpFrom0WGSL(
+        'index_B', input_B_batch_dims.length,
+        (index_B: number) => `
     ${
-      ((i < input_B_batch_dims.length) &&
-       (!input_B_broadcasted_dims.includes(input_B_batch_dims.length - 1 - i))) ?
-          `
-    input_B_batch += dim * ${input_B_batch_dims.slice(input_B_batch_dims.length - i).reduce((a, b) => a * b, 1)};
-    ` :
-          ''}
-  }`).join('\n')}
+          input_B_batch_helper.indicesSet(
+              'input_B_indices', index_B,
+              input_B_broadcasted_dims.includes(index_B) ?
+                  0 :
+                  output_batch_helper.indicesGet(
+                      'batch_indices', index_B + output_batch_dims.length - input_B_batch_dims.length))}
+        `,
+        false, 4)}
+  let input_A_batch: u32 = ${input_A_batch_helper.indicesToOffset('input_A_indices')};
+  let input_B_batch: u32 = ${input_B_batch_helper.indicesToOffset('input_B_indices')};
   return vec2<u32>(input_A_batch, input_B_batch);
 }`);
@@ -1384,14 +1412,16 @@ fn ConvertOutputBatchToInputBatch(output_batch: u32) -> vec2<u32> {
         function_name: string,
         buffer_variable: IndicesHelper,
         packed_type_in_buffer_WGSL: string,
-        buffer_layout: BufferLayoutInfo,
+        // buffer_layout: BufferLayoutInfo,
+        loading_points_layout: 'NHW'|'NWH',
+        loading_points_matrix_dims_vec2i_WGSL: string,
         ) => {
       if (buffer_variable.rank !== 3) {
         throw new Error(`Expected buffer A of rank 3, got ${buffer_variable.rank}.`);
       }
       assert(!helper_functions.has(function_name), `Redefining function ${function_name}`);

-      const {loading_points_layout, buffer_inner_boundary, buffer_outer_boundary} = buffer_layout;
+      // const {loading_points_layout} = buffer_layout;

       helper_functions.set(
           function_name,
@@ -1404,8 +1434,8 @@ fn ${function_name}(loading_point_row: i32, loading_point_col: i32, batch: u32)
   let inner_dim = ${loading_points_layout === 'NHW' ? 'loading_point_row' : 'loading_point_col'};
   let outer_dim = ${loading_points_layout === 'NHW' ? 'loading_point_col' : 'loading_point_row'};
-  if (inner_dim >= 0 && inner_dim < ${buffer_inner_boundary} && outer_dim >= 0 && outer_dim < ${
-    buffer_outer_boundary})
+  if (inner_dim >= 0 && inner_dim < ${loading_points_matrix_dims_vec2i_WGSL}.x && outer_dim >= 0 && outer_dim < ${
+    loading_points_matrix_dims_vec2i_WGSL}.y)
   {
     // Within boundary, read from buffer[batch][inner_dim][outer_dim].
     var indices: ${buffer_variable.type.indices};
@@ -1420,26 +1450,31 @@ fn ${function_name}(loading_point_row: i32, loading_point_col: i32, batch: u32)
     };

     const buffer_A_loading_function_name = createLoadBufferHelperFunction(
-        'LoadFromBufferA', input_A_variable, input_A_packed_type_WGSL, input_A.buffer_layout);
+        'LoadFromBufferA', input_A_variable, input_A_packed_type_WGSL, input_A.buffer_layout.loading_points_layout,
+        'input_A_loading_points_matrix_dims');
     const buffer_B_loading_function_name = createLoadBufferHelperFunction(
-        'LoadFromBufferB', input_B_variable, input_B_packed_type_WGSL, input_B.buffer_layout);
+        'LoadFromBufferB', input_B_variable, input_B_packed_type_WGSL, input_B.buffer_layout.loading_points_layout,
+        'input_B_loading_points_matrix_dims');
     const buffer_Bias_loading_function_name = has_bias ?
         createLoadBufferHelperFunction(
-            'LoadFromBufferBias', input_Bias_variable!, output_packed_type_WGSL, input_Bias.buffer_layout) :
+            'LoadFromBufferBias', input_Bias_variable!, output_packed_type_WGSL,
+            input_Bias.buffer_layout.loading_points_layout, 'input_Bias_loading_points_matrix_dims') :
         undefined;

     const createStoreBufferHelperFunction = (
         function_name: string,
         buffer_variable: IndicesHelper,
         packed_type_in_buffer_WGSL: string,
-        buffer_layout: BufferLayoutInfo,
+        // buffer_layout: BufferLayoutInfo,
+        loading_points_layout: 'NHW'|'NWH',
+        loading_points_matrix_dims_vec2i_WGSL: string,
         ) => {
       if (buffer_variable.rank !== 3) {
         throw new Error(`Expected buffer A of rank 3, got ${buffer_variable.rank}.`);
       }
       assert(!helper_functions.has(function_name), `Redefining function ${function_name}`);

-      const {loading_points_layout, buffer_inner_boundary, buffer_outer_boundary} = buffer_layout;
+      // const {loading_points_layout, buffer_inner_boundary, buffer_outer_boundary} = buffer_layout;

       helper_functions.set(
           function_name,
@@ -1450,8 +1485,8 @@ fn ${function_name}(value: ${packed_type_in_buffer_WGSL}, loading_point_row: i32
   let inner_dim = ${loading_points_layout === 'NHW' ? 'loading_point_row' : 'loading_point_col'};
   let outer_dim = ${loading_points_layout === 'NHW' ? 'loading_point_col' : 'loading_point_row'};
-  if (inner_dim >= 0 && inner_dim < ${buffer_inner_boundary} && outer_dim >= 0 && outer_dim < ${
-    buffer_outer_boundary})
+  if (inner_dim >= 0 && inner_dim < ${loading_points_matrix_dims_vec2i_WGSL}.x && outer_dim >= 0 && outer_dim < ${
+    loading_points_matrix_dims_vec2i_WGSL}.y)
   {
     // Within boundary, read from buffer[batch][inner_dim][outer_dim].
     var indices: ${buffer_variable.type.indices};
@@ -1465,7 +1500,8 @@ fn ${function_name}(value: ${packed_type_in_buffer_WGSL}, loading_point_row: i32
     };

     const buffer_Output_storing_function_name = createStoreBufferHelperFunction(
-        'StoreToBufferOutput', output_variable, output_packed_type_WGSL, output.buffer_layout);
+        'StoreToBufferOutput', output_variable, output_packed_type_WGSL, output.buffer_layout.loading_points_layout,
+        'output_loading_points_matrix_dims');

     const callBufferLoadingFunctionExprBuilder = (buffer_loading_function: string, batch_variable: string) =>
         ((loading_point_row: number|string, loading_point_col: number|string) =>
@@ -1927,10 +1963,10 @@ let compute_block_thread_K_in_TSG = compute_block_thread_K_outer_loop + ${
     valid_loop_vars_or_values.compute_block_thread_K_inner!};
 // Compute block K is unbiased from workgroup to global
 let compute_block_K_global_biased = compute_block_thread_K_in_TSG + compute_blocks_TSG_workgroup_base_K;
-if (compute_block_thread_K_in_TSG < compute_blocks_per_thread_K) {
+if (compute_block_thread_K_in_TSG < uniforms.compute_blocks_per_thread_K) {
 `,
         loop_body_after_inner_loop: () => `
-// End of condition (compute_block_thread_K_in_TSG < compute_blocks_per_thread_K)
+// End of condition (compute_block_thread_K_in_TSG < uniforms.compute_blocks_per_thread_K)
 }
 // End of compute_block_thread_K_inner loop`,
         disable_unfold: !unfold_K_inner_loop,
@@ -2018,7 +2054,7 @@ const compute_block_shape_K = ${compute_block_shape.K}u;
 const compute_block_shape_N = ${compute_block_shape.N}u;
 const compute_blocks_per_workgroup_M = ${compute_blocks_per_workgroup.M}u;
-const compute_blocks_per_workgroup_K = ${compute_blocks_per_workgroup.K}u;
+// const compute_blocks_per_workgroup_K = ${compute_blocks_per_workgroup.K}u;
 const compute_blocks_per_workgroup_N = ${compute_blocks_per_workgroup.N}u;
 const A_loading_point_rows_per_compute_block = ${loading_points_per_compute_block_A.rows}u;
@@ -2032,7 +2068,8 @@ const threads_along_M_per_TSG = ${threads_along_M_per_TSG}u;
 const threads_along_N_per_TSG = ${threads_along_N_per_TSG}u;
 const_assert(threads_along_M_per_TSG * threads_along_N_per_TSG == threads_per_TSG);
-const compute_blocks_per_thread_K = ${Math.ceil(compute_blocks_per_workgroup.K / tensor_slice_factor)}u;
+// compute_blocks_per_thread_K = ceil(compute_blocks_per_workgroup_K / tensor_slice_factor) comes from uniform
+// to allow more shader reusing.
 const compute_blocks_per_thread_M = ${
     compute_blocks_per_thread_M}u; // compute_blocks_per_workgroup_M / threads_along_M_per_TSG
 const compute_blocks_per_thread_N = ${
@@ -2046,7 +2083,23 @@ alias OutputPackedType = ${output_packed_type_WGSL}; // Bias, if any, should
 // 2D array type holding a whole output block
 alias ComputeBlockOutputType = ${compute_block_output_type_WGSL};
-${shader_helper.registerUniforms(uniforms_WGSL_info).declareVariables(...input_variables, output_variable)}
+${
+    shader_helper.registerUniforms(uniforms_WGSL_info)
+        .registerInternalVariables(...batch_helper_variables)
+        .declareVariables(...input_variables, output_variable)}
+// Workgroup-cached uniforms
+var<workgroup> input_A_loading_points_matrix_dims: vec2<i32>;
+var<workgroup> input_B_loading_points_matrix_dims: vec2<i32>;
+${has_bias ? 'var<workgroup> input_Bias_loading_points_matrix_dims: vec2<i32>;' : ''}
+var<workgroup> output_loading_points_matrix_dims: vec2<i32>;
+fn WorkgroupInit() {
+  input_A_loading_points_matrix_dims = vec2<i32>(${input_A_variable.shape}.yz);
+  input_B_loading_points_matrix_dims = vec2<i32>(${input_B_variable.shape}.yz);
+  ${has_bias ? `input_Bias_loading_points_matrix_dims = vec2<i32>(${input_Bias_variable!.shape}.yz);` : ''}
+  output_loading_points_matrix_dims = vec2<i32>(${output_variable.shape}.yz);
+}
 // Invocation scope variables
 var<private> tensor_slice_group: u32;
@@ -2121,9 +2174,14 @@ ${buffer_B_cache.cacheMemoryModuleDefinitionWGSL()}
 fn main(
     @builtin(workgroup_id) wid: vec3u,
     @builtin(local_invocation_id) lid: vec3u,
-    ${subgroup_cache_params.use_subgroups?'@builtin(subgroup_invocation_id) subgroup_id: u32,':''}
-    ${subgroup_cache_params.use_subgroups?'@builtin(subgroup_size) subgroup_size: u32,':''}
+    ${subgroup_cache_params.use_subgroups ? '@builtin(subgroup_invocation_id) subgroup_id: u32,' : ''}
+    ${subgroup_cache_params.use_subgroups ? '@builtin(subgroup_size) subgroup_size: u32,' : ''}
 ) {
+  if(all(lid == vec3u())) {
+    WorkgroupInit();
+  }
+  workgroupBarrier();
   let compute_blocks_workgroup_global_base_M = wid.x * compute_blocks_per_workgroup_M;
   // compute_blocks_workgroup_global_base_K == 0
   let compute_blocks_workgroup_global_base_N = wid.y * compute_blocks_per_workgroup_N;
@@ -2148,16 +2206,16 @@ fn main(
     'thread_id_in_TSG / threads_along_M_per_TSG'};
 loading_point_A_TSG_global_row_base = compute_blocks_workgroup_global_base_M * A_loading_point_rows_per_compute_block;
-loading_point_A_TSG_global_col_base = tensor_slice_group * compute_blocks_per_thread_K * A_loading_point_cols_per_compute_block;
-loading_point_B_TSG_global_row_base = tensor_slice_group * compute_blocks_per_thread_K * B_loading_point_rows_per_compute_block;
+loading_point_A_TSG_global_col_base = tensor_slice_group * uniforms.compute_blocks_per_thread_K * A_loading_point_cols_per_compute_block;
+loading_point_B_TSG_global_row_base = tensor_slice_group * uniforms.compute_blocks_per_thread_K * B_loading_point_rows_per_compute_block;
 loading_point_B_TSG_global_col_base = compute_blocks_workgroup_global_base_N * B_loading_point_cols_per_compute_block;
 loading_point_Output_TSG_global_row_base = compute_blocks_workgroup_global_base_M * Output_loading_point_rows_per_compute_block;
 loading_point_Output_TSG_global_col_base = compute_blocks_workgroup_global_base_N * Output_loading_point_cols_per_compute_block;
 compute_blocks_thread_TSG_base_M = thread_M_in_TSG * compute_blocks_per_thread_M;
 compute_blocks_thread_TSG_base_N = thread_N_in_TSG * compute_blocks_per_thread_N;
-let compute_blocks_TSG_workgroup_base_K = tensor_slice_group * compute_blocks_per_thread_K;
+let compute_blocks_TSG_workgroup_base_K = tensor_slice_group * uniforms.compute_blocks_per_thread_K;
 // Buffer cache function-scope definition, if any
 ${buffer_A_cache.cacheMemoryFunctionDefinitionWGSL()}
@@ -2167,7 +2225,7 @@ fn main(
   for (
       var compute_block_thread_K_outer_loop: u32 = 0;
-      compute_block_thread_K_outer_loop < compute_blocks_per_thread_K;
+      compute_block_thread_K_outer_loop < uniforms.compute_blocks_per_thread_K;
       compute_block_thread_K_outer_loop += compute_block_thread_K_inner_loop_step
   ) {
     // Handle cache steps along K, if any
@@ -2219,17 +2277,27 @@ fn main(
     activation,
     workgroup_params,
     compute_block_shape,
-    compute_blocks_per_workgroup,
+    output_compute_blocks_per_workgroup: {M: compute_blocks_per_workgroup.M, N: compute_blocks_per_workgroup.N},
     tensor_slice_factor,
-    input_A_layout: input_A.buffer_layout,
-    input_B_layout: input_B.buffer_layout,
-    output_layout: output.buffer_layout,
+    input_A_layout: input_A.buffer_layout.loading_points_layout,
+    input_B_layout: input_B.buffer_layout.loading_points_layout,
+    output_layout: output.buffer_layout.loading_points_layout,
+    packing_size: {
+      input_A: input_A.buffer_layout.packed_vector_size,
+      input_B: input_B.buffer_layout.packed_vector_size,
+      output: output.buffer_layout.packed_vector_size,
+    },
+    batch_indices: {
+      input_A_broadcasted_dims,
+      input_B_broadcasted_dims,
+    },
   };

   // LOG(`fatal`, `templatedMatMulProgram Return`);

   return {
-    name: `MatMulTemplatedTrival-${Object.entries(cacheKey).map((entry) => entry.map(x => JSON.stringify(x)).join(':')).join('-')};`,
+    name: `MatMulTemplatedTrival-${
+        Object.entries(cacheKey).map((entry) => entry.map(x => JSON.stringify(x)).join(':')).join('-')};`,
     shaderCache: {
       // hint: `${activation.activation};${Object.entries(workgroup_params).map((entry) =>
       //     entry.join(':')).join(';')};`,
@@ -2267,7 +2335,7 @@ export const templatedMatMulDriver = (
   assert(
       batchAggregatedInputs[0].dims[0] * batchAggregatedInputs[1].dims[0] === batchAggregatedOutputShape[0],
       `MatMul output batches ${batchAggregatedOutputShape[0]} should be equal to \
-input A batches ${batchAggregatedInputs[0].dims[0]} * input B batches ${batchAggregatedInputs[1].dims[0]}`);
+  input A batches ${batchAggregatedInputs[0].dims[0]} * input B batches ${batchAggregatedInputs[1].dims[0]}`);
   */

   const createMatrixTensorInfo = (
@@ -2349,12 +2417,7 @@ input A batches ${batchAggregatedInputs[0].dims[0]} * input B batches ${batchAgg

   LOG(`fatal`, `Before templatedMatMulProgram`);

-  return templatedMatMulProgram(
-      op_params,
-      workgroup_params,
-      schedule_params,
-      subgroup_cache_params
-  );
+  return templatedMatMulProgram(op_params, workgroup_params, schedule_params, subgroup_cache_params);
 }

 export const createComputeBlockMatmulNaiveProgramInfo =
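
For reference, the index arithmetic that the new ConvertOutputBatchToInputBatch codegen emits in WGSL can be written on the host side in plain TypeScript. This is a sketch of the same mapping, not code from the file: decompose the output batch offset into unaggregated indices, zero each input's broadcasted dims, and re-linearize against that input's batch dims.

function offsetToIndices(offset: number, dims: number[]): number[] {
  // Row-major decomposition; the last dim varies fastest.
  const indices = new Array<number>(dims.length).fill(0);
  for (let i = dims.length - 1; i >= 0; i--) {
    indices[i] = offset % dims[i];
    offset = Math.floor(offset / dims[i]);
  }
  return indices;
}

function indicesToOffset(indices: number[], dims: number[]): number {
  return indices.reduce((acc, idx, i) => acc * dims[i] + idx, 0);
}

// Map an output batch offset to one input's batch offset. broadcastedDims
// lists the input dims (indexed within the input's own batch dims) that have
// size 1 in the input and are therefore pinned to index 0.
function outputBatchToInputBatch(
    outputBatch: number, outputBatchDims: number[], inputBatchDims: number[],
    broadcastedDims: number[]): number {
  const outIndices = offsetToIndices(outputBatch, outputBatchDims);
  const rankDiff = outputBatchDims.length - inputBatchDims.length;
  const inIndices = inputBatchDims.map(
      (_, i) => broadcastedDims.includes(i) ? 0 : outIndices[i + rankDiff]);
  return indicesToOffset(inIndices, inputBatchDims);
}

// Output batches [2, 3]; input A batches [3] (leading dim broadcast away),
// input B batches [2, 1] (dim 1 broadcast). Output batch 4 is indices (1, 1).
console.log(outputBatchToInputBatch(4, [2, 3], [3], []));     // 1
console.log(outputBatchToInputBatch(4, [2, 3], [2, 1], [1])); // 1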
