Skip to content

Commit 531199b

Browse files
loganchien authored and tensorflower-gardener committed
Add tfl.split_v operation verification
This commit adds `tfl.split_v` operation verification.

PiperOrigin-RevId: 266467137
1 parent 4221413 commit 531199b

File tree

4 files changed

+352
-29
lines changed

4 files changed

+352
-29
lines changed

tensorflow/compiler/mlir/lite/ir/tfl_ops.cc

+118-13
Original file line numberDiff line numberDiff line change
@@ -655,14 +655,40 @@ static LogicalResult Verify(UnpackOp op) {
655655
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value *value) {
656656
ElementsAttr attr;
657657
if (!matchPattern(value, m_Constant(&attr))) return {};
658-
659658
IntegerAttr int_attr = attr.getValue(llvm::None).cast<IntegerAttr>();
660659
return int_attr.getValue().getSExtValue();
661660
}
662661

662+
// Returns a RankedTensorType which is similar to `input_type` but replaces the
663+
// dimension size of `dim` with `dim_size`. For example,
664+
// `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
665+
// `tensor<3x2xi32>`.
666+
static RankedTensorType SubstituteRankedTensorTypeDimSize(
667+
RankedTensorType input_type, int64_t dim, int64_t dim_size) {
668+
auto shape = input_type.getShape().vec();
669+
shape[dim] = dim_size;
670+
return RankedTensorType::get(shape, input_type.getElementType());
671+
}
672+
673+
// Verifies the output tensor types of SplitOp or SplitVOp.
674+
template <typename ExpectedOutputTypeGetter>
675+
static LogicalResult VerifySplitOpOutputTypes(
676+
Operation *op, int64_t num_splits,
677+
ExpectedOutputTypeGetter get_expected_output_type) {
678+
for (int64_t i = 0; i < num_splits; ++i) {
679+
auto expected_output_type = get_expected_output_type(i);
680+
Value *output = op->getResult(i);
681+
auto output_type = output->getType().dyn_cast<RankedTensorType>();
682+
if (!output_type || output_type != expected_output_type)
683+
return op->emitOpError()
684+
<< "output #" << i << " should be " << expected_output_type;
685+
}
686+
return success();
687+
}
688+
663689
static LogicalResult Verify(SplitOp op) {
664690
int64_t num_splits = op.num_splits().getSExtValue();
665-
if (op.getOperation()->getNumResults() != num_splits)
691+
if (op.getNumResults() != num_splits)
666692
return op.emitOpError("output count should match 'num_splits' attribute");
667693

668694
// If 'split_dim' is not a constant, there are no other checks.
@@ -688,21 +714,100 @@ static LogicalResult Verify(SplitOp op) {
688714
if (dim_size % num_splits != 0)
689715
return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");
690716

691-
// Creates sliced tensor type.
692-
auto slice_shape = input_type.getShape().vec();
693-
slice_shape[split_dim] = dim_size / num_splits;
694-
RankedTensorType slice_type =
695-
RankedTensorType::get(slice_shape, input_type.getElementType());
717+
// Verifies output tensor types.
718+
RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
719+
input_type, split_dim, dim_size / num_splits);
720+
return VerifySplitOpOutputTypes(
721+
op.getOperation(), num_splits,
722+
[expected_output_type](int64_t) { return expected_output_type; });
723+
}
724+
725+
static LogicalResult Verify(SplitVOp op) {
726+
int64_t num_splits = op.num_splits().getSExtValue();
727+
if (op.getNumResults() != num_splits)
728+
return op.emitOpError("output count should match 'num_splits' attribute");
729+
730+
// If 'split_dim' is not a constant, there are no other checks.
731+
llvm::Optional<int64_t> split_dim_opt =
732+
ExtractConstantIntFromTensor(op.split_dim());
733+
if (!split_dim_opt) return success();
734+
735+
// If 'input' is not a ranked tensor, there are no other checks.
736+
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
737+
if (!input_type) return success();
738+
739+
int64_t split_dim = split_dim_opt.getValue();
740+
const int64_t rank = input_type.getRank();
741+
if (split_dim < 0) split_dim += rank;
742+
if (split_dim < 0 || split_dim >= rank)
743+
return op.emitOpError("'split_dim' should be in [-rank, rank)");
744+
745+
// If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
746+
// there are no other checks.
747+
const int64_t dim_size = input_type.getDimSize(split_dim);
748+
if (ShapedType::isDynamic(dim_size)) return success();
749+
750+
// If 'size_splits' is not a constant, there are no other checks.
751+
ElementsAttr size_splits_attr;
752+
if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
753+
return success();
754+
755+
if (size_splits_attr.getNumElements() != num_splits) {
756+
auto size_splits_type =
757+
op.size_splits()->getType().cast<RankedTensorType>();
758+
RankedTensorType expected_size_splits_type =
759+
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
760+
return op.emitOpError("'size_splits' should be ")
761+
<< expected_size_splits_type;
762+
}
763+
764+
// Normalizes and verifies 'size_splits'.
765+
// Note: TensorFlow allows one -1 element in 'size_splits'. The -1 element
766+
// means the rest of the dimension size.
767+
llvm::SmallVector<int64_t, 4> size_splits;
768+
size_splits.reserve(num_splits);
769+
770+
int64_t negative_size_split_loc = -1;
771+
int64_t total_size_splits = 0;
696772

697-
// Verifies result tensor types.
698773
for (int64_t i = 0; i < num_splits; ++i) {
699-
Value *result = op.getResult(i);
700-
auto result_type = result->getType().dyn_cast<RankedTensorType>();
701-
if (!result_type || result_type != slice_type)
702-
return op.emitOpError() << "output #" << i << " should be " << slice_type;
774+
auto size_split_attr = size_splits_attr.getValue<IntegerAttr>(i);
775+
int64_t size_split = size_split_attr.getValue().getSExtValue();
776+
size_splits.push_back(size_split);
777+
if (size_split >= 0) {
778+
total_size_splits += size_split;
779+
continue;
780+
}
781+
if (size_split < -1)
782+
return op.emitOpError(
783+
"elements of 'size_splits' should be greater than or equal to -1");
784+
if (negative_size_split_loc != -1)
785+
return op.emitOpError("'size_splits' can only have one -1");
786+
negative_size_split_loc = i;
703787
}
704788

705-
return success();
789+
if (negative_size_split_loc != -1) {
790+
if (total_size_splits > dim_size)
791+
return op.emitOpError(
792+
"sum of non-negative elements of 'size_splits' is greater than the "
793+
"dimension size of 'split_dim' axis");
794+
size_splits[negative_size_split_loc] = dim_size - total_size_splits;
795+
total_size_splits = dim_size;
796+
}
797+
798+
if (total_size_splits != dim_size)
799+
return op.emitOpError(
800+
"sum of 'size_splits' should match the dimension size of 'split_dim' "
801+
"axis");
802+
803+
// Verifies result tensor types.
804+
auto get_expected_output_type = [input_type, split_dim,
805+
&size_splits](int64_t i) {
806+
return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
807+
size_splits[i]);
808+
};
809+
return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
810+
get_expected_output_type);
706811
}
707812

