-
Hi, is there any function like `keras.layers.Masking` that masks padded values in input sequences?
-
Hi @dev-sora, Flax doesn't have a function like that.
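Since there is no built-in layer, a rough JAX counterpart of `keras.layers.Masking` can be written by hand. The sketch below follows the rule the Keras layer uses to decide what to mask (a timestep is masked when every one of its features equals `mask_value`). The helper name `masking` and the choice to return the mask next to the zeroed inputs are assumptions for illustration, not part of the Flax API.

```python
import jax.numpy as jnp


def masking(inputs, mask_value=0.0):
    """Hypothetical helper that mimics keras.layers.Masking in plain JAX.

    inputs: array of shape (batch, time, features).
    Returns (outputs, mask): mask has shape (batch, time) and is True for
    timesteps where at least one feature differs from mask_value;
    outputs equals inputs with the masked timesteps set to zero.
    """
    # A timestep is kept if any of its features differs from mask_value.
    mask = jnp.any(inputs != mask_value, axis=-1)
    # Zero out the masked timesteps (broadcast the mask over features).
    outputs = jnp.where(mask[..., None], inputs, 0.0)
    return outputs, mask


x = jnp.array([[[1.0, 2.0], [-1.0, -1.0], [3.0, -1.0]]])
outputs, mask = masking(x, mask_value=-1.0)
# mask    -> [[True, False, True]]
# outputs -> [[[1.0, 2.0], [0.0, 0.0], [3.0, -1.0]]]
```

Unlike Keras, JAX does not propagate masks between layers automatically, so the returned `mask` has to be passed explicitly to whatever needs it (for example a masked mean or an attention mask).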
-
Hi @matthias-wright, thanks for your response.
-
I was asking myself the same question. Thanks, @matthias-wright, for pointing out jax2tf. By the way, …
-
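@dev-sora, I think you want to mask out a batch of sequences after a certain length, where the length can differ per example. Is that correct? Will this work for you?

```python
import jax.numpy as jnp

def mask_sequences(sequences, *, eos_id=0):
    # Column index of the eos_id token in each row (assumes exactly one per row).
    lengths = jnp.where(sequences == eos_id)[-1]
    lengths = jnp.expand_dims(lengths, -1)      # shape (batch, 1)
    positions = jnp.arange(sequences.shape[1])  # shape (seq_len,)
    # Keep only the positions before eos_id; everything else becomes 0.
    return sequences * (lengths > positions)

sequences = jnp.array([[1, 2, 3, 0, 5, 6],
                       [1, 0, 3, 4, 5, 6]])
mask_sequences(sequences)
# returns [[1, 2, 3, 0, 0, 0],
#          [1, 0, 0, 0, 0, 0]]
```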
This will mask everything after the occurrence of …
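The snippet above relies on every row containing exactly one `eos_id`. `jnp.where(sequences == eos_id)` returns one index per match, so a row with no `eos_id` or with several of them produces a `lengths` array that no longer lines up with the batch. The single-argument form of `jnp.where` also has a data-dependent output shape, so it cannot run under `jax.jit` without a `size` argument. A minimal sketch of an alternative that finds the first `eos_id` per row with `argmax` (the function name is hypothetical):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("eos_id",))
def mask_after_first_eos(sequences, eos_id=0):
    """Hypothetical helper: zero out every token from the first eos_id onwards.

    Rows without any eos_id are left unchanged. All shapes are static,
    so the function can be jit-compiled.
    """
    seq_len = sequences.shape[1]
    is_eos = sequences == eos_id
    # argmax returns the index of the first True; use seq_len when a row has no eos.
    first_eos = jnp.where(
        is_eos.any(axis=1),
        jnp.argmax(is_eos.astype(jnp.int32), axis=1),
        seq_len,
    )
    positions = jnp.arange(seq_len)
    mask = positions[None, :] < first_eos[:, None]  # shape (batch, seq_len)
    return jnp.where(mask, sequences, 0), mask


seqs = jnp.array([[1, 2, 3, 0, 5, 6],
                  [1, 0, 3, 0, 5, 6],   # two eos tokens
                  [1, 2, 3, 4, 5, 6]])  # no eos token
masked, mask = mask_after_first_eos(seqs)
# masked -> [[1, 2, 3, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6]]
```

Returning the boolean `mask` as well makes it easy to reuse for losses or pooling, rather than relying on the value 0 to mean "padding".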
Hi @matthias-wright, thanks for your response.
I want to use something like padding so that I can feed data acquired from a sensor with an unstable sampling rate into a neural network.
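For recordings of different lengths, one common approach is to pad every sequence to the same length and carry a boolean mask alongside it, then use the mask wherever padded steps must be ignored, for example when averaging over time. A minimal sketch, assuming each recording is a 1-D list of floats; `pad_and_mask` and `masked_mean` are hypothetical names. This only handles differing lengths; it does not account for irregular spacing between samples.

```python
import numpy as np
import jax.numpy as jnp


def pad_and_mask(recordings, pad_value=0.0):
    """Hypothetical helper: pad 1-D recordings to a common length.

    Returns (padded, mask):
      padded: float32 array of shape (batch, max_len)
      mask:   bool array of shape (batch, max_len), True for real samples
    """
    max_len = max(len(r) for r in recordings)
    padded = np.full((len(recordings), max_len), pad_value, dtype=np.float32)
    mask = np.zeros((len(recordings), max_len), dtype=bool)
    for i, r in enumerate(recordings):
        padded[i, : len(r)] = r
        mask[i, : len(r)] = True
    return jnp.asarray(padded), jnp.asarray(mask)


def masked_mean(x, mask):
    """Hypothetical helper: mean over the time axis, ignoring padded positions."""
    total = jnp.sum(jnp.where(mask, x, 0.0), axis=-1)
    count = jnp.maximum(jnp.sum(mask, axis=-1), 1)  # avoid division by zero
    return total / count


recordings = [[1.0, 2.0, 3.0], [4.0], [5.0, 7.0]]
x, mask = pad_and_mask(recordings)
means = masked_mean(x, mask)  # [2.0, 4.0, 6.0]
```

Padding is done in NumPy here because it happens once on the host before the data reaches the model; inside the model only the fixed-shape `x` and `mask` are used, which keeps the computation compatible with `jax.jit`.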