Skip to content

Commit 8fd9362

Browse files
authored
Adds a closeable session result (#411)
1 parent d518678 commit 8fd9362

23 files changed

+477
-341
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2020-2022 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -295,8 +295,8 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
295295
}
296296

297297
@Override
298-
public Map<String, Tensor> call(Map<String, Tensor> arguments) {
299-
// FIXME need to manage input/output operand lifetimes
298+
public Result call(Map<String, Tensor> arguments) {
299+
// FIXME need to manage input operand lifetimes
300300
Ops tf = Ops.create();
301301
Map<String, Operand<?>> inputs = new LinkedHashMap<>(arguments.size());
302302

@@ -305,11 +305,11 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments) {
305305
inputs.put(inputName, tf.constantOf((TType) argument));
306306
}
307307
Map<String, Operand<?>> outputs = tf.call(this, inputs);
308-
Map<String, Tensor> tensorOutputs = new LinkedHashMap<>(outputs.size());
308+
LinkedHashMap<String, Tensor> tensorOutputs = new LinkedHashMap<>(outputs.size());
309309
for (String outputName : outputs.keySet()) {
310310
tensorOutputs.put(outputName, outputs.get(outputName).asTensor());
311311
}
312-
return tensorOutputs;
312+
return new Result(tensorOutputs);
313313
}
314314

