Skip to content

Commit 8cda648

Browse files
committed
templatize canonicalization
1 parent e3322e8 commit 8cda648

File tree

1 file changed

+28
-289
lines changed

1 file changed

+28
-289
lines changed

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 28 additions & 289 deletions
Original file line numberDiff line numberDiff line change
@@ -596,295 +596,34 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
596596
* function signature, and only modify the number of outputs.
597597
*
598598
*/
599-
class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
600-
public:
601-
using OpRewritePattern<AutoDiffOp>::OpRewritePattern;
602-
603-
LogicalResult matchAndRewrite(AutoDiffOp uop,
604-
PatternRewriter &rewriter) const override {
605-
// early return if there are no outputs
606-
if (uop.getOutputs().size() == 0)
607-
return failure();
608-
609-
auto inpActivity = uop.getActivity();
610-
auto retActivity = uop.getRetActivity();
611-
auto out_idx = 0;
612-
SmallVector<mlir::Value, 2> in_args;
613-
SmallVector<mlir::Value, 2> outs_args;
614-
SmallVector<Type, 2> in_ty;
615-
SmallVector<Type, 2> out_ty;
616-
SmallVector<ActivityAttr, 2> newInActivityArgs;
617-
SmallVector<ActivityAttr, 2> newRetActivityArgs;
618-
619-
bool changed = false;
620-
auto in_idx = 0;
621-
622-
// go upto dOutput
623-
for (auto [idx, act] : llvm::enumerate(inpActivity)) {
624-
auto iattr = cast<ActivityAttr>(act);
625-
auto val = iattr.getValue();
626-
mlir::Value res = uop.getInputs()[in_idx];
627-
in_args.push_back(res);
628-
in_ty.push_back(res.getType());
629-
in_idx++;
630-
631-
if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
632-
mlir::Value dres = uop.getInputs()[in_idx];
633-
in_args.push_back(dres);
634-
in_ty.push_back(dres.getType());
635-
in_idx++;
636-
}
637-
}
638-
// function isn't differentiable
639-
if (in_idx == uop.getInputs().size())
640-
return failure();
641-
642-
// handle pOutput
643-
for (auto [idx, act] : llvm::enumerate(retActivity)) {
644-
auto iattr = cast<ActivityAttr>(act);
645-
auto val = iattr.getValue();
646-
647-
// skip primal return
648-
if (val == Activity::enzyme_constnoneed ||
649-
val == Activity::enzyme_dupnoneed) {
650-
newRetActivityArgs.push_back(iattr);
651-
continue;
652-
}
653-
654-
mlir::Value res = uop.getOutputs()[out_idx];
655-
656-
switch (val) {
657-
case Activity::enzyme_active: {
658-
// active -> activenoneed(if res isn't used)
659-
// active -> const(if dres == 0)
660-
// active -> constnoneed(both)
661-
662-
mlir::Value dres = uop.getInputs()[in_idx];
663-
in_idx++;
664-
665-
auto dres_type = dres.getType();
666-
auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
667-
668-
if (!res.use_empty()) {
669-
outs_args.push_back(res);
670-
out_ty.push_back(res.getType());
671-
ActivityAttr new_act = iattr;
672-
if (dres_type_intf && !isMutable(dres_type) &&
673-
dres_type_intf.isZero(dres)) {
674-
// const
675-
changed = true;
676-
new_act = ActivityAttr::get(rewriter.getContext(),
677-
Activity::enzyme_const);
678-
} else {
679-
in_args.push_back(dres);
680-
in_ty.push_back(dres_type);
681-
}
682-
newRetActivityArgs.push_back(new_act);
683-
} else {
684-
changed = true;
685-
ActivityAttr new_act = ActivityAttr::get(
686-
rewriter.getContext(), Activity::enzyme_activenoneed);
687-
if (dres_type_intf && !isMutable(dres_type) &&
688-
dres_type_intf.isZero(dres)) {
689-
// constnoneed
690-
new_act = ActivityAttr::get(rewriter.getContext(),
691-
Activity::enzyme_constnoneed);
692-
} else {
693-
// activenoneed
694-
in_args.push_back(dres);
695-
in_ty.push_back(dres_type);
696-
}
697-
newRetActivityArgs.push_back(new_act);
698-
}
699-
700-
++out_idx;
701-
break;
702-
}
703-
704-
case Activity::enzyme_activenoneed:
705-
// activenoneed -> constnoneed
706-
{
707-
mlir::Value dres = uop.getInputs()[in_idx];
708-
in_idx++;
709-
auto new_act = iattr;
710-
711-
auto dres_type = dres.getType();
712-
auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
713-
if (dres_type_intf && !isMutable(dres_type) &&
714-
dres_type_intf.isZero(dres)) {
715-
// constnoneed
716-
new_act = ActivityAttr::get(rewriter.getContext(),
717-
Activity::enzyme_constnoneed);
718-
} else {
719-
in_args.push_back(dres);
720-
in_ty.push_back(dres_type);
721-
}
722-
newRetActivityArgs.push_back(iattr);
723-
break;
724-
}
725-
case Activity::enzyme_const:
726-
// const -> constnoneed
727-
{
728-
auto new_act = iattr;
729-
if (!res.use_empty()) {
730-
outs_args.push_back(res);
731-
out_ty.push_back(res.getType());
732-
newRetActivityArgs.push_back(new_act);
733-
} else {
734-
changed = true;
735-
new_act = ActivityAttr::get(rewriter.getContext(),
736-
Activity::enzyme_constnoneed);
737-
newRetActivityArgs.push_back(new_act);
738-
}
739-
++out_idx;
740-
break;
741-
}
742-
743-
case Activity::enzyme_dup:
744-
// TODO: check if ret_arg == enzyme_dup inserts a derivative as the
745-
// output and input both
746-
outs_args.push_back(res);
747-
out_ty.push_back(res.getType());
748-
newRetActivityArgs.push_back(iattr);
749-
++out_idx;
750-
break;
751-
752-
case Activity::enzyme_constnoneed:
753-
case Activity::enzyme_dupnoneed:
754-
break;
755-
756-
default:
757-
llvm_unreachable("unexpected activity arg");
758-
}
759-
}
760-
761-
// handle dInputs
762-
for (auto [idx, act] : llvm::enumerate(inpActivity)) {
763-
auto iattr = cast<ActivityAttr>(act);
764-
auto val = iattr.getValue();
765-
766-
if (val == Activity::enzyme_active) {
767-
mlir::Value res = uop.getOutputs()[out_idx];
768-
if (!res.use_empty()) {
769-
out_ty.push_back(res.getType());
770-
outs_args.push_back(res);
771-
newInActivityArgs.push_back(iattr);
772-
} else {
773-
// TODO: check if we can relax immutability here
774-
if (!isMutable(res.getType())) {
775-
changed = true;
776-
auto new_const = ActivityAttr::get(rewriter.getContext(),
777-
Activity::enzyme_const);
778-
newInActivityArgs.push_back(new_const);
779-
} else {
780-
// noop even if its not used.
781-
out_ty.push_back(res.getType());
782-
outs_args.push_back(res);
783-
newInActivityArgs.push_back(iattr);
784-
}
785-
}
786-
787-
++out_idx;
788-
} else if (val == Activity::enzyme_activenoneed) {
789-
mlir::Value res = uop.getOutputs()[out_idx];
790-
out_ty.push_back(res.getType());
791-
outs_args.push_back(res);
792-
newInActivityArgs.push_back(iattr);
793-
++out_idx;
794-
llvm_unreachable("unsupported arg activenoneed");
795-
} else {
796-
newInActivityArgs.push_back(iattr);
797-
}
798-
}
799-
800-
if (!changed)
801-
return failure();
802-
803-
ArrayAttr newInActivity =
804-
ArrayAttr::get(rewriter.getContext(),
805-
llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
806-
newInActivityArgs.end()));
807-
ArrayAttr newRetActivity =
808-
ArrayAttr::get(rewriter.getContext(),
809-
llvm::ArrayRef<Attribute>(newRetActivityArgs.begin(),
810-
newRetActivityArgs.end()));
811599

