Skip to content

New pass Reduce variable liveness #3965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
815681b
Add new pass that tries to reduce the register pressure by moving loa…
mfrancepillois Apr 18, 2025
34955bf
Fix types mismatch bug.
mfrancepillois Apr 18, 2025
3915019
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois Apr 23, 2025
bdf0999
Rename pass + use liveness analysis for heuristic.
mfrancepillois Apr 24, 2025
ac4da4b
Fix typo
mfrancepillois Apr 24, 2025
f34ef03
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois Apr 25, 2025
ec4394a
Improve heuristic.
mfrancepillois Apr 29, 2025
a937fb8
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois Apr 29, 2025
4d17f01
Add new test cases to match new heuristic conditions.
mfrancepillois Apr 29, 2025
e84b4ca
Improves the heuristic + addresses comments
mfrancepillois Apr 30, 2025
67aec58
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois Apr 30, 2025
db2c1ed
Enforce only ConvertLayoutOp between dot and load + bug fix.
mfrancepillois Apr 30, 2025
79ab6ed
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois Apr 30, 2025
6fd22d8
Bug fix
mfrancepillois Apr 30, 2025
9428627
remove unused variable.
mfrancepillois Apr 30, 2025
3bae4e6
Add element type check.
mfrancepillois Apr 30, 2025
d92e667
Address comments: improve code quality.
mfrancepillois May 7, 2025
68f9b48
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois May 7, 2025
3996a60
Extend support to handle tensor of pointers without mask + add test.
mfrancepillois May 12, 2025
6971b37
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois May 12, 2025
2da8a5d
Add support for multiple users
mfrancepillois May 13, 2025
ddc692b
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois May 13, 2025
2456692
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois May 20, 2025
7ce5195
Merge branch 'main' into maxime/reduceRegisterPressure
mfrancepillois May 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
832 changes: 832 additions & 0 deletions test/TritonIntelGPU/reduce-variable-liveness.mlir

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class XPUOptions:
optimize_epilogue: bool = False
enable_fp_fusion: bool = True
launch_cooperative_grid: bool = False
reduce_variable_liveness: bool = True
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4nv", "fp8e4b15")
deprecated_fp8_dtypes: Tuple[str] = ()
default_dot_input_precision: str = "tf32"
Expand Down Expand Up @@ -291,6 +292,9 @@ def make_ttgir(mod, metadata, opt, properties):
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt))

if (opt.reduce_variable_liveness):
intel.passes.ttgpuir.add_reduce_variable_liveness(pm)

passes.ttgpuir.add_fuse_nested_loops(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,18 @@ def TritonIntelGPURewriteStackPtr
"mlir::arith::ArithDialect"
];
}

// Pass definition for the liveness-reduction transform implemented in
// ReduceVariableLiveness.cpp. Runs on the module so it can visit every
// function's loops.
def TritonIntelGPUReduceVariableLiveness
    : Pass<"tritonintelgpu-reduce-variable-liveness", "mlir::ModuleOp"> {
  let summary = "Attempt to reduce the variable liveness";

  let description = [{
    This pass attempts to reduce the variable liveness
    by reducing the distance between loads and usage.
  }];

  // Dialects whose ops/attributes the pass may create.
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];
}
#endif // TRITON_INTEL_GPU_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_triton_library(TritonIntelGPUTransforms
Pipeliner/SoftwarePipeliner.cpp
PrefetchBlock.cpp
ReduceDataDuplication.cpp
ReduceVariableLiveness.cpp
RemoveLayoutConversions.cpp
RewriteStackPtr.cpp
ScheduleLoad.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
#include "Dialect/TritonIntelGPU/IR/Attributes.h"
#include "Dialect/TritonIntelGPU/Transforms/Utility.h"
#include "intel/include/Analysis/DPAS.h"
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "llvm/Support/Debug.h"

#include "intel/include/Analysis/Liveness.h"
#include "mlir/Analysis/Liveness.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <optional>

namespace mlir::triton::gpu::intel {
#define GEN_PASS_DEF_TRITONINTELGPUREDUCEVARIABLELIVENESS
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
} // namespace mlir::triton::gpu::intel

using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttgi = mlir::triton::gpu::intel;

using TensorValue = TypedValue<RankedTensorType>;

#define DEBUG_TYPE "tritonintelgpu-reduce-variable-liveness"

