Skip to content

Commit a879da8

Browse files
committed
Add tests
Signed-off-by: Ryan Nett <[email protected]>
1 parent 9c2e647 commit a879da8

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

tensorflow-core-kotlin/tensorflow-core-kotlin-api/pom.xml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@
6060
<artifactId>jmh-generator-annprocess</artifactId>
6161
<scope>test</scope>
6262
</dependency>
63+
<dependency>
64+
<groupId>org.jetbrains.kotlin</groupId>
65+
<artifactId>kotlin-test-junit5</artifactId>
66+
<version>1.4.31</version>
67+
<scope>test</scope>
68+
</dependency>
6369
<!-- Include native binaries dependencies only for testing -->
6470
<dependency>
6571
<groupId>org.tensorflow</groupId>
@@ -221,6 +227,16 @@
221227
</dependency>
222228
</dependencies>
223229
</plugin>
230+
<plugin>
231+
<groupId>org.apache.maven.plugins</groupId>
232+
<artifactId>maven-surefire-plugin</artifactId>
233+
<version>2.22.2</version>
234+
<!-- <configuration>-->
235+
<!-- <forkCount>1</forkCount>-->
236+
<!-- <reuseForks>false</reuseForks>-->
237+
<!-- <argLine>-Xmx2G -XX:MaxPermSize=256m</argLine>-->
238+
<!-- </configuration>-->
239+
</plugin>
224240
</plugins>
225241
</build>
226242
</project>

tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/main/kotlin/org/tensorflow/op/kotlin/OpsHelpers.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public fun KotlinOps.withSubScope(childScopeName: String): KotlinOps = KotlinOps
4545
*
4646
* @see org.tensorflow.op.Scope.withSubScope
4747
*/
48+
// TODO should be a decorator too, when possible
4849
public inline fun <R> KotlinOps.withSubScope(childScopeName: String, block: KotlinOps.() -> R): R {
4950
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
5051
return withSubScope(childScopeName).run(block)

tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/test/kotlin/org/tensorflow/Example.kt renamed to tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/test/kotlin/org/tensorflow/ExampleTest.kt

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,32 @@
1616
*/
1717
package org.tensorflow
1818

19-
import org.junit.jupiter.api.Test
2019
import org.tensorflow.ndarray.Shape
2120
import org.tensorflow.ndarray.get
2221
import org.tensorflow.op.kotlin.KotlinOps
2322
import org.tensorflow.op.kotlin.tf
2423
import org.tensorflow.op.kotlin.withSubScope
2524
import org.tensorflow.types.TFloat32
25+
import kotlin.test.Test
2626

27-
public fun KotlinOps.DenseLayer(
27+
private fun KotlinOps.DenseLayer(
2828
name: String,
2929
x: Operand<TFloat32>,
3030
n: Int,
3131
activation: KotlinOps.(Operand<TFloat32>) -> Operand<TFloat32> = { tf.nn.relu(it) }
3232
): Operand<TFloat32> = tf.withSubScope(name) {
3333
val inputDims = x.shape()[1]
34-
val W = tf.variable(tf.math.add(tf.zeros(tf.array(inputDims.toInt(), n), TFloat32::class.java), constant(1f)))
35-
val b = tf.variable(tf.math.add(tf.zeros(tf.array(n), TFloat32::class.java), constant(1f)))
36-
activation(tf.math.add(tf.linalg.matMul(x, W), b))
34+
val W = tf.variable(tf.ones<TFloat32>(tf.array(inputDims.toInt(), n)))
35+
val b = tf.variable(tf.ones<TFloat32>(tf.array(n)))
36+
activation((x matMul W) + b)
3737
}
3838

39-
public class Example {
39+
public class ExampleTest {
4040
@Test
4141
public fun mnistExample() {
4242
Graph {
4343
val input = tf.placeholderWithDefault(
44-
tf.math.add(tf.zeros(tf.array(1, 28, 28, 3)), tf.constant(1f)),
44+
tf.ones<TFloat32>(tf.array(1, 28, 28, 3)),
4545
Shape.of(-1, 28, 28, 3)
4646
)
4747

@@ -53,10 +53,11 @@ public class Example {
5353
DenseLayer("OutputLayer", x, 10) { tf.math.sigmoid(x) }
5454
}
5555

56-
// useSession {
57-
// val outputValue = it.run(fetches = listOf(output))[output]
58-
// println(outputValue.data())
59-
// }
56+
useSession { session ->
57+
58+
val outputValue = session.runner().fetch(output).run()[0] as TFloat32
59+
println(outputValue.getFloat(0))
60+
}
6061
}
6162
}
6263
}

0 commit comments

Comments
 (0)