Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ line_length = 88
known_first_party = ["sparse_attention_hub"]

[tool.mypy]
python_version = "3.9"
python_version = "3.12"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
Expand Down Expand Up @@ -178,3 +178,4 @@ module = [
]
ignore_missing_imports = true
ignore_errors = true
follow_imports = "skip"
85 changes: 85 additions & 0 deletions scripts/extract_benchmark_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
"""Extract and print benchmark scores from a results directory.

Scans subdirectories (e.g. ruler32k_fwe, ruler32k_qa_1) for metrics.json,
reads the overall_score (or a fallback score), and prints each dataset name
with its score plus the average across all datasets.

Usage:
python scripts/extract_benchmark_scores.py --directory /path/to/model/masker
"""

import argparse
import json
from pathlib import Path


def get_score_from_metrics(metrics: dict) -> float | None:
"""Extract a single numeric score from a metrics dict.

Prefers 'overall_score'. If missing, tries task_scores (e.g. string_match).
"""
if "overall_score" in metrics:
val = metrics["overall_score"]
if isinstance(val, (int, float)):
return float(val)
if "task_scores" in metrics:
task_scores = metrics["task_scores"]
if not task_scores:
return None
# Use first task; prefer string_match then first numeric value
first = next(iter(task_scores.values()))
if isinstance(first, (int, float)):
return float(first)
if isinstance(first, dict) and "string_match" in first:
return float(first["string_match"])
for v in first.values() if isinstance(first, dict) else []:
if isinstance(v, (int, float)):
return float(v)
return None


def main() -> None:
    """Parse CLI args, collect per-dataset scores, and print the summary.

    Scans each subdirectory of --directory for a metrics.json file, extracts
    a score via get_score_from_metrics, prints one "name: score" line per
    dataset, and finishes with the average (or a no-results message).
    """
    parser = argparse.ArgumentParser(
        description="Extract benchmark scores from subdirs and print dataset: score plus average."
    )
    parser.add_argument(
        "--directory",
        type=str,
        required=True,
        help="Path to the results folder (e.g. .../Qwen_Qwen2.5-72B-Instruct/dense)",
    )
    args = parser.parse_args()

    root = Path(args.directory)
    if not root.is_dir():
        raise SystemExit(f"Not a directory: {root}")

    scores: list[tuple[str, float]] = []
    for entry in sorted(root.iterdir()):
        metrics_file = entry / "metrics.json"
        # Only consider subdirectories that actually contain a metrics.json.
        if not (entry.is_dir() and metrics_file.exists()):
            continue
        try:
            with metrics_file.open("r", encoding="utf-8") as fh:
                payload: dict = json.load(fh)
        except (json.JSONDecodeError, OSError):
            # Unreadable or malformed metrics: skip this dataset silently.
            continue
        value = get_score_from_metrics(payload)
        if value is not None:
            scores.append((entry.name, value))

    for dataset, value in scores:
        print(f"{dataset}: {value}")

    if not scores:
        print("No metrics found in subdirectories.")
    else:
        print(f"average: {sum(v for _, v in scores) / len(scores)}")


if __name__ == "__main__":
main()
18 changes: 13 additions & 5 deletions sparse_attention_hub/adapters/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
device: Optional[str] = None,
hybrid: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Initialize HuggingFace adapter.

