1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
.idea

# C extensions
*.so
4 changes: 3 additions & 1 deletion umbrella/models/auto_model.py
@@ -2,6 +2,7 @@
from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph
from .gemma import Gemma2
from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq
from .glm import GLM4
class AutoModelLM:
"""
Automatic model loader that dynamically loads the corresponding class based on the model type.
@@ -117,7 +118,8 @@ class AutoModelLM:
"mistralai/Mistral-Small-24B-Instruct-2501": Mistral,
"stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq,
"PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq,
"mistralai/Ministral-8B-Instruct-2410": Mistral
"mistralai/Ministral-8B-Instruct-2410": Mistral,
"THUDM/glm-4-9b-chat": GLM4
}

_CUDAGRAPH_MODEL_MAPPING = {
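To make the new "THUDM/glm-4-9b-chat" entry concrete, here is a minimal sketch of constructing the registered GLM4 backend directly. The constructor and alloc() signatures are taken from umbrella/models/glm.py below; how AutoModelLM resolves the mapping is not shown in this diff, so the class is instantiated by hand.

import torch
from umbrella.models.glm import GLM4

# Build the GLM-4 backend that the new mapping entry points to, then
# materialize its weights, KV cache, and RoPE tables on the target device.
model = GLM4("THUDM/glm-4-9b-chat", batch_size=1, max_length=512,
             device="cuda:0", dtype=torch.float16)
model.alloc()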
177 changes: 177 additions & 0 deletions umbrella/models/glm.py
@@ -0,0 +1,177 @@
from transformers.models.glm.modeling_glm import GlmForCausalLM, GlmConfig, apply_rotary_pos_emb, rotate_half
import torch
import torch.nn.functional as F
import gc
import flashinfer
from ..attn.cache import KV_Cache, StaticKV_Cache
from .glm_layer import GLM4Layer
from .base import LLMBase
from .model_utils import layer_norm

class GLM4(LLMBase):
def __init__(self,
model_name: str,
batch_size: int = 1,
max_length: int = 256,
device: str = 'cuda:0',
dtype=torch.float16
) -> None:

super().__init__()
self.batch_size = batch_size
self.device = device
self.dtype = dtype
self.config = GlmConfig.from_pretrained(model_name)
self.model_name = model_name
self.max_length = max_length
self.hidden_size = self.config.hidden_size
self.num_heads = self.config.num_attention_heads
self.head_dim = self.config.head_dim
self.num_key_value_heads = self.config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = self.config.max_position_embeddings
self.rope_theta = self.config.rope_theta
self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id]

def alloc(self, **kwargs):
self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype,
batch_size=self.batch_size)
hf_model = GlmForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype)
self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device)
if self.config.tie_word_embeddings:
self.lm_head = self.embed_tokens
else:
self.lm_head = hf_model.lm_head.weight.detach().to(self.device)

self.norm_weight = hf_model.model.norm.weight.detach().to(self.device)
self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon

self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device)
self.attention_scaling = hf_model.model.rotary_emb.attention_scaling

position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cache = emb.cos()[0]
self.sin_cache = emb.sin()[0]
self.cos_cache = self.cos_cache * self.attention_scaling
self.sin_cache = self.sin_cache * self.attention_scaling
self.cos_cache = self.cos_cache.to(self.dtype)
self.sin_cache = self.sin_cache.to(self.dtype)

self.layers: list[GLM4Layer] = []

for idx, hf_layer in enumerate(hf_model.model.layers):
layer = GLM4Layer(idx)
layer.init_parameters(hf_layer=hf_layer)
layer.to(self.device)
self.layers.append(layer)
hf_model.model.layers[idx] = None
gc.collect()

self.num_layers = len(self.layers)

@torch.inference_mode()
def layer_compute(self,
buffer: GLM4Layer,
layer_idx: int,
hidden_states: torch.FloatTensor,
position_ids: torch.LongTensor,
attention_mask: torch.FloatTensor,
storage_ids: torch.LongTensor):

residual = hidden_states
bsz, q_len, _ = hidden_states.size()

# Layer norm at the input
hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon,
buffer.input_layernorm_weight)

# Attention computation
query_states = F.linear(hidden_states, buffer.wq)
key_states = F.linear(hidden_states, buffer.wk)
value_states = F.linear(hidden_states, buffer.wv)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids)

hidden_states = self.kv_cache.compute_attention(
query_states, key_states, value_states, layer_idx, storage_ids, attention_mask
)
hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size)

# Post-attention and residual connection
hidden_states = F.linear(hidden_states, buffer.wo)
hidden_states = residual + hidden_states
residual = hidden_states

# Layer norm after attention (RMSNorm)
hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon,
buffer.post_attention_layernorm_weight)

# MLP
gate_up = F.linear(hidden_states, buffer.gate_up_proj)
gate, up = gate_up.chunk(2, dim=-1) # Split into gate and up
gate = F.silu(gate)

hidden_states = gate * up
hidden_states = F.linear(hidden_states, buffer.down_proj)
hidden_states = residual + hidden_states
return hidden_states

@torch.inference_mode()
def inference(self,
input_ids: torch.LongTensor,
position_ids: torch.LongTensor,
attention_mask: torch.FloatTensor,
storage_ids: torch.LongTensor):

hidden_states = F.embedding(input_ids, self.embed_tokens)
for idx in range(self.num_layers):
hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask,
storage_ids)

b, s, h = hidden_states.shape

hidden_states = hidden_states.reshape(b * s, h)
hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon)
hidden_states = hidden_states.reshape(b, s, h)
logits = F.linear(hidden_states, self.lm_head).float()
return logits

def gather_kv_incremental(self, indices: torch.LongTensor, offset: int):

self.kv_cache.gather_kv_incremental(indices=indices, offset=offset)

def clear(self):

self.kv_cache.clear()

@torch.inference_mode()
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids, unsqueeze_dim=2):
"""Applies Rotary Position Embedding to the query and key tensors for GLM.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

# Interleave them instead of usual shape
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)

# Keep half or full tensor for later concatenation
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

# Apply rotary embeddings on the first half or full tensor
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
return q_embed, k_embed
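apply_rotary_pos_emb above applies an interleaved cos/sin layout to only the first rotary_dim channels of each head. The standalone snippet below (pure PyTorch, with made-up sizes) contrasts the usual concatenated layout with the repeat-interleaved layout built from the same cached table:

import torch

rotary_dim = 8                                    # illustrative per-head rotary width
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
pos = torch.arange(4).float()
freqs = torch.outer(pos, inv_freq)                # (seq, rotary_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)           # (seq, rotary_dim), as built in alloc()

cos = emb.cos()
cos_usual = cos                                                        # [c0, c1, c2, c3, c0, c1, c2, c3]
cos_glm = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)  # [c0, c0, c1, c1, c2, c2, c3, c3]
print(cos_usual[1])
print(cos_glm[1])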
82 changes: 82 additions & 0 deletions umbrella/models/glm_layer.py
@@ -0,0 +1,82 @@
from __future__ import annotations
import torch
from transformers.models.glm.modeling_glm import GlmDecoderLayer
from ..quantization.awq_utils import AwqLinear


# refers to https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm/modeling_glm.py#L319

class GLM4Layer:
def __init__(self, layer_idx, device="cpu") -> None:
self.wq: torch.Tensor = None
self.wk: torch.Tensor = None
self.wv: torch.Tensor = None
self.wo: torch.Tensor = None

self.gate_up_proj: torch.Tensor = None
self.down_proj: torch.Tensor = None

self.input_layernorm_weight: torch.Tensor = None
self.input_layernorm_variance_epsilon: float = 0.0

self.post_attention_layernorm_weight: torch.Tensor = None
self.post_attention_layernorm_variance_epsilon: float = 0.0

self.layer_idx = layer_idx
self.device = device

def init_parameters(self, hf_layer: GlmDecoderLayer):
self.wq: torch.Tensor = hf_layer.self_attn.q_proj.weight.detach()
self.wk: torch.Tensor = hf_layer.self_attn.k_proj.weight.detach()
self.wv: torch.Tensor = hf_layer.self_attn.v_proj.weight.detach()
self.wo: torch.Tensor = hf_layer.self_attn.o_proj.weight.detach()

self.gate_up_proj = hf_layer.mlp.gate_up_proj.weight.detach()
self.down_proj = hf_layer.mlp.down_proj.weight.detach()

self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach()
self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon

self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach()
self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon

def to(self, device: str = 'cuda:0', non_blocking=True):
self.device = device
self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking)
self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device,
non_blocking=non_blocking)

self.wq = self.wq.to(device, non_blocking=non_blocking)
self.wk = self.wk.to(device, non_blocking=non_blocking)
self.wv = self.wv.to(device, non_blocking=non_blocking)
self.wo = self.wo.to(device, non_blocking=non_blocking)
self.gate_up_proj = self.gate_up_proj.to(device, non_blocking=non_blocking)
self.down_proj = self.down_proj.to(device, non_blocking=non_blocking)

def copy(self, layer: GLM4Layer):
self.wq.copy_(layer.wq, non_blocking=True)
self.wk.copy_(layer.wk, non_blocking=True)
self.wv.copy_(layer.wv, non_blocking=True)
self.wo.copy_(layer.wo, non_blocking=True)
self.gate_up_proj.copy_(layer.gate_up_proj, non_blocking=True)
self.down_proj.copy_(layer.down_proj, non_blocking=True)

self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True)
self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True)

self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon
self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon

self.layer_idx = layer.layer_idx

def alloc_space(self, layer: GLM4Layer, device):
self.device = device
self.wq = torch.zeros_like(layer.wq).to(device)
self.wk = torch.zeros_like(layer.wk).to(device)
self.wv = torch.zeros_like(layer.wv).to(device)
self.wo = torch.zeros_like(layer.wo).to(device)

self.gate_up_proj = torch.zeros_like(layer.gate_up_proj).to(device)
self.down_proj = torch.zeros_like(layer.down_proj).to(device)
self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device)
self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device)
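alloc_space() and copy() suggest a streaming pattern in which one GPU-resident buffer is refilled layer by layer while the full weights stay on the CPU. The sketch below is hypothetical and only illustrates how those two methods could be combined; the repository's actual offload models may orchestrate this differently.

import torch
from transformers.models.glm.modeling_glm import GlmForCausalLM
from umbrella.models.glm_layer import GLM4Layer

hf_model = GlmForCausalLM.from_pretrained("THUDM/glm-4-9b-chat",
                                          torch_dtype=torch.float16)

# Keep every decoder layer on the CPU.
cpu_layers = []
for idx, hf_layer in enumerate(hf_model.model.layers):
    layer = GLM4Layer(idx)
    layer.init_parameters(hf_layer=hf_layer)
    cpu_layers.append(layer)

# Allocate one GPU buffer with the right shapes, then stream layers through it.
gpu_buffer = GLM4Layer(layer_idx=0)
gpu_buffer.alloc_space(cpu_layers[0], device="cuda:0")

for cpu_layer in cpu_layers:
    gpu_buffer.copy(cpu_layer)   # host-to-device copy into the shared buffer
    # ... run this layer's attention/MLP with gpu_buffer before the next copy ...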
10 changes: 8 additions & 2 deletions umbrella/templates.py
@@ -22,8 +22,11 @@
""",

'gemma2': "{}",
'mistral': "[INST] {} [/INST]"

'mistral': "[INST] {} [/INST]",
'glm4': """<|user|>
{}<|assistant|>
"""
}

SysPrompts = {
@@ -39,6 +42,9 @@
'gemma2': "",
'gemma2-it': "",
'mistral': "",
'glm4': """<|system|>
You are a helpful assistant.
"""

}
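As a quick check of the new glm4 entries, the hypothetical snippet below stitches the system prompt and user template into one prompt string. Plain strings are used because the name of the user-template mapping is not visible in this diff.

# Mirror the glm4 SysPrompts and template entries above.
system_prompt = """<|system|>
You are a helpful assistant.
"""
user_template = """<|user|>
{}<|assistant|>
"""

prompt = system_prompt + user_template.format("What does speculative decoding do?")
print(prompt)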