812-
AutoDiffOp newOp = rewriter.create<AutoDiffOp>(
813-
uop.getLoc(), out_ty, uop.getFnAttr(), in_args, newInActivity,
814-
newRetActivity, uop.getWidthAttr(), uop.getStrongZeroAttr());
815-
816-
// Map old uses of uop to newOp
817-
auto oldIdx = 0;
818-
auto newIdx = 0;
819-
for (auto [idx, old_act, new_act] :
820-
llvm::enumerate(retActivity, newRetActivityArgs)) {
821-
auto iattr = cast<ActivityAttr>(old_act);
822-
auto old_val = iattr.getValue();
823-
auto new_val = new_act.getValue();
824-
825-
if (old_val == new_val) {
826-
// don't index into op if no primal is returned
827-
if (old_val == Activity::enzyme_constnoneed ||
828-
old_val == Activity::enzyme_activenoneed ||
829-
old_val == Activity::enzyme_dupnoneed) {
830-
continue;
831-
}
832-
// replace current Primal
833-
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
834-
newOp.getOutputs()[newIdx++]);
835-
} else {
836-
// handle all substitutions
837-
if (new_val == Activity::enzyme_activenoneed &&
838-
old_val == Activity::enzyme_active) {
839-
++oldIdx; // skip active primal
840-
} else if (new_val == Activity::enzyme_constnoneed &&
841-
old_val == Activity::enzyme_const) {
842-
++oldIdx; // skip const primal
843-
} else if (old_val == Activity::enzyme_active &&
844-
new_val == Activity::enzyme_const) {
845-
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
846-
newOp.getOutputs()[newIdx++]);
847-
} else if (old_val == Activity::enzyme_active &&
848-
new_val == Activity::enzyme_constnoneed) {
849-
++oldIdx;
850-
} else if (old_val == Activity::enzyme_activenoneed &&
851-
new_val == Activity::enzyme_constnoneed) {
852-
// just skip
853-
}
854-
}
855-
}
856-
857-
for (auto [idx, old_act, new_act] :
858-
llvm::enumerate(inpActivity, newInActivityArgs)) {
859-
auto iattr = cast<ActivityAttr>(old_act);
860-
auto old_val = iattr.getValue();
861-
auto new_val = new_act.getValue();
600+
// createOp overload for AutoDiffOp: builds the replacement op used by
// ReverseRetOpt<SourceOp>::matchAndRewrite. The new op is created at `uop`'s
// location with result types `out_ty`, operands `in_args` and the updated
// input/return activity arrays; the fn, width and strong-zero attributes are
// copied from `uop` unchanged.
601+
static inline AutoDiffOp createOp(PatternRewriter &rewriter, AutoDiffOp uop,
602+
ArrayRef<Type> out_ty,
603+
ArrayRef<Value> in_args,
604+
ArrayAttr newInActivity,
605+
ArrayAttr newRetActivity) {
606+
return rewriter.create<AutoDiffOp>(
607+
uop.getLoc(), out_ty, uop.getFnAttr(), in_args, newInActivity,
608+
newRetActivity, uop.getWidthAttr(), uop.getStrongZeroAttr());
609+
}
862610

863-
if (old_val == new_val) {
864-
if (old_val == Activity::enzyme_active ||
865-
old_val == Activity::enzyme_activenoneed) {
866-
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
867-
newOp.getOutputs()[newIdx++]);
868-
} else {
869-
continue;
870-
}
871-
} else {
872-
if (old_val == Activity::enzyme_active &&
873-
new_val == Activity::enzyme_const) {
874-
oldIdx++; // skip derivative
875-
}
876-
}
877-
}
878-
rewriter.eraseOp(uop);
879-
return success();
880-
}
881-
};
611+
// createOp overload for AutoDiffRegionOp: same role as the AutoDiffOp
// overload, selected by the static type of `uop`. Note the builder argument
// order differs here: the fn attribute is passed last, after the width and
// strong-zero attributes, instead of directly after the result types.
612+
static inline AutoDiffRegionOp
613+
createOp(PatternRewriter &rewriter, AutoDiffRegionOp uop, ArrayRef<Type> out_ty,
614+
ArrayRef<Value> in_args, ArrayAttr newInActivity,
615+
ArrayAttr newRetActivity) {
616+
return rewriter.create<AutoDiffRegionOp>(
617+
uop.getLoc(), out_ty, in_args, newInActivity, newRetActivity,
618+
uop.getWidthAttr(), uop.getStrongZeroAttr(), uop.getFnAttr());
619+
}
882620

