Description
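While analyzing the sampling implementations in gemma/gm/text/_sampling.py, I identified a potential division-by-zero issue when temperature is set to 0.0.
The RandomSampling, TopkSampling, and TopPSampling classes all divide the logits by self.temperature without checking whether it is zero. Floating-point division by zero does not raise an error: positive logits become +inf, negative logits become -inf, and a zero logit becomes NaN (0/0). Normalizing or sampling from these values then produces NaN probabilities or token choices that no longer reflect the logits, instead of failing cleanly.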
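A minimal pure-JAX sketch of the failure, independent of gemma and assuming only that the samplers divide the logits by the temperature before normalizing. It uses the same logits as the reproduction in this issue: the scaled logits become inf, and softmax turns them into NaN because it subtracts the row maximum (inf - inf).

```python
import jax
import jax.numpy as jnp

# Same logits as the reproduction: one row, two-token vocabulary.
logits = jnp.array([[1.0, 2.0]])
temperature = 0.0

# Floating-point division by zero does not raise in JAX; both positive
# logits become +inf.
scaled = logits / temperature
print(scaled)  # [[inf inf]]

# softmax subtracts the row max before exponentiating: inf - inf = NaN,
# so every probability in the row is NaN.
probs = jax.nn.softmax(scaled, axis=-1)
print(probs)  # [[nan nan]]
```

Setting `jax.config.update("jax_debug_nans", True)` makes JAX raise at the first NaN-producing operation, which can help locate where the NaN is introduced.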
Affected Components
RandomSampling.get_next_tokens (line 61)
TopkSampling.get_next_tokens (line 78)
TopPSampling.get_next_tokens (line 95)
Expected Behavior
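When temperature is 0.0 (or effectively zero), the sampler should behave deterministically, equivalent to Greedy sampling. As the temperature approaches zero, the softmax distribution over logits / temperature concentrates all of its mass on the highest-scoring token, which is exactly the token Greedy selects.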
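A small sketch of that limit in plain JAX (not gemma code): dividing by a shrinking temperature widens the gap between logits, so the softmax probability of the largest logit approaches 1, matching the argmax that greedy decoding returns. This assumes the maximum logit is unique; with ties, the limit splits the mass evenly among the tied tokens.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0])

# Probability assigned to each token at decreasing temperatures.
probs_by_t = {t: jax.nn.softmax(logits / t) for t in (1.0, 0.1, 0.01)}
for t, probs in probs_by_t.items():
    print(f"T={t}: {probs}")

# Greedy decoding takes the argmax, which is the T -> 0 limit above.
greedy_token = int(jnp.argmax(logits))
print(greedy_token)  # 1
```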
Proposed Solution
Add a guard clause to these methods that checks whether self.temperature is below a small epsilon (e.g., 1e-6) and, if so, delegates to Greedy().get_next_tokens(...) so that the output is stable and deterministic.
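A minimal sketch of the proposed guard, written as a standalone class rather than a patch to gemma. `GuardedRandomSampling`, `greedy_next_tokens`, and `_MIN_TEMPERATURE` are hypothetical names, and the sampling branch assumes RandomSampling draws with `jax.random.categorical` over `logits / temperature`; the method signature only mirrors the reproduction in this issue.

```python
import dataclasses

import jax
import jax.numpy as jnp

# Hypothetical threshold below which the temperature is treated as zero.
_MIN_TEMPERATURE = 1e-6


def greedy_next_tokens(logits: jax.Array) -> jax.Array:
    """Hypothetical helper: pick the highest-scoring token on the vocab axis."""
    return jnp.argmax(logits, axis=-1)


@dataclasses.dataclass(frozen=True)
class GuardedRandomSampling:
    """Hypothetical stand-in for RandomSampling with the proposed guard."""

    temperature: float = 1.0

    def get_next_tokens(self, logits: jax.Array, rng: jax.Array) -> jax.Array:
        # `temperature` is a plain Python float, so this is an ordinary
        # Python branch evaluated at trace time, not a traced condition.
        if self.temperature < _MIN_TEMPERATURE:
            return greedy_next_tokens(logits)
        return jax.random.categorical(rng, logits / self.temperature, axis=-1)


logits = jnp.array([[1.0, 2.0]])
rng = jax.random.PRNGKey(0)
tokens = GuardedRandomSampling(temperature=0.0).get_next_tokens(logits, rng)
print(tokens)  # [1]
```

The same check could sit in front of the division in the top-k and top-p paths. As written, a negative temperature also takes the greedy branch; raising a ValueError for negative values may be the safer choice.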
Reproduction
```python
import jax
from gemma.gm.text import _sampling

logits = jax.numpy.array([[1.0, 2.0]])
rng = jax.random.PRNGKey(0)
sampler = _sampling.RandomSampling(temperature=0.0)
# This results in division by zero
sampler.get_next_tokens(logits, rng)
```