@@ -102,31 +102,44 @@ as.logical.tensorflow.python.ops.variables.Variable <- as.logical.python.builtin
102
102
# ' runtime to apply optimizations and exploit parallelism in the computation
103
103
# ' defined by `f`.
104
104
# '
105
+ # ' A guide to getting started with
106
+ # ' [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) can
107
+ # ' be found [here](https://www.tensorflow.org/guide/function).
108
+ # '
105
109
# ' @param f the function to be compiled
106
110
# ' @param input_signature A possibly nested sequence of `tf$TensorSpec` objects
107
111
# ' specifying the shapes and dtypes of the tensors that will be supplied to
108
112
# ' this function. If `NULL`, a separate function is instantiated for each
109
113
# ' inferred input signature. If `input_signature` is specified, every input to
110
114
# ' `f` must be a tensor.
111
- # ' @param autograph Whether autograph should be applied on `f` before tracing a
112
- # ' graph. This allows for dynamic control flow (if's, loops etc.) in the
113
- # ' traced graph. See https://www. tensorflow.org/ guide/autograph for more
114
- # ' information. Note: We set the default to `FALSE` until this functionality
115
- # ' is available from R.
115
+ # ' @param autograph TRUE or FALSE. If TRUE (the default), you can use tensors in
116
+ # ' R control flow expressions `if`, `while`, `for` and `break` and they will
117
+ # ' be traced into the tensorflow graph. A guide to getting started and
118
+ # ' additional details can be found:
119
+ # ' [here](https://t-kalinowski.github.io/tfautograph/)
116
120
# ' @param ... additional arguments passed on to `tf.function` (vary based on
117
121
# ' Tensorflow version). See
118
- # ' https://www.tensorflow.org/api_docs/python/tf/function for details.
122
+ # ' [here](https://www.tensorflow.org/api_docs/python/tf/function#args_1) for
123
+ # ' details.
119
124
# '
120
125
# ' @export
121
126
tf_function <- function (f ,
122
127
input_signature = NULL ,
123
- autograph = FALSE , # default will change to TRUE in TF 2.6
128
+ autograph = TRUE ,
124
129
... ) {
130
+ if (! is.function(f ))
131
+ stop(" `f` must be an R function" )
125
132
126
- if (! isFALSE(autograph )) stop(" Autograph functionality is not (yet) supported from R." )
133
+ if (! (isTRUE(autograph ) || isFALSE(autograph )))
134
+ stop(" `autograph` must be TRUE or FALSE" )
127
135
128
- args <- list (py_func(f ), input_signature , autograph , ... )
136
+ if (autograph ) {
137
+ # Can't register tfautograph in Imports yet due to circular dependency
138
+ if (! requireNamespace(" tfautograph" , quietly = TRUE ))
139
+ stop(' "tfautograph" package required if autograph=TRUE. Please run install.packages("tfautograph")' )
140
+ f <- tfautograph :: autograph(f )
141
+ }
129
142
143
+ args <- list (py_func(f ), input_signature , FALSE , ... )
130
144
do.call(tf $ `function` , args )
131
145
}
132
-
0 commit comments