883-
class ReverseRetOpt2 final : public OpRewritePattern<AutoDiffRegionOp> {
621+
template <typename SourceOp>
622+
class ReverseRetOpt final : public OpRewritePattern<SourceOp> {
884623
public:
885-
using OpRewritePattern<AutoDiffRegionOp>::OpRewritePattern;
624+
using OpRewritePattern<SourceOp>::OpRewritePattern;
886625

887-
LogicalResult matchAndRewrite(AutoDiffRegionOp uop,
626+
LogicalResult matchAndRewrite(SourceOp uop,
888627
PatternRewriter &rewriter) const override {
889628
// early return if there are no outputs
890629
if (uop.getOutputs().size() == 0)
@@ -1093,9 +832,8 @@ class ReverseRetOpt2 final : public OpRewritePattern<AutoDiffRegionOp> {
1093832
llvm::ArrayRef<Attribute>(newRetActivityArgs.begin(),
1094833
newRetActivityArgs.end()));
1095834

1096-
AutoDiffRegionOp newOp = rewriter.create<AutoDiffRegionOp>(
1097-
uop.getLoc(), out_ty, in_args, newInActivity,
1098-
newRetActivity, uop.getWidthAttr(), uop.getStrongZeroAttr(),uop.getFnAttr());
835+
SourceOp newOp =
836+
createOp(rewriter, uop, out_ty, in_args, newInActivity, newRetActivity);
1099837

1100838
// Map old uses of uop to newOp
1101839
auto oldIdx = 0;
@@ -1163,14 +901,15 @@ class ReverseRetOpt2 final : public OpRewritePattern<AutoDiffRegionOp> {
1163901
return success();
1164902
}
1165903
};
904+
1166905
// Adds the ReverseRetOpt rewrite pattern to `patterns` for AutoDiffOp.
// In this diff the non-template registration (the `-` line) is replaced by
// the ReverseRetOpt<AutoDiffOp> instantiation (the `+` line).
void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1167906
MLIRContext *context) {
1168-
patterns.add<ReverseRetOpt>(context);
907+
patterns.add<ReverseRetOpt<AutoDiffOp>>(context);
1169908
}
1170909

1171910
// Adds the ReverseRetOpt rewrite pattern to `patterns` for AutoDiffRegionOp.
// In this diff the separate ReverseRetOpt2 class (the `-` line) is replaced
// by the ReverseRetOpt<AutoDiffRegionOp> instantiation (the `+` line).
void AutoDiffRegionOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1172911
MLIRContext *context) {
1173-
patterns.add<ReverseRetOpt2>(context);
912+
patterns.add<ReverseRetOpt<AutoDiffRegionOp>>(context);
1174913
}
1175914
//===----------------------------------------------------------------------===//
1176915
// SampleOp

0 commit comments

Comments
 (0)