Skip to content

Commit 478924e

Browse files
committed
Working gradients
Signed-off-by: Ryan Nett <[email protected]>
1 parent 1bf4ba1 commit 478924e

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ public void close() {
119119
}
120120
delete(nativeHandle);
121121
nativeHandle = null;
122+
allGraphs.remove(this);
122123
}
123124
}
124125

@@ -1106,7 +1107,7 @@ private static SaverDef addVariableSaver(Graph graph) {
11061107
*/
11071108
public static Graph findGraphForPointer(NativeGraphPointer pointer) {
11081109
for (Graph g : allGraphs) {
1109-
if (g.nativeHandle.graph().equals(pointer)) {
1110+
if (g.nativeHandle != null && !g.nativeHandle.isNull() && g.nativeHandle.graph().equals(pointer)) {
11101111
return g;
11111112
}
11121113
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.tensorflow.internal.c_api.global.tensorflow.TF_FinishOperationLocked;
2323
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewOperation;
2424
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewOperationLocked;
25+
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationName;
2526
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrBool;
2627
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrBoolList;
2728
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFloat;
@@ -413,7 +414,7 @@ private static TF_Operation finishDangerousGradient(TF_Graph g, TF_OperationDesc
413414
TF_Status status = TF_Status.newStatus();
414415
TF_Operation op = TF_FinishOperationLocked(handle, status);
415416
status.throwExceptionIfNotOK();
416-
// g.name_map().put(TF_OperationName(op), null);
417+
g.name_map().erase(TF_OperationName(op));
417418
return op;
418419
}
419420
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ public void map(InfoMap infoMap) {
285285
.put(new Info("TF_Graph::refiner", "TF_Graph::mu",
286286
"TF_Graph::sessions", "TF_Graph::delete_requested").skip())
287287
.put(new Info("std::unordered_map<tensorflow::string,tensorflow::Node*>")
288-
.pointerTypes("NameMap").define())
288+
.pointerTypes("NameMap").define().javaText("public native long erase(@StdString BytePointer key);"))
289289
.put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions")
290290
.base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions"))
291291
.put(new Info("TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell",

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public void addGradientsToGraph() {
3838
Operand<?> out = gradInputs.get(0);
3939
Operand<?> a = tf.stridedSlice(out, Indices.slice(0, 1));
4040
Operand<?> b = tf.stridedSlice(out, Indices.slice(1, 2));
41-
return Arrays.asList(a, b);
41+
return Arrays.asList(a, b, tf.constant(0f));
4242
});
4343
Ops tf = Ops.create(g);
4444

@@ -64,8 +64,8 @@ public void addGradientsToGraph() {
6464
.run())) {
6565

6666
assertEquals(2, outputs.size());
67-
assertEquals(3.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
68-
assertEquals(2.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
67+
assertEquals(6.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
68+
assertEquals(4.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
6969
}
7070
}
7171
}

0 commit comments

Comments
 (0)