|
| 1 | +/* |
| 2 | +Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. |
| 3 | +Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| 4 | +
|
| 5 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +you may not use this file except in compliance with the License. |
| 7 | +You may obtain a copy of the License at |
| 8 | +
|
| 9 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +
|
| 11 | +Unless required by applicable law or agreed to in writing, software |
| 12 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +See the License for the specific language governing permissions and |
| 15 | +limitations under the License. |
| 16 | +======================================================================= |
| 17 | +*/ |
| 18 | +package org.tensorflow; |
| 19 | + |
| 20 | +import java.util.ArrayList; |
| 21 | +import java.util.Collections; |
| 22 | +import java.util.Iterator; |
| 23 | +import java.util.LinkedHashMap; |
| 24 | +import java.util.List; |
| 25 | +import java.util.Map; |
| 26 | +import java.util.Optional; |
| 27 | +import java.util.Set; |
| 28 | +import java.util.logging.Level; |
| 29 | +import java.util.logging.Logger; |
| 30 | +import org.tensorflow.exceptions.TensorFlowException; |
| 31 | +import org.tensorflow.proto.framework.RunMetadata; |
| 32 | + |
| 33 | +/** |
| 34 | + * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s. |
| 35 | + * |
| 36 | + * <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a reference |
| 37 | + * to a value after this object has been closed it will throw an {@link IllegalStateException} upon |
| 38 | + * access. |
| 39 | + * |
| 40 | + * <p>This class is not thread-safe with respect to the close operation. Multiple closers or one |
| 41 | + * thread closing a tensor while another is reading may throw exceptions. |
| 42 | + * |
| 43 | + * <p>Note this class is used to manage the lifetimes of tensors produced by the TensorFlow runtime, |
| 44 | + * from sessions and function calls. It is not used as an argument to {@code session.run} or |
| 45 | + * function calls as users are in control of the creation of input tensors. |
| 46 | + */ |
| 47 | +public final class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> { |
| 48 | + @Override |
| 49 | + public void close() { |
| 50 | + if (!closed) { |
| 51 | + for (Tensor t : list) { |
| 52 | + try { |
| 53 | + t.close(); |
| 54 | + } catch (TensorFlowException e) { |
| 55 | + logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e); |
| 56 | + } |
| 57 | + } |
| 58 | + closed = true; |
| 59 | + } else { |
| 60 | + logger.warning("Closing an already closed Result"); |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + @Override |
| 65 | + public Iterator<Map.Entry<String, Tensor>> iterator() { |
| 66 | + if (!closed) { |
| 67 | + return map.entrySet().iterator(); |
| 68 | + } else { |
| 69 | + throw new IllegalStateException("Result is closed"); |
| 70 | + } |
| 71 | + } |
| 72 | + |
| 73 | + /** |
| 74 | + * Returns the number of outputs in this Result. |
| 75 | + * |
| 76 | + * @return The number of outputs. |
| 77 | + */ |
| 78 | + public int size() { |
| 79 | + return map.size(); |
| 80 | + } |
| 81 | + |
| 82 | + /** |
| 83 | + * Gets the set containing all the tensor names. |
| 84 | + * |
| 85 | + * @return The tensor names set. |
| 86 | + */ |
| 87 | + public Set<String> keySet() { |
| 88 | + return Collections.unmodifiableSet(map.keySet()); |
| 89 | + } |
| 90 | + |
| 91 | + /** |
| 92 | + * Does this result object have a tensor for the supplied key? |
| 93 | + * |
| 94 | + * @param key The key to check. |
| 95 | + * @return True if this result object has a tensor for this key. |
| 96 | + */ |
| 97 | + public boolean containsKey(String key) { |
| 98 | + return map.containsKey(key); |
| 99 | + } |
| 100 | + |
| 101 | + /** |
| 102 | + * Gets the value from the container at the specified index. |
| 103 | + * |
| 104 | + * <p>Throws {@link IllegalStateException} if the container has been closed, and {@link |
| 105 | + * IndexOutOfBoundsException} if the index is invalid. |
| 106 | + * |
| 107 | + * @param index The index to lookup. |
| 108 | + * @return The value at the index. |
| 109 | + */ |
| 110 | + public Tensor get(int index) { |
| 111 | + if (!closed) { |
| 112 | + return list.get(index); |
| 113 | + } else { |
| 114 | + throw new IllegalStateException("Result is closed"); |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + /** |
| 119 | + * Gets the value from the container assuming it's not been closed. |
| 120 | + * |
| 121 | + * <p>Throws {@link IllegalStateException} if the container has been closed. |
| 122 | + * |
| 123 | + * @param key The key to lookup. |
| 124 | + * @return Optional.of the value if it exists. |
| 125 | + */ |
| 126 | + public Optional<Tensor> get(String key) { |
| 127 | + if (!closed) { |
| 128 | + return Optional.ofNullable(map.get(key)); |
| 129 | + } else { |
| 130 | + throw new IllegalStateException("Result is closed"); |
| 131 | + } |
| 132 | + } |
| 133 | + |
| 134 | + /** |
| 135 | + * Metadata about the run. |
| 136 | + * |
| 137 | + * <p>A <a |
| 138 | + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata |
| 139 | + * protocol buffer</a>. |
| 140 | + */ |
| 141 | + public Optional<RunMetadata> getMetadata() { |
| 142 | + return Optional.ofNullable(metadata); |
| 143 | + } |
| 144 | + |
| 145 | + /** |
| 146 | + * Creates a Result from the names and values produced by {@link Session.Runner#run()}. |
| 147 | + * |
| 148 | + * @param names The output names. |
| 149 | + * @param values The output values. |
| 150 | + * @param metadata The run metadata, may be null. |
| 151 | + */ |
| 152 | + Result(List<String> names, List<Tensor> values, RunMetadata metadata) { |
| 153 | + this.map = new LinkedHashMap<>(); |
| 154 | + this.list = new ArrayList<>(values); |
| 155 | + |
| 156 | + if (names.size() != values.size()) { |
| 157 | + throw new IllegalArgumentException( |
| 158 | + "Expected same number of names and values, found names.length = " |
| 159 | + + names.size() |
| 160 | + + ", values.length = " |
| 161 | + + values.size()); |
| 162 | + } |
| 163 | + |
| 164 | + for (int i = 0; i < names.size(); i++) { |
| 165 | + Tensor old = this.map.put(names.get(i), values.get(i)); |
| 166 | + if (old != null) { |
| 167 | + throw new IllegalArgumentException( |
| 168 | + "Name collision in the result set, two outputs are named '" + names.get(i) + "'"); |
| 169 | + } |
| 170 | + } |
| 171 | + this.metadata = metadata; |
| 172 | + this.closed = false; |
| 173 | + } |
| 174 | + |
| 175 | + /** |
| 176 | + * Creates a Result from the names and values. |
| 177 | + * |
| 178 | + * @param outputs The run outputs. |
| 179 | + */ |
| 180 | + Result(LinkedHashMap<String, Tensor> outputs) { |
| 181 | + this.map = outputs; |
| 182 | + this.list = new ArrayList<>(outputs.size()); |
| 183 | + for (Map.Entry<String, Tensor> e : outputs.entrySet()) { |
| 184 | + list.add(e.getValue()); |
| 185 | + } |
| 186 | + this.metadata = null; |
| 187 | + this.closed = false; |
| 188 | + } |
| 189 | + |
| 190 | + private final Map<String, Tensor> map; |
| 191 | + |
| 192 | + private final List<Tensor> list; |
| 193 | + |
| 194 | + private final RunMetadata metadata; |
| 195 | + |
| 196 | + private boolean closed; |
| 197 | + |
| 198 | + private static final Logger logger = Logger.getLogger(Result.class.getName()); |
| 199 | +} |
0 commit comments