-
Notifications
You must be signed in to change notification settings - Fork 215
Graph custom gradient support #292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
ee96979
Add JavaCPP generation for gradient registry stuff
rnett c5717b0
Working gradients
rnett 12d579c
Add missing requireHandles
rnett 1f1ba2c
Rebase, use try-with-resources
rnett 89690ab
Expose the necessary symbols on Windows
rnett dfaac8d
Nicely handle pre-existing gradients
rnett 2f22b57
Small-ish review changes
rnett 3acc001
Use annotation instead of field reflection to store op types
rnett 5127413
Cleanup and more review changes
rnett 293fecc
Remove empty init file
rnett 2eb9342
Update annotation generator comments
rnett f1f6e8b
Update annotation names and comments, and registerCustomGradient javadoc
rnett 8c40b8c
Add no-arg ctor to BaseGradientAdapter
rnett d08b38b
Add documentation about dangerousGradientBuilder
rnett 72ed4f0
Add Javadoc for getUnsafeNativeHandle
rnett fd2609d
More dangerous gradient builder javadocs
rnett 36a6e30
Add note about why gradientFuncs is required
rnett 2ddbb6c
Store and allow getting native scope device when it has been set from…
rnett 759a754
Rename withDevice's parameter
rnett ed1da29
Update scope for fix review comments
rnett 4504a6f
Clarify the difference between CustomGradient and RawCustomGradient
rnett be4840c
Remove experiment
rnett 9601138
Adjust GraphOperation#input to not require a graph lock
rnett ca5d343
Remove printing from CustomGradientTest
rnett 517fd8d
Cleanup adapter exceptions, name gradient scopes
rnett f5dd0e5
Document use of rawtypes
rnett bfbfb49
Generate the new op classes
rnett File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
tensorflow-core/tensorflow-core-api/external/custom-grad-helpers.patch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc | ||
index f3bf7b98a1e6b..c9194c36c116b 100644 | ||
--- a/tensorflow/c/c_api.cc | ||
+++ b/tensorflow/c/c_api.cc | ||
@@ -782,9 +782,9 @@ void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, | ||
|
||
extern "C" { | ||
|
||
-static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, | ||
- const char* op_type, | ||
- const char* oper_name) | ||
+TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, | ||
+ const char* op_type, | ||
+ const char* oper_name) | ||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { | ||
return new TF_OperationDescription(graph, op_type, oper_name); | ||
} | ||
@@ -1041,8 +1041,8 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, | ||
status->status = Status::OK(); | ||
} | ||
|
||
-static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, | ||
- TF_Status* status) | ||
+TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, | ||
+ TF_Status* status) | ||
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { | ||
Node* ret = nullptr; | ||
|
||
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h | ||
index 705cf85e0512f..fb746dd4af94f 100644 | ||
--- a/tensorflow/c/c_api.h | ||
+++ b/tensorflow/c/c_api.h | ||
@@ -255,6 +255,12 @@ TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, | ||
int64_t* dims, int num_dims, | ||
TF_Status* status); | ||
|
||
+// TF_NewOperation, but without locking the graph. | ||
+// Should prefer TF_NewOperation when possible. | ||
+TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, | ||
+ const char* op_type, | ||
+ const char* oper_name); | ||
+ | ||
// Operation will only be added to *graph when TF_FinishOperation() is | ||
// called (assuming TF_FinishOperation() does not return an error). | ||
// *graph must not be deleted until after TF_FinishOperation() is | ||
@@ -406,6 +412,11 @@ TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, | ||
size_t proto_len, | ||
TF_Status* status); | ||
|
||
+// TF_FinishOperation, but without locking the graph. | ||
+// TF_FinishOperation should be preferred when possible. | ||
+TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, | ||
+ TF_Status* status); | ||
+ | ||
// If this function succeeds: | ||
// * *status is set to an OK value, | ||
// * a TF_Operation is added to the graph, |
151 changes: 151 additions & 0 deletions
151
tensorflow-core/tensorflow-core-api/external/custom-grad-symbols.patch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
Index: tensorflow/tools/def_file_filter/BUILD | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/BUILD b/tensorflow/tools/def_file_filter/BUILD | ||
--- a/tensorflow/tools/def_file_filter/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad) | ||
+++ b/tensorflow/tools/def_file_filter/BUILD (date 1629063191558) | ||
@@ -12,3 +12,8 @@ | ||
name = "symbols_pybind", | ||
srcs = ["symbols_pybind.txt"], | ||
) | ||
+ | ||
+filegroup( | ||
+ name = "symbols_java", | ||
+ srcs = ["symbols_java.txt"], | ||
+) | ||
Index: tensorflow/BUILD | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/BUILD b/tensorflow/BUILD | ||
--- a/tensorflow/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad) | ||
+++ b/tensorflow/BUILD (date 1629063361078) | ||
@@ -1069,13 +1069,20 @@ | ||
# the dynamic libraries of custom ops can find it at runtime. | ||
genrule( | ||
name = "tensorflow_filtered_def_file", | ||
- srcs = [":tensorflow_def_file"], | ||
+ srcs = [ | ||
+ ":tensorflow_def_file", | ||
+ ":java_symbol_target_libs_file", | ||
+ ":win_lib_files_for_java_exported_symbols", | ||
+ "//tensorflow/tools/def_file_filter:symbols_java", | ||
+ ], | ||
outs = ["tensorflow_filtered_def_file.def"], | ||
cmd = select({ | ||
"//tensorflow:windows": """ | ||
$(location @local_config_def_file_filter//:def_file_filter) \\ | ||
--input $(location :tensorflow_def_file) \\ | ||
- --output $@ | ||
+ --output $@ \\ | ||
+ --symbols $(location //tensorflow/tools/def_file_filter:symbols_java) \\ | ||
+ --lib_paths_file $(location :java_symbol_target_libs_file) | ||
""", | ||
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms | ||
}), | ||
@@ -1083,6 +1090,34 @@ | ||
visibility = ["//visibility:public"], | ||
) | ||
|
||
+# Write to a file a list of all cc_library targets that we need for exporting symbols on Windows. | ||
+genrule( | ||
+ name = "java_symbol_target_libs_file", | ||
+ srcs = [":win_lib_files_for_java_exported_symbols"], | ||
+ outs = ["java_symbol_target_libs_file.txt"], | ||
+ cmd = select({ | ||
+ "//tensorflow:windows": """ | ||
+ for SRC in $(SRCS); do | ||
+ echo $$SRC | sed 's/third_party\\///g' >> $@ | ||
+ done | ||
+ """, | ||
+ "//conditions:default": "touch $@", # Just a placeholder for Unix platforms | ||
+ }), | ||
+ visibility = ["//visibility:public"], | ||
+) | ||
+ | ||
+filegroup( | ||
+ name = "win_lib_files_for_java_exported_symbols", | ||
+ srcs = [ | ||
+ "//tensorflow/cc:scope", | ||
+ "//tensorflow/cc:grad_op_registry", | ||
+ "//tensorflow/c:tf_status_helper", | ||
+ "//tensorflow/cc:ops" | ||
+ ], | ||
+ visibility = ["//visibility:private"], | ||
+) | ||
+ | ||
+ | ||
# The interface library (tensorflow.dll.if.lib) for linking tensorflow DLL library (tensorflow.dll) on Windows. | ||
# To learn more about import library (called interface library in Bazel): | ||
# https://docs.microsoft.com/en-us/cpp/build/linking-an-executable-to-a-dll?view=vs-2017#linking-implicitly | ||
Index: tensorflow/tools/def_file_filter/BUILD.tpl | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/BUILD.tpl b/tensorflow/tools/def_file_filter/BUILD.tpl | ||
--- a/tensorflow/tools/def_file_filter/BUILD.tpl (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad) | ||
+++ b/tensorflow/tools/def_file_filter/BUILD.tpl (date 1629063191583) | ||
@@ -18,3 +18,8 @@ | ||
name = "symbols_pybind", | ||
srcs = ["symbols_pybind.txt"], | ||
) | ||
+ | ||
+filegroup( | ||
+ name = "symbols_java", | ||
+ srcs = ["symbols_java.txt"], | ||
+) | ||
Index: tensorflow/tools/def_file_filter/symbols_java.txt | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/symbols_java.txt b/tensorflow/tools/def_file_filter/symbols_java.txt | ||
new file mode 100644 | ||
--- /dev/null (date 1629063607794) | ||
+++ b/tensorflow/tools/def_file_filter/symbols_java.txt (date 1629063607794) | ||
@@ -0,0 +1,26 @@ | ||
+[//tensorflow/cc:scope] # scope | ||
+tensorflow::Scope::graph | ||
+tensorflow::Scope::ok | ||
+tensorflow::Scope::UpdateBuilder | ||
+tensorflow::Scope::GetUniqueNameForOp | ||
+tensorflow::Scope::ExitOnError | ||
+tensorflow::Scope::WithDevice | ||
+tensorflow::Scope::WithNoControlDependencies | ||
+tensorflow::Scope::WithControlDependencies | ||
+tensorflow::Scope::NewSubScope | ||
+tensorflow::Scope::NewRootScope | ||
+tensorflow::Scope::operator= | ||
+tensorflow::Scope::~Scope | ||
+tensorflow::Scope::Scope | ||
+ | ||
+[//tensorflow/cc:ops] | ||
+tensorflow::Operation::Operation | ||
+ | ||
+[//tensorflow/cc:grad_op_registry] # custom gradients for graph | ||
+tensorflow::ops::GradOpRegistry::Global | ||
+tensorflow::ops::GradOpRegistry::Lookup | ||
+tensorflow::ops::GradOpRegistry::Register | ||
+ | ||
+[//tensorflow/c:tf_status_helper] # status helpers | ||
+tensorflow::Set_TF_Status_from_Status | ||
+tensorflow::StatusFromTF_Status | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl | ||
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (revision 919f693420e35d00c8d0a42100837ae3718f7927) | ||
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (date 1632048268359) | ||
@@ -143,8 +143,8 @@ | ||
re_filter_comp = re.compile(r"{}".format(re_filter)) | ||
|
||
# Filter out symbol from the split line (`sym_split` in the for loop below). | ||
- sym_line_filter = r".*\s+\| (.*) \(.*" | ||
- sym_line_filter_anomaly = r".*\s+\| (.*)" | ||
+ sym_line_filter = r".*\s+\| (.*?) \(.*" | ||
+ sym_line_filter_anomaly = r".*\s+\| (.*?)" | ||
|
||
for sym_line in sym_split: | ||
if re_filter_comp.search(sym_line): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/GradFunc.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE | ||
|
||
package org.tensorflow.internal.c_api; | ||
|
||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.tensorflow.internal.c_api.global.tensorflow.*; | ||
|
||
|
||
/** GradFunc is the signature for all gradient functions in GradOpRegistry. | ||
* Implementations should add operations to compute the gradient outputs of | ||
* 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. */ | ||
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) | ||
public class GradFunc extends FunctionPointer { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public GradFunc(Pointer p) { super(p); } | ||
protected GradFunc() { allocate(); } | ||
private native void allocate(); | ||
public native @ByVal NativeStatus call(@Const @ByRef TF_Scope scope, @Const @ByRef NativeOperation op, | ||
@Const @ByRef NativeOutputVector grad_inputs, | ||
NativeOutputVector grad_outputs); | ||
} |
48 changes: 48 additions & 0 deletions
48
...w-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/GradOpRegistry.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE | ||
|
||
package org.tensorflow.internal.c_api; | ||
|
||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.tensorflow.internal.c_api.global.tensorflow.*; | ||
|
||
|
||
/** GradOpRegistry maintains a static registry of gradient functions. | ||
* Gradient functions are indexed in the registry by the forward op name (i.e. | ||
* "MatMul" -> MatMulGrad func). */ | ||
@Namespace("tensorflow::ops") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) | ||
public class GradOpRegistry extends Pointer { | ||
static { Loader.load(); } | ||
/** Default native constructor. */ | ||
public GradOpRegistry() { super((Pointer)null); allocate(); } | ||
Craigacp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/** Native array allocator. Access with {@link Pointer#position(long)}. */ | ||
public GradOpRegistry(long size) { super((Pointer)null); allocateArray(size); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public GradOpRegistry(Pointer p) { super(p); } | ||
private native void allocate(); | ||
private native void allocateArray(long size); | ||
@Override public GradOpRegistry position(long position) { | ||
return (GradOpRegistry)super.position(position); | ||
} | ||
@Override public GradOpRegistry getPointer(long i) { | ||
return new GradOpRegistry((Pointer)this).offsetAddress(i); | ||
} | ||
|
||
/** Registers 'func' as the gradient function for 'op'. | ||
* Returns true if registration was successful, check fails otherwise. */ | ||
public native @Cast("bool") boolean Register(@StdString BytePointer op, GradFunc func); | ||
public native @Cast("bool") boolean Register(@StdString String op, GradFunc func); | ||
|
||
/** Sets 'func' to the gradient function for 'op' and returns Status OK if | ||
* the gradient function for 'op' exists in the registry. | ||
* Note that 'func' can be null for ops that have registered no-gradient with | ||
* the registry. | ||
* Returns error status otherwise. */ | ||
public native @ByVal NativeStatus Lookup(@StdString BytePointer op, @ByPtrPtr GradFunc func); | ||
public native @ByVal NativeStatus Lookup(@StdString String op, @ByPtrPtr GradFunc func); | ||
|
||
/** Returns a pointer to the global gradient function registry. */ | ||
public static native GradOpRegistry Global(); | ||
} |
41 changes: 41 additions & 0 deletions
41
tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/NameMap.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE | ||
|
||
package org.tensorflow.internal.c_api; | ||
|
||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.tensorflow.internal.c_api.global.tensorflow.*; | ||
|
||
@Name("std::unordered_map<tensorflow::string,tensorflow::Node*>") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) | ||
public class NameMap extends Pointer { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public NameMap(Pointer p) { super(p); } | ||
public NameMap() { allocate(); } | ||
Craigacp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
private native void allocate(); | ||
public native @Name("operator =") @ByRef NameMap put(@ByRef NameMap x); | ||
|
||
public boolean empty() { return size() == 0; } | ||
public native long size(); | ||
|
||
@Index public native Node get(@StdString BytePointer i); | ||
public native NameMap put(@StdString BytePointer i, Node value); | ||
|
||
public native void erase(@ByVal Iterator pos); | ||
public native @ByVal Iterator begin(); | ||
public native @ByVal Iterator end(); | ||
@NoOffset @Name("iterator") public static class Iterator extends Pointer { | ||
public Iterator(Pointer p) { super(p); } | ||
public Iterator() { } | ||
|
||
public native @Name("operator ++") @ByRef Iterator increment(); | ||
public native @Name("operator ==") boolean equals(@ByRef Iterator it); | ||
public native @Name("operator *().first") @MemberGetter @StdString BytePointer first(); | ||
public native @Name("operator *().second") @MemberGetter @Const Node second(); | ||
} | ||
|
||
public native long erase(@StdString BytePointer key); | ||
} | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.