Skip to content

BowTen/mLLM

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

58 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

mLLM 🔥

一个从零开始实现的轻量级大语言模型推理框架,使用 C++ 编写,专注于学习和理解 Transformer 架构的底层实现!

✨ 核心特性

  • 🤖 模型支持: 目前支持 Qwen3-0.6B 模型,直接兼容 Hugging Face 官方模型文件格式
  • ⚡ 双后端支持: 实现了 CUDA 和 CPU 算子,支持 GPU 和 CPU 推理
  • 🧮 Sgemm: 实现了 CUDA 的 Sgemm 算子,在 RTX3060 上性能达到 cublas 的 105%,RTX4090 上达到 82%
  • 📝 BPE Tokenizer: 自主实现的字节对编码分词器
  • 💾 KV Cache: 实现了键值缓存机制,提升推理效率

🏛️ 项目结构

MiniLLM/
├── include/
│   ├── base/                  # 基础组件
│   │   ├── allocator.h        # 内存分配器
│   │   ├── buffer.h           # 数据缓冲区
│   │   ├── common.h           # 通用定义
│   │   ├── json.hpp           # JSON 解析库
│   │   ├── safetensors.h      # SafeTensors 文件解析
│   │   ├── tensor.h           # 张量类定义
│   │   └── util.h             # 通用工具函数
│   ├── kernel/                # 计算内核
│   │   ├── cpu/               # CPU 算子实现
│   │   │   └── ...
│   │   ├── cuda/              # CUDA 算子实现
│   │   │   └── ...
│   │   └── kernel.h           # 内核接口
│   ├── model/                 # 模型架构
│   │   ├── qwen3.h            # Qwen3 主模型
│   │   ├── qwen3_decode_layer.h    # 解码层
│   │   ├── qwen3_mlp.h        # 多层感知机
│   │   ├── qwen3_rotary_embedding.h # 旋转位置编码
│   │   └── qwen3_self_attn.h  # 自注意力机制
│   ├── op/                    # 算子定义
│   │   ├── add.h              # 加法操作
│   │   ├── causal_mask.h      # 因果掩码
│   │   ├── ele_mul.h          # 元素乘法
│   │   ├── embedding.h        # 嵌入层
│   │   ├── layer.h            # 层基类
│   │   ├── linear.h           # 线性层
│   │   ├── mat_mul.h          # 矩阵乘法
│   │   ├── rms_norm.h         # RMS 归一化
│   │   ├── silu.h             # SiLU 激活函数
│   │   └── softmax.h          # Softmax 操作
│   └── tokenizer/             # 分词器
│       └── tokenizer.h        # 分词器接口
├── src/                       # 源文件实现
│   ├── base/
│   ├── kernel/
│   ├── model/
│   ├── op/
│   ├── tokenizer/
│   └── CMakeLists.txt
├── demo/
│   ├── qwen3_chat.cpp         # Qwen3 聊天演示
│   └── CMakeLists.txt
├── test/                      # 测试程序
├── scripts/                   # 辅助脚本
├── CMakeLists.txt
└── README.md

🧮 Sgemm 优化

主要优化手段:

  • 矩阵分块: 将矩阵分为若干tile,每个线程块负责一个tile的计算。每个tile再分为若干frag,每个线程负责一个frag的计算。线程间并行计算frag,提高效率。同时大量线程并发可隐藏访存延迟。
  • 缓存: 每个线程块处理一个tile前先将其从全局内存缓存到共享内存中,每个线程处理一个frag前将其从共享内存缓存到寄存器中。减少访存次数,防止计算核心等待,提高其利用率。
  • 预取: 在处理当前tile/frag时,可以发射指令预加载下一个tile/frag到寄存器中。当前计算不依赖这次访存,可并行执行,实现一定的访存延迟隐藏

M=N=K, M 整除64的情况

RTX 3060: RTX3060 SgemmDiv64性能对比图 RTX 4090: RTX4090 SgemmDiv64性能对比图

M=N=K, M 整除1的情况

这种情况目前只是在 div64 的基础上将所有 float4 存取展开,导致性能大幅下降,后续优化TODO
RTX 3060: RTX3060 SgemmDiv1性能对比图 RTX 4090: RTX4090 SgemmDiv1性能对比图

🤖 Qwen3 架构图

Qwen3架构图

⌨️ 聊天示例

#include "model/qwen3.h"

using namespace mllm;
using namespace mllm::tokenizer;

class Qwen3Chat
{
public:
    Qwen3Chat(const std::string &model_path, mllm::base::Device device, float temperature, float top_p, size_t top_k, float min_p)
        : model(mllm::model::Qwen3::from_pretrained(model_path, device, temperature, top_k, top_p, min_p)),
          tokenizer(model.get_tokenizer())
    {
    }

    std::string chat(const std::string &input, bool enable_thinking)
    {
        auto chat_token_ids = tokenizer->encode_with_chat_template(input, true, enable_thinking);
        auto input_id = tokenizer->to_tensor(chat_token_ids, model.device());
        base::Tensor next_id({1, 1}, model.device(), false, model.stream());

        bool is_end_think = false;
        std::string output;
        while (true)
        {
            if (!enable_thinking)
                is_end_think = true;
            input_id.toDevice(model.device());
            next_id.toDevice(model.device());
            model.forward(input_id, next_id);
            next_id.toDevice(base::Device::CPU);
            uint32_t id = *(reinterpret_cast<uint32_t *>(next_id.data()));

            std::string next_token = tokenizer->decode(id);
            output.append(next_token);

            if (id != BPETokenizer::QWEN3_END_OF_TEXT && id != BPETokenizer::QWEN3_IM_END)
            {
                std::cout << next_token;
                std::cout.flush();
            }
            if (id == BPETokenizer::QWEN3_END_THINK)
                is_end_think = true;
            if (is_end_think && id == BPETokenizer::QWEN3_END_OF_TEXT)
            {
                std::cout << std::endl;
                break;
            }
            input_id = next_id.clone();
        }

        return output;
    }

private:
    mllm::model::Qwen3 model;
    BPETokenizerPtr tokenizer;
};

int main()
{
    std::string model_path = "path_to_the_model/Qwen/Qwen3-0.6B";
    std::cout << "Loading model..." << std::endl;
    Qwen3Chat qwen3(model_path, base::Device::CUDA, 1.0, 1, 10000, 0.0);
    std::cout << "Loading accomplished." << std::endl;

    while (true)
    {
        std::string input;
        std::cout << "User: ";
        getline(std::cin, input);
        if (input == "exit")
            break;
        std::cout << "Qwen3: ";
        std::cout.flush();
        qwen3.chat(input, false);
    }

    return 0;
}

Qwen3 chat demo 演示动图

依赖库

  • armadillo
  • glog
  • gtest
  • cnpy

About

A simple LLM inference framework implemented in C++, mini LLM !

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published