@@ -14,8 +14,8 @@ pub fn matmul(
1414 #[ spirv( storage_buffer, descriptor_set = 0 , binding = 2 ) ] b : & [ f32 ] ,
1515 #[ spirv( storage_buffer, descriptor_set = 0 , binding = 3 ) ] result : & mut [ f32 ] ,
1616) {
17- let row = ( global_id. y * TILE_M as u32 ) as usize ;
18- let col = ( global_id. x * TILE_N as u32 ) as usize ;
17+ let row = ( global_id. y * TILE_M ) as usize ;
18+ let col = ( global_id. x * TILE_N ) as usize ;
1919
2020 // Initialize sums array to zeros
2121 // Note: This is uglier than it needs to be to work around
@@ -33,7 +33,7 @@ pub fn matmul(
3333
3434 for j in 0 ..TILE_N as usize {
3535 let b_element = if col + j < dimensions. n as usize {
36- b[ k * dimensions. n as usize + ( col + j as usize ) ]
36+ b[ k * dimensions. n as usize + ( col + j) ]
3737 } else {
3838 0.0
3939 } ;
@@ -46,8 +46,8 @@ pub fn matmul(
4646 // Write results
4747 for i in 0 ..TILE_M as usize {
4848 for j in 0 ..TILE_N as usize {
49- let output_row = row + i as usize ;
50- let output_col = col + j as usize ;
49+ let output_row = row + i;
50+ let output_col = col + j;
5151
5252 if output_row < dimensions. m as usize && output_col < dimensions. n as usize {
5353 result[ output_row * dimensions. n as usize + output_col] = sums[ i] [ j] ;
0 commit comments