Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions benchmark/mock_benchmark/mock_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def _load_datasets(self) -> pd.DataFrame:

# Convert to DataFrame
df: pd.DataFrame = pd.DataFrame(sample_data)

# Ensure compatibility with base Benchmark request processing contract.
if "answer_prefix" not in df.columns:
df["answer_prefix"] = ""
if "max_new_tokens" not in df.columns:
df["max_new_tokens"] = 64

# Add sample IDs for tracking
df["sample_id"] = range(1, len(df) + 1)
Expand Down
10 changes: 9 additions & 1 deletion sparse_attention_hub/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,17 @@ def __init__(
self.sparse_attention_config = sparse_attention_config
self.sparse_attention = None
self.kwargs = kwargs

has_sparse_masking = True
if sparse_attention_config is None:
has_sparse_masking = False
elif hasattr(sparse_attention_config, "masker_configs") and not sparse_attention_config.masker_configs:
# Defensive path: empty masker configs are equivalent to dense attention.
has_sparse_masking = False

self.sparse_attention = (
SparseAttention.create_from_config(self.sparse_attention_config)
if self.sparse_attention_config is not None
if has_sparse_masking
else None
)

Expand Down
4 changes: 2 additions & 2 deletions sparse_attention_hub/adapters/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def __init__(
)
self.torch_dtype = self.model_kwargs.get("torch_dtype", torch.float32)

# Handle dense-only mode when sparse_attention_config is None
self._sparse_attention_available: bool = sparse_attention_config is not None
# Handle dense-only mode when sparse attention is absent or has no active maskers.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the right change; we want the empty masker to use our own backend, i.e., the soft_cap functionality.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you're right — I misunderstood the intended behavior.
I'll push a fix shortly.

self._sparse_attention_available: bool = self.sparse_attention is not None
# Control token-by-token question processing (for hybrid models)
self.hybrid = hybrid if hybrid is not None else False
# Convert device string to GPU ID for ModelServer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def create_sampling_mask_with_per_head_budget(

return sampling_mask

def apply_softcap(scores: torch.Tensor, softcap: Optional[float]) -> torch.Tensor:
    """Apply Gemma-style logit soft-capping to raw attention scores.

    Rescales logits via ``softcap * tanh(scores / softcap)``, which smoothly
    bounds every value into ``(-softcap, softcap)`` while remaining close to
    the identity for scores small relative to ``softcap``.

    Args:
        scores: Raw attention logits (any shape).
        softcap: Cap magnitude. ``None`` disables capping, in which case
            ``scores`` is returned unchanged (same tensor, no copy).

    Returns:
        The soft-capped scores, or ``scores`` itself when ``softcap`` is None.
    """
    if softcap is None:
        return scores
    return softcap * torch.tanh(scores / softcap)

def _compute_masked_exp_attention_weights(
queries: torch.Tensor,
Expand All @@ -192,6 +196,7 @@ def _compute_masked_exp_attention_weights(
sparse_attention_mask: Mask,
dropout: float = 0.0,
training: bool = False,
softcap: Optional[float] = None,
) -> torch.Tensor:
"""Compute masked attention weights (common logic for numerator and denominator).

Expand All @@ -217,6 +222,9 @@ def _compute_masked_exp_attention_weights(
k = key_states.to(torch.float32)
raw_attention_weights: torch.Tensor = torch.matmul(q, k.transpose(2, 3)) * scaling

# Gemma softcap
raw_attention_weights = apply_softcap(raw_attention_weights, softcap)

if attention_mask is not None:
raw_attention_weights = raw_attention_weights + attention_mask[
:, :, :, : key_states.shape[-2]
Expand All @@ -240,7 +248,7 @@ def _compute_masked_exp_attention_weights(
return exp_attention_weights


def _get_attention_denominator(exp_attention_weights: torch.Tensor) -> torch.Tensor:
def _get_attention_denominator(exp_attention_weights: torch.Tensor, softcap: Optional[float] = None) -> torch.Tensor:
"""Get attention denominator from pre-computed exponential attention weights.

Args:
Expand All @@ -255,6 +263,7 @@ def _get_attention_denominator(exp_attention_weights: torch.Tensor) -> torch.Ten
def _get_attention_numerator(
exp_attention_weights: torch.Tensor,
value_states: torch.Tensor,
softcap: Optional[float] = None
) -> torch.Tensor:
"""Get attention numerator from pre-computed exponential attention weights and prepared values.

Expand All @@ -276,7 +285,8 @@ def get_attention_denominator(
scaling: float,
dropout: float,
sparse_attention_mask: Mask,
**kwargs: Dict[str, Any],
softcap: Optional[float] = None,
**kwargs: Dict[str, Any],
) -> torch.Tensor:
"""Get masked attention denominator.

Expand All @@ -302,6 +312,7 @@ def get_attention_denominator(
sparse_attention_mask=sparse_attention_mask,
dropout=dropout,
training=training,
softcap=softcap,
)

return _get_attention_denominator(exp_attention_weights)
Expand All @@ -316,6 +327,7 @@ def get_attention_numerator(
scaling: float,
dropout: float,
sparse_attention_mask: Mask,
softcap: Optional[float] = None,
**kwargs: Dict[str, Any],
) -> torch.Tensor:
"""Get masked attention numerator.
Expand Down Expand Up @@ -343,6 +355,7 @@ def get_attention_numerator(
sparse_attention_mask=sparse_attention_mask,
dropout=dropout,
training=training,
softcap=softcap,
)

# Prepare values by applying key-value grouping
Expand Down Expand Up @@ -413,6 +426,7 @@ def get_masked_attention_output(
scaling: float,
dropout: float,
sparse_attention_mask: Mask,
softcap: Optional[float] = None,
return_attention_weights: bool = False,
**kwargs: Dict[str, Any],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Expand Down Expand Up @@ -448,6 +462,7 @@ def get_masked_attention_output(
sparse_attention_mask=sparse_attention_mask,
dropout=dropout,
training=training,
softcap=softcap,
)

# Prepare values by applying key-value grouping
Expand All @@ -457,8 +472,12 @@ def get_masked_attention_output(
)

# Use internal helpers with pre-computed weights
num: torch.Tensor = _get_attention_numerator(exp_attention_weights, value_states)
den: torch.Tensor = _get_attention_denominator(exp_attention_weights)
num: torch.Tensor = _get_attention_numerator(
exp_attention_weights, value_states, softcap=softcap
)
den: torch.Tensor = _get_attention_denominator(
exp_attention_weights, softcap=softcap
)

num, den, exp_attention_weights = apply_sink_bias(
num=num,
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,28 @@ def test_generate_unique_attention_name(self, mock_tokenizer, mock_model) -> Non
assert name2.startswith("sparse_attention_")
assert name1 != name2 # Should be unique

@patch(
    "sparse_attention_hub.adapters.model_servers.huggingface.AutoModelForCausalLM"
)
@patch("sparse_attention_hub.adapters.model_servers.huggingface.AutoTokenizer")
def test_empty_masker_config_treated_as_dense(
    self, mock_tokenizer, mock_model
) -> None:
    """An empty masker-config list must disable sparse attention entirely.

    NOTE(review): the PR discussion suggests empty masker configs may be
    intended to route through the project's own backend (soft_cap path)
    rather than dense mode — confirm before relying on this contract.
    """
    mock_tokenizer.from_pretrained.return_value = Mock()
    mock_model.from_pretrained.return_value = Mock()

    adapter = ModelAdapterHF(
        sparse_attention_config=ResearchAttentionConfig(masker_configs=[]),
        model_name="test-model",
    )

    assert adapter.sparse_attention is None
    assert adapter._sparse_attention_available is False

@patch(
"sparse_attention_hub.adapters.model_servers.huggingface.AutoModelForCausalLM"
)
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/sparse_attention/utils/test_mask_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sparse_attention_hub.sparse_attention.utils.mask import Mask
from sparse_attention_hub.sparse_attention.utils.mask_attention_utils import (
_compute_masked_exp_attention_weights,
apply_softcap,
apply_inv_mask_sum,
create_sampling_mask_with_per_head_budget,
get_attention_denominator,
Expand Down Expand Up @@ -601,6 +602,15 @@ def test_compute_masked_attention_weights_different_scaling(self):

assert torch.allclose(result, expected_exp_weights, atol=1e-6)

def test_softcap_limits_attention_scores(self):
    """A large logit passed through the softcap must not exceed the cap."""
    large_logit = torch.tensor([100.0])
    cap_value = 50.0

    capped = apply_softcap(large_logit, cap_value)

    assert capped.item() <= cap_value


@pytest.mark.unit
class TestGetAttentionDenominator:
Expand Down