@@ -794,6 +794,148 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
794
794
return success ();
795
795
});
796
796
797
+ // split with fixed-size parts
798
+ // Arguments:
799
+ // - input: the tensor to split
800
+ // Attributes:
801
+ // - axis: the axis along which to split the input
802
+ // - num_outputs: the number of outputs to produce
803
+ // Outputs:
804
+ // - outputs: the produced outputs. Variadic with num_outputs elements.
805
+ // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
806
+ // tensors
807
+ // so we need to unpack the list
808
+ patterns.onOp (
809
+ " Split" , 1 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
810
+ Value self;
811
+ int64_t axis;
812
+ int64_t num_outputs;
813
+ if (binder.tensorOperand (self))
814
+ return rewriter.notifyMatchFailure (
815
+ binder.op , " Not converting to AtenSplitTensorOp due to input "
816
+ " tensor mismatch" );
817
+ if (binder.s64IntegerAttr (axis, " axis" , 0 ))
818
+ return rewriter.notifyMatchFailure (binder.op ,
819
+ " Failed to get axis attribute" );
820
+ if (binder.s64IntegerAttr (num_outputs, " num_outputs" , 0 ))
821
+ return rewriter.notifyMatchFailure (
822
+ binder.op , " Failed to get num_outputs attribute" );
823
+
824
+ auto result0Ty =
825
+ binder.op ->getResult (0 ).getType ().cast <Torch::ValueTensorType>();
826
+ auto selfTy = self.getType ().cast <Torch::ValueTensorType>();
827
+
828
+ int64_t dim = axis;
829
+ if (dim < 0 )
830
+ dim += selfTy.getSizes ().size ();
831
+
832
+ // set intermediate shape to the shape of the first result
833
+ // if the results are of different shapes
834
+ // set the splitted axis to variable shape
835
+ llvm::SmallVector<int64_t > intermediateShape (result0Ty.getSizes ());
836
+ for (auto result : binder.op ->getResultTypes ()) {
837
+ int64_t d = result.cast <Torch::ValueTensorType>().getSizes ()[dim];
838
+ intermediateShape[dim] = d == intermediateShape[dim] ? d : -1 ;
839
+ }
840
+
841
+ Value dimValue = rewriter.create <Torch::ConstantIntOp>(
842
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
843
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), dim));
844
+
845
+ Value splitSize = rewriter.create <Torch::ConstantIntOp>(
846
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
847
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), num_outputs));
848
+
849
+ // TODO: Attempting to use the shape expected by the ONNX mlir as ground
850
+ // truth. For now just use dynamic shapes.
851
+ auto resultOuterType =
852
+ Torch::ListType::get (rewriter.getType <Torch::ValueTensorType>(
853
+ /* std::optional<llvm::ArrayRef<int64_t>>=*/ intermediateShape,
854
+ result0Ty.getOptionalDtype ()));
855
+ Torch::AtenSplitTensorOp new_op =
856
+ rewriter.create <Torch::AtenSplitTensorOp>(
857
+ binder.getLoc (), resultOuterType, self, splitSize, dimValue);
858
+
859
+ // the onnx op is variadic with multiple results, but AtenSplitWithSizes
860
+ // outputs a list so we need to unpack the list
861
+ rewriter.replaceOpWithNewOp <Torch::PrimListUnpackOp>(
862
+ binder.op , binder.op ->getResults ().getType (), new_op.getResult ());
863
+
864
+ return success ();
865
+ });
866
+
867
+ // split with variable parts
868
+ // Arguments:
869
+ // - input: the tensor to split
870
+ // - split: the sizes of the splits to be produced
871
+ // Attributes:
872
+ // - axis: the axis along which to split the input
873
+ // - num_outputs: the number of outputs to produce
874
+ // Outputs:
875
+ // - outputs: the produced outputs. Variadic with num_outputs elements.
876
+ // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
877
+ // tensors
878
+ // so we need to unpack the list
879
+ patterns.onOp (
880
+ " Split" , 1 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
881
+ Value self;
882
+ Value split;
883
+ int64_t axis;
884
+ int64_t num_outputs;
885
+ if (binder.tensorOperandAtIndex (self, 0 ) ||
886
+ binder.tensorOperandAtIndex (split, 1 ))
887
+ return rewriter.notifyMatchFailure (
888
+ binder.op , " Not converting to AtenSplitWithSizesOp due to input "
889
+ " tensor mismatch" );
890
+ if (binder.s64IntegerAttr (axis, " axis" , 0 ))
891
+ return rewriter.notifyMatchFailure (binder.op ,
892
+ " Failed to get axis attribute" );
893
+ if (binder.s64IntegerAttr (num_outputs, " num_outputs" , 0 ))
894
+ return rewriter.notifyMatchFailure (
895
+ binder.op , " Failed to get num_outputs attribute" );
896
+
897
+ auto result0Ty =
898
+ binder.op ->getResult (0 ).getType ().cast <Torch::ValueTensorType>();
899
+ auto selfTy =
900
+ cast<Torch::ValueTensorType>(binder.op ->getOperand (0 ).getType ());
901
+
902
+ int64_t dim = axis;
903
+ if (dim < 0 )
904
+ dim += selfTy.getSizes ().size ();
905
+
906
+ llvm::SmallVector<int64_t > intermediateShape (result0Ty.getSizes ());
907
+ for (auto result : binder.op ->getResultTypes ()) {
908
+ int64_t d = result.cast <Torch::ValueTensorType>().getSizes ()[dim];
909
+ intermediateShape[dim] = d == intermediateShape[dim] ? d : -1 ;
910
+ }
911
+
912
+ Torch::PrimTolistOp splitToList = rewriter.create <Torch::PrimTolistOp>(
913
+ binder.getLoc (),
914
+ Torch::ListType::get (rewriter.getType <Torch::IntType>()), split);
915
+
916
+ Value dimValue = rewriter.create <Torch::ConstantIntOp>(
917
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
918
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), dim));
919
+
920
+ // TODO: Attempting to use the shape expected by the ONNX mlir as ground
921
+ // truth. For now just use dynamic shapes.
922
+ auto resultOuterType =
923
+ Torch::ListType::get (rewriter.getType <Torch::ValueTensorType>(
924
+ /* std::optional<llvm::ArrayRef<int64_t>>=*/ intermediateShape,
925
+ result0Ty.getOptionalDtype ()));
926
+ Torch::AtenSplitWithSizesOp new_op =
927
+ rewriter.create <Torch::AtenSplitWithSizesOp>(
928
+ binder.getLoc (), resultOuterType, self,
929
+ splitToList.getResult (0 ), dimValue);
930
+
931
+ // the onnx op is variadic with multiple results, but AtenSplitWithSizes
932
+ // outputs a list so we need to unpack the list
933
+ rewriter.replaceOpWithNewOp <Torch::PrimListUnpackOp>(
934
+ binder.op , binder.op ->getResults ().getType (), new_op.getResult ());
935
+
936
+ return success ();
937
+ });
938
+
797
939
patterns.onOp (" Tan" , 7 ,
798
940
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
799
941
Torch::ValueTensorType resultType;
0 commit comments