AmitMY commented on Dec 4, 2025

Related to #347

Summary

  • Replace manual attention with F.scaled_dot_product_attention (see the sketch after this list)
  • Use repeat_interleave instead of meshgrid for position computation (reduces arange calls from 2052 to ~1026)
  • Build image_ids efficiently with repeat_interleave instead of incremental F.pad
  • Remove unused Rearrange import
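
A minimal sketch of the attention change, with hypothetical tensor names and shapes (the actual module layout in the repo may differ): the hand-rolled softmax(QKᵀ / √d)·V path is replaced by the fused kernel behind F.scaled_dot_product_attention.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: q, k, v are (batch, heads, seq, dim_head); mask is a
# boolean (batch, heads, seq, seq) tensor where True means "attend".
def manual_attention(q, k, v, mask=None):
    scale = q.shape[-1] ** -0.5
    sim = (q @ k.transpose(-2, -1)) * scale
    if mask is not None:
        sim = sim.masked_fill(~mask, torch.finfo(sim.dtype).min)
    attn = sim.softmax(dim=-1)
    return attn @ v

def fused_attention(q, k, v, mask=None):
    # Dispatches to flash / memory-efficient kernels when available; this is
    # where most of the speedup and the ~5e-4 output difference come from.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

The repeat_interleave changes follow the same idea: build packed indices in a single call instead of a per-image meshgrid or a tensor grown by incremental F.pad. A rough sketch, with illustrative helper names that are not the actual functions in the repo:

```python
import torch

# Hypothetical helper: (row, col) positions for a ph x pw patch grid,
# built from two arange calls instead of a per-image meshgrid.
def patch_positions(ph, pw, device=None):
    rows = torch.arange(ph, device=device).repeat_interleave(pw)  # 0,0,...,1,1,...
    cols = torch.arange(pw, device=device).repeat(ph)             # 0,1,...,0,1,...
    return torch.stack((rows, cols), dim=-1)                      # (ph * pw, 2)

# Hypothetical helper: image ids for a packed sequence, produced by one
# repeat_interleave call rather than growing a tensor with F.pad in a loop.
def build_image_ids(patches_per_image):
    counts = torch.as_tensor(patches_per_image)
    return torch.arange(len(counts)).repeat_interleave(counts)
```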

Benchmark

Tested on 512 variable-sized images (16px tall, 32-80px wide), depth=12, ViT-tiny config:

| Version   | Time    | Speedup |
|-----------|---------|---------|
| Original  | 91.1 ms | 1.0×    |
| Optimized | 58.4 ms | 1.56×   |

Compared to a standard ViT on the same images (padded to a fixed size for ViT):

| Model             | fp32         | bf16        |
|-------------------|--------------|-------------|
| ViT (padded)      | 4.8 ms       | 5.9 ms      |
| NaViT (optimized) | 54.9 ms      | 36.2 ms     |
| Ratio             | 11.5× slower | 6.1× slower |
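
For context, a rough sketch of the kind of timing loop used to produce numbers like these (the benchmark script itself is not part of this PR; all names below are assumptions):

```python
import time
import torch

@torch.no_grad()
def time_model(model, batch, warmup=3, iters=10):
    # Warm-up runs exclude one-time allocation and kernel-selection costs.
    for _ in range(warmup):
        model(batch)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model(batch)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```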

Numerical equivalence

  • Max diff: ~5e-4 (within flash attention tolerance)
  • torch.allclose(original, optimized, rtol=1e-3, atol=1e-3) returns True
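
The equivalence check described above amounts to something like this (variable names are illustrative):

```python
import torch

out_original = original_model(images)     # hypothetical: same inputs to both models
out_optimized = optimized_model(images)

print((out_original - out_optimized).abs().max())   # ~5e-4
assert torch.allclose(out_original, out_optimized, rtol=1e-3, atol=1e-3)
```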

Test plan

  • Verified outputs match original within numerical tolerance
  • Tested with variable-sized images
  • Tested with group_images=True
  • No API changes - drop-in replacement

🤖 Generated with Claude Code
