@@ -596,295 +596,34 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
596596 * function signature, and only modify the number of outputs.
597597 *
598598 */
599- class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
600- public:
601- using OpRewritePattern<AutoDiffOp>::OpRewritePattern;
602-
603- LogicalResult matchAndRewrite (AutoDiffOp uop,
604- PatternRewriter &rewriter) const override {
605- // early return if there are no outputs
606- if (uop.getOutputs ().size () == 0 )
607- return failure ();
608-
609- auto inpActivity = uop.getActivity ();
610- auto retActivity = uop.getRetActivity ();
611- auto out_idx = 0 ;
612- SmallVector<mlir::Value, 2 > in_args;
613- SmallVector<mlir::Value, 2 > outs_args;
614- SmallVector<Type, 2 > in_ty;
615- SmallVector<Type, 2 > out_ty;
616- SmallVector<ActivityAttr, 2 > newInActivityArgs;
617- SmallVector<ActivityAttr, 2 > newRetActivityArgs;
618-
619- bool changed = false ;
620- auto in_idx = 0 ;
621-
622- // go upto dOutput
623- for (auto [idx, act] : llvm::enumerate (inpActivity)) {
624- auto iattr = cast<ActivityAttr>(act);
625- auto val = iattr.getValue ();
626- mlir::Value res = uop.getInputs ()[in_idx];
627- in_args.push_back (res);
628- in_ty.push_back (res.getType ());
629- in_idx++;
630-
631- if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
632- mlir::Value dres = uop.getInputs ()[in_idx];
633- in_args.push_back (dres);
634- in_ty.push_back (dres.getType ());
635- in_idx++;
636- }
637- }
638- // function isn't differentiable
639- if (in_idx == uop.getInputs ().size ())
640- return failure ();
641-
642- // handle pOutput
643- for (auto [idx, act] : llvm::enumerate (retActivity)) {
644- auto iattr = cast<ActivityAttr>(act);
645- auto val = iattr.getValue ();
646-
647- // skip primal return
648- if (val == Activity::enzyme_constnoneed ||
649- val == Activity::enzyme_dupnoneed) {
650- newRetActivityArgs.push_back (iattr);
651- continue ;
652- }
653-
654- mlir::Value res = uop.getOutputs ()[out_idx];
655-
656- switch (val) {
657- case Activity::enzyme_active: {
658- // active -> activenoneed(if res isn't used)
659- // active -> const(if dres == 0)
660- // active -> constnoneed(both)
661-
662- mlir::Value dres = uop.getInputs ()[in_idx];
663- in_idx++;
664-
665- auto dres_type = dres.getType ();
666- auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
667-
668- if (!res.use_empty ()) {
669- outs_args.push_back (res);
670- out_ty.push_back (res.getType ());
671- ActivityAttr new_act = iattr;
672- if (dres_type_intf && !isMutable (dres_type) &&
673- dres_type_intf.isZero (dres)) {
674- // const
675- changed = true ;
676- new_act = ActivityAttr::get (rewriter.getContext (),
677- Activity::enzyme_const);
678- } else {
679- in_args.push_back (dres);
680- in_ty.push_back (dres_type);
681- }
682- newRetActivityArgs.push_back (new_act);
683- } else {
684- changed = true ;
685- ActivityAttr new_act = ActivityAttr::get (
686- rewriter.getContext (), Activity::enzyme_activenoneed);
687- if (dres_type_intf && !isMutable (dres_type) &&
688- dres_type_intf.isZero (dres)) {
689- // constnoneed
690- new_act = ActivityAttr::get (rewriter.getContext (),
691- Activity::enzyme_constnoneed);
692- } else {
693- // activenoneed
694- in_args.push_back (dres);
695- in_ty.push_back (dres_type);
696- }
697- newRetActivityArgs.push_back (new_act);
698- }
699-
700- ++out_idx;
701- break ;
702- }
703-
704- case Activity::enzyme_activenoneed:
705- // activenoneed -> constnoneed
706- {
707- mlir::Value dres = uop.getInputs ()[in_idx];
708- in_idx++;
709- auto new_act = iattr;
710-
711- auto dres_type = dres.getType ();
712- auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
713- if (dres_type_intf && !isMutable (dres_type) &&
714- dres_type_intf.isZero (dres)) {
715- // constnoneed
716- new_act = ActivityAttr::get (rewriter.getContext (),
717- Activity::enzyme_constnoneed);
718- } else {
719- in_args.push_back (dres);
720- in_ty.push_back (dres_type);
721- }
722- newRetActivityArgs.push_back (iattr);
723- break ;
724- }
725- case Activity::enzyme_const:
726- // const -> constnoneed
727- {
728- auto new_act = iattr;
729- if (!res.use_empty ()) {
730- outs_args.push_back (res);
731- out_ty.push_back (res.getType ());
732- newRetActivityArgs.push_back (new_act);
733- } else {
734- changed = true ;
735- new_act = ActivityAttr::get (rewriter.getContext (),
736- Activity::enzyme_constnoneed);
737- newRetActivityArgs.push_back (new_act);
738- }
739- ++out_idx;
740- break ;
741- }
742-
743- case Activity::enzyme_dup:
744- // TODO: check if ret_arg == enzyme_dup inserts a derivative as the
745- // output and input both
746- outs_args.push_back (res);
747- out_ty.push_back (res.getType ());
748- newRetActivityArgs.push_back (iattr);
749- ++out_idx;
750- break ;
751-
752- case Activity::enzyme_constnoneed:
753- case Activity::enzyme_dupnoneed:
754- break ;
755-
756- default :
757- llvm_unreachable (" unexpected activity arg" );
758- }
759- }
760-
761- // handle dInputs
762- for (auto [idx, act] : llvm::enumerate (inpActivity)) {
763- auto iattr = cast<ActivityAttr>(act);
764- auto val = iattr.getValue ();
765-
766- if (val == Activity::enzyme_active) {
767- mlir::Value res = uop.getOutputs ()[out_idx];
768- if (!res.use_empty ()) {
769- out_ty.push_back (res.getType ());
770- outs_args.push_back (res);
771- newInActivityArgs.push_back (iattr);
772- } else {
773- // TODO: check if we can relax immutability here
774- if (!isMutable (res.getType ())) {
775- changed = true ;
776- auto new_const = ActivityAttr::get (rewriter.getContext (),
777- Activity::enzyme_const);
778- newInActivityArgs.push_back (new_const);
779- } else {
780- // noop even if its not used.
781- out_ty.push_back (res.getType ());
782- outs_args.push_back (res);
783- newInActivityArgs.push_back (iattr);
784- }
785- }
786-
787- ++out_idx;
788- } else if (val == Activity::enzyme_activenoneed) {
789- mlir::Value res = uop.getOutputs ()[out_idx];
790- out_ty.push_back (res.getType ());
791- outs_args.push_back (res);
792- newInActivityArgs.push_back (iattr);
793- ++out_idx;
794- llvm_unreachable (" unsupported arg activenoneed" );
795- } else {
796- newInActivityArgs.push_back (iattr);
797- }
798- }
799-
800- if (!changed)
801- return failure ();
802-
803- ArrayAttr newInActivity =
804- ArrayAttr::get (rewriter.getContext (),
805- llvm::ArrayRef<Attribute>(newInActivityArgs.begin (),
806- newInActivityArgs.end ()));
807- ArrayAttr newRetActivity =
808- ArrayAttr::get (rewriter.getContext (),
809- llvm::ArrayRef<Attribute>(newRetActivityArgs.begin (),
810- newRetActivityArgs.end ()));
811599
812- AutoDiffOp newOp = rewriter.create <AutoDiffOp>(
813- uop.getLoc (), out_ty, uop.getFnAttr (), in_args, newInActivity,
814- newRetActivity, uop.getWidthAttr (), uop.getStrongZeroAttr ());
815-
816- // Map old uses of uop to newOp
817- auto oldIdx = 0 ;
818- auto newIdx = 0 ;
819- for (auto [idx, old_act, new_act] :
820- llvm::enumerate (retActivity, newRetActivityArgs)) {
821- auto iattr = cast<ActivityAttr>(old_act);
822- auto old_val = iattr.getValue ();
823- auto new_val = new_act.getValue ();
824-
825- if (old_val == new_val) {
826- // don't index into op if no primal is returned
827- if (old_val == Activity::enzyme_constnoneed ||
828- old_val == Activity::enzyme_activenoneed ||
829- old_val == Activity::enzyme_dupnoneed) {
830- continue ;
831- }
832- // replace current Primal
833- uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
834- newOp.getOutputs ()[newIdx++]);
835- } else {
836- // handle all substitutions
837- if (new_val == Activity::enzyme_activenoneed &&
838- old_val == Activity::enzyme_active) {
839- ++oldIdx; // skip active primal
840- } else if (new_val == Activity::enzyme_constnoneed &&
841- old_val == Activity::enzyme_const) {
842- ++oldIdx; // skip const primal
843- } else if (old_val == Activity::enzyme_active &&
844- new_val == Activity::enzyme_const) {
845- uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
846- newOp.getOutputs ()[newIdx++]);
847- } else if (old_val == Activity::enzyme_active &&
848- new_val == Activity::enzyme_constnoneed) {
849- ++oldIdx;
850- } else if (old_val == Activity::enzyme_activenoneed &&
851- new_val == Activity::enzyme_constnoneed) {
852- // just skip
853- }
854- }
855- }
856-
857- for (auto [idx, old_act, new_act] :
858- llvm::enumerate (inpActivity, newInActivityArgs)) {
859- auto iattr = cast<ActivityAttr>(old_act);
860- auto old_val = iattr.getValue ();
861- auto new_val = new_act.getValue ();
600+ // Overload for AutoDiffOp
601+ static inline AutoDiffOp createOp (PatternRewriter &rewriter, AutoDiffOp uop,
602+ ArrayRef<Type> out_ty,
603+ ArrayRef<Value> in_args,
604+ ArrayAttr newInActivity,
605+ ArrayAttr newRetActivity) {
606+ return rewriter.create <AutoDiffOp>(
607+ uop.getLoc (), out_ty, uop.getFnAttr (), in_args, newInActivity,
608+ newRetActivity, uop.getWidthAttr (), uop.getStrongZeroAttr ());
609+ }
862610
863- if (old_val == new_val) {
864- if (old_val == Activity::enzyme_active ||
865- old_val == Activity::enzyme_activenoneed) {
866- uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
867- newOp.getOutputs ()[newIdx++]);
868- } else {
869- continue ;
870- }
871- } else {
872- if (old_val == Activity::enzyme_active &&
873- new_val == Activity::enzyme_const) {
874- oldIdx++; // skip derivative
875- }
876- }
877- }
878- rewriter.eraseOp (uop);
879- return success ();
880- }
881- };
611+ // Overload for AutoDiffRegionOp
612+ static inline AutoDiffRegionOp
613+ createOp (PatternRewriter &rewriter, AutoDiffRegionOp uop, ArrayRef<Type> out_ty,
614+ ArrayRef<Value> in_args, ArrayAttr newInActivity,
615+ ArrayAttr newRetActivity) {
616+ return rewriter.create <AutoDiffRegionOp>(
617+ uop.getLoc (), out_ty, in_args, newInActivity, newRetActivity,
618+ uop.getWidthAttr (), uop.getStrongZeroAttr (), uop.getFnAttr ());
619+ }
882620
883- class ReverseRetOpt2 final : public OpRewritePattern<AutoDiffRegionOp> {
621+ template <typename SourceOp>
622+ class ReverseRetOpt final : public OpRewritePattern<SourceOp> {
884623public:
885- using OpRewritePattern<AutoDiffRegionOp >::OpRewritePattern;
624+ using OpRewritePattern<SourceOp >::OpRewritePattern;
886625
887- LogicalResult matchAndRewrite (AutoDiffRegionOp uop,
626+ LogicalResult matchAndRewrite (SourceOp uop,
888627 PatternRewriter &rewriter) const override {
889628 // early return if there are no outputs
890629 if (uop.getOutputs ().size () == 0 )
@@ -1093,9 +832,8 @@ class ReverseRetOpt2 final : public OpRewritePattern<AutoDiffRegionOp> {
1093832 llvm::ArrayRef<Attribute>(newRetActivityArgs.begin (),
1094833 newRetActivityArgs.end ()));
1095834
1096- AutoDiffRegionOp newOp = rewriter.create <AutoDiffRegionOp>(
1097- uop.getLoc (), out_ty, in_args, newInActivity,
1098- newRetActivity, uop.getWidthAttr (), uop.getStrongZeroAttr (),uop.getFnAttr ());
835+ SourceOp newOp =
836+ createOp (rewriter, uop, out_ty, in_args, newInActivity, newRetActivity);
1099837
1100838 // Map old uses of uop to newOp
1101839 auto oldIdx = 0 ;
@@ -1163,14 +901,15 @@ class ReverseRetOpt2 final : public OpRewritePattern<AutoDiffRegionOp> {
1163901 return success ();
1164902 }
1165903};
904+
1166905void AutoDiffOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
1167906 MLIRContext *context) {
1168- patterns.add <ReverseRetOpt>(context);
907+ patterns.add <ReverseRetOpt<AutoDiffOp> >(context);
1169908}
1170909
1171910void AutoDiffRegionOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
1172911 MLIRContext *context) {
1173- patterns.add <ReverseRetOpt2 >(context);
912+ patterns.add <ReverseRetOpt<AutoDiffRegionOp> >(context);
1174913}
1175914// ===----------------------------------------------------------------------===//
1176915// SampleOp
0 commit comments