Skip to content

Commit d865d15

Browse files
committed
[TritonGEN] Use the sub-group-size of the module instead of hard code number of 16 in block load.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent c08b0ba commit d865d15

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "llvm/ADT/STLExtras.h"
1414
#include <cstdint>
1515

16+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
17+
1618
using namespace mlir;
1719
using namespace mlir::triton;
1820

@@ -238,7 +240,9 @@ verify2DBlockLoadHWRestriction(TritonGEN::Matrix2DBlockLoadOp op) {
238240
VectorType resTy = op.getRes().getType();
239241
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
240242
unsigned resSize = resTy.getNumElements() * resElemTySize;
241-
constexpr unsigned subgroupSize = 16;
243+
unsigned subgroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
244+
op->getParentOfType<mlir::ModuleOp>());
245+
;
242246
unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
243247
op.getTileWidth() * op.getVBlocks() / subgroupSize;
244248
if (resSize != expectedSize)

0 commit comments

Comments
 (0)