
Sliding window attention #14

Merged
gonzalobenegas merged 2 commits into main from sliding-window-attention
Dec 3, 2025

Conversation

@gonzalobenegas
Member

No description provided.

@claude

claude bot commented Dec 3, 2025

Code Review: Sliding Window Attention Implementation

This PR implements sliding window attention with alternating global/local patterns. Overall, the implementation is well-designed and follows project standards. Below is detailed feedback organized by category.


✅ Strengths

1. Excellent Code Quality

  • Clean, well-documented code with comprehensive Google-style docstrings
  • Proper type hints using Python 3.13+ syntax (int | None)
  • Good separation of concerns (attention logic, patterns, transformer integration)
  • Follows DRY principle with reusable pattern generators

2. Strong Test Coverage

  • 472 lines of real tests in test_attention.py
  • 237 lines of tests in test_attention_patterns.py
  • Tests validate actual behavior, not mocks (follows anti-fake test principle)
  • Comprehensive edge case coverage (batch=1, seq_len=1, window=0, etc.)
  • Integration tests with Transformer class
  • Device and dtype tests (CPU, CUDA, float32, bfloat16)

3. Smart Performance Optimizations

  • lru_cache on _get_or_create_block_mask() to avoid FlexAttention recompilation
  • Backward-compatible fallback to F.scaled_dot_product_attention when sliding_window=None
  • Proper dtype handling for autocast compatibility (lines 132-140 in attention.py)

4. Thoughtful Architecture

  • Drop-in replacement for standard attention function
  • Per-layer sliding window configuration via list
  • Rich set of pattern generators (7 different patterns!)
  • Hydra integration with _target_ instantiation

🔍 Issues Found

1. Experiment Config Inconsistency ⚠️ Minor

File: configs/experiment/clm_transformer_small.yaml

Issue: Line 6 removes the data override that was present before:

- defaults:
-   - override /data: plants
+ defaults:

Impact: This changes the default data source for the experiment. Is this intentional?

Recommendation:

  • If intentional, add a comment explaining why
  • If not, restore the - override /data: plants line

2. Incomplete Cache Test ⚠️ Minor

File: tests/test_attention.py, lines 314-317

Issue: Test function body is empty:

def test_mask_caching_reuses_block_mask():
    """Verify same mask configuration returns cached BlockMask."""
    # Clear cache
    _get_or_create_block_mask.cache_clear()
    # Test body is missing!

Recommendation: Either implement the test or remove it. Suggestion:

def test_mask_caching_reuses_block_mask():
    """Verify same mask configuration returns cached BlockMask."""
    _get_or_create_block_mask.cache_clear()
    
    # Create mask twice with same params
    mask1 = _get_or_create_block_mask("sliding_window", 8, 2, 4, 16, "cpu")
    mask2 = _get_or_create_block_mask("sliding_window", 8, 2, 4, 16, "cpu")
    
    # Should return same object (cached)
    assert mask1 is mask2
    
    # Second call should have registered as a cache hit
    assert _get_or_create_block_mask.cache_info().hits == 1

3. Missing Input Validation ⚠️ Minor

File: glm_experiments/models/utils/attention_patterns.py

Issue: No validation for negative or zero values in several functions:

  • alternating_global_local(): No check for n_layers <= 0 or window_size <= 0
  • sparse_transformer(): No check for global_every <= 0
  • longformer_style(): No check for base_window <= 0

Recommendation: Add validation at the top of each function:

def alternating_global_local(n_layers: int, window_size: int, start_with_global: bool = True):
    if n_layers <= 0:
        raise ValueError(f"n_layers must be positive, got {n_layers}")
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    # ... rest of function

💡 Suggestions (Optional Improvements)

1. Add Benchmarking Tests

The PR mentions performance in docstrings but doesn't include timing tests. Consider adding one (marked @pytest.mark.slow; run_sliding_window_attention below is a placeholder for a call into the sliding-window code path):

@pytest.mark.slow
def test_sliding_window_compilation_overhead():
    """Verify first call pays compilation overhead, subsequent calls are fast."""
    import time

    # First call: builds the BlockMask and triggers FlexAttention compilation
    start = time.perf_counter()
    run_sliding_window_attention()  # placeholder helper
    first_call = time.perf_counter() - start

    # Second call: should reuse the cached BlockMask
    start = time.perf_counter()
    run_sliding_window_attention()
    second_call = time.perf_counter() - start

    assert second_call < first_call

2. Document PyTorch Version Requirements

FlexAttention API is relatively new. Consider adding version requirements to docstring or README.

3. Consider Variable-Length Sequences

Current implementation assumes fixed seq_len. For variable-length batches (common in NLP), padding tokens need special handling. Worth documenting this limitation.
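One way padding could be handled is a mask predicate in the spirit of FlexAttention's mask_mod callback. The function below is a hypothetical plain-Python illustration, not code from this PR: seq_lens is an assumed per-sequence length list, and the window convention q_idx - kv_idx < window is one of several reasonable choices:

```python
def sliding_window_mask(
    b: int, q_idx: int, kv_idx: int, window: int, seq_lens: list[int]
) -> bool:
    """May query position q_idx attend to key position kv_idx in batch b?"""
    causal = kv_idx <= q_idx                  # no attending to the future
    in_window = q_idx - kv_idx < window       # local window, assumed convention
    not_padding = q_idx < seq_lens[b] and kv_idx < seq_lens[b]
    return causal and in_window and not_padding


print(sliding_window_mask(0, q_idx=3, kv_idx=2, window=4, seq_lens=[5]))  # True
print(sliding_window_mask(0, q_idx=7, kv_idx=2, window=4, seq_lens=[8]))  # False
```

A vectorized version of this predicate is what a padding-aware mask_mod would compute; documenting the fixed-length assumption until then seems like the right call.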


🔒 Security & Performance

  • No security concerns: no user input processing, no file I/O, no network calls
  • Performance well-considered: the caching strategy is sound
  • No memory leaks: cache size is limited to 16 entries


📋 Checklist Compliance

Checking against CLAUDE.md requirements:

  • Formatting: Code follows Black/isort standards
  • Type hints: All functions properly typed
  • Docstrings: Comprehensive Google-style docstrings
  • Tests: Real tests, no fakes/mocks
  • DRY: No code duplication
  • Simplicity: Design is simple and focused
  • Hydra integration: Proper use of _target_ and hydra.utils.instantiate
  • ⚠️ No bloat: Mostly clean, but one empty test function

🎯 Verdict

APPROVE with minor fixes recommended

This is high-quality work that significantly extends the project's capabilities. The implementation is solid, well-tested, and follows all project standards. The issues identified are minor and can be addressed in follow-up commits or merged as-is with the understanding that:

  1. The data config change in clm_transformer_small.yaml is intentional
  2. The empty cache test will be implemented or removed in a follow-up
  3. Input validation can be added incrementally as needed

Recommended action: Merge after addressing the empty test function and confirming the config change is intentional.


Great work! The sliding window attention implementation is clean, well-architected, and production-ready. 🚀

gonzalobenegas merged commit d4f6338 into main Dec 3, 2025
6 checks passed
gonzalobenegas deleted the sliding-window-attention branch December 3, 2025 21:15