Skip to content

Commit e23bfd3

Browse files
committed
Separate out optimization
1 parent 0b07226 commit e23bfd3

File tree

3 files changed

+56
-11
lines changed

3 files changed

+56
-11
lines changed

enzyme/Enzyme/MLIR/Passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
1111
EnzymeBatchToTensorPass.cpp
1212
EnzymeWrapPass.cpp
1313
InlineEnzymeRegions.cpp
14+
HoistEnzymeRegions.cpp
1415
LowerLLVMExtPass.cpp
1516
PrintActivityAnalysis.cpp
1617
PrintAliasAnalysis.cpp
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- HoistEnzymeRegions.cpp -LICM for enzyme.autodiff_region ----------=== //
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements passes to hoist computations within autodiff_region ops
10+
// to the caller
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "Dialect/Ops.h"
15+
#include "Interfaces/AutoDiffOpInterface.h"
16+
#include "Passes/Passes.h"
17+
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Interfaces/FunctionInterfaces.h"
21+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
#include "mlir/Transforms/RegionUtils.h"
23+
#include "llvm/ADT/TypeSwitch.h"
24+
25+
using namespace mlir;
26+
27+
namespace mlir {
28+
namespace enzyme {
29+
#define GEN_PASS_DEF_HOISTENZYMEFROMREGIONPASS
30+
#include "Passes/Passes.h.inc"
31+
} // namespace enzyme
32+
} // namespace mlir
33+
34+
namespace {
35+
36+
struct HoistEnzymeAutoDiff : public OpRewritePattern<enzyme::AutoDiffRegionOp> {
37+
using OpRewritePattern<enzyme::AutoDiffRegionOp>::OpRewritePattern;
38+
LogicalResult matchAndRewrite(enzyme::AutoDiffRegionOp op,
39+
PatternRewriter &rewriter) const override {
40+
41+
return success();
42+
}
43+
};
44+
45+
struct HoistEnzymeFromRegion
46+
: public enzyme::impl::HoistEnzymeFromRegionPassBase<
47+
HoistEnzymeFromRegion> {
48+
void runOnOperation() override {
49+
RewritePatternSet patterns(&getContext());
50+
patterns.add<HoistEnzymeAutoDiff>(&getContext());
51+
GreedyRewriteConfig config;
52+
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
53+
}
54+
};
55+
} // namespace

enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ namespace mlir {
2727
namespace enzyme {
2828
#define GEN_PASS_DEF_INLINEENZYMEINTOREGIONPASS
2929
#define GEN_PASS_DEF_OUTLINEENZYMEFROMREGIONPASS
30-
#define GEN_PASS_DEF_HOISTENZYMEFROMREGIONPASS
3130
#include "Passes/Passes.h.inc"
3231
} // namespace enzyme
3332
} // namespace mlir
@@ -262,16 +261,6 @@ struct OutlineEnzymeFromRegion
262261
}
263262
};
264263

265-
struct HoistEnzymeFromRegion
266-
: public enzyme::impl::HoistEnzymeFromRegionPassBase<
267-
HoistEnzymeFromRegion> {
268-
void runOnOperation() override {
269-
RewritePatternSet patterns(&getContext());
270-
271-
GreedyRewriteConfig config;
272-
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
273-
}
274-
};
275264
} // namespace
276265

277266
bool mlir::enzyme::inlineAutodiffOp(enzyme::AutoDiffOp &op,

0 commit comments

Comments
 (0)