315315
/**
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
3+
Copyright 2022 The TensorFlow Authors. All Rights Reserved.
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
=======================================================================
17+
*/
18+
package org.tensorflow;
19+
20+
import java.util.ArrayList;
21+
import java.util.Collections;
22+
import java.util.Iterator;
23+
import java.util.LinkedHashMap;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.Optional;
27+
import java.util.Set;
28+
import java.util.logging.Level;
29+
import java.util.logging.Logger;
30+
import org.tensorflow.exceptions.TensorFlowException;
31+
import org.tensorflow.proto.framework.RunMetadata;
32+
33+
/**
34+
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
35+
*
36+
* <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a reference
37+
* to a value after this object has been closed it will throw an {@link IllegalStateException} upon
38+
* access.
39+
*
40+
* <p>This class is not thread-safe with respect to the close operation. Multiple closers or one
41+
* thread closing a tensor while another is reading may throw exceptions.
42+
*
43+
* <p>Note this class is used to manage the lifetimes of tensors produced by the TensorFlow runtime,
44+
* from sessions and function calls. It is not used as an argument to {@code session.run} or
45+
* function calls as users are in control of the creation of input tensors.
46+
*/
47+
public final class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> {
48+
@Override
49+
public void close() {
50+
if (!closed) {
51+
for (Tensor t : list) {
52+
try {
53+
t.close();
54+
} catch (TensorFlowException e) {
55+
logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e);
56+
}
57+
}
58+
closed = true;
59+
} else {
60+
logger.warning("Closing an already closed Result");
61+
}
62+
}
63+
64+
@Override
65+
public Iterator<Map.Entry<String, Tensor>> iterator() {
66+
if (!closed) {
67+
return map.entrySet().iterator();
68+
} else {
69+
throw new IllegalStateException("Result is closed");
70+
}
71+
}
72+
73+
/**
74+
* Returns the number of outputs in this Result.
75+
*
76+
* @return The number of outputs.
77+
*/
78+
public int size() {
79+
return map.size();
80+
}
81+
82+
/**
83+
* Gets the set containing all the tensor names.
84+
*
85+
* @return The tensor names set.
86+
*/
87+
public Set<String> keySet() {
88+
return Collections.unmodifiableSet(map.keySet());
89+
}
90+
91+
/**
92+
* Does this result object have a tensor for the supplied key?
93+
*
94+
* @param key The key to check.
95+
* @return True if this result object has a tensor for this key.
96+
*/
97+
public boolean containsKey(String key) {
98+
return map.containsKey(key);
99+
}
100+
101+
/**
102+
* Gets the value from the container at the specified index.
103+
*
104+
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
105+
* IndexOutOfBoundsException} if the index is invalid.
106+
*
107+
* @param index The index to lookup.
108+
* @return The value at the index.
109+
*/
110+
public Tensor get(int index) {
111+
if (!closed) {
112+
return list.get(index);
113+
} else {
114+
throw new IllegalStateException("Result is closed");
115+
}
116+
}
117+
118+
/**
119+
* Gets the value from the container assuming it's not been closed.
120+
*
121+
* <p>Throws {@link IllegalStateException} if the container has been closed.
122+
*
123+
* @param key The key to lookup.
124+
* @return Optional.of the value if it exists.
125+
*/
126+
public Optional<Tensor> get(String key) {
127+
if (!closed) {
128+
return Optional.ofNullable(map.get(key));
129+
} else {
130+
throw new IllegalStateException("Result is closed");
131+
}
132+
}
133+
134+
/**
135+
* Metadata about the run.
136+
*
137+
* <p>A <a
138+
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
139+
* protocol buffer</a>.
140+
*/
141+
public Optional<RunMetadata> getMetadata() {
142+
return Optional.ofNullable(metadata);
143+
}
144+
145+
/**
146+
* Creates a Result from the names and values produced by {@link Session.Runner#run()}.
147+
*
148+
* @param names The output names.
149+
* @param values The output values.
150+
* @param metadata The run metadata, may be null.
151+
*/
152+
Result(List<String> names, List<Tensor> values, RunMetadata metadata) {
153+
this.map = new LinkedHashMap<>();
154+
this.list = new ArrayList<>(values);
155+
156+
if (names.size() != values.size()) {
157+
throw new IllegalArgumentException(
158+
"Expected same number of names and values, found names.length = "
159+
+ names.size()
160+
+ ", values.length = "
161+
+ values.size());
162+
}
163+
164+
for (int i = 0; i < names.size(); i++) {
165+
Tensor old = this.map.put(names.get(i), values.get(i));
166+
if (old != null) {
167+
throw new IllegalArgumentException(
168+
"Name collision in the result set, two outputs are named '" + names.get(i) + "'");
169+
}
170+
}
171+
this.metadata = metadata;
172+
this.closed = false;
173+
}
174+
175+
/**
176+
* Creates a Result from the names and values.
177+
*
178+
* @param outputs The run outputs.
179+
*/
180+
Result(LinkedHashMap<String, Tensor> outputs) {
181+
this.map = outputs;
182+
this.list = new ArrayList<>(outputs.size());
183+
for (Map.Entry<String, Tensor> e : outputs.entrySet()) {
184+
list.add(e.getValue());
185+
}
186+
this.metadata = null;
187+
this.closed = false;
188+
}
189+
190+
private final Map<String, Tensor> map;
191+
192+
private final List<Tensor> list;
193+
194+
private final RunMetadata metadata;
195+
196+
private boolean closed;
197+
198+
private static final Logger logger = Logger.getLogger(Result.class.getName());
199+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ public List<SessionFunction> functions() {
468468
* @return list of output tensors, mapped by the signature name
469469
* @throws IllegalArgumentException if no function can be selected by default
470470
*/
471-
public Map<String, Tensor> call(Map<String, Tensor> arguments) {
471+
public Result call(Map<String, Tensor> arguments) {
472472
SessionFunction function = null;
473473
if (functions.size() == 1) {
474474
function = functions.values().iterator().next();

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2019-2022 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -306,7 +306,9 @@ public Runner feed(Operand<?> operand, Tensor t) {
306306
* @throws IllegalArgumentException if no output exists with the provided name
307307
*/
308308
public Runner fetch(String operation) {
309-
return fetch(graph.outputOrThrow(operation));
309+
Runner r = fetch(graph.outputOrThrow(operation), false);
310+
outputNames.add(operation);
311+
return r;
310312
}
311313

312314
/**
@@ -336,6 +338,20 @@ public Runner fetch(String operation, int index) {
336338
* @return this session runner
337339
*/
338340
public Runner fetch(Output<?> output) {
341+
return fetch(output, true);
342+
}
343+
344+
/**
345+
* Makes {@link #run()} return the Tensor referred to by {@code output}.
346+
*
347+
* <p>If {@code output} is a resource variable, will fetch the value.
348+
*
349+
* @param output the node to fetch the tensor from
350+
* @param recordName Records the output name. If false the output name must be recorded by the
351+
* calling method as otherwise the result object will throw on construction.
352+
* @return this session runner
353+
*/
354+
private Runner fetch(Output<?> output, boolean recordName) {
339355
if (output.env() != graph) {
340356
throw new IllegalStateException(
341357
"Can't fetch output "
@@ -378,6 +394,9 @@ public Runner fetch(Output<?> output) {
378394
} else {
379395
outputs.add(output);
380396
}
397+
if (recordName) {
398+
outputNames.add(output.name());
399+
}
381400
return this;
382401
}
383402

@@ -490,13 +509,13 @@ private void doInit() {
490509
*
491510
* @return list of resulting tensors fetched by this session runner
492511
*/
493-
public List<Tensor> run() {
512+
public Result run() {
494513
doInit();
495514
return runNoInit();
496515
}
497516

498-
List<Tensor> runNoInit() {
499-
return runHelper(false).outputs;
517+
Result runNoInit() {
518+
return runHelper(false);
500519
}
501520

502521
/**
@@ -509,12 +528,12 @@ List<Tensor> runNoInit() {
509528
*
510529
* @return list of resulting tensors fetched by this session runner, with execution metadata
511530
*/
512-
public Run runAndFetchMetadata() {
531+
public Result runAndFetchMetadata() {
513532
doInit();
514533
return runHelper(true);
515534
}
516535

517-
private Run runHelper(boolean wantMetadata) {
536+
private Result runHelper(boolean wantMetadata) {
518537
TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()];
519538
TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()];
520539
int[] inputOpIndices = new int[inputs.size()];
@@ -569,10 +588,7 @@ private Run runHelper(boolean wantMetadata) {
569588
} finally {
570589
runRef.close();
571590
}
572-
Run ret = new Run();
573-
ret.outputs = outputs;
574-
ret.metadata = metadata;
575-
return ret;
591+
return new Result(outputNames, outputs, metadata);
576592
}
577593

578594
private class Reference implements AutoCloseable {
@@ -602,6 +618,7 @@ public void close() {
602618
private final ArrayList<Output<?>> inputs = new ArrayList<>();
603619
private final ArrayList<Tensor> inputTensors = new ArrayList<>();
604620
private final ArrayList<Output<?>> outputs = new ArrayList<>();
621+
private final ArrayList<String> outputNames = new ArrayList<>();
605622
private final ArrayList<GraphOperation> targets = new ArrayList<>();
606623
private RunOptions runOptions = null;
607624
}
@@ -648,8 +665,9 @@ public SessionFunction function(Signature signature) {
648665
*
649666
* @param signature the signature of the function
650667
* @param arguments the arguments to call with.
668+
* @return The results of the function call.
651669
*/
652-
public Map<String, Tensor> run(Signature signature, Map<String, Tensor> arguments) {
670+
public Result run(Signature signature, Map<String, Tensor> arguments) {
653671
return function(signature).call(arguments);
654672
}
655673

@@ -698,26 +716,6 @@ public void restore(String prefix) {
698716
setInitialized();
699717
}
700718

701-
/**
702-
* Output tensors and metadata obtained when executing a session.
703-
*
704-
* <p>See {@link Runner#runAndFetchMetadata()}
705-
*/
706-
public static final class Run {
707-
708-
/** Tensors from requested fetches. */
709-
public List<Tensor> outputs;
710-
711-
/**
712-
* Metadata about the run.
713-
*
714-
* <p>A <a
715-
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
716-
* protocol buffer</a>.
717-
*/
718-
public RunMetadata metadata;
719-
}
720-
721719
Graph graph() {
722720
return graph;
723721
}

0 commit comments

Comments
 (0)