Use of PyTorch layers and tensors inside a Flax network? #2272
I am trying to build a spatial transformation layer inside my network. PyTorch has great capabilities for this, so can I use PyTorch layers and tensors inside a Flax network? Thanks!
No, that is not possible if you are planning on jitting your code (which is strongly recommended). jax.jit compiles your code by tracing it with abstract Tracer objects, and a Tracer cannot be converted into a concrete NumPy array (or, by the same logic, a PyTorch tensor). Here's a short example:

import jax
import jax.numpy as jnp
import numpy as np

def sum_it(x):
    x = np.array(x)  # fails under jit: x is a Tracer, not a concrete array
    x = x.sum()
    return jnp.array(x)

jax.jit(sum_it)(jnp.array(1))  # Tracer error!
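A minimal sketch of the usual fix, assuming the computation can be expressed with jax.numpy: keep everything inside the traced function in jax.numpy, and convert to NumPy only outside of jit, once the result is a concrete array.

```python
import jax
import jax.numpy as jnp
import numpy as np

def sum_it(x):
    # jnp operations accept Tracers, so jit can trace this function.
    return jnp.sum(x)

result = jax.jit(sum_it)(jnp.arange(4.0))

# Outside of jit the result is concrete, so converting it to NumPy works here.
host_value = np.asarray(result)
print(host_value)  # 6.0
```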
Currently, I do not believe JAX provides those functions. Perhaps you can file an issue with the JAX developers and ask what their plans are?
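If you need a spatial transform inside jit in the meantime, one option is to write it yourself in jax.numpy. The sketch below is a minimal, hypothetical version and not a JAX API: the helpers affine_grid and grid_sample are only named after PyTorch's torch.nn.functional.affine_grid and grid_sample. They handle a single 2-D image, assume normalized coordinates in [-1, 1] with the corners mapped to the edge pixels, and are not guaranteed to match PyTorch's conventions. The bilinear sampling relies on jax.scipy.ndimage.map_coordinates with order=1.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def affine_grid(theta, height, width):
    """Hypothetical helper: map each output pixel through a 2x3 affine matrix.

    Assumes normalized coordinates in [-1, 1], corners aligned to edge pixels.
    """
    ys = jnp.linspace(-1.0, 1.0, height)
    xs = jnp.linspace(-1.0, 1.0, width)
    gy, gx = jnp.meshgrid(ys, xs, indexing="ij")  # each has shape (H, W)
    # Apply theta to homogeneous coordinates (x, y, 1) elementwise.
    src_x = theta[0, 0] * gx + theta[0, 1] * gy + theta[0, 2]
    src_y = theta[1, 0] * gx + theta[1, 1] * gy + theta[1, 2]
    return src_x, src_y

def grid_sample(image, src_x, src_y):
    """Hypothetical helper: bilinearly sample a 2-D image at normalized coords."""
    height, width = image.shape
    # Convert normalized coordinates back to pixel indices.
    px = (src_x + 1.0) * (width - 1) / 2.0
    py = (src_y + 1.0) * (height - 1) / 2.0
    # order=1 is bilinear interpolation; samples outside the image become cval.
    return map_coordinates(image, [py, px], order=1, mode="constant", cval=0.0)

@jax.jit
def spatial_transform(image, theta):
    # image.shape is static under jit, so it can size the sampling grid.
    src_x, src_y = affine_grid(theta, image.shape[0], image.shape[1])
    return grid_sample(image, src_x, src_y)

image = jnp.arange(12.0).reshape(3, 4)
identity = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(spatial_transform(image, identity))
```

With the identity matrix the output should reproduce the input. Because every step is plain jax.numpy, jax.grad with respect to theta should also work, which is what a learnable transformation layer needs.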