@@ -655,14 +655,40 @@ static LogicalResult Verify(UnpackOp op) {
655
655
static llvm::Optional<int64_t > ExtractConstantIntFromTensor (Value *value) {
656
656
ElementsAttr attr;
657
657
if (!matchPattern (value, m_Constant (&attr))) return {};
658
-
659
658
IntegerAttr int_attr = attr.getValue (llvm::None).cast <IntegerAttr>();
660
659
return int_attr.getValue ().getSExtValue ();
661
660
}
662
661
662
+ // Returns a RankedTensorType which is similar to `input_type` but replaces the
663
+ // dimension size of `dim` with `dim_size`. For example,
664
+ // `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
665
+ // `tensor<3x2xi32>`.
666
+ static RankedTensorType SubstituteRankedTensorTypeDimSize (
667
+ RankedTensorType input_type, int64_t dim, int64_t dim_size) {
668
+ auto shape = input_type.getShape ().vec ();
669
+ shape[dim] = dim_size;
670
+ return RankedTensorType::get (shape, input_type.getElementType ());
671
+ }
672
+
673
+ // Verifies the output tensor types of SplitOp or SplitVOp.
674
+ template <typename ExpectedOutputTypeGetter>
675
+ static LogicalResult VerifySplitOpOutputTypes (
676
+ Operation *op, int64_t num_splits,
677
+ ExpectedOutputTypeGetter get_expected_output_type) {
678
+ for (int64_t i = 0 ; i < num_splits; ++i) {
679
+ auto expected_output_type = get_expected_output_type (i);
680
+ Value *output = op->getResult (i);
681
+ auto output_type = output->getType ().dyn_cast <RankedTensorType>();
682
+ if (!output_type || output_type != expected_output_type)
683
+ return op->emitOpError ()
684
+ << " output #" << i << " should be " << expected_output_type;
685
+ }
686
+ return success ();
687
+ }
688
+
663
689
static LogicalResult Verify (SplitOp op) {
664
690
int64_t num_splits = op.num_splits ().getSExtValue ();
665
- if (op.getOperation ()-> getNumResults () != num_splits)
691
+ if (op.getNumResults () != num_splits)
666
692
return op.emitOpError (" output count should match 'num_splits' attribute" );
667
693
668
694
// If 'split_dim' is not a constant, there are no other checks.
@@ -688,21 +714,100 @@ static LogicalResult Verify(SplitOp op) {
688
714
if (dim_size % num_splits != 0 )
689
715
return op.emitOpError (" 'num_splits' should evenly divide 'split_dim' axis" );
690
716
691
- // Creates sliced tensor type.
692
- auto slice_shape = input_type.getShape ().vec ();
693
- slice_shape[split_dim] = dim_size / num_splits;
694
- RankedTensorType slice_type =
695
- RankedTensorType::get (slice_shape, input_type.getElementType ());
717
+ // Verifies output tensor types.
718
+ RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize (
719
+ input_type, split_dim, dim_size / num_splits);
720
+ return VerifySplitOpOutputTypes (
721
+ op.getOperation (), num_splits,
722
+ [expected_output_type](int64_t ) { return expected_output_type; });
723
+ }
724
+
725
+ static LogicalResult Verify (SplitVOp op) {
726
+ int64_t num_splits = op.num_splits ().getSExtValue ();
727
+ if (op.getNumResults () != num_splits)
728
+ return op.emitOpError (" output count should match 'num_splits' attribute" );
729
+
730
+ // If 'split_dim' is not a constant, there are no other checks.
731
+ llvm::Optional<int64_t > split_dim_opt =
732
+ ExtractConstantIntFromTensor (op.split_dim ());
733
+ if (!split_dim_opt) return success ();
734
+
735
+ // If 'input' is not a ranked tensor, there are no other checks.
736
+ auto input_type = op.value ()->getType ().dyn_cast <RankedTensorType>();
737
+ if (!input_type) return success ();
738
+
739
+ int64_t split_dim = split_dim_opt.getValue ();
740
+ const int64_t rank = input_type.getRank ();
741
+ if (split_dim < 0 ) split_dim += rank;
742
+ if (split_dim < 0 || split_dim >= rank)
743
+ return op.emitOpError (" 'split_dim' should be in [-rank, rank)" );
744
+
745
+ // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
746
+ // there are no other checks.
747
+ const int64_t dim_size = input_type.getDimSize (split_dim);
748
+ if (ShapedType::isDynamic (dim_size)) return success ();
749
+
750
+ // If 'size_splits' is not a constant, there are no other checks.
751
+ ElementsAttr size_splits_attr;
752
+ if (!matchPattern (op.size_splits (), m_Constant (&size_splits_attr)))
753
+ return success ();
754
+
755
+ if (size_splits_attr.getNumElements () != num_splits) {
756
+ auto size_splits_type =
757
+ op.size_splits ()->getType ().cast <RankedTensorType>();
758
+ RankedTensorType expected_size_splits_type =
759
+ RankedTensorType::get ({num_splits}, size_splits_type.getElementType ());
760
+ return op.emitOpError (" 'size_splits' should be " )
761
+ << expected_size_splits_type;
762
+ }
763
+
764
+ // Normalizes and verifies 'size_splits'.
765
+ // Note: TensorFlow allows one -1 element in 'size_splits'. The -1 element
766
+ // means the rest of the dimension size.
767
+ llvm::SmallVector<int64_t , 4 > size_splits;
768
+ size_splits.reserve (num_splits);
769
+
770
+ int64_t negative_size_split_loc = -1 ;
771
+ int64_t total_size_splits = 0 ;
696
772
697
- // Verifies result tensor types.
698
773
for (int64_t i = 0 ; i < num_splits; ++i) {
699
- Value *result = op.getResult (i);
700
- auto result_type = result->getType ().dyn_cast <RankedTensorType>();
701
- if (!result_type || result_type != slice_type)
702
- return op.emitOpError () << " output #" << i << " should be " << slice_type;
774
+ auto size_split_attr = size_splits_attr.getValue <IntegerAttr>(i);
775
+ int64_t size_split = size_split_attr.getValue ().getSExtValue ();
776
+ size_splits.push_back (size_split);
777
+ if (size_split >= 0 ) {
778
+ total_size_splits += size_split;
779
+ continue ;
780
+ }
781
+ if (size_split < -1 )
782
+ return op.emitOpError (
783
+ " elements of 'size_splits' should be greater than or equal to -1" );
784
+ if (negative_size_split_loc != -1 )
785
+ return op.emitOpError (" 'size_splits' can only have one -1" );
786
+ negative_size_split_loc = i;
703
787
}
704
788
705
- return success ();
789
+ if (negative_size_split_loc != -1 ) {
790
+ if (total_size_splits > dim_size)
791
+ return op.emitOpError (
792
+ " sum of non-negative elements of 'size_splits' is greater than the "
793
+ " dimension size of 'split_dim' axis" );
794
+ size_splits[negative_size_split_loc] = dim_size - total_size_splits;
795
+ total_size_splits = dim_size;
796
+ }
797
+
798
+ if (total_size_splits != dim_size)
799
+ return op.emitOpError (
800
+ " sum of 'size_splits' should match the dimension size of 'split_dim' "
801
+ " axis" );
802
+
803
+ // Verifies result tensor types.
804
+ auto get_expected_output_type = [input_type, split_dim,
805
+ &size_splits](int64_t i) {
806
+ return SubstituteRankedTensorTypeDimSize (input_type, split_dim,
807
+ size_splits[i]);
808
+ };
809
+ return VerifySplitOpOutputTypes (op.getOperation (), num_splits,
810
+ get_expected_output_type);
706
811
}
707
812
708
813
// ===----------------------------------------------------------------------===//
0 commit comments