-
Notifications
You must be signed in to change notification settings - Fork 62
[MatmulLoopPipeline] Predicate PrefetchOp
#4016
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
33558fa
21e483c
1a38544
ed31253
cc4150d
f47bf85
5b538dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
#include "mlir/Interfaces/SideEffectInterfaces.h" | ||
#include "triton/Analysis/AxisInfo.h" | ||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/Debug.h" | ||
|
||
|
@@ -149,12 +150,13 @@ static void collectOpsToPipeline(scf::ForOp forOp, | |
} | ||
} | ||
|
||
/// Combine the current mask with the given predicate. | ||
static Value getPredMask(RewriterBase &rewriter, Type typeLike, | ||
Value currentMask, Value pred) { | ||
/// Return a new mask of type of shape \p typeLike, and value combining the | ||
/// current mask \p currentMask with the given predicate \p pred. | ||
static Value computeNewMask(RewriterBase &rewriter, Type typeLike, | ||
Value currentMask, Value pred) { | ||
Location loc = pred.getLoc(); | ||
Value mask = pred; | ||
Type maskType = tt::getI1SameShape(typeLike); | ||
Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type of mask should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does typeLike means ? Lets document the parameters this function take and also what it does. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added function description. |
||
|
||
if (isa<RankedTensorType>(maskType)) | ||
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred); | ||
|
@@ -167,18 +169,17 @@ static Value getPredMask(RewriterBase &rewriter, Type typeLike, | |
static Operation *predicateOp(RewriterBase &rewriter, Operation *op, | ||
Value pred) { | ||
OpBuilder::InsertionGuard guard(rewriter); | ||
if (mlir::isMemoryEffectFree(op) || isa<ttgi::PrefetchOp>(op)) | ||
if (mlir::isMemoryEffectFree(op)) | ||
return op; | ||
|
||
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) { | ||
rewriter.setInsertionPoint(loadOp); | ||
Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), | ||
loadOp.getMask(), pred); | ||
loadOp.getMaskMutable().assign(mask); | ||
return loadOp; | ||
} | ||
|
||
llvm_unreachable("don't know how to predicate this operation"); | ||
return TypeSwitch<Operation *, Operation *>(op) | ||
.Case<tt::LoadOp, ttgi::PrefetchOp>([&](auto op) { | ||
rewriter.setInsertionPoint(op); | ||
Value mask = | ||
computeNewMask(rewriter, op.getPtr().getType(), op.getMask(), pred); | ||
op.getMaskMutable().assign(mask); | ||
return op; | ||
}); | ||
} | ||
|
||
/// Helper to get the defining operation of a value. | ||
|
Uh oh!
There was an error while loading. Please reload this page.