Skip to content

Commit 3d535d2

Browse files
authored
Merge pull request #470 from rstudio/as-tensor-generic
add S3 generic `as_tensor`
2 parents 8daddd0 + d466e59 commit 3d535d2

File tree

6 files changed

+140
-2
lines changed

6 files changed

+140
-2
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ S3method(as.matrix,tensorflow.python.ops.variables.Variable)
4545
S3method(as.numeric,python.builtin.EagerTensor)
4646
S3method(as.numeric,tensorflow.python.framework.ops.EagerTensor)
4747
S3method(as.numeric,tensorflow.python.ops.variables.Variable)
48+
S3method(as_tensor,default)
49+
S3method(as_tensor,double)
4850
S3method(asin,tensorflow.tensor)
4951
S3method(atan,tensorflow.tensor)
5052
S3method(ceiling,tensorflow.tensor)
@@ -74,6 +76,7 @@ S3method(tanpi,tensorflow.tensor)
7476
export("%as%")
7577
export(all_dims)
7678
export(array_reshape)
79+
export(as_tensor)
7780
export(dict)
7881
export(evaluate)
7982
export(export_savedmodel)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# tensorflow (development version)
22

3+
- Added S3 generic `as_tensor()`.
4+
35
# tensorflow 2.5.0
46

57
- Updated default Tensorflow version to 2.5.

R/generics.R

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,59 @@ switch_fun_if_tf <- function(x, y, version = "1.14") {
336336
else
337337
y
338338
}
339+
340+
341+
342+
#' as_tensor
343+
#'
344+
#' Coerce objects to tensorflow tensors (potentially of a specific dtype). The
345+
#' provided default methods will call
346+
#' [`tf.convert_to_tensor`](https://www.tensorflow.org/api_docs/python/tf/convert_to_tensor)
347+
#' and [`tf.cast`](https://www.tensorflow.org/api_docs/python/tf/cast) as
348+
#' appropriate.
349+
#'
350+
#' @param x object to convert
351+
#' @param dtype `NULL`, a tensorflow dtype (`tf$int32`), or something coercible
352+
#' to one (e.g. a string `"int32"`)
353+
#' @param ..., ignored
354+
#' @param name `NULL` or a string. Useful for debugging in graph mode, ignored
355+
#' while in eager mode.
356+
#'
357+
#' @return a tensorflow tensor
358+
#'
359+
#' @export
360+
#'
361+
#' @examples
362+
#' \dontrun{
363+
#' as_tensor(42, "int32")
364+
#' as_tensor(as_tensor(42))
365+
#' }
366+
as_tensor <- function(x, dtype = NULL, ..., name = NULL) UseMethod("as_tensor")
367+
368+
#' @rdname as_tensor
369+
#' @export
370+
as_tensor.default <- function(x, dtype = NULL, ..., name = NULL) {
371+
x <- tf$convert_to_tensor(x, dtype_hint = dtype, name = name)
372+
if (is.null(dtype))
373+
x
374+
else
375+
tf$cast(x, dtype, name = name)
376+
}
377+
378+
#' @rdname as_tensor
379+
#' @export
380+
as_tensor.double <- function(x, dtype = NULL, ..., name = NULL) {
381+
if (!is.null(dtype)) {
382+
dtype <- tf$as_dtype(dtype)
383+
if (dtype$is_integer) {
384+
# tf.cast() overflows quietly, at least R raises a warning (and produces NA)
385+
if (dtype$max <= .Machine$integer.max)
386+
storage.mode(x) <- "integer"
387+
388+
if (anyNA(x))
389+
stop("converting R numerics with NA values to integer dtypes not supported")
390+
}
391+
}
392+
393+
NextMethod()
394+
}

man/as_tensor.Rd

Lines changed: 41 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/helper-utils.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ arr <- function(..., mode = "double", gen = seq_len)
3131
set.seed(42)
3232
rarr <- function(...) arr(..., gen=runif)
3333

34-
as_tensor <- function(...) tf$convert_to_tensor(...)
35-
3634
expect_near <- function(..., tol = 1e-5) expect_equal(..., tolerance = tol)
3735

3836

tests/testthat/test-as_tensor.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
test_that("as_tensor works", {
2+
3+
test_is_tensor <- function(x, dtype) {
4+
expect_s3_class(x, "tensorflow.tensor")
5+
if(is.character(dtype))
6+
expect_true(x$dtype[[dtype]])
7+
else if(!is.null(dtype))
8+
expect_true(x$dtype == tf$as_dtype(dtype))
9+
}
10+
11+
test_is_tensor(as_tensor(3), 'is_floating')
12+
test_is_tensor(as_tensor(3L), tf$int32)
13+
test_is_tensor(as_tensor("foo"), tf$string)
14+
test_is_tensor(as_tensor(TRUE), tf$bool)
15+
test_is_tensor(as_tensor(1+1i), 'is_complex')
16+
17+
test_is_tensor(as_tensor(3L, tf$int32) , tf$int32)
18+
test_is_tensor(as_tensor(3L, tf$int64) , tf$int64)
19+
test_is_tensor(as_tensor(3L, tf$float32) , tf$float32)
20+
test_is_tensor(as_tensor(3L, tf$float64) , tf$float64)
21+
test_is_tensor(as_tensor(3L, tf$int8) , tf$int8)
22+
23+
test_is_tensor(as_tensor(3.0, tf$float32), tf$float32)
24+
test_is_tensor(as_tensor(3.0, tf$float64), tf$float64)
25+
test_is_tensor(as_tensor(3.0, tf$int32) , tf$int32)
26+
test_is_tensor(as_tensor(3.0, tf$int64) , tf$int64)
27+
test_is_tensor(as_tensor(3.0, tf$int8) , tf$int8)
28+
29+
# currently scalars -> float32; arrays -> float64
30+
test_is_tensor(as_tensor(arr(3)) , 'is_floating')
31+
test_is_tensor(as_tensor(arr(3, 3)) , 'is_floating')
32+
test_is_tensor(as_tensor(arr(3, 3, 3)) , 'is_floating')
33+
34+
x <- tf$constant(3)
35+
test_is_tensor(as_tensor(x, tf$int32), tf$int32)
36+
test_is_tensor(as_tensor(x, tf$int64), tf$int64)
37+
38+
})

0 commit comments

Comments
 (0)