from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available
+ from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+ from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu


class CachedGemma3Attention(keras.layers.Layer):
    """A cached grouped query attention layer for Gemma3.

-     This is different from Gemma and Gemma2 in several ways:
+     This is the same as the attention layer used for Gemma and Gemma2. It
+     exposes a few additional args:

-     - `use_query_key_norm`: Applies RMS Norm on query, key.
-     - `rope_wavelength`: RoPE wavelength differs from local to global attention
-       layers.
-     - `rope_scaling_factor`: RoPE scaling factor differs from local to global
-       attention layers.
+     `use_query_key_norm`: bool. If True, applies RMS normalization to the
+         query and key. For Gemma3, this is True.
+     `rope_wavelength`: float. Configurable value for the RoPE wavelength.
+         Gemma3 uses 10K for local attention layers and 1M for global
+         attention layers.
+     `gate_dim_reduction`: int. In the gating layers, the output dimension is
+         `intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
+         value is 2; for Gemma3, it is 1.
+
+     Moreover, the `call()` method takes a `cache_update_mask` argument to
+     ensure that the key-value cache is updated only for non-prompt tokens
+     during generation.
    """

    def __init__(
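As an illustration of the `cache_update_mask` semantics described in the docstring above, here is a minimal sketch of how a caller could build such a mask. The helper name `make_cache_update_mask` and the idea of deriving it from a boolean `padding_mask` over prompt tokens are assumptions made for this sketch; the real mask is produced by the surrounding Gemma3 generation code, not by this layer.

from keras import ops

def make_cache_update_mask(padding_mask, cache_update_index, seq_len):
    # `padding_mask` is True at prompt positions, shape `(bsz, max_len)`.
    # The cache slice being written spans
    # `[cache_update_index, cache_update_index + seq_len)`. Prompt positions
    # in that slice map to False so their cached keys/values are preserved.
    # `cache_update_index` and `seq_len` are assumed to be Python ints here.
    prompt_slice = padding_mask[:, cache_update_index : cache_update_index + seq_len]
    return ops.logical_not(prompt_slice)

# During token-by-token decoding, seq_len == 1 and the mask has shape (bsz, 1).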
@@ -139,17 +148,22 @@ def _apply_rope(self, x, start_index):
        x = self.rope_layer(x, start_index=start_index)
        return x

-     def _can_use_flash_attention(self):
+     def _use_fused_attention_op(self):
        if not fused_attention_op_available():
            return False
        if self.dropout > 0.0:
            return False
-         if self.logit_soft_cap is None:
-             return True
-         sig = inspect.signature(ops.dot_product_attention)
-         # We can currently only run soft capped attention for keras >= 3.10
-         # and only on TPU.
-         return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+         if running_on_gpu():
+             # GPU never supports softcap in the fused op.
+             if self.logit_soft_cap is not None:
+                 return False
+             return gpu_supports_fused_attention_op()
+         elif running_on_tpu():
+             # TPU supports softcap with Keras >= 3.10.
+             sig = inspect.signature(ops.dot_product_attention)
+             return "attn_logits_soft_cap" in sig.parameters
+         else:
+             return False

    def _compute_attention(
        self,
@@ -166,7 +180,14 @@ def _compute_attention(
        query_normalization = 1 / np.sqrt(
            self.hidden_dim // self.num_query_heads
        )
-         if self._can_use_flash_attention():
+
+         if self.use_sliding_window_attention and attention_mask is not None:
+             attention_mask = self._mask_sliding_window(
+                 attention_mask,
+                 cache_update_index=cache_update_index,
+             )
+
+         if self._use_fused_attention_op():
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
@@ -205,13 +226,8 @@ def _compute_attention(
                ops.tanh(attention_logits), self.logit_soft_cap
            )

-         if self.use_sliding_window_attention:
-             attention_mask = self._mask_sliding_window(
-                 attention_mask,
-                 cache_update_index=cache_update_index,
-             )
-
-         attention_mask = attention_mask[:, None, None, :, :]
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :, :]
        orig_dtype = attention_logits.dtype
        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
        attention_softmax = ops.cast(attention_softmax, orig_dtype)
@@ -256,6 +272,7 @@ def call(
        attention_mask=None,
        cache=None,
        cache_update_index=0,
+         cache_update_mask=None,
        training=False,
    ):
        query = self.query_dense(x)
@@ -275,7 +292,43 @@ def call(

        key_update = self._apply_rope(key_update, cache_update_index)
        value_update = self.value_dense(x)
+
+         # Update cache. Note that the cache is updated only if the
+         # corresponding `cache_update_mask` value is True. This is to
+         # ensure that we don't update the cache at indices corresponding
+         # to the prompt. For Gemma3 in particular, this is useful because
+         # image tokens have bidirectional attention. During generation,
+         # if we have uneven inputs, we might end up with causal attention
+         # between image tokens, which is incorrect. To avoid this,
+         # bidirectional attention is taken care of during the prefill
+         # step, and during generation the cache is not updated for the
+         # prompt. The shape of `cache_update_mask` is `(bsz, seq_len)`,
+         # where `seq_len` is 1 when we are generating token by token.
        start = [0, cache_update_index, 0, 0]
+         if cache_update_mask is not None:
+             cache_update_mask = ops.expand_dims(
+                 ops.expand_dims(cache_update_mask, axis=-1),
+                 axis=-1,
+             )
+             key_original = ops.slice(
+                 key_cache, start, ops.shape(key_update)
+             )
+             value_original = ops.slice(
+                 value_cache, start, ops.shape(value_update)
+             )
+
+             key_update = ops.where(
+                 cache_update_mask,
+                 key_update,
+                 key_original,
+             )
+             value_update = ops.where(
+                 cache_update_mask,
+                 value_update,
+                 value_original,
+             )
+
        key = ops.slice_update(key_cache, start, key_update)
        value = ops.slice_update(value_cache, start, value_update)
        cache = ops.stack((key, value), axis=1)
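To make the masked cache update above concrete, the following small sketch replays the `ops.slice` / `ops.where` / `ops.slice_update` sequence with toy shapes; the values and shapes are illustrative only, not taken from the library.

import numpy as np
from keras import ops

# Toy key cache: (bsz=1, max_len=4, num_heads=1, head_dim=1), all zeros.
key_cache = np.zeros((1, 4, 1, 1), dtype="float32")
# New keys for a 2-token slice starting at cache_update_index = 1.
key_update = np.ones((1, 2, 1, 1), dtype="float32")
# The first position of the slice is a prompt token -> do not overwrite it.
cache_update_mask = np.array([[False, True]])

start = [0, 1, 0, 0]  # cache_update_index = 1
mask = ops.expand_dims(ops.expand_dims(cache_update_mask, axis=-1), axis=-1)
key_original = ops.slice(key_cache, start, ops.shape(key_update))
key_update = ops.where(mask, key_update, key_original)
key = ops.slice_update(key_cache, start, key_update)
# Position 1 keeps the cached value (0.0), position 2 takes the new value
# (1.0), and everything outside the written slice is untouched.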