Skip to content

Commit e44747c

Browse files
Merge pull request #2290 from j2kun:extract_slice_of_splat
PiperOrigin-RevId: 814850947
2 parents c12a41a + 6e0c63f commit e44747c

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

lib/Transforms/FoldConstantTensors/FoldConstantTensors.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,39 @@ struct CollapseEmptyTensor
223223
}
224224
};
225225

226+
struct ExtractSliceOfSplat
227+
: public OpRewritePattern<mlir::tensor::ExtractSliceOp> {
228+
public:
229+
ExtractSliceOfSplat(MLIRContext* context)
230+
: OpRewritePattern<mlir::tensor::ExtractSliceOp>(context) {}
231+
232+
using OpRewritePattern::OpRewritePattern;
233+
234+
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp op,
235+
PatternRewriter& rewriter) const override {
236+
auto splatOp =
237+
dyn_cast_or_null<tensor::SplatOp>(op.getSource().getDefiningOp());
238+
if (!splatOp) return failure();
239+
240+
auto resultTy = op.getResult().getType();
241+
auto newSplat = tensor::SplatOp::create(
242+
rewriter, op.getLoc(), splatOp.getInput(), resultTy.getShape());
243+
rewriter.replaceOp(op, newSplat);
244+
return success();
245+
}
246+
};
247+
226248
struct FoldConstantTensors
227249
: public impl::FoldConstantTensorsBase<FoldConstantTensors> {
228250
void runOnOperation() override {
229251
MLIRContext* context = &getContext();
230252
auto* module = getOperation();
231253

232254
RewritePatternSet patterns(context);
233-
patterns.add<InsertAfterConstant, CollapseShapeAfterConstant,
234-
CollapseEmptyTensor, InsertIntoFromElements>(context);
255+
patterns
256+
.add<InsertAfterConstant, CollapseShapeAfterConstant,
257+
CollapseEmptyTensor, InsertIntoFromElements, ExtractSliceOfSplat>(
258+
context);
235259

236260
// Run pattern matching and conversion
237261
if (failed(applyPatternsGreedily(module, std::move(patterns)))) {

lib/Transforms/FoldConstantTensors/FoldConstantTensors.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def FoldConstantTensors : Pass<"fold-constant-tensors"> {
1515

1616
The following folders are supported:
1717
* `tensor.insert` of a constant tensor
18-
* `tensor.collapse_shape` of a constant tensor
18+
* `tensor.collapse_shape` of a constant or empty tensor
19+
* `tensor.extract_slice` of a splat to a splat of the new shape
1920
}];
2021
}
2122

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: heir-opt --fold-constant-tensors %s | FileCheck %s
2+
3+
// CHECK: func @extract_slice_of_splat
4+
func.func @extract_slice_of_splat(%arg0: i32) -> (tensor<1024xi32>) {
5+
// Fold a collapse shape of a constant
6+
// CHECK-NEXT: %[[splat:.+]] = tensor.splat
7+
// CHECK-NEXT: return %[[splat]]
8+
%0 = tensor.splat %arg0 : tensor<10x1024xi32>
9+
%slice = tensor.extract_slice %0[1, 0] [1, 1024] [1, 1] : tensor<10x1024xi32> to tensor<1024xi32>
10+
return %slice : tensor<1024xi32>
11+
}

0 commit comments

Comments
 (0)