Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions 1_nn/src/nn/attention.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::{Context, Distribution, Linear, NNError, NuralNetwork, TPTensor, Tensor, macros::*};
use super::{
Context, Distribution, Linear, NNError, Normalization, NuralNetwork, TPTensor, Tensor,
macros::*,
};
use crate::{
Arg, TPAction,
Arg, Dim, TPAction,
weight_types::{AttnQKV, RowTPWeight},
};
use tensor::digit_layout::types;
Expand All @@ -10,6 +13,8 @@ pub struct Attention<T> {
pub nh: usize,
pub nkvh: usize,
pub qkv: Linear<T>,
pub q_norm: Option<Normalization<T>>,
pub k_norm: Option<Normalization<T>>,
pub rope: Option<RoPE<T>>,
pub output: Linear<T>,
}
Expand All @@ -28,6 +33,8 @@ impl<T> Attention<T> {
nh,
nkvh,
qkv,
q_norm,
k_norm,
rope,
output,
} = self;
Expand All @@ -37,6 +44,8 @@ impl<T> Attention<T> {
nh: nh / dist.total * dist.len,
nkvh: nkvh / dist.total * dist.len,
qkv: qkv.parallel(TPAction::new(AttnQKV(nh / nkvh), dist)),
q_norm: q_norm.map(|norm| norm.tensor_parallel()),
k_norm: k_norm.map(|norm| norm.tensor_parallel()),
rope: rope.map(
|RoPE {
multimodal,
Expand Down Expand Up @@ -67,10 +76,11 @@ impl<T> NuralNetwork<T> for Attention<T> {
nh,
nkvh,
qkv,
q_norm,
k_norm,
rope,
output,
} = self;

destruct!([x] = ctx.trap("attn-qkv", qkv, [x])?);
dims!([_, dqkv] = x);
let dh = dqkv.clone() / (nh + nkvh + nkvh);
Expand All @@ -90,6 +100,78 @@ impl<T> NuralNetwork<T> for Attention<T> {
)?
);

// Apply normalization to q and k if they exist
let q = match q_norm {
Some(norm) => {
destruct!(
[q] = ctx.call(
"",
"tile",
Some(Arg::dict([
("axis".into(), Arg::int(1)),
(
"tile".into(),
Arg::arr([Dim::from(nh), dh.clone()].map(Arg::from))
),
])),
[q],
)?
);
destruct!([q] = ctx.trap("attn-q-norm", norm, [q])?);

destruct!(
[q] = ctx
.call(
"",
"merge",
Some(Arg::dict([
("start".into(), Arg::int(1)),
("len".into(), Arg::int(2),)
])),
[q],
)
.unwrap()
);
q
}
None => q,
};

let k = match k_norm {
Some(norm) => {
destruct!(
[k] = ctx.call(
"",
"tile",
Some(Arg::dict([
("axis".into(), Arg::int(1)),
(
"tile".into(),
Arg::arr([Dim::from(nkvh), dh.clone()].map(Arg::from))
),
])),
[k],
)?
);
destruct!([k] = ctx.trap("attn-k-norm", norm, [k])?);
destruct!(
[k] = ctx
.call(
"",
"merge",
Some(Arg::dict([
("start".into(), Arg::int(1)),
("len".into(), Arg::int(2),)
])),
[k],
)
.unwrap()
);
k
}
None => k,
};

let [q, k] = match rope {
Some(RoPE {
multimodal,
Expand Down
23 changes: 17 additions & 6 deletions 1_nn/src/op/normalization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@ impl Operator for RmsNorm {

match inputs {
[x, scale] => {
dims!([_n, _d] = x);
dims!([_d] = scale);

let _d = make_eq(&[&x.shape[1], &scale.shape[0]]).ok_or(OpError::ShapeMismatch)?;

Ok(vec![TensorMeta::new(x.dt, [_n.clone(), _d])])
let (x_d, scale_d) = match x.shape().len() {
2 => {
dims!([_n, d] = x);
dims!([d_] = scale);
(d, d_)
}
3 => {
dims!([_n, _, d] = x);
dims!([d_] = scale);
(d, d_)
}
_ => {
return Err(OpError::ShapeError);
}
};
let _d = make_eq(&[x_d, scale_d]).ok_or(OpError::ShapeMismatch)?;
Ok(vec![TensorMeta::new(x.dt, x.shape().to_vec())])
}
_ => Err(OpError::ShapeError),
}
Expand Down
2 changes: 1 addition & 1 deletion 1_nn/src/op/tile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Operator for Tile {

let mut new_shape = shape[..axis].to_vec();
new_shape.extend_from_slice(tile.as_slice());
new_shape.extend_from_slice(&shape[axis..]);
new_shape.extend_from_slice(&shape[axis + 1..]);

Ok(vec![TensorMeta::new(x.dt, new_shape)])
}
Expand Down
8 changes: 4 additions & 4 deletions 2_mem/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ pub(crate) fn tile<T>(node: &mut Node, topo: NodeRef, edges: &mut [Edge<T>]) {
};
let tile = tile
.iter()
.map(|p| {
if let Arg::Dim(dim) = p {
dim.to_usize()
} else {
.map(|p| match p {
Arg::Dim(dim) => dim.to_usize(),
Arg::Int(dim) => *dim as usize,
_ => {
unreachable!()
}
})
Expand Down
2 changes: 2 additions & 0 deletions example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ fn main() {
.register_op("layer-norm", op::normalization::LayerNorm)
.register_op("attention", op::attention::Attention)
.register_op("split", op::split::Split)
.register_op("tile", op::tile::Tile)
.register_op("merge", op::merge::Merge)
.register_op("swiglu", op::activation::SwiGLU)
.register_op("gelu", op::activation::GeLU)
.register_op("linear", op::linear::Linear)
Expand Down
38 changes: 37 additions & 1 deletion example/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub fn init(gguf: &mut GGufModel) -> nn::LLaMA<String> {
let dt_bias = match arch {
"llama" => None,
"qwen2" => Some(gguf.tensors["blk.0.attn_qkv.bias"].dt()),
"qwen3" => None,
arch => panic!("unsupported arch {arch}"),
};

Expand All @@ -21,7 +22,12 @@ pub fn init(gguf: &mut GGufModel) -> nn::LLaMA<String> {
let d = meta![gguf => llm_embedding_length];
let nh = meta![gguf => llm_attention_head_count];
let nkvh = meta![gguf => llm_attention_head_count_kv; nh];
let dh = meta![gguf => llm_rope_dimension_count; d / nh];
let dh = match arch {
"qwen3" => gguf.tensors["blk.0.attn_qkv.weight"].shape()[0]
.checked_div(nh + nkvh + nkvh)
.unwrap(),
_ => meta![gguf => llm_rope_dimension_count; d / nh],
};
let di = meta![gguf => llm_feed_forward_length];
let epsilon = meta![gguf => llm_attention_layer_norm_rms_epsilon; 1e-5];
let dt_embd = gguf.tensors["token_embd.weight"].dt();
Expand Down Expand Up @@ -63,6 +69,36 @@ pub fn init(gguf: &mut GGufModel) -> nn::LLaMA<String> {
format!("blk.{iblk}.attn_qkv.weight"),
dt_bias.map(|dt| (dt, format!("blk.{iblk}.attn_qkv.bias"))),
),
q_norm: if gguf
.tensors
.contains_key(format!("blk.{iblk}.attn_q_norm.weight").as_str())
{
Some(::nn::Normalization {
d: dh,
epsilon: epsilon as _,
items: ::nn::NormType::RmsNorm {
dt: dt_norm,
scale: format!("blk.{iblk}.attn_q_norm.weight"),
},
})
} else {
None
},
k_norm: if gguf
.tensors
.contains_key(format!("blk.{iblk}.attn_k_norm.weight").as_str())
{
Some(::nn::Normalization {
d: dh,
epsilon: epsilon as _,
items: ::nn::NormType::RmsNorm {
dt: dt_norm,
scale: format!("blk.{iblk}.attn_k_norm.weight"),
},
})
} else {
None
},
rope: Some(::nn::RoPE {
multimodal: false,
nctx,
Expand Down