-
Notifications
You must be signed in to change notification settings - Fork 11
The current implementation of NeoBERT is far slower than ModernBERT's. #5
Copy link
Open
Description
The GPU I am using is an A100 80G SXM4. Below is my test code and the results.
15:02:01.344 ==================================================
15:02:01.344 评 测 总 结
15:02:01.344 ==================================================
15:02:01.344 配置: Batch Size=8, Sequence Length=1024, DType=torch.bfloat16
15:02:01.344 硬件: NVIDIA A100-SXM4-80GB
15:02:01.344 --------------------------------------------------
15:02:01.344 ModernBERT (Custom):
15:02:01.345 - 参数量: 335.23 M
15:02:01.345 - 平均前向时间: 43.265 ms
15:02:01.345 - 吞吐量: 184.91 sequences/sec
15:02:01.345
15:02:01.345 ModernBERT (Transformers):
15:02:01.345 - 参数量: 336.28 M
15:02:01.345 - 平均前向时间: 39.233 ms
15:02:01.345 - 吞吐量: 203.91 sequences/sec
15:02:01.345
15:02:01.345 NeoBERT:
15:02:01.345 - 参数量: 368.00 M
15:02:01.345 - 平均前向时间: 72.868 ms
15:02:01.345 - 吞吐量: 109.79 sequences/sec
15:02:01.345 --------------------------------------------------
15:02:01.345 结论:
15:02:01.345 ✅ 自定义 ModernBERT 比 NeoBERT 快 68.42%
15:02:01.345 ✅ Transformers ModernBERT 比 NeoBERT 快 85.73%
15:02:01.345 ✅ Transformers ModernBERT 比自定义 ModernBERT 快 9.32%
15:02:01.345 ==================================================
import torch
import time
from transformers import AutoConfig,ModernBertConfig, ModernBertForMaskedLM
# 从本地文件导入模型定义
# 确保 modernbert_model.py 和 neobert_model.py 在同一目录下
from modeling import ModernBertForDiffusionLM
from modeling_neobert import NeoBERTLMHead, NeoBERTConfig
# --- 辅助函数 ---
def count_parameters(model):
    """Return the number of trainable parameters in ``model``, in millions (M)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total / 1e6
def create_modernbert_model(config_params):
    """Build the locally-defined ModernBERT diffusion-LM model from ``config_params``."""
    print(" - 创建 ModernBERT 配置...")
    cfg = ModernBertConfig(**config_params)
    # Ask Hugging Face for the fastest attention backend (e.g. flash_attention_2).
    cfg._attn_implementation = "flash_attention_2"
    print(" - 实例化 ModernBertForDiffusionLM...")
    return ModernBertForDiffusionLM(cfg)
def create_transformers_modernbert_model(config_params):
    """Build the stock Transformers ModernBERT MLM model from ``config_params``."""
    print(" - 创建 Transformers ModernBERT 配置...")
    cfg = ModernBertConfig(**config_params)
    # Ask Hugging Face for the fastest attention backend (e.g. flash_attention_2).
    cfg._attn_implementation = "flash_attention_2"
    print(" - 实例化 ModernBertForMaskedLM...")
    return ModernBertForMaskedLM(cfg)
def create_neobert_model(config_params):
    """Build the NeoBERT LM-head model from ``config_params``."""
    print(" - 创建 NeoBERT 配置...")
    cfg = NeoBERTConfig(**config_params)
    print(" - 实例化 NeoBERTLMHead...")
    return NeoBERTLMHead(cfg)
def benchmark(model, model_name, input_ids, attention_mask, n_warmup, n_runs, dtype):
    """Measure average forward latency and throughput of ``model`` on CUDA.

    Args:
        model: module under test; moved to the inputs' device in ``dtype`` and set to eval.
        model_name: "ModernBERT", "Transformers-ModernBERT", or "NeoBERT".
        input_ids: (batch, seq_len) token-id tensor already on the target device.
        attention_mask: matching mask (cast to bool for NeoBERT, which expects that).
        n_warmup: number of untimed warm-up forward passes.
        n_runs: number of timed forward passes.
        dtype: compute dtype for the model weights.

    Returns:
        Tuple of (average forward time in ms, throughput in sequences/sec).
    """
    print(f"\n--- 开始评测: {model_name} ---")
    device = input_ids.device
    model.to(device, dtype=dtype).eval()
    # NeoBERT handles the mask internally but wants it as a boolean tensor.
    if model_name == "NeoBERT":
        attention_mask = attention_mask.to(torch.bool)

    known_models = ("ModernBERT", "Transformers-ModernBERT", "NeoBERT")

    def forward_once():
        # All three models share the same keyword-call signature; unknown
        # names are skipped (matching the original if/elif dispatch).
        if model_name in known_models:
            model(input_ids=input_ids, attention_mask=attention_mask)

    # Warm-up: excluded from timing (kernel compilation, cache effects).
    print(f" - 预热 {n_warmup} 次...")
    with torch.inference_mode():
        for _ in range(n_warmup):
            forward_once()
        torch.cuda.synchronize()

    # Timed region: CUDA events bracket the whole loop, then we synchronize.
    print(f" - 运行 {n_runs} 次进行计时...")
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    with torch.inference_mode():
        start_event.record()
        for _ in range(n_runs):
            forward_once()
        end_event.record()
        torch.cuda.synchronize()

    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / n_runs
    throughput = (input_ids.shape[0] * n_runs) / (elapsed_time_ms / 1000)
    print("--- 评测结果 ---")
    print(f" - 平均前向时间: {avg_time_ms:.3f} ms")
    print(f" - 吞吐量: {throughput:.2f} sequences/sec")
    return avg_time_ms, throughput
def main():
    """Benchmark two ModernBERT variants against NeoBERT and print a summary.

    Creates the three models with matched (~comparable) parameter budgets,
    runs a timed forward-pass benchmark on each, frees GPU memory between
    runs, and prints a comparative report. Requires a CUDA GPU.
    """
    # --- Benchmark hyperparameters ---
    BATCH_SIZE = 8
    SEQ_LEN = 1024
    N_WARMUP = 20
    N_RUNS = 100

    # --- Model configuration (target size: ~280M parameters) ---
    VOCAB_SIZE = 32000

    # ModernBERT configuration. Note: intermediate_size=2736 is ModernBERT's
    # GLU-style sizing, NOT 4*hidden_size (the original comment was wrong).
    modernbert_params = {
        "vocab_size": VOCAB_SIZE,
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "intermediate_size": 2736,
        "max_position_embeddings": SEQ_LEN,
        "pad_token_id": 0,
        "global_attn_every_n_layers": 1,
    }

    # NeoBERT configuration (depth chosen to roughly match parameter count).
    # NeoBERT's FFN uses SwiGLU, so the count differs slightly:
    #   BERT FFN params:   H*4H + 4H*H = 8H^2
    #   SwiGLU FFN params: H*I + H*I + I*H  (I = intermediate_size)
    #   LLaMA-style SwiGLU uses I = 2/3 * 4H -> 16/3 H^2 + 8/3 H^2 = 8H^2,
    # so with intermediate_size = 4H the totals come out nearly identical.
    neobert_params = {
        "vocab_size": VOCAB_SIZE,
        "hidden_size": 1024,
        "num_hidden_layers": 24,  # same depth as ModernBERT for now
        "num_attention_heads": 16,
        "intermediate_size": 4096,  # 4 * hidden_size
        "max_length": SEQ_LEN,
        "pad_token_id": 0,
    }

    # --- Environment check: fail fast with a clear message instead of
    # crashing later when tensors are placed on "cuda". ---
    if not torch.cuda.is_available():
        print("错误: 本评测需要 CUDA GPU。")
        return
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    print(f"使用设备: {device}, 数据类型: {dtype}")

    # --- Model creation and parameter counting ---
    print("\n[1] 正在创建 ModernBERT 模型...")
    model_modern = create_modernbert_model(modernbert_params)
    params_modern = count_parameters(model_modern)
    print(f" - ModernBERT 参数量: {params_modern:.2f} M")

    print("\n[2] 正在创建 Transformers ModernBERT 模型...")
    model_transformers_modern = create_transformers_modernbert_model(modernbert_params)
    params_transformers_modern = count_parameters(model_transformers_modern)
    print(f" - Transformers ModernBERT 参数量: {params_transformers_modern:.2f} M")

    print("\n[3] 正在创建 NeoBERT 模型...")
    model_neo = create_neobert_model(neobert_params)
    params_neo = count_parameters(model_neo)
    print(f" - NeoBERT 参数量: {params_neo:.2f} M")

    # Warn when the parameter budgets diverge enough to skew the comparison.
    if abs(params_modern - params_neo) / params_modern > 0.05:
        print("\n警告: 两个模型的参数量差距超过5%,请微调配置以保证公平对比。")
    print(model_modern)
    print(model_transformers_modern)
    print(model_neo)

    # --- Input data (random token ids; an all-ones mask = no padding) ---
    print(f"\n[4] 准备输入数据 (Batch={BATCH_SIZE}, SeqLen={SEQ_LEN})...")
    input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
    attention_mask = torch.ones_like(input_ids)

    # --- Run the benchmarks, releasing GPU memory between models so each
    # one is measured without pressure from the previous model's weights. ---
    results = {}
    results["modernbert"] = benchmark(model_modern, "ModernBERT", input_ids, attention_mask.clone(), N_WARMUP, N_RUNS, dtype)
    del model_modern
    torch.cuda.empty_cache()

    results["transformers_modernbert"] = benchmark(model_transformers_modern, "Transformers-ModernBERT", input_ids, attention_mask.clone(), N_WARMUP, N_RUNS, dtype)
    del model_transformers_modern
    torch.cuda.empty_cache()

    results["neobert"] = benchmark(model_neo, "NeoBERT", input_ids, attention_mask.clone(), N_WARMUP, N_RUNS, dtype)
    del model_neo
    torch.cuda.empty_cache()

    # --- Summary report ---
    print("\n" + "=" * 50)
    print(" 评 测 总 结")
    print("=" * 50)
    print(f"配置: Batch Size={BATCH_SIZE}, Sequence Length={SEQ_LEN}, DType={dtype}")
    print(f"硬件: {torch.cuda.get_device_name(0)}")
    print("-" * 50)
    modern_avg_time, modern_tput = results["modernbert"]
    transformers_modern_avg_time, transformers_modern_tput = results["transformers_modernbert"]
    neo_avg_time, neo_tput = results["neobert"]
    print("ModernBERT (自定义):")
    print(f" - 参数量: {params_modern:.2f} M")
    print(f" - 平均前向时间: {modern_avg_time:.3f} ms")
    print(f" - 吞吐量: {modern_tput:.2f} sequences/sec")
    print("\nModernBERT (Transformers):")
    print(f" - 参数量: {params_transformers_modern:.2f} M")
    print(f" - 平均前向时间: {transformers_modern_avg_time:.3f} ms")
    print(f" - 吞吐量: {transformers_modern_tput:.2f} sequences/sec")
    print("\nNeoBERT:")
    print(f" - 参数量: {params_neo:.2f} M")
    print(f" - 平均前向时间: {neo_avg_time:.3f} ms")
    print(f" - 吞吐量: {neo_tput:.2f} sequences/sec")
    print("-" * 50)

    # Final pairwise comparisons: ratio > 1 means the second model is faster.
    print("结论:")
    # NeoBERT vs custom ModernBERT
    speed_ratio_1 = neo_avg_time / modern_avg_time
    if neo_avg_time < modern_avg_time:
        print(f"✅ NeoBERT 比自定义 ModernBERT 快 {100 * (1 - speed_ratio_1):.2f}%")
    else:
        print(f"✅ 自定义 ModernBERT 比 NeoBERT 快 {100 * (speed_ratio_1 - 1):.2f}%")
    # NeoBERT vs Transformers ModernBERT
    speed_ratio_2 = neo_avg_time / transformers_modern_avg_time
    if neo_avg_time < transformers_modern_avg_time:
        print(f"✅ NeoBERT 比 Transformers ModernBERT 快 {100 * (1 - speed_ratio_2):.2f}%")
    else:
        print(f"✅ Transformers ModernBERT 比 NeoBERT 快 {100 * (speed_ratio_2 - 1):.2f}%")
    # Custom vs Transformers ModernBERT
    speed_ratio_3 = transformers_modern_avg_time / modern_avg_time
    if transformers_modern_avg_time < modern_avg_time:
        print(f"✅ Transformers ModernBERT 比自定义 ModernBERT 快 {100 * (1 - speed_ratio_3):.2f}%")
    else:
        print(f"✅ 自定义 ModernBERT 比 Transformers ModernBERT 快 {100 * (speed_ratio_3 - 1):.2f}%")
    print("=" * 50)
# Script entry point. (The scraped source had web-page text fused onto this
# line — "Reactions are currently unavailable" — which made it a SyntaxError.)
if __name__ == "__main__":
    main()
Metadata
Assignees
Labels
No labels