708813
//===----------------------------------------------------------------------===//

tensorflow/compiler/mlir/lite/ir/tfl_ops.td

+6-7
Original file line numberDiff line numberDiff line change
@@ -2178,9 +2178,6 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
21782178
let hasOptions = 1;
21792179
}
21802180

2181-
def Rank0I32Tensor : Type<And<[I32Tensor.predicate, HasAnyRankOfPred<[0]>]>,
2182-
"tensor<i32>">;
2183-
21842181
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, SameOperandsAndResultsScale]> {
21852182
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
21862183

@@ -2191,7 +2188,7 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, SameOperandsAndResultsScale]> {
21912188
}];
21922189

21932190
let arguments = (ins
2194-
Rank0I32Tensor:$split_dim,
2191+
0DTensorOf<[I32]>:$split_dim,
21952192
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
21962193
PositiveI32Attr:$num_splits
21972194
);
@@ -2216,15 +2213,17 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
22162213

22172214
let arguments = (ins
22182215
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
2219-
I32Tensor:$size_splits,
2220-
I32Tensor:$split_dim,
2221-
I32Attr:$num_splits
2216+
1DTensorOf<[I32]>:$size_splits,
2217+
0DTensorOf<[I32]>:$split_dim,
2218+
PositiveI32Attr:$num_splits
22222219
);
22232220

22242221
let results = (outs
22252222
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
22262223
);
22272224

2225+
let verifier = [{ return Verify(*this); }];
2226+
22282227
let hasOptions = 1;
22292228
}
22302229

tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -832,12 +832,12 @@ func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32>
832832
// CHECK: %0:3 = "tfl.split"(%arg0, %arg1) {num_splits = 3 : i32} : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
833833
}
834834

835-
func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<1xi32>) -> tensor<1x4x2x3xf32> {
836-
%0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) {num_split = 2 : i64} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
835+
func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<1x4x2x3xf32> {
836+
%0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) {num_split = 2 : i64} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
837837
return %0#0 : tensor<1x4x2x3xf32>
838838

839839
// CHECK-LABEL: splitv
840-
// CHECK: %0:2 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
840+
// CHECK: %0:2 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
841841
}
842842

843843
func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {

0 commit comments

Comments (0)