Bug: Division by zero in sampling methods when temperature is 0.0 #562

@KumarADITHYA123

Description

While analyzing the sampling implementations in gemma/gm/text/_sampling.py, I identified a potential division-by-zero issue when temperature is set to 0.0.

The RandomSampling, TopkSampling, and TopPSampling classes all divide by self.temperature without checking whether it is zero. With a temperature of 0.0, the scaled logits become inf or NaN, leading to undefined behavior or crashes.
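To illustrate the failure mode in isolation, here is a minimal sketch that does not call into gemma. It assumes the samplers scale logits as `logits / temperature`, as described above, and the logit values are made up for illustration. With JAX's default settings the division does not raise; it quietly produces non-finite values.

```python
import jax.numpy as jnp

# Illustrative logits (not from the gemma code): positive, zero, and negative entries.
logits = jnp.array([[1.0, 2.0, 0.0, -1.0]])

# Dividing by 0.0 follows IEEE 754 rules: nonzero / 0 gives +/-inf, 0 / 0 gives NaN.
scaled = logits / 0.0

has_inf = bool(jnp.isinf(scaled).any())
has_nan = bool(jnp.isnan(scaled).any())
print(scaled, has_inf, has_nan)
```

Because no exception is raised by default, the problem may only surface later, when the non-finite logits reach the sampling step.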

Affected Components
RandomSampling.get_next_tokens (line 61)
TopkSampling.get_next_tokens (line 78)
TopPSampling.get_next_tokens (line 95)

Expected Behavior
When temperature is 0.0 (or effectively zero), the sampler should theoretically behave deterministically, equivalent to Greedy sampling.
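The reasoning behind this expectation can be checked numerically. The sketch below is a standalone illustration (`probs_at` is a hypothetical helper, not part of gemma): as the temperature shrinks toward zero, the softmax of the scaled logits puts almost all probability on the argmax token, which is the token greedy decoding would choose.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0])  # same values as the reproduction, without the batch axis


def probs_at(temperature):
    """Softmax probabilities after temperature scaling (requires temperature > 0)."""
    return jax.nn.softmax(logits / temperature)


# Probability mass shifts toward the largest logit as temperature decreases.
for t in (1.0, 0.1, 0.01):
    print(t, probs_at(t))

greedy_token = int(jnp.argmax(logits))  # the token greedy decoding would pick
```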

Proposed Solution
Add a guard clause to these methods to check if self.temperature is below a small epsilon (e.g., 1e-6). If so, delegate to Greedy().get_next_tokens(...) to ensure stable and deterministic output.
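A rough sketch of what such a guard could look like, written as standalone functions rather than against the actual gemma classes. The names `greedy_next_tokens`, `random_next_tokens`, and `_EPS` are hypothetical, and the use of `jax.random.categorical` for the non-zero case is an assumption about how sampling might be done, not a copy of the library code.

```python
import jax
import jax.numpy as jnp

_EPS = 1e-6  # epsilon suggested in the proposal; the exact value is a judgment call


def greedy_next_tokens(logits):
    """Pick the highest-scoring token for each batch row."""
    return jnp.argmax(logits, axis=-1)


def random_next_tokens(logits, rng, temperature):
    """Temperature sampling that falls back to greedy near zero temperature."""
    # Python-level branch: assumes temperature is a plain float, not a traced value.
    if temperature < _EPS:
        return greedy_next_tokens(logits)
    return jax.random.categorical(rng, logits / temperature, axis=-1)


logits = jnp.array([[1.0, 2.0]])
rng = jax.random.PRNGKey(0)
tokens = random_next_tokens(logits, rng, temperature=0.0)
print(tokens)
```

The plain `if` only works when temperature is a static Python number. If it were passed through `jax.jit` as a traced array, the branch would likely need `jax.lax.cond` or `jnp.where` instead.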

Reproduction
Python:
import jax
from gemma.gm.text import _sampling

logits = jax.numpy.array([[1.0, 2.0]])
rng = jax.random.PRNGKey(0)
sampler = _sampling.RandomSampling(temperature=0.0)

# This results in division by zero
sampler.get_next_tokens(logits, rng)
