diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 4d07b678811..c822678fda6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -1,4 +1,4 @@ -/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020-2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -295,8 +295,8 @@ public Operand call(Scope scope, Operand argument) { } @Override - public Map call(Map arguments) { - // FIXME need to manage input/output operand lifetimes + public Result call(Map arguments) { + // FIXME need to manage input operand lifetimes Ops tf = Ops.create(); Map> inputs = new LinkedHashMap<>(arguments.size()); @@ -305,11 +305,11 @@ public Map call(Map arguments) { inputs.put(inputName, tf.constantOf((TType) argument)); } Map> outputs = tf.call(this, inputs); - Map tensorOutputs = new LinkedHashMap<>(outputs.size()); + LinkedHashMap tensorOutputs = new LinkedHashMap<>(outputs.size()); for (String outputName : outputs.keySet()) { tensorOutputs.put(outputName, outputs.get(outputName).asTensor()); } - return tensorOutputs; + return new Result(tensorOutputs); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java new file mode 100644 index 00000000000..a3560b068b1 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java @@ -0,0 +1,199 @@ +/* +Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. +Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.proto.framework.RunMetadata; + +/** + * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s. + * + *

When this is closed it closes all the {@link Tensor}s inside it. If you maintain a reference + * to a value after this object has been closed it will throw an {@link IllegalStateException} upon + * access. + * + *

This class is not thread-safe with respect to the close operation. Multiple closers or one + * thread closing a tensor while another is reading may throw exceptions. + * + *

Note this class is used to manage the lifetimes of tensors produced by the TensorFlow runtime, + * from sessions and function calls. It is not used as an argument to {@code session.run} or + * function calls as users are in control of the creation of input tensors. + */ +public final class Result implements AutoCloseable, Iterable> { + @Override + public void close() { + if (!closed) { + for (Tensor t : list) { + try { + t.close(); + } catch (TensorFlowException e) { + logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e); + } + } + closed = true; + } else { + logger.warning("Closing an already closed Result"); + } + } + + @Override + public Iterator> iterator() { + if (!closed) { + return map.entrySet().iterator(); + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Returns the number of outputs in this Result. + * + * @return The number of outputs. + */ + public int size() { + return map.size(); + } + + /** + * Gets the set containing all the tensor names. + * + * @return The tensor names set. + */ + public Set keySet() { + return Collections.unmodifiableSet(map.keySet()); + } + + /** + * Does this result object have a tensor for the supplied key? + * + * @param key The key to check. + * @return True if this result object has a tensor for this key. + */ + public boolean containsKey(String key) { + return map.containsKey(key); + } + + /** + * Gets the value from the container at the specified index. + * + *

Throws {@link IllegalStateException} if the container has been closed, and {@link + * IndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return The value at the index. + */ + public Tensor get(int index) { + if (!closed) { + return list.get(index); + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Gets the value from the container assuming it's not been closed. + * + *

Throws {@link IllegalStateException} if the container has been closed. + * + * @param key The key to lookup. + * @return Optional.of the value if it exists. + */ + public Optional get(String key) { + if (!closed) { + return Optional.ofNullable(map.get(key)); + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Metadata about the run. + * + *

A RunMetadata + * protocol buffer. + */ + public Optional getMetadata() { + return Optional.ofNullable(metadata); + } + + /** + * Creates a Result from the names and values produced by {@link Session.Runner#run()}. + * + * @param names The output names. + * @param values The output values. + * @param metadata The run metadata, may be null. + */ + Result(List names, List values, RunMetadata metadata) { + this.map = new LinkedHashMap<>(); + this.list = new ArrayList<>(values); + + if (names.size() != values.size()) { + throw new IllegalArgumentException( + "Expected same number of names and values, found names.length = " + + names.size() + + ", values.length = " + + values.size()); + } + + for (int i = 0; i < names.size(); i++) { + Tensor old = this.map.put(names.get(i), values.get(i)); + if (old != null) { + throw new IllegalArgumentException( + "Name collision in the result set, two outputs are named '" + names.get(i) + "'"); + } + } + this.metadata = metadata; + this.closed = false; + } + + /** + * Creates a Result from the names and values. + * + * @param outputs The run outputs. + */ + Result(LinkedHashMap outputs) { + this.map = outputs; + this.list = new ArrayList<>(outputs.size()); + for (Map.Entry e : outputs.entrySet()) { + list.add(e.getValue()); + } + this.metadata = null; + this.closed = false; + } + + private final Map map; + + private final List list; + + private final RunMetadata metadata; + + private boolean closed; + + private static final Logger logger = Logger.getLogger(Result.class.getName()); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 4295dbb6c4a..35d81e7bc16 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -468,7 +468,7 @@ public List functions() { * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default */ - public Map call(Map arguments) { + public Result call(Map arguments) { SessionFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 71fdcec3f41..76be5597cc1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -1,4 +1,4 @@ -/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -306,7 +306,9 @@ public Runner feed(Operand operand, Tensor t) { * @throws IllegalArgumentException if no output exists with the provided name */ public Runner fetch(String operation) { - return fetch(graph.outputOrThrow(operation)); + Runner r = fetch(graph.outputOrThrow(operation), false); + outputNames.add(operation); + return r; } /** @@ -336,6 +338,20 @@ public Runner fetch(String operation, int index) { * @return this session runner */ public Runner fetch(Output output) { + return fetch(output, true); + } + + /** + * Makes {@link #run()} return the Tensor referred to by {@code output}. + * + *

If {@code output} is a resource variable, will fetch the value. + * + * @param output the node to fetch the tensor from + * @param recordName Records the output name. If false the output name must be recorded by the + * calling method as otherwise the result object will throw on construction. + * @return this session runner + */ + private Runner fetch(Output output, boolean recordName) { if (output.env() != graph) { throw new IllegalStateException( "Can't fetch output " @@ -378,6 +394,9 @@ public Runner fetch(Output output) { } else { outputs.add(output); } + if (recordName) { + outputNames.add(output.name()); + } return this; } @@ -490,13 +509,13 @@ private void doInit() { * * @return list of resulting tensors fetched by this session runner */ - public List run() { + public Result run() { doInit(); return runNoInit(); } - List runNoInit() { - return runHelper(false).outputs; + Result runNoInit() { + return runHelper(false); } /** @@ -509,12 +528,12 @@ List runNoInit() { * * @return list of resulting tensors fetched by this session runner, with execution metadata */ - public Run runAndFetchMetadata() { + public Result runAndFetchMetadata() { doInit(); return runHelper(true); } - private Run runHelper(boolean wantMetadata) { + private Result runHelper(boolean wantMetadata) { TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()]; TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()]; int[] inputOpIndices = new int[inputs.size()]; @@ -569,10 +588,7 @@ private Run runHelper(boolean wantMetadata) { } finally { runRef.close(); } - Run ret = new Run(); - ret.outputs = outputs; - ret.metadata = metadata; - return ret; + return new Result(outputNames, outputs, metadata); } private class Reference implements AutoCloseable { @@ -602,6 +618,7 @@ public void close() { private final ArrayList> inputs = new ArrayList<>(); private final ArrayList inputTensors = new ArrayList<>(); private final ArrayList> outputs = new ArrayList<>(); + private final ArrayList outputNames = new ArrayList<>(); private final ArrayList targets = new ArrayList<>(); private RunOptions runOptions = null; } @@ -648,8 +665,9 @@ public SessionFunction function(Signature signature) { * * @param signature the signature of the function * @param arguments the arguments to call with. + * @return The results of the function call. */ - public Map run(Signature signature, Map arguments) { + public Result run(Signature signature, Map arguments) { return function(signature).call(arguments); } @@ -698,26 +716,6 @@ public void restore(String prefix) { setInitialized(); } - /** - * Output tensors and metadata obtained when executing a session. - * - *

See {@link Runner#runAndFetchMetadata()} - */ - public static final class Run { - - /** Tensors from requested fetches. */ - public List outputs; - - /** - * Metadata about the run. - * - *

A RunMetadata - * protocol buffer. - */ - public RunMetadata metadata; - } - Graph graph() { return graph; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index 07bc418ac51..877ba1b2f2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -1,23 +1,22 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021-2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import java.io.IOException; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; /** @@ -89,7 +88,7 @@ public SessionFunction withNewSession(Session session) { } @Override - public Map call(Map arguments) { + public Result call(Map arguments) { Session.Runner runner = session.runner(); signature .getInputs() @@ -113,15 +112,16 @@ public Map call(Map arguments) { signature.getOutputs().values().forEach(x -> runner.fetch(x.name)); - List results = runner.run(); + Result results = runner.run(); - Map outputs = new LinkedHashMap<>(results.size()); + // Unpack the result object and rebuild it with the expected names. + LinkedHashMap outputs = new LinkedHashMap<>(results.size()); int i = 0; for (String outputName : signature.outputNames()) { outputs.put(outputName, results.get(i)); i++; } - return outputs; + return new Result(outputs); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 3b9deff9cd4..2ba3dc0a906 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -210,7 +210,7 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData *

When this methods retuns {@code true}, the tensor could be cast to a {@link SparseTensor * SparseTensor} to access its indices, values and denseShape tensors. * - * @retrun true if this tensor is a sparse + * @return true if this tensor is a sparse */ default boolean isSparse() { return false; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java index 0304d786494..1b83a1176ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import java.util.LinkedHashMap; @@ -28,7 +28,7 @@ public interface TensorFunction { /** * Invokes a function using the default eager session. * - *

Caller is responsible for closing all Tensors. + *

Caller is responsible for close the result object. * * @param arguments list of tensors to pass in input to the function, mapped by their signature * name @@ -37,7 +37,7 @@ public interface TensorFunction { * @throws IllegalArgumentException if the passed arguments don't match up to the function's * parameters. */ - Map call(Map arguments); + Result call(Map arguments); /** * Invokes a function with a single input and output using the default eager session. @@ -76,12 +76,11 @@ default Tensor call(Tensor tensor) { } String inputName = signature().inputNames().iterator().next(); - String outputName = signature().outputNames().iterator().next(); Map inputMap = new LinkedHashMap<>(); inputMap.put(inputName, tensor); - return call(inputMap).get(outputName); + return call(inputMap).get(0); } static Operand validateDescription( diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java deleted file mode 100644 index 330a40bae6b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.tensorflow; - -import java.util.ArrayList; -import java.util.Collection; - -public final class AutoCloseableList extends ArrayList - implements AutoCloseable { - - public AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 250ff9cc383..b303618eae2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -162,9 +162,9 @@ public void testFunctionWithTwoOutputs() { Map inputs = new HashMap<>(); inputs.put("x", TInt32.scalarOf(2)); - Map outputs = cf.call(inputs); - assertEquals(4, ((TInt32) outputs.get("dbl")).getInt()); - assertEquals(6, ((TInt32) outputs.get("trpl")).getInt()); + Result outputs = cf.call(inputs); + assertEquals(4, ((TInt32) outputs.get("dbl").get()).getInt()); + assertEquals(6, ((TInt32) outputs.get("trpl").get()).getInt()); } private static Signature square(Ops tf) { @@ -205,15 +205,14 @@ public void testGradientsGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { + Result outputs = + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run()) { assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java index 62626c35641..0ad94ad2130 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java @@ -66,8 +66,7 @@ public void testCustomGradient() { assertEquals(DataType.DT_FLOAT, grads0[0].dataType()); try (TFloat32 c1 = TFloat32.vectorOf(3.0f, 2.0f, 1.0f, 0.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(x, c1).fetch(grads0[0]).run())) { + Result outputs = s.runner().feed(x, c1).fetch(grads0[0]).run()) { assertEquals(1, outputs.size()); assertEquals(0.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index e4340da3275..28a549d72ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -14,6 +14,11 @@ ==============================================================================*/ package org.tensorflow; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.tensorflow.DeviceSpec.DeviceType; + import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; @@ -21,92 +26,87 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.types.TInt32; -import static com.google.common.truth.Truth.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; -import static org.tensorflow.DeviceSpec.DeviceType; - /** Tests for {@link DeviceSpec}. */ public class DeviceSpecTest { @Test public void withDeviceMethod() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { + try (Graph g = new Graph(); + Session session = new Session(g, config)) { Ops tf = Ops.create(g).withSubScope("testScope"); Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() + DeviceSpec deviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) .deviceType(DeviceSpec.DeviceType.CPU) .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = + tf.withName("absWithDevice").withDevice(deviceSpec).math.abs(aOps).asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + try (Result t = session.runner().fetch(absOps).run()) { + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withEmptyDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { + try (Graph g = new Graph(); + Session session = new Session(g, config)) { Ops tf = Ops.create(g).withSubScope("testScope"); Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() + DeviceSpec deviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) .deviceType(DeviceSpec.DeviceType.CPU) .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = + tf.withName("absWithDevice").withDevice(deviceSpec).math.abs(aOps).asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + try (Result t = session.runner().fetch(absOps).run()) { + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withTwoScopes() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec1 = DeviceSpec.newBuilder() + try (Graph g = new Graph(); + Session session = new Session(g, config)) { + DeviceSpec deviceSpec1 = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) .deviceType(DeviceSpec.DeviceType.CPU) .build(); - DeviceSpec deviceSpec2 = DeviceSpec.newBuilder() + DeviceSpec deviceSpec2 = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) @@ -119,33 +119,27 @@ public void withTwoScopes() { Constant aOps = tf1.constant(-1); Constant bOps = tf2.constant(10); - Output absOps = tf1 - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); + Output absOps = tf1.withName("absWithDevice").math.abs(aOps).asOutput(); - Output mulOps = tf2 - .withName("mulWithDevice") - .math - .mul(absOps, bOps) - .asOutput(); + Output mulOps = tf2.withName("mulWithDevice").math.mul(absOps, bOps).asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { - assertEquals(10, ((TInt32)t.get(0)).getInt()); + try (Result t = session.runner().fetch(mulOps).run()) { + assertEquals(10, ((TInt32) t.get(0)).getInt()); } } } @Test public void withIncorrectDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec correctDeviceSpec = DeviceSpec.newBuilder() + try (Graph g = new Graph(); + Session session = new Session(g, config)) { + DeviceSpec correctDeviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) @@ -153,7 +147,8 @@ public void withIncorrectDeviceSpec() { .build(); // Incorrect device spec, it will never be executed - DeviceSpec incorrectDeviceSpec = DeviceSpec.newBuilder() + DeviceSpec incorrectDeviceSpec = + DeviceSpec.newBuilder() .job("UNKNOWN") .replica(1) .task(1000) @@ -165,22 +160,17 @@ public void withIncorrectDeviceSpec() { Constant aOps = tf.constant(-1); Constant bOps = tf.constant(10); - Output absOps = tf - .withName("absWithDevice") - .withDevice(incorrectDeviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = + tf.withName("absWithDevice").withDevice(incorrectDeviceSpec).math.abs(aOps).asOutput(); - Output mulOps = tf - .withName("mulWithDevice") + Output mulOps = + tf.withName("mulWithDevice") .withDevice(correctDeviceSpec) .math .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { + try (Result t = session.runner().fetch(mulOps).run()) { fail(); } catch (TFInvalidArgumentException e) { // ok @@ -190,12 +180,15 @@ public void withIncorrectDeviceSpec() { @Test public void withDeviceSpecInScope() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec = DeviceSpec.newBuilder() + try (Graph g = new Graph(); + Session session = new Session(g, config)) { + DeviceSpec deviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) @@ -206,15 +199,10 @@ public void withDeviceSpecInScope() { Constant aOps = tf.constant(-1); - Output absOps = tf - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); + Output absOps = tf.withName("absWithDevice").math.abs(aOps).asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + try (Result t = session.runner().fetch(absOps).run()) { + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index 154d3903dcd..ff691e30adb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -25,7 +25,6 @@ import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashSet; -import java.util.List; import java.util.Set; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; @@ -84,15 +83,13 @@ public void graphDefRoundTripWithInit() { Operand variable2 = init.withName("var2").variable(init.constant(4)); - try (Session s = new Session(g, true)) { - List results = s.runner().fetch("result").fetch("var2").run(); + try (Session s = new Session(g, true); + Result results = s.runner().fetch("result").fetch("var2").run()) { TInt32 result = (TInt32) results.get(0); assertEquals(6, result.getInt()); TInt32 var2Result = (TInt32) results.get(1); assertEquals(4, var2Result.getInt()); - - results.forEach(Tensor::close); } } } @@ -266,15 +263,14 @@ public void addGradientsToGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { + Result outputs = + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run()) { assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); @@ -418,14 +414,13 @@ public void buildWhileLoopMultipleInputs() { try (TInt32 c1 = TInt32.scalarOf(2); TInt32 c2 = TInt32.scalarOf(5); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() - .feed(input1, c1) - .feed(input2, c2) - .fetch(loopOutputs[0]) - .fetch(loopOutputs[1]) - .run())) { + Result outputs = + s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run()) { assertEquals(2, outputs.size()); assertEquals(16, ((TInt32) outputs.get(0)).getInt()); // ((2^2)^2) assertEquals(625, ((TInt32) outputs.get(1)).getInt()); // ((5^2)^2) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index be6f952fb6a..deff52ffbeb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -215,7 +215,10 @@ public void exportFunctionWithVariables() throws IOException { // Now call the same function directly from the model try (TFloat32 zTensor = (TFloat32) - savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { + savedModel + .call(Collections.singletonMap("input", xTensor)) + .get("reducedSum") + .get()) { assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } } @@ -293,9 +296,9 @@ public void pythonTfFunction() { System.out.println(add.signature()); args.put("a", a); args.put("b", b); - Map result = add.call(args); + Result result = add.call(args); assertEquals(result.size(), 1); - try (TFloat32 c = (TFloat32) result.values().iterator().next()) { + try (TFloat32 c = (TFloat32) result.get(0)) { assertEquals(25.5f, c.getFloat()); } } @@ -307,11 +310,7 @@ public void pythonTfFunction() { args.put("dummy", dummy); // TF functions always require an input, so we supply a dummy one here // This test actually checks that resource variables can be loaded correctly. - try (TFloat32 v = - (TFloat32) - getVariable - .call(args) - .get(getVariable.signature().outputNames().iterator().next())) { + try (TFloat32 v = (TFloat32) getVariable.call(args).get(0)) { assertEquals(2f, v.getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 95da0520f7d..918ccac5fe2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -16,7 +16,6 @@ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -26,6 +25,7 @@ import java.nio.file.Path; import java.util.Comparator; import java.util.Iterator; +import java.util.Optional; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -38,6 +38,7 @@ import org.tensorflow.op.math.Add; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.GraphDef; +import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -69,8 +70,7 @@ public void runUsingOperationNames() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { + Result outputs = s.runner().feed("X", x).fetch("Y").run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } @@ -86,8 +86,7 @@ public void runUsingOperationHandles() { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { + Result outputs = s.runner().feed(feed, x).fetch(fetch).run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } @@ -124,20 +123,20 @@ public void runWithMetadata() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Run result = + Result result = s.runner() .feed("X", x) .fetch("Y") .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); + assertEquals(1, result.size()); + assertEquals(31, ((TInt32) result.get(0)).getInt(0, 0)); // Sanity check on metadata - assertNotNull(result.metadata); - assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); - outputs.close(); + Optional metadata = result.getMetadata(); + assertTrue(metadata.isPresent()); + assertTrue(metadata.get().hasStepStats(), metadata.get().toString()); + result.close(); } } } @@ -149,8 +148,7 @@ public void runMultipleOutputs() { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); + Result outputs = s.runner().fetch("c2").fetch("c1").run(); assertEquals(2, outputs.size()); assertEquals(31415, ((TInt32) outputs.get(0)).getInt()); assertEquals(2718, ((TInt32) outputs.get(1)).getInt()); @@ -227,10 +225,8 @@ public void saveAndRestore() throws IOException { restoredGraph.importGraphDef(graphDef); try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); - try (AutoCloseableList oldList = - new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = - new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())) { + try (Result oldList = s.runner().fetch("x").fetch("y").run(); + Result newList = restoredSession.runner().fetch("x").fetch("y").run()) { assertEquals(oldList.get(0), newList.get(0)); assertEquals(oldList.get(1), newList.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index 16c14f7a9a3..4edbea33b0d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -18,12 +18,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.OpScope; import org.tensorflow.op.Scope; @@ -50,7 +49,7 @@ public void testBooleanMaskUpdateSlice() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { @@ -89,7 +88,7 @@ public void testBooleanMaskUpdateSliceWithBroadcast() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { @@ -131,7 +130,7 @@ public void testBooleanMaskUpdateAxis() { BooleanMaskUpdate.create( scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 5c413b3abeb..5194fccd707 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -19,12 +19,11 @@ import java.io.IOException; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.IntNdArray; @@ -66,8 +65,7 @@ public void createInts() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -85,8 +83,7 @@ public void createFloats() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -104,8 +101,7 @@ public void createDoubles() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -123,8 +119,7 @@ public void createLongs() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -142,8 +137,7 @@ public void createStrings() throws IOException { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 80150b64bb6..fb52b2d1059 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -21,11 +21,10 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Output; +import org.tensorflow.Result; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; @@ -48,12 +47,10 @@ public void createGradients() { assertEquals(2, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(18.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); } } } @@ -75,10 +72,9 @@ public void createGradientsWithSum() { assertEquals(1, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { - assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(114.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } } @@ -94,18 +90,17 @@ public void createGradientsWithInitialValues() { Output y1 = tf.math.square(y0).y(); Gradients grads0 = Gradients.create(tf.scope(), y1, Arrays.asList(y0)); - Gradients grads1 = Gradients.create(tf.scope(), y0, Arrays.asList(x), Gradients.dx(grads0.dy())); + Gradients grads1 = + Gradients.create(tf.scope(), y0, Arrays.asList(x), Gradients.dx(grads0.dy())); assertNotNull(grads1); assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + Result outputs = sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index b4d36702c93..73b7e0a551c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -19,9 +19,9 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.op.OpScope; import org.tensorflow.op.Scope; @@ -134,7 +134,7 @@ public void operationsComposingZerosAreCorrectlyNamed() { long[] shape = {2, 2}; Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); - List results = + Result results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 1f8503829b7..1bbeb1a3f0a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; @@ -51,15 +52,10 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List outputs = session.runner().fetch(x).fetch(y).run(); - - try (TInt32 xBatch = (TInt32) outputs.get(0); - TInt32 yBatch = (TInt32) outputs.get(1)) { - assertEquals(testMatrix1.get(batches), xBatch); - assertEquals(testMatrix2.get(batches), yBatch); - batches++; - } + try (Result outputs = session.runner().fetch(x).fetch(y).run()) { + assertEquals(testMatrix1.get(batches), outputs.get(0)); + assertEquals(testMatrix2.get(batches), outputs.get(1)); + batches++; } catch (TFOutOfRangeException e) { break; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index afa38e04ee8..e75bdde766e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.ndarray.IntNdArray; @@ -76,17 +77,11 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List outputs = session.runner().fetch(X).fetch(y).run(); + try (Result outputs = session.runner().fetch(X).fetch(y).run()) { + assertEquals(mapped1.get(batches), outputs.get(0)); + assertEquals(mapped2.get(batches), outputs.get(1)); - try (TInt32 XBatch = (TInt32) outputs.get(0); - TInt32 yBatch = (TInt32) outputs.get(1)) { - - assertEquals(mapped1.get(batches), XBatch); - assertEquals(mapped2.get(batches), yBatch); - - batches++; - } + batches++; } catch (TFOutOfRangeException e) { break; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 4330fa0aed7..fc1e2fe9573 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -14,8 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Op; @@ -26,10 +29,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertThrows; - public class AssertBroadcastableTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -69,10 +68,10 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (Result tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) { + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); testSession diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java index 3322a81fe5b..9df29436e31 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -14,8 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -25,12 +30,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class BroadcastWeightsTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -78,55 +77,57 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (Result tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) { + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Operand dynamicOp = MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder); - List result = + try (Result result = testSession .getGraphSession() .runner() .feed(weightsPlaceholder, weightsTensor) .feed(valuesPlaceholder, valuesTensor) .fetch(dynamicOp) - .run(); - - if (expected != null) { - if (type.equals(TInt32.class)) { - TInt32 intT = (TInt32) result.get(0); - AtomicInteger i = new AtomicInteger(); - intT.scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); - } else if (type.equals(TInt64.class)) { - TInt64 floatT = (TInt64) result.get(0); - AtomicInteger i = new AtomicInteger(); - floatT - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); - } else if (type.equals(TFloat32.class)) { - TFloat32 floatT = (TFloat32) result.get(0); - AtomicInteger i = new AtomicInteger(); - floatT - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals( - expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); - } else if (type.equals(TFloat64.class)) { - TFloat64 doubleT = (TFloat64) result.get(0); - AtomicInteger i = new AtomicInteger(); - doubleT - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals( - expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F)); + .run()) { + + if (expected != null) { + if (type.equals(TInt32.class)) { + TInt32 intT = (TInt32) result.get(0); + AtomicInteger i = new AtomicInteger(); + intT.scalars() + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); + } else if (type.equals(TInt64.class)) { + TInt64 floatT = (TInt64) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); + } else if (type.equals(TFloat32.class)) { + TFloat32 floatT = (TFloat32) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); + } else if (type.equals(TFloat64.class)) { + TFloat64 doubleT = (TFloat64) result.get(0); + AtomicInteger i = new AtomicInteger(); + doubleT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F)); + } } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 909fd53ca27..a59f67f5a99 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -4,6 +4,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -11,6 +12,7 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.initializers.Glorot; @@ -189,13 +191,18 @@ public void testDeterminism() { g.importGraphDef(def); s.initialize(); - initialized.add( + Result initializationRes = s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName) .fetch(outputBiasName) - .run()); + .run(); + List initializedRun = new ArrayList<>(); + for (Map.Entry e : initializationRes) { + initializedRun.add(e.getValue()); + } + initialized.add(initializedRun); TFloat32 lossVal = (TFloat32) @@ -209,13 +216,18 @@ public void testDeterminism() { initialLoss[i] = lossVal.getFloat(); lossVal.close(); - trained.add( + Result trainedRes = s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName) .fetch(outputBiasName) - .run()); + .run(); + List trainedRun = new ArrayList<>(); + for (Map.Entry e : trainedRes) { + trainedRun.add(e.getValue()); + } + trained.add(trainedRun); lossVal = (TFloat32)