Expand All @@ -46,10 +46,18 @@ def __init__(
super().__init__(model_name, sparse_attention_config, **kwargs)
self._registered_attention_name: Optional[str] = None
self._custom_attention_fn: Optional[Callable] = None
self.model_kwargs = model_kwargs or {}
self.tokenizer_kwargs = tokenizer_kwargs or {}
self.model_registry_path = kwargs.get("model_registry_path", "")
self.allow_unregistered_models = kwargs.get("allow_unregistered_models", True)
self.model_kwargs: Dict[str, Any] = model_kwargs or {}
self.tokenizer_kwargs: Dict[str, Any] = tokenizer_kwargs or {}

raw_registry_path: Any = kwargs.get("model_registry_path", "")
self.model_registry_path: str = (
raw_registry_path if isinstance(raw_registry_path, str) else ""
)

raw_allow_unregistered: Any = kwargs.get("allow_unregistered_models", True)
self.allow_unregistered_models: bool = (
raw_allow_unregistered if isinstance(raw_allow_unregistered, bool) else True
)

# more useful parameters to store
self.device = (
Expand Down
9 changes: 3 additions & 6 deletions sparse_attention_hub/adapters/model_servers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,17 @@ def _create_model(
)
elif gpu_id is not None:
if torch.cuda.is_available():
if isinstance(gpu_id, str) and gpu_id.startswith("cuda"):
device = torch.device(gpu_id)
else:
device = torch.device(f"cuda:{gpu_id}")
device = torch.device(f"cuda:{gpu_id}")
self.logger.debug(f"Moving model {model_name} to device: {device}")
model = model.to(device)
else:
self.logger.warning(
f"CUDA not available, placing model {model_name} on CPU instead of GPU {gpu_id}"
)
model = model.to(torch.device("cpu")) # type: ignore[arg-type]
model = model.to(torch.device("cpu"))
else:
# Explicitly place on CPU
model = model.to(torch.device("cpu")) # type: ignore[arg-type]
model = model.to(torch.device("cpu"))

self.logger.info(
f"Successfully created HuggingFace model: {model_name} on {'GPU' if gpu_id is not None else 'CPU'}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_model_registry(path: str) -> Dict[str, RegistryEntry]:
raise ModelRegistryError(f"Model registry path does not exist: {path}")

try:
import yaml # type: ignore
import yaml
except Exception as e:
raise ModelRegistryError(
"PyYAML is required to load a model registry; install with `pip install pyyaml`"
Expand Down
18 changes: 6 additions & 12 deletions sparse_attention_hub/sparse_attention/utils/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,27 @@ def __init__(
self.from_index = from_index
self.is_full = is_full
self.is_empty = is_empty
self.device: torch.device = device
self.mask: Optional[torch.Tensor] = None
self.indices: Optional[torch.Tensor] = None
self.ptr: Optional[torch.Tensor] = None
self.data: Optional[torch.Tensor] = None

if is_full:
# Full mask optimization - don't store any actual data
# Device must be provided for full masks since we have no tensors to infer from
if device is None:
raise ValueError("device must be specified for full masks")
self.device = device
self.mask = None
self.indices = None
self.ptr = None
self.data = None
elif is_empty:
# Empty mask optimization - don't store any actual data
# Device must be provided for empty masks since we have no tensors to infer from
if device is None:
raise ValueError("device must be specified for empty masks")
self.device = device
self.mask = None
self.indices = None
self.ptr = None
self.data = None
elif from_dense_mask and mask is not None:
self.mask = mask.to(dtype)
self.device = mask.device
self.indices = None
self.ptr = None
self.data = None
# Check if this is actually a full mask
if self._detect_full_mask():
self.is_full = True
Expand Down Expand Up @@ -172,7 +166,7 @@ def _detect_empty_mask(self, only_check_flag: bool = True) -> bool:
return bool(torch.all(self.mask == 0.0).item())
elif self.from_index and self.indices is not None:
# For sparse representation, empty means no indices
return self.indices.numel() == 0
return int(self.indices.numel()) == 0

return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def custom_attention_callable(
@pytest.fixture
def test_config() -> LlamaConfig:
"""Create a test Llama configuration matching Llama-3.1-8B-Instruct dimensions."""
return LlamaConfig(
config = LlamaConfig(
vocab_size=128256,
hidden_size=4096,
intermediate_size=14336,
Expand All @@ -145,6 +145,10 @@ def test_config() -> LlamaConfig:
attention_dropout=0.0,
attention_bias=False,
)
# Backwards/forwards compatibility across Transformers versions.
if not hasattr(config, "rope_theta"):
setattr(config, "rope_theta", 500000.0)
return config


@pytest.fixture(params=TEST_CONFIGS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def custom_attention_callable(

@pytest.fixture
def test_config() -> LlamaConfig:
return LlamaConfig(
config = LlamaConfig(
vocab_size=128256,
hidden_size=4096,
intermediate_size=14336,
Expand All @@ -195,6 +195,10 @@ def test_config() -> LlamaConfig:
attention_dropout=0.0,
attention_bias=False,
)
# Backwards/forwards compatibility across Transformers versions.
if not hasattr(config, "rope_theta"):
setattr(config, "rope_theta", 500000.0)
return config


@pytest.fixture(params=TEST_CONFIGS)
Expand Down Expand Up @@ -305,11 +309,40 @@ def _effective_ratio(past_len: int, page_size: int, ratio: float) -> float:

@pytest.fixture
def original_attention(
test_config: LlamaConfig, test_params: Dict[str, Any]
request: pytest.FixtureRequest,
test_config: LlamaConfig,
test_params: Dict[str, Any],
) -> nn.Module:
"""
Build a LlamaAttention and monkeypatch its forward with the upstream Quest implementation.
"""
# Quest's upstream implementation calls `apply_rotary_pos_emb` with an older positional
# signature where the 5th positional argument is `position_ids`. Newer Transformers
# versions interpret the 5th positional argument as `unsqueeze_dim` (an int), causing:
# `TypeError: unsqueeze(): argument 'dim' must be int, not Tensor`.
#
# Patch in a tiny compatibility shim for this test only.
from transformers.models.llama import modeling_llama

original_apply_rotary = modeling_llama.apply_rotary_pos_emb

def _apply_rotary_pos_emb_compat(*args: Any, **kwargs: Any):
if (
len(args) == 5
and torch.is_tensor(args[4])
and "unsqueeze_dim" not in kwargs
):
q, k, cos, sin, _position_ids = args
return original_apply_rotary(q, k, cos, sin, unsqueeze_dim=1)
return original_apply_rotary(*args, **kwargs)

modeling_llama.apply_rotary_pos_emb = _apply_rotary_pos_emb_compat # type: ignore[assignment]

def _restore_apply_rotary() -> None:
modeling_llama.apply_rotary_pos_emb = original_apply_rotary # type: ignore[assignment]

request.addfinalizer(_restore_apply_rotary)

quest_forward = _load_quest_forward()

attn = LlamaAttention(config=test_config, layer_idx=32)
Expand Down
Loading