Skip to content

Commit 70823b2

Browse files
committed
requireShape helper methods
Signed-off-by: Ryan Nett <[email protected]>
1 parent be5eff0 commit 70823b2

File tree

1 file changed

+42
-0
lines changed
  • tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/main/kotlin/org/tensorflow

1 file changed

+42
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow
18+
19+
import org.tensorflow.ndarray.Shape
20+
import org.tensorflow.ndarray.Shaped
21+
22+
/**
23+
* Require the [Shaped] object have a certain shape.
24+
*
25+
* Throws [IllegalStateException] on failure.
26+
*/
27+
public fun <T: Shaped> T.requireShape(shape: Shape): T = apply{
28+
check(this.shape().isCompatibleWith(shape)){
29+
"Shape ${this.shape()} is not compatible with the required shape $shape"
30+
}
31+
}
32+
33+
/**
34+
* Require the [Shaped] object have a certain shape.
35+
*
36+
* Throws [IllegalStateException] on failure.
37+
*/
38+
public fun <T: Shaped> T.requireShape(vararg shape: Long): T = apply{
39+
check(this.shape().isCompatibleWith(Shape.of(*shape))){
40+
"Shape ${this.shape()} is not compatible with the required shape $shape"
41+
}
42+
}

0 commit comments

Comments
 (0)