Description
```
ERROR:absl:Failed to apply Splash kernel for flash attention. Falling back to JAX native dot_product_attention.
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/keras/src/backend/jax/nn.py", line 1336, in dot_product_attention
    output = wrap_flash_attention(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/keras/src/backend/jax/nn.py", line 1188, in wrap_flash_attention
    splash_kernel = splash_attention_kernel.make_splash_mha(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2557, in _make_splash_attention
    fwd_mask_info, mask_function_fwd = process_mask_fn(
                                       ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py", line 222, in __hash__
    return hash((type(self),) + tuple(hash(mask) for mask in self.masks))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py", line 222, in <genexpr>
    return hash((type(self),) + tuple(hash(mask) for mask in self.masks))
                                      ^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py", line 516, in __hash__
    return hash((type(self), self.array.tobytes()))
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/core.py", line 935, in tobytes
    raise ConcretizationTypeError(self,
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[1024,1024]
The tobytes() method was called on traced array with shape bool[1024,1024].
The error occurred while tracing the function compiled_generate_function at /usr/local/lib/python3.12/site-packages/keras_hub/src/models/causal_lm.py:167 for jit. This concrete value was not available in Python because it depends on the value of the argument inputs['padding_mask'].
```
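The underlying failure is generic JAX tracing behavior rather than anything TPU-specific: inside a jit-traced function every array is an abstract tracer, so `tobytes()` (which the splash attention mask classes use to implement `__hash__`) cannot produce concrete bytes. A minimal sketch, independent of Keras, that raises the same error (the function and variable names here are illustrative, not from the library):

```python
# Minimal sketch: hashing an array's raw bytes inside jit fails, mirroring
# the splash attention mask's __hash__ implementation.
import jax
import jax.numpy as jnp

@jax.jit
def hash_mask(mask):
    # On a concrete array this would succeed; under jit, `mask` is an
    # abstract tracer and tobytes() raises ConcretizationTypeError.
    return hash((type(mask), mask.tobytes()))

mask = jnp.ones((1024, 1024), dtype=bool)
hash_mask(mask)  # raises jax.errors.ConcretizationTypeError
```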
Environment: Kaggle TPU v5e-8 instance, Keras 3.12.0.
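For reference, a hypothetical reproduction sketch (the preset name, prompt, and length below are placeholders, not taken from this report): any KerasHub `CausalLM.generate()` call on a TPU backend with flash attention enabled runs through the jitted `compiled_generate_function`, so the `padding_mask` that reaches the splash kernel's mask hashing is a traced array.

```python
# Hypothetical repro sketch; the preset and prompt are placeholders.
import os
os.environ["KERAS_BACKEND"] = "jax"  # run on a TPU runtime, e.g. Kaggle v5e-8

import keras
import keras_hub

keras.config.enable_flash_attention()  # route attention to the TPU splash kernel

lm = keras_hub.models.CausalLM.from_preset("gemma2_2b_en")  # placeholder preset
# generate() traces compiled_generate_function under jit; inputs['padding_mask']
# becomes an abstract tracer, so make_splash_mha's mask hashing fails and the
# absl ERROR plus fallback shown above is logged.
lm.generate("Hello", max_length=64)
```

If flash attention is not required, `keras.config.disable_flash_attention()` should sidestep the splash kernel path entirely.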