From 7c3939bacb2502d368ab357ed89e73ca5352b457 Mon Sep 17 00:00:00 2001 From: pwhMass Date: Wed, 9 Jul 2025 01:15:09 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=AF=B9qwen3?= =?UTF-8?q?=E7=9A=84=E6=94=AF=E6=8C=81=20=E5=90=8C=E6=97=B6=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E4=B8=80=E4=BA=9B=E6=9C=89=E5=85=B3tile=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 1_nn/src/nn/attention.rs | 88 ++++++++++++++++++++++++++++++++++-- 1_nn/src/op/normalization.rs | 23 +++++++--- 1_nn/src/op/tile.rs | 2 +- 2_mem/src/op.rs | 8 ++-- example/src/main.rs | 2 + example/src/model.rs | 38 +++++++++++++++- 6 files changed, 146 insertions(+), 15 deletions(-) diff --git a/1_nn/src/nn/attention.rs b/1_nn/src/nn/attention.rs index 947af43..041bfd3 100644 --- a/1_nn/src/nn/attention.rs +++ b/1_nn/src/nn/attention.rs @@ -1,6 +1,9 @@ -use super::{Context, Distribution, Linear, NNError, NuralNetwork, TPTensor, Tensor, macros::*}; +use super::{ + Context, Distribution, Linear, NNError, Normalization, NuralNetwork, TPTensor, Tensor, + macros::*, +}; use crate::{ - Arg, TPAction, + Arg, Dim, TPAction, weight_types::{AttnQKV, RowTPWeight}, }; use tensor::digit_layout::types; @@ -10,6 +13,8 @@ pub struct Attention { pub nh: usize, pub nkvh: usize, pub qkv: Linear, + pub q_norm: Option>, + pub k_norm: Option>, pub rope: Option>, pub output: Linear, } @@ -28,6 +33,8 @@ impl Attention { nh, nkvh, qkv, + q_norm, + k_norm, rope, output, } = self; @@ -37,6 +44,8 @@ impl Attention { nh: nh / dist.total * dist.len, nkvh: nkvh / dist.total * dist.len, qkv: qkv.parallel(TPAction::new(AttnQKV(nh / nkvh), dist)), + q_norm: q_norm.map(|norm| norm.tensor_parallel()), + k_norm: k_norm.map(|norm| norm.tensor_parallel()), rope: rope.map( |RoPE { multimodal, @@ -67,10 +76,11 @@ impl NuralNetwork for Attention { nh, nkvh, qkv, + q_norm, + k_norm, rope, output, } = self; - destruct!([x] = ctx.trap("attn-qkv", qkv, [x])?); dims!([_, dqkv] = x); let dh = dqkv.clone() / (nh + nkvh + nkvh); @@ -90,6 +100,78 @@ impl NuralNetwork for Attention { )? ); + // Apply normalization to q and k if they exist + let q = match q_norm { + Some(norm) => { + destruct!( + [q] = ctx.call( + "", + "tile", + Some(Arg::dict([ + ("axis".into(), Arg::int(1)), + ( + "tile".into(), + Arg::arr([Dim::from(nh), dh.clone()].map(Arg::from)) + ), + ])), + [q], + )? + ); + destruct!([q] = ctx.trap("attn-q-norm", norm, [q])?); + + destruct!( + [q] = ctx + .call( + "", + "merge", + Some(Arg::dict([ + ("start".into(), Arg::int(1)), + ("len".into(), Arg::int(2),) + ])), + [q], + ) + .unwrap() + ); + q + } + None => q, + }; + + let k = match k_norm { + Some(norm) => { + destruct!( + [k] = ctx.call( + "", + "tile", + Some(Arg::dict([ + ("axis".into(), Arg::int(1)), + ( + "tile".into(), + Arg::arr([Dim::from(nkvh), dh.clone()].map(Arg::from)) + ), + ])), + [k], + )? + ); + destruct!([k] = ctx.trap("attn-k-norm", norm, [k])?); + destruct!( + [k] = ctx + .call( + "", + "merge", + Some(Arg::dict([ + ("start".into(), Arg::int(1)), + ("len".into(), Arg::int(2),) + ])), + [k], + ) + .unwrap() + ); + k + } + None => k, + }; + let [q, k] = match rope { Some(RoPE { multimodal, diff --git a/1_nn/src/op/normalization.rs b/1_nn/src/op/normalization.rs index 62cc711..bdfaedb 100644 --- a/1_nn/src/op/normalization.rs +++ b/1_nn/src/op/normalization.rs @@ -12,12 +12,23 @@ impl Operator for RmsNorm { match inputs { [x, scale] => { - dims!([_n, _d] = x); - dims!([_d] = scale); - - let _d = make_eq(&[&x.shape[1], &scale.shape[0]]).ok_or(OpError::ShapeMismatch)?; - - Ok(vec![TensorMeta::new(x.dt, [_n.clone(), _d])]) + let (x_d, scale_d) = match x.shape().len() { + 2 => { + dims!([_n, d] = x); + dims!([d_] = scale); + (d, d_) + } + 3 => { + dims!([_n, _, d] = x); + dims!([d_] = scale); + (d, d_) + } + _ => { + return Err(OpError::ShapeError); + } + }; + let _d = make_eq(&[x_d, scale_d]).ok_or(OpError::ShapeMismatch)?; + Ok(vec![TensorMeta::new(x.dt, x.shape().to_vec())]) } _ => Err(OpError::ShapeError), } diff --git a/1_nn/src/op/tile.rs b/1_nn/src/op/tile.rs index 62204b6..0a9a26c 100644 --- a/1_nn/src/op/tile.rs +++ b/1_nn/src/op/tile.rs @@ -43,7 +43,7 @@ impl Operator for Tile { let mut new_shape = shape[..axis].to_vec(); new_shape.extend_from_slice(tile.as_slice()); - new_shape.extend_from_slice(&shape[axis..]); + new_shape.extend_from_slice(&shape[axis + 1..]); Ok(vec![TensorMeta::new(x.dt, new_shape)]) } diff --git a/2_mem/src/op.rs b/2_mem/src/op.rs index 9dbac6d..a7963e9 100644 --- a/2_mem/src/op.rs +++ b/2_mem/src/op.rs @@ -48,10 +48,10 @@ pub(crate) fn tile(node: &mut Node, topo: NodeRef, edges: &mut [Edge]) { }; let tile = tile .iter() - .map(|p| { - if let Arg::Dim(dim) = p { - dim.to_usize() - } else { + .map(|p| match p { + Arg::Dim(dim) => dim.to_usize(), + Arg::Int(dim) => *dim as usize, + _ => { unreachable!() } }) diff --git a/example/src/main.rs b/example/src/main.rs index 008af60..b262b03 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -24,6 +24,8 @@ fn main() { .register_op("layer-norm", op::normalization::LayerNorm) .register_op("attention", op::attention::Attention) .register_op("split", op::split::Split) + .register_op("tile", op::tile::Tile) + .register_op("merge", op::merge::Merge) .register_op("swiglu", op::activation::SwiGLU) .register_op("gelu", op::activation::GeLU) .register_op("linear", op::linear::Linear) diff --git a/example/src/model.rs b/example/src/model.rs index c175e4e..657ea54 100644 --- a/example/src/model.rs +++ b/example/src/model.rs @@ -12,6 +12,7 @@ pub fn init(gguf: &mut GGufModel) -> nn::LLaMA { let dt_bias = match arch { "llama" => None, "qwen2" => Some(gguf.tensors["blk.0.attn_qkv.bias"].dt()), + "qwen3" => None, arch => panic!("unsupported arch {arch}"), }; @@ -21,7 +22,12 @@ pub fn init(gguf: &mut GGufModel) -> nn::LLaMA { let d = meta![gguf => llm_embedding_length]; let nh = meta![gguf => llm_attention_head_count]; let nkvh = meta![gguf => llm_attention_head_count_kv; nh]; - let dh = meta![gguf => llm_rope_dimension_count; d / nh]; + let dh = match arch { + "qwen3" => gguf.tensors["blk.0.attn_qkv.weight"].shape()[0] + .checked_div(nh + nkvh + nkvh) + .unwrap(), + _ => meta![gguf => llm_rope_dimension_count; d / nh], + }; let di = meta![gguf => llm_feed_forward_length]; let epsilon = meta![gguf => llm_attention_layer_norm_rms_epsilon; 1e-5]; let dt_embd = gguf.tensors["token_embd.weight"].dt(); @@ -63,6 +69,36 @@ pub fn init(gguf: &mut GGufModel) -> nn::LLaMA { format!("blk.{iblk}.attn_qkv.weight"), dt_bias.map(|dt| (dt, format!("blk.{iblk}.attn_qkv.bias"))), ), + q_norm: if gguf + .tensors + .contains_key(format!("blk.{iblk}.attn_q_norm.weight").as_str()) + { + Some(::nn::Normalization { + d: dh, + epsilon: epsilon as _, + items: ::nn::NormType::RmsNorm { + dt: dt_norm, + scale: format!("blk.{iblk}.attn_q_norm.weight"), + }, + }) + } else { + None + }, + k_norm: if gguf + .tensors + .contains_key(format!("blk.{iblk}.attn_k_norm.weight").as_str()) + { + Some(::nn::Normalization { + d: dh, + epsilon: epsilon as _, + items: ::nn::NormType::RmsNorm { + dt: dt_norm, + scale: format!("blk.{iblk}.attn_k_norm.weight"), + }, + }) + } else { + None + }, rope: Some(::nn::RoPE { multimodal: false, nctx,