
Use of PyTorch layers and tensors inside a flax network? #2272

Answered by marcvanzee
rosikand asked this question in Q&A

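No, that is not possible if you plan to jit your code (which is strongly recommended). To compile your function, jax.jit traces it with abstract Tracer values instead of concrete arrays, and a Tracer cannot be converted to a NumPy array. PyTorch tensors run into the same limitation, because they also need concrete values. Here's a short example: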

import jax
import jax.numpy as jnp
import numpy as np

def sum_it(x):
  x = np.array(x)  # fails under jit: x is a Tracer, not a concrete array
  x = x.sum()
  return jnp.array(x)

jax.jit(sum_it)(jnp.array(1))  # raises TracerArrayConversionError
Error trace
UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on 
the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
While tracing the function sum_it at <ipython-input-129-c2c23ebbfd77>:4 for jit, this concrete v…
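For contrast, here is a minimal sketch of the version that does work, assuming the computation can be written entirely with jax.numpy. Those operations accept Tracers, so jax.jit can trace and compile the function, and jax.grad can differentiate it. The name sum_it_jit is only illustrative.

```python
import jax
import jax.numpy as jnp


def sum_it(x):
    # Stay inside jax.numpy: these ops accept Tracers, so jit/grad can trace them.
    return jnp.sum(x)


# Compiles on the first call and returns a concrete JAX array.
sum_it_jit = jax.jit(sum_it)
result = sum_it_jit(jnp.array([1, 2, 3]))
print(result)  # 6

# The same function is differentiable; a round-trip through NumPy would break
# both jit and grad, because both pass Tracers into the function.
grads = jax.grad(sum_it)(jnp.array([1.0, 2.0, 3.0]))
print(grads)
```

The same reasoning applies inside a Flax network: layers need to be expressed with JAX/Flax operations so the whole model can be traced.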
