11/*
2- Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
33
4- Licensed under the Apache License, Version 2.0 (the "License");
5- you may not use this file except in compliance with the License.
6- You may obtain a copy of the License at
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
77
8- http://www.apache.org/licenses/LICENSE-2.0
8+ http://www.apache.org/licenses/LICENSE-2.0
99
10- Unless required by applicable law or agreed to in writing, software
11- distributed under the License is distributed on an "AS IS" BASIS,
12- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13- See the License for the specific language governing permissions and
14- limitations under the License.
15- ==============================================================================
16- */
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ ==============================================================================
16+ */
1717package org .tensorflow ;
1818
19+ import static org .tensorflow .internal .c_api .global .tensorflow .OutputsFromTFOutputs ;
20+ import static org .tensorflow .internal .c_api .global .tensorflow .TFOutputsFromOutputs ;
1921import static org .tensorflow .internal .c_api .global .tensorflow .ToOperation ;
2022
2123import java .util .ArrayList ;
2224import java .util .List ;
2325import org .bytedeco .javacpp .PointerScope ;
24- import org .tensorflow .internal .c_api .NativeOutput ;
2526import org .tensorflow .internal .c_api .NativeOutputVector ;
2627import org .tensorflow .internal .c_api .Node ;
28+ import org .tensorflow .internal .c_api .TF_Output ;
2729
2830/**
2931 * Helpers for {@link org.tensorflow.op.TypedGradientAdapter} and {@link
@@ -34,39 +36,49 @@ public class GradientAdapterHelpers {
3436 /**
3537 * Convert a array of native outputs to a list of {@link Output}s.
3638 *
37- * @param g the graph the outputs are in
39+ * @param g the graph the outputs are in
3840 * @param nativeOutputs the native outputs to convert
3941 */
4042 public static List <Output <?>> fromNativeOutputs (Graph g , NativeOutputVector nativeOutputs ) {
43+ TF_Output outputs = new TF_Output (nativeOutputs .size ());
44+ TFOutputsFromOutputs (nativeOutputs , outputs );
4145 List <Output <?>> gradInputs = new ArrayList <>((int ) nativeOutputs .size ());
4246 for (int i = 0 ; i < nativeOutputs .size (); i ++) {
43- NativeOutput output = nativeOutputs .get (i );
44- gradInputs .add (new Output <>(getGraphOp (g , output .node ()),
45- output .index ()));
47+ TF_Output output = outputs .position (i );
48+ gradInputs .add (new Output <>(new GraphOperation (g , output .oper ()), output .index ()));
4649 }
4750 return gradInputs ;
4851 }
4952
5053 /**
5154 * Put the Java outputs into the array of native outputs, resizing it to the necessary size.
5255 *
53- * @param outputs the outputs to put
56+ * @param outputs the outputs to put
5457 * @param nativeOutputs the native array to put the outputs into
5558 */
56- public static void putToNativeOutputs (List < Operand <?>> outputs ,
57- NativeOutputVector nativeOutputs ) {
59+ public static void putToNativeOutputs (
60+ List < Operand <?>> outputs , NativeOutputVector nativeOutputs ) {
5861 nativeOutputs .resize (outputs .size ());
62+
63+ TF_Output tempOutputs = new TF_Output (outputs .size ());
5964 for (int i = 0 ; i < outputs .size (); i ++) {
6065 Output <?> output = outputs .get (i ).asOutput ();
61- Node node = ((GraphOperation ) output .op ()).getUnsafeNativeHandle ().node ();
62- nativeOutputs .put (i , new NativeOutput (node , output .index ()));
66+ GraphOperation graphOp = (GraphOperation ) output .op ();
67+ tempOutputs
68+ .position (i )
69+ .put (new TF_Output ().oper (graphOp .getUnsafeNativeHandle ()).index (output .index ()));
70+ }
71+
72+ NativeOutputVector temp = OutputsFromTFOutputs (tempOutputs , outputs .size ());
73+ for (int i = 0 ; i < outputs .size (); i ++) {
74+ nativeOutputs .put (i , temp .get (i ));
6375 }
6476 }
6577
6678 /**
6779 * Make a {@link GraphOperation} from a native {@link Node}
6880 *
69- * @param g the graph the operation is in
81+ * @param g the graph the operation is in
7082 * @param node the native node
7183 * @return a graph operation with the underlying native node
7284 */
0 commit comments