ONIONop.jl


Fused kernels (with ChainRules.jl integration): Flash Attention, online softmax, RMS norm, layer norm, and Llama RoPE; see the sections below.

This project originates from NNop.jl and would not have been possible without Anton Smirnov.

We're developing this fork to provide hackable kernels for Onion.jl, after which it is named.

Installation

using Pkg
Pkg.Registry.add(Pkg.RegistrySpec(url="https://github.com/MurrellGroup/MurrellGroupRegistry"))
Pkg.add("ONIONop")

Benchmarking

See benchmarks/main.jl for comparison scripts between naive & fused versions.
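
The script itself is not reproduced here, but the general pattern is timing a standard baseline against the fused kernel. A minimal sketch, using NNlib's softmax as the naive baseline (the dims=1 choice is an assumption about the fused kernel's normalization axis):

using BenchmarkTools, NNlib

x = CuArray(rand(Float32, 8192, 1024))

@btime CUDA.@sync NNlib.softmax($x; dims=1)   # naive baseline
@btime CUDA.@sync ONIONop.online_softmax($x)  # fused kernel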

Flash Attention

Implementation of FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

E, L, H, B = 64, 4096, 4, 4  # head dimension, sequence length, number of heads, batch size
causal = false

q = CuArray(rand(Float32, E, L, H, B))
k = CuArray(rand(Float32, E, L, H, B))
v = CuArray(rand(Float32, E, L, H, B))

o = ONIONop.flash_attention(q, k, v; causal)
∇ = Zygote.gradient(q, k, v) do q, k, v
    sum(ONIONop.flash_attention(q, k, v; causal))
end

Features:

  • Forward & backward passes.
  • Arbitrary sequence length.
  • FP32, FP16, BF16 support.
  • Variable sequence length.
  • Causal masking (see the sketch below).
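
For example, the FP16 and causal-masking paths can be combined like this (a sketch reusing the dimensions from the snippet above):

qh = CuArray(rand(Float16, E, L, H, B))
kh = CuArray(rand(Float16, E, L, H, B))
vh = CuArray(rand(Float16, E, L, H, B))

o_causal = ONIONop.flash_attention(qh, kh, vh; causal=true)  # query i attends only to keys at positions ≤ i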

Softmax

Implementation of Online normalizer calculation for softmax.

x = CuArray(rand(Float32, 8192, 1024))
y = ONIONop.online_softmax(x)
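
Assuming the kernel normalizes along the first dimension, each column of the output should sum to one, which gives a quick sanity check:

colsums = Array(vec(sum(y; dims=1)))
@assert all(s -> isapprox(s, 1f0; atol=1f-4), colsums)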

RMS Norm

x = CuArray(rand(Float32, 1024, 1024))
w = CuArray(rand(Float32, 1024))
y = ONIONop.rms_norm(x, w)
∇ = Zygote.gradient(x, w) do x, w
    sum(ONIONop.rms_norm(x, w))
end
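
For reference, RMSNorm scales each feature vector by the reciprocal of its root-mean-square before applying the weight. A naive broadcast version to compare against (the epsilon value and normalization over the first dimension are assumptions, not necessarily the kernel's defaults):

using Statistics

naive_rms_norm(x, w; eps=1f-6) = w .* x ./ sqrt.(mean(abs2, x; dims=1) .+ eps)

y_ref = naive_rms_norm(x, w)  # should match ONIONop.rms_norm(x, w) up to epsilon handling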

Layer Norm

x = CuArray(rand(Float32, 1024, 1024))
w = CuArray(rand(Float32, 1024))
b = CuArray(rand(Float32, 1024))
y = ONIONop.layer_norm(x, w, b)
∇ = Zygote.gradient(x, w, b) do x, w, b
    sum(ONIONop.layer_norm(x, w, b))
end
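
Likewise, layer norm standardizes each feature vector before applying the weight and bias. A naive broadcast version (again, epsilon and the normalization axis are assumptions):

using Statistics

function naive_layer_norm(x, w, b; eps=1f-6)
    μ  = mean(x; dims=1)
    σ² = mean(abs2, x .- μ; dims=1)
    return w .* (x .- μ) ./ sqrt.(σ² .+ eps) .+ b
end

y_ref = naive_layer_norm(x, w, b)  # should match ONIONop.layer_norm(x, w, b) up to epsilon handling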

Llama RoPE

E, L, B = 16, 1024, 1  # head dimension, sequence length, batch size
QH, KH = 16, 16        # number of query and key heads

kab = CuArray  # Adapt target; CuArray matches the other examples here

emb = ONIONop.LlamaRotaryEmbedding(E)
position_ids = reshape(collect(0f0:Float32(L) - 1f0), :, 1)
position_ids = repeat(position_ids; inner=(1, B))

cos, sin = emb(position_ids)
cos = Adapt.adapt(kab, cos)
sin = Adapt.adapt(kab, sin)

q = Adapt.adapt(kab, ones(Float32, (E, L, QH, B)))
k = Adapt.adapt(kab, ones(Float32, (E, L, KH, B)))
q, k = ONIONop.llama_rope(q, k; cos, sin)
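
Since RoPE only rotates pairs of features, it should leave vector norms unchanged. A quick sanity check on the rotated queries (which were all ones before rotation, so their total sum of squares was E * L * QH * B):

@assert isapprox(sum(abs2, q), Float32(E * L * QH * B); rtol=1f-4)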
