Skip to content

Commit 517fd8d

Browse files
committed
Cleanup adapter exceptions, name gradient scopes
Signed-off-by: Ryan Nett <[email protected]>
1 parent ca5d343 commit 517fd8d

File tree

2 files changed

+10
-24
lines changed

2 files changed

+10
-24
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
*/
1818
package org.tensorflow.op;
1919

20-
import static org.tensorflow.internal.c_api.global.tensorflow.StatusFromTF_Status;
21-
2220
import java.util.List;
2321
import org.bytedeco.javacpp.PointerScope;
2422
import org.tensorflow.BaseGradientAdapter;
@@ -30,7 +28,6 @@
3028
import org.tensorflow.internal.c_api.NativeOutputVector;
3129
import org.tensorflow.internal.c_api.NativeStatus;
3230
import org.tensorflow.internal.c_api.TF_Scope;
33-
import org.tensorflow.internal.c_api.TF_Status;
3431

3532
/** A native adapter for {@link RawCustomGradient}. */
3633
final class RawGradientAdapter extends BaseGradientAdapter {
@@ -54,24 +51,21 @@ public NativeStatus call(
5451
throw new IllegalStateException("No graph found for native gradient scope.");
5552
}
5653

57-
Scope nativeScope = new GradientScope(scope, g, null);
54+
GraphOperation operation = BaseGradientAdapter.getGraphOp(g, op.node());
55+
56+
Scope nativeScope = new GradientScope(scope, g, null).withSubScope(operation.name());
5857
Ops tf = new Ops(nativeScope);
5958

6059
List<Output<?>> gradInputs = BaseGradientAdapter.fromNativeOutputs(g, grad_inputs);
6160

62-
GraphOperation operation = BaseGradientAdapter.getGraphOp(g, op.node());
63-
6461
// The graph locks are not re-entrant, so attempting to add an op to a graph that has been
6562
// locked by the gradient builder will fail without this.
6663
BaseGradientAdapter.useDangerousLockedBuilders(g, true);
6764
List<Operand<?>> gradOutputs = gradient.call(tf, operation, gradInputs);
6865
BaseGradientAdapter.useDangerousLockedBuilders(g, false);
6966

7067
BaseGradientAdapter.putToNativeOutputs(gradOutputs, grad_outputs);
71-
} catch (Throwable t) {
72-
t.printStackTrace();
73-
throw t;
7468
}
75-
return StatusFromTF_Status(TF_Status.newStatus());
69+
return NativeStatus.OK();
7670
}
7771
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
*/
1818
package org.tensorflow.op;
1919

20-
import static org.tensorflow.internal.c_api.global.tensorflow.StatusFromTF_Status;
21-
2220
import java.lang.reflect.Constructor;
2321
import java.lang.reflect.InvocationTargetException;
2422
import java.util.List;
@@ -31,7 +29,6 @@
3129
import org.tensorflow.internal.c_api.NativeOutputVector;
3230
import org.tensorflow.internal.c_api.NativeStatus;
3331
import org.tensorflow.internal.c_api.TF_Scope;
34-
import org.tensorflow.internal.c_api.TF_Status;
3532

3633
/** A native adapter for {@link CustomGradient}. */
3734
final class TypedGradientAdapter<T extends RawOpInputs<?>> extends BaseGradientAdapter {
@@ -60,14 +57,15 @@ public NativeStatus call(
6057
throw new IllegalStateException("No graph found for native gradient scope.");
6158
}
6259

63-
Scope nativeScope = new GradientScope(scope, g, null);
60+
T rawOp = ctor.newInstance(BaseGradientAdapter.getGraphOp(g, op.node()));
61+
62+
Scope nativeScope =
63+
new GradientScope(scope, g, null).withSubScope(rawOp.getOutputs().op().name());
6464

6565
Ops tf = new Ops(nativeScope);
6666

6767
List<Output<?>> gradInputs = BaseGradientAdapter.fromNativeOutputs(g, grad_inputs);
6868

69-
T rawOp = ctor.newInstance(BaseGradientAdapter.getGraphOp(g, op.node()));
70-
7169
// The graph locks are not re-entrant, so attempting to add an op to a graph that has been
7270
// locked by the gradient builder will fail without this.
7371
BaseGradientAdapter.useDangerousLockedBuilders(g, true);
@@ -77,14 +75,8 @@ public NativeStatus call(
7775
BaseGradientAdapter.putToNativeOutputs(gradOutputs, grad_outputs);
7876

7977
} catch (InvocationTargetException | InstantiationException | IllegalAccessException e) {
80-
RuntimeException re =
81-
new RuntimeException("Could not instantiate Op class " + opInputClass, e);
82-
re.printStackTrace();
83-
throw re;
84-
} catch (Throwable t) {
85-
t.printStackTrace();
86-
throw t;
78+
throw new RuntimeException("Could not instantiate Op class " + opInputClass, e);
8779
}
88-
return StatusFromTF_Status(TF_Status.newStatus());
80+
return NativeStatus.OK();
8981
}
9082
}

0 commit comments

Comments
 (0)