-
Hey @PVarnai, we don't have a "sliding_window_scan". It sounds cool, but you can mimic it by first creating the windows and then scanning over them. Here is a link to some JAX code for creating such windows: Copying the code here for convenience (credit to @erdmann):

from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap

# `size` is marked static so the slice shape is known at trace time.
@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    # One start index per window: 0, 1, ..., len(a) - size.
    starts = jnp.arange(len(a) - size + 1)
    # Slice `size` elements at each start; vmap stacks the windows along axis 0.
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)

a = jnp.arange(10)
print(moving_window(a, 4))  # 7 windows of length 4
You can generalize it for |
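To make the "create the windows, then scan over them" idea concrete for an input of shape (T, dims...), here is a minimal sketch. It is not a Flax API: `moving_window_nd` and `step` are hypothetical names introduced for illustration, and the scan body (a running sum of per-window means) is only a placeholder for whatever you want to compute per window. It assumes the window slides along axis 0.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def moving_window_nd(x, size: int):
    """Stack all length-`size` windows along axis 0 of `x` (shape (T, dims...))."""
    starts = jnp.arange(x.shape[0] - size + 1)
    # Slice only along the leading (time) axis; trailing dims are kept whole.
    take = lambda start: jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
    return jax.vmap(take)(starts)  # shape (T - size + 1, size, dims...)


def step(carry, window):
    # Placeholder scan body (an assumption, not from the thread):
    # `window` has shape (size, dims...); emit its mean over the window axis
    # and accumulate those means in the carry.
    window_mean = window.mean(axis=0)
    return carry + window_mean, window_mean


x = jnp.arange(12.0).reshape(6, 2)   # T = 6, dims = (2,)
windows = moving_window_nd(x, 3)     # shape (4, 3, 2)
final_carry, per_window = jax.lax.scan(step, jnp.zeros(2), windows)
print(windows.shape, per_window.shape)
```

Each scan iteration receives one window, i.e. x[0:3], x[1:4], and so on. Note that materializing every window costs roughly `size` times the memory of `x`; if that matters, scanning over the start indices and slicing inside the scan body is an alternative.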
-
Hi!
I found some great examples of scanning over a sequence of data one by one, also using the lifted linen.scan version, which is convenient with the flax training loop. But what if I want to scan a rolling window over a sequence of data? I couldn't really find anything, and I'm surprised this is not a usual operation (or it is and I'm just missing something). For example, if I had a tensor x of shape (T, dims...), instead of getting the slices x[0, dims...], x[1, dims...], ... to operate on within each iteration of the scan, I would want to get x[0:window, dims...], x[1:window+1, dims...], and so on.
Thanks for any help!