namespace {

// Heuristic thresholds (in bytes) used by isLongLifeSpanVariable:
// a load is only worth sinking when the block's live-in set is large
// (register pressure is high) and the tensor itself is large.
// Use typed constants instead of object-like macros: the previous
// `#define LARGE_... 128 * 64 * 2` expanded to an unparenthesized
// expression, which is an operator-precedence hazard at use sites.
constexpr unsigned TOTAL_BLOCK_SIZE_THRESHOLD_IN_BYTES = 32768;
constexpr unsigned LARGE_TENSOR_SIZE_THRESHOLD_IN_BYTES = 128 * 64 * 2;

/// Return the storage size of \p tensorType in bytes, or 0 if the element
/// type is not an integer or floating-point type.
static unsigned getSizeInBytes(RankedTensorType &tensorType) {
  Type elType = tensorType.getElementType();
  if (!elType.isIntOrFloat())
    return 0;
  unsigned elTypeBitWidth = elType.getIntOrFloatBitWidth();
  // Accumulate in 64 bits: the product of the dims times the bit width can
  // overflow `unsigned` for large shapes.
  uint64_t totalNumBits = elTypeBitWidth;
  for (int64_t dim : tensorType.getShape())
    totalNumBits *= static_cast<uint64_t>(dim);
  // Round up so sub-byte element types (e.g. i1, i4) do not report a size
  // of zero (the previous `bitWidth / 8` truncated to 0 for them).
  return static_cast<unsigned>((totalNumBits + 7) / 8);
}

/// Sum up the size, in bytes, of all live-in values of the block described
/// by \p livenessBlockInfo. Ranked tensors contribute their full storage
/// size; plain int/float scalars contribute their width; anything else
/// (pointers, opaque types) is ignored.
static unsigned
getBlockLiveInSizeInBytes(const LivenessBlockInfo *livenessBlockInfo) {
  unsigned totalBytes = 0;
  for (Value liveValue : livenessBlockInfo->in()) {
    if (auto tensorValue = dyn_cast<TensorValue>(liveValue)) {
      auto rankedType = dyn_cast<RankedTensorType>(tensorValue.getType());
      totalBytes += getSizeInBytes(rankedType);
      continue;
    }
    Type valueType = liveValue.getType();
    if (valueType.isIntOrFloat())
      totalBytes += valueType.getIntOrFloatBitWidth() / 8;
  }
  return totalBytes;
}

/// Return true if the lifespan of the \p v value is considered long.
/// The variable is elected for being moved when all of the following hold:
///  - the live-in variables of the forOp amount to a large number of bytes,
///  - the variable defined by `v` is a large tensor,
///  - the liveness of `v` extends before the dot block
///    (i.e. it is used in one block but loaded in another).
static bool isLongLifeSpanVariable(Value v,
                                   const LivenessBlockInfo *livenessBlockInfo,
                                   unsigned LiveInSizeInBytes) {
  TensorValue tensorV = dyn_cast<TensorValue>(v);
  if (!tensorV)
    return false;
  // High overall register pressure in the block?
  if (LiveInSizeInBytes <= TOTAL_BLOCK_SIZE_THRESHOLD_IN_BYTES)
    return false;
  // Large enough tensor to be worth moving?
  auto tensorType = cast<RankedTensorType>(tensorV.getType());
  if (getSizeInBytes(tensorType) < LARGE_TENSOR_SIZE_THRESHOLD_IN_BYTES)
    return false;
  // Only interesting when the value is live across the block boundary.
  return livenessBlockInfo->isLiveIn(v);
}

/// Return true if the \p loadOp is suitable to be moved.
/// \p expectedElementType is the element type expected for the load to be a
/// candidate,
/// \p forOp operation to which we want to move the loadOp.
static bool isLoadCandidate(tt::LoadOp loadOp, Type expectedElementType,
                            Operation *forOp) {
  if (!mlir::triton::isTensorOrTensorPointerType(loadOp.getPtr().getType()))
    return false;
  // LoadOps with non-null mask are not considered to be moved.
  if (loadOp.getMask())
    return false;
  RankedTensorType loadType =
      cast<RankedTensorType>(loadOp.getResult().getType());
  Type loadElType = loadType.getElementType();
  // Types mismatch => skip this case to avoid inserting too
  // many additional operations in the loop.
  if (expectedElementType != loadElType)
    return false;
  // Only loads annotated for block IO are candidates.
  Attribute blockIOAttr = loadOp->getAttr(
      mlir::triton::gpu::intel::TritonIntelGPUDialect::getBlockIOAttrName());
  if (!blockIOAttr)
    return false;
  // Only tensors with rank = 2 are considered to be moved.
  if (loadType.getShape().size() != 2)
    return false;
  // Only loadOps out of the for-loop body are considered to be moved.
  if (loadOp->getParentOp() == forOp)
    return false;
  // Multiple users: if any user precedes the loop in the same block, the
  // value is needed before the loop and the load must not be sunk into it.
  if (any_of(loadOp->getUsers(), [&](Operation *user) {
        return ((user->getBlock() == forOp->getBlock()) &&
                user->isBeforeInBlock(forOp));
      }))
    return false;
  // We skip the load if the pointer has no defining op in the same region
  // (e.g. it is a block or function argument). The prefetch is added right
  // after the defining op, and prefetching in another region could have
  // cache side effects (an early fetch may evict data still needed).
  if (!loadOp.getPtr().getDefiningOp())
    return false;
  return true;
}

/// Create a prefetch operation for the given load operation.
/// The prefetch is inserted right after the defining op of the load pointer
/// (isLoadCandidate guarantees such a defining op exists).
static void createPrefetchOp(tt::LoadOp loadOp) {
  Operation *op = loadOp.getPtr().getDefiningOp();
  OpBuilder builder(op);
  // TODO: Add prefetchOp after last dependency between ptr and mask,
  // if this support is extended to support masks.
  builder.setInsertionPointAfter(op);
  auto prefetchOp = builder.create<ttgi::PrefetchOp>(
      loadOp->getLoc(), loadOp.getPtr(), loadOp.getCache(), loadOp.getEvict(),
      loadOp.getIsVolatile());

  // Inherit attributes from the load operation (e.g. the block IO attribute).
  auto attrs = loadOp->getAttrDictionary();
  prefetchOp->setAttrs(attrs);
}

/// Investigate opportunities for reducing register pressure by moving DotOp
/// operands: loads of long-lived dot operands are re-materialized next to
/// the dot inside \p forOp, and a prefetch is emitted at the original load
/// position. Returns true if the loop contained DPAS-encoded dots that were
/// processed.
static bool optimizeDotOperands(scf::ForOp forOp,
                                SmallVector<Value> &prefetchedValue,
                                Liveness &livenessAnalysis) {
  Block *loop = forOp.getBody();

  // Returns the loadOp that loads the value v, walking back through
  // ConvertLayoutOps only (any other defining op disqualifies the value).
  auto getLoad = [](Value v) -> std::optional<triton::LoadOp> {
    Operation *op = v.getDefiningOp();
    while (op) {
      if (auto loadOp = dyn_cast<triton::LoadOp>(op))
        return loadOp;
      if (!isa<ttg::ConvertLayoutOp>(op))
        break;
      op = op->getOperand(0).getDefiningOp();
    }
    return std::nullopt;
  };

  // Prefetch the dotOp operand and move its load closer to dotOp.
  auto moveOperand = [&prefetchedValue](uint8_t opId, triton::DotOp dotOp,
                                        tt::LoadOp loadOp) {
    assert(opId < 2 && "opId must be 0 or 1");
    OpBuilder b(dotOp);
    TensorValue tensorV = opId == 0 ? dotOp.getA() : dotOp.getB();
    auto tensorType = cast<RankedTensorType>(tensorV.getType());
    Operation *insertBeforeOp = dotOp;
    SmallVector<Operation *> usesInSameLoop;
    // Collect other use(s) in the same loop; the cloned load must be placed
    // before the earliest of them so it dominates all uses.
    for (Operation *user : loadOp->getUsers()) {
      if (user == dotOp)
        continue;
      if (user->getParentOp() == dotOp->getParentOp()) {
        usesInSameLoop.push_back(user);
        if (user->isBeforeInBlock(insertBeforeOp))
          insertBeforeOp = user;
      }
    }

    // Emit at most one prefetch per pointer.
    if (std::find(prefetchedValue.begin(), prefetchedValue.end(),
                  loadOp.getPtr()) == prefetchedValue.end()) {
      createPrefetchOp(loadOp);
      prefetchedValue.push_back(loadOp.getPtr());
    }
    b.setInsertionPoint(insertBeforeOp);
    auto newLoad = cast<tt::LoadOp>(b.clone(*loadOp.getOperation()));
    auto newCvt = b.create<ttg::ConvertLayoutOp>(tensorV.getLoc(), tensorType,
                                                 newLoad.getResult());
    dotOp.setOperand(opId, newCvt.getResult());

    // Update other users in the same loop if any.
    for (Operation *user : usesInSameLoop)
      user->replaceUsesOfWith(loadOp.getResult(), newLoad.getResult());

    // Multiple users:
    // Note that if other users come before the loop, the loadOp is not a
    // candidate for being moved. Remaining users (after the loop) get a
    // fresh copy of the load placed after the loop.
    if (!loadOp->use_empty()) {
      b.setInsertionPointAfter(dotOp->getParentOp());
      auto copyLoad = cast<tt::LoadOp>(b.clone(*loadOp.getOperation()));
      loadOp->replaceAllUsesWith(copyLoad);
    }
  };

  SmallVector<triton::DotOp> dotsInFor;
  for (Operation &op : *loop)
    if (auto dotOp = dyn_cast<triton::DotOp>(op)) {
      // Only accept dotOps encoded as DPAS MMA.
      if (!mlir::triton::gpu::intel::hasDpasEncoding(
              dotOp.getResult().getType()))
        // Don't rewrite if any other type is found.
        return false;
      dotsInFor.push_back(dotOp);
    }

  if (dotsInFor.empty())
    return false;

  for (triton::DotOp dot : dotsInFor) {
    auto aVals = getLoad(dot.getA());
    auto bVals = getLoad(dot.getB());

    auto livenessBlockInfo = livenessAnalysis.getLiveness(dot->getBlock());
    unsigned LiveInSizeInBytes = getBlockLiveInSizeInBytes(livenessBlockInfo);

    if (aVals && isLongLifeSpanVariable(aVals.value(), livenessBlockInfo,
                                        LiveInSizeInBytes)) {
      tt::LoadOp loadOp = aVals.value();
      auto tensorType = cast<RankedTensorType>(dot.getA().getType());
      if (isLoadCandidate(loadOp, tensorType.getElementType(), forOp))
        moveOperand(0, dot, loadOp);
    }
    if (bVals && isLongLifeSpanVariable(bVals.value(), livenessBlockInfo,
                                        LiveInSizeInBytes)) {
      tt::LoadOp loadOp = bVals.value();
      auto tensorType = cast<RankedTensorType>(dot.getB().getType());
      // Bug fix: compare element types. The previous code passed the whole
      // tensor type as `expectedElementType`, which can never equal the
      // load's element type, so B operands were never moved.
      if (isLoadCandidate(loadOp, tensorType.getElementType(), forOp))
        moveOperand(1, dot, loadOp);
    }
  }
  return true;
}

/// Pass driver: canonicalizes layout conversions, then visits every
/// `scf.for` loop and attempts to sink long-lived dot-operand loads into the
/// loop (see optimizeDotOperands).
class ReduceVariableLivenessPass
    : public triton::gpu::intel::impl::TritonIntelGPUReduceVariableLivenessBase<
          ReduceVariableLivenessPass> {
public:
  using triton::gpu::intel::impl::TritonIntelGPUReduceVariableLivenessBase<
      ReduceVariableLivenessPass>::TritonIntelGPUReduceVariableLivenessBase;

  void runOnOperation() override {
    // Canonicalize convert ops to make the pattern matching easier.
    // `prefetchedValue` records pointers already prefetched so that each
    // pointer gets at most one prefetch across all visited loops.
    SmallVector<Value> prefetchedValue;
    RewritePatternSet cleanUpPatterns(&getContext());
    triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns,
                                                              &getContext());
    if (mlir::applyPatternsGreedily(getOperation(), std::move(cleanUpPatterns))
            .failed()) {
      signalPassFailure();
    }

    Operation *rootOperation = getOperation();
    rootOperation->walk([&](scf::ForOp forOp) {
      // The liveness analysis must be re-performed before the processing of
      // each "for loop" given that the liveness of variables may have changed
      // as a result of the code, and specifically `LoadOps`, being modified
      // by the pass.
      Liveness livenessAnalysis(rootOperation);
      if (optimizeDotOperands(forOp, prefetchedValue, livenessAnalysis))
        return;
    });
  }
};

} // namespace
2 changes: 2 additions & 0 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
gpu::intel::createTritonIntelGPUMaterializeBlockPointer);
ADD_PASS_WRAPPER_0("add_optimize_reduction_locality",
gpu::intel::createTritonIntelGPUOptimizeReductionLocality);
ADD_PASS_WRAPPER_0("add_reduce_variable_liveness",
gpu::intel::createTritonIntelGPUReduceVariableLiveness);
}

void init_triton_intel_passes_arith(py::module &&m) {
Expand Down