Fused kernels (with ChainRules.jl integration):
This project originates from NNop.jl and would not have been possible without Anton Smirnov.
We're developing this fork to have hackable kernels for Onion.jl, which is also its namesake.
```julia
using Pkg
Pkg.Registry.add(RegistrySpec(url="https://github.com/MurrellGroup/MurrellGroupRegistry"))
Pkg.add("Onion")
```

See benchmarks/main.jl for comparison scripts between the naive & fused versions.
Implementation of FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

```julia
using CUDA, Zygote, ONIONop

E, L, H, B = 64, 4096, 4, 4
causal = false

q = CuArray(rand(Float32, E, L, H, B))
k = CuArray(rand(Float32, E, L, H, B))
v = CuArray(rand(Float32, E, L, H, B))

o = ONIONop.flash_attention(q, k, v; causal)
∇ = Zygote.gradient(q, k, v) do q, k, v
    sum(ONIONop.flash_attention(q, k, v; causal))
end
```

- Forward & backward passes.
- Arbitrary sequence length.
- FP32, FP16, BF16 support.
- Variable sequence length.
- Causal masking.
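
As a sanity check, the fused kernel can be compared against a dense reference (benchmarks/main.jl has the full naive-vs-fused comparison scripts). The sketch below is a hypothetical non-causal reference, assuming the standard 1/sqrt(E) scaling; `naive_attention` is not part of ONIONop.

```julia
using CUDA, NNlib

# Hypothetical dense (naive) attention reference; it materializes the full L × L score
# matrix (~1 GiB at the sizes above), which is exactly what flash attention avoids.
function naive_attention(q, k, v)
    E, L, H, B = size(q)
    qr = reshape(q, E, L, H * B)
    kr = reshape(k, E, L, H * B)
    vr = reshape(v, E, L, H * B)
    s = NNlib.batched_mul(NNlib.batched_transpose(kr), qr) ./ sqrt(Float32(E))  # L × L × (H * B)
    p = NNlib.softmax(s; dims=1)   # normalize over the key dimension
    o = NNlib.batched_mul(vr, p)   # E × L × (H * B)
    return reshape(o, E, L, H, B)
end

isapprox(Array(o), Array(naive_attention(q, k, v)); atol=1f-3, rtol=1f-3)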
Implementation of Online normalizer calculation for softmax.

```julia
x = CuArray(rand(Float32, 8192, 1024))
y = ONIONop.online_softmax(x)
```
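
For intuition: the online normalizer tracks the running maximum and the running sum of exponentials in a single pass, rescaling the sum whenever the maximum changes. A plain-Julia sketch for one column (assuming the kernel applies softmax along the first dimension); `online_softmax_column` is just an illustrative helper:

```julia
# Single-pass max and normalizer (the online normalizer recurrence); the fused kernel
# applies the same idea per column, tiled across the GPU.
function online_softmax_column(x::AbstractVector{T}) where {T<:AbstractFloat}
    m = typemin(T)   # running maximum
    d = zero(T)      # running normalizer Σ exp(xᵢ - m)
    for xi in x
        m_new = max(m, xi)
        d = d * exp(m - m_new) + exp(xi - m_new)
        m = m_new
    end
    return exp.(x .- m) ./ d
end

# Matches the usual two-pass softmax on a random column:
col = rand(Float32, 8192)
ref = exp.(col .- maximum(col)); ref ./= sum(ref)
online_softmax_column(col) ≈ ref
```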

Implementation of Root Mean Square Layer Normalization (RMSNorm).

```julia
x = CuArray(rand(Float32, 1024, 1024))
w = CuArray(rand(Float32, 1024))

y = ONIONop.rms_norm(x, w)
∇ = Zygote.gradient(x, w) do x, w
    sum(ONIONop.rms_norm(x, w))
end
```
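
What `rms_norm` computes, as a hypothetical CPU reference: the standard RMSNorm formula, assuming normalization over the feature (first) dimension and a small epsilon (the kernel's exact epsilon and layout may differ).

```julia
using Statistics

# Hypothetical reference: y = x / sqrt(mean(x²) + ϵ) * w, normalizing over the first dimension.
rms_norm_ref(x, w; eps = 1f-6) = x ./ sqrt.(mean(abs2, x; dims=1) .+ eps) .* w

# If the layout assumption holds, this approximately matches the kernel output above.
isapprox(Array(y), rms_norm_ref(Array(x), Array(w)); atol=1f-4)
```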

Implementation of Layer Normalization.

```julia
x = CuArray(rand(Float32, 1024, 1024))
w = CuArray(rand(Float32, 1024))
b = CuArray(rand(Float32, 1024))

y = ONIONop.layer_norm(x, w, b)
∇ = Zygote.gradient(x, w, b) do x, w, b
    sum(ONIONop.layer_norm(x, w, b))
end
```
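
Similarly, a hypothetical CPU reference for `layer_norm`, under the same feature-first layout assumption: mean-center, scale by the inverse standard deviation, then apply the affine weight and bias.

```julia
using Statistics

# Hypothetical reference with mean-centering and bias; epsilon is an assumed default.
function layer_norm_ref(x, w, b; eps = 1f-5)
    μ = mean(x; dims=1)
    σ² = var(x; dims=1, mean=μ, corrected=false)
    return (x .- μ) ./ sqrt.(σ² .+ eps) .* w .+ b
end

isapprox(Array(y), layer_norm_ref(Array(x), Array(w), Array(b)); atol=1f-4)
```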

Implementation of Llama-style rotary position embeddings (RoPE).

```julia
using Adapt, CUDA, ONIONop

kab = CUDABackend()  # KernelAbstractions backend the arrays are adapted to (CUDA GPU here)

E, L, B = 16, 1024, 1
QH, KH = 16, 16
emb = ONIONop.LlamaRotaryEmbedding(E)
position_ids = reshape(collect(0f0:Float32(L) - 1f0), :, 1)
position_ids = repeat(position_ids; inner=(1, B))
cos, sin = emb(position_ids)
cos = Adapt.adapt(kab, cos)
sin = Adapt.adapt(kab, sin)
q = Adapt.adapt(kab, ones(Float32, (E, L, QH, B)))
k = Adapt.adapt(kab, ones(Float32, (E, L, KH, B)))
q, k = ONIONop.llama_rope(q, k; cos, sin)
```
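
For reference, rotary embeddings rotate pairs of embedding channels by a position-dependent angle. Below is a conceptual scalar sketch with the usual Llama frequencies (base 10 000); it does not reproduce the kernel's exact channel pairing or memory layout, and `rope_pair` is purely illustrative.

```julia
# One rotation: channel pair (x1, x2) at 1-based frequency index i and position pos,
# with θ = pos * base^(-2(i - 1) / E). The example above rebinds `cos`/`sin` to arrays,
# so Base's functions are qualified explicitly.
function rope_pair(x1, x2, pos, i, E; base = 10_000f0)
    θ = Float32(pos) * base^(-2f0 * (i - 1) / E)
    c, s = Base.cos(θ), Base.sin(θ)
    return (x1 * c - x2 * s, x1 * s + x2 * c)
end

rope_pair(1f0, 0f0, 3, 1, E)  # rotates the pair (1, 0) by θ = 3 radians at the lowest index
```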