1717*/
1818package org .tensorflow .op ;
1919
20- import static org .tensorflow .internal .c_api .global .tensorflow .StatusFromTF_Status ;
21-
2220import java .lang .reflect .Constructor ;
2321import java .lang .reflect .InvocationTargetException ;
2422import java .util .List ;
3129import org .tensorflow .internal .c_api .NativeOutputVector ;
3230import org .tensorflow .internal .c_api .NativeStatus ;
3331import org .tensorflow .internal .c_api .TF_Scope ;
34- import org .tensorflow .internal .c_api .TF_Status ;
3532
3633/** A native adapter for {@link CustomGradient}. */
3734final class TypedGradientAdapter <T extends RawOpInputs <?>> extends BaseGradientAdapter {
@@ -60,14 +57,15 @@ public NativeStatus call(
6057 throw new IllegalStateException ("No graph found for native gradient scope." );
6158 }
6259
63- Scope nativeScope = new GradientScope (scope , g , null );
60+ T rawOp = ctor .newInstance (BaseGradientAdapter .getGraphOp (g , op .node ()));
61+
62+ Scope nativeScope =
63+ new GradientScope (scope , g , null ).withSubScope (rawOp .getOutputs ().op ().name ());
6464
6565 Ops tf = new Ops (nativeScope );
6666
6767 List <Output <?>> gradInputs = BaseGradientAdapter .fromNativeOutputs (g , grad_inputs );
6868
69- T rawOp = ctor .newInstance (BaseGradientAdapter .getGraphOp (g , op .node ()));
70-
7169 // The graph locks are not re-entrant, so attempting to add an op to a graph that has been
7270 // locked by the gradient builder will fail without this.
7371 BaseGradientAdapter .useDangerousLockedBuilders (g , true );
@@ -77,14 +75,8 @@ public NativeStatus call(
7775 BaseGradientAdapter .putToNativeOutputs (gradOutputs , grad_outputs );
7876
7977 } catch (InvocationTargetException | InstantiationException | IllegalAccessException e ) {
80- RuntimeException re =
81- new RuntimeException ("Could not instantiate Op class " + opInputClass , e );
82- re .printStackTrace ();
83- throw re ;
84- } catch (Throwable t ) {
85- t .printStackTrace ();
86- throw t ;
78+ throw new RuntimeException ("Could not instantiate Op class " + opInputClass , e );
8779 }
88- return StatusFromTF_Status ( TF_Status . newStatus () );
80+ return NativeStatus . OK ( );
8981 }
9082}
0 commit comments