503 changes: 503 additions & 0 deletions backends/intel_hpu/custom_ops/llama_infer/fused_mlp_new.cc

Large diffs are not rendered by default.

265 changes: 265 additions & 0 deletions backends/intel_hpu/custom_ops/python/paddlenlp_ops/Model_convert.py
@@ -0,0 +1,265 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import os
import shutil
import sys
from typing import Dict

import paddle
from safetensors.paddle import load_file, save_file
from tqdm import tqdm

paddle.device.set_device("intel_hpu:5")

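# Cap each output shard at ~5 GB (decimal bytes), mirroring the usual
# sharding convention for safetensors checkpoints.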
MAX_FILE_SIZE_IN_GB = 5
max_size_bytes = MAX_FILE_SIZE_IN_GB * 1000**3


def tensor_size(tensor):
    return tensor.nbytes if hasattr(tensor, "nbytes") else tensor.numpy().nbytes


def tensors_total_size(tensors_dict):
    return sum(tensor_size(tensor) for tensor in tensors_dict.values())


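# Fold the per-rank measurement scales into the in-flight tensor dict, write
# the final shard, and emit model.safetensors.index.json. Each measurement
# line is expected to be a tab-separated "<tensor_key>\t<amax>" pair.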
def save_tail_tensors_and_index(
    tensors_dict,
    measurement_files,
    model_fp8_path,
    total_size,
    out_file_idx,
    out_files,
    approximate_total_files,
):
    for measurement_file in measurement_files:
        with open(measurement_file, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                key, value = line.split("\t")
                # Parse before comparing: a str never compares equal to 0.0.
                value = float(value)
                if value == 0.0:
                    print(f"warning: amax is 0.0 for {key}, set to 1e-5")
                    value = 1e-5
                # Non-attention amax values are converted to FP8 scales
                # (amax / 240); attention scales are stored as-is.
                if "self_attn" not in key:
                    scale = value / 240.0
                else:
                    scale = value
                meas_scale_tensor = paddle.to_tensor([scale], dtype=paddle.bfloat16)
                if key in tensors_dict:
                    # Keep the largest scale seen across ranks for this key.
                    tensors_dict[key] = paddle.maximum(
                        tensors_dict[key], meas_scale_tensor
                    )
                else:
                    tensors_dict[key] = meas_scale_tensor
                    total_size += tensor_size(meas_scale_tensor)

    file_name = f"model-{out_file_idx:05d}-of-{approximate_total_files:05d}.safetensors"
    file_path = os.path.join(model_fp8_path, file_name)
    save_file(tensors_dict, file_path)
    out_files.append({"filename": file_name, "keys": list(tensors_dict.keys())})

    index_json = {"metadata": {"total_size": total_size}, "weight_map": {}}
    for file_info in out_files:
        for key in file_info["keys"]:
            index_json["weight_map"][key] = file_info["filename"]

    index_path = os.path.join(model_fp8_path, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(index_json, f, indent=2)


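# Per-tensor FP8 quantization: scale = amax / 240 (240 is Gaudi's FP8-E4M3
# maximum), then store the FP8 payload bit-cast to int8 so that it can be
# serialized with safetensors.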
def tensorwise_quant_to_fp8(tensor):
    x_abs = paddle.abs(tensor).astype(paddle.float32)
    x_amax = paddle.amax(x_abs)
    x_amax = paddle.clip(x_amax, min=1e-4)
    scale = x_amax / 240.0
    x_scaled = (tensor.cast("float32") / scale).cast("float8_e4m3fn").clone()

    return paddle.view(x_scaled, "int8").clone(), paddle.to_tensor([scale]).cast(
        "bfloat16"
    )

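# Dequantization sketch for reference (assumes the consumer re-views the int8
# payload as float8_e4m3fn before scaling):
#   fp8_bits = paddle.view(quant_weight, "float8_e4m3fn")
#   restored = fp8_bits.cast("bfloat16") * weight_scale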

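# Stream tensors from one source shard: *_proj.weight tensors are replaced by
# an FP8 quant_weight/weight_scale pair, everything else passes through, and
# the in-flight dict is flushed to a new shard once it would exceed
# max_size_bytes.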
def process_safetensors_file(
    tensors_dict,
    src_path,
    model_fp8_path,
    total_size,
    out_file_idx,
    out_files,
    max_size_bytes,
    approximate_total_files,
):
    current_size = tensors_total_size(tensors_dict)

    loaded_tensors = load_file(src_path)
    for key, tensor in loaded_tensors.items():
        if "_proj.weight" in key:
            if tensor.dtype != paddle.bfloat16:
                print(
                    f"Warning: Expected bfloat16 tensor for key {key}, but got {tensor.dtype}. Skipping."
                )
                continue
            else:
                tensor = paddle.Tensor(tensor, zero_copy=True)
                quant_tensor, scale = tensorwise_quant_to_fp8(tensor)

            t_size = tensor_size(quant_tensor) + tensor_size(scale)
            if current_size + t_size > max_size_bytes and tensors_dict:
                file_name = f"model-{out_file_idx:05d}-of-{approximate_total_files:05d}.safetensors"
                file_path = os.path.join(model_fp8_path, file_name)
                save_file(tensors_dict, file_path)
                out_files.append(
                    {"filename": file_name, "keys": list(tensors_dict.keys())}
                )
                out_file_idx += 1
                tensors_dict = {}
                current_size = 0

            new_key = key.replace("_proj.weight", "_proj.quant_weight")
            tensors_dict[new_key] = quant_tensor
            scale_key = key.replace("_proj.weight", "_proj.weight_scale")
            tensors_dict[scale_key] = scale
            current_size += t_size
            total_size += t_size
        else:
            t_size = tensor_size(tensor)
            if current_size + t_size > max_size_bytes and tensors_dict:
                file_name = f"model-{out_file_idx:05d}-of-{approximate_total_files:05d}.safetensors"
                file_path = os.path.join(model_fp8_path, file_name)
                save_file(tensors_dict, file_path)
                out_files.append(
                    {"filename": file_name, "keys": list(tensors_dict.keys())}
                )
                out_file_idx += 1
                tensors_dict = {}
                current_size = 0
            tensors_dict[key] = tensor
            current_size += t_size
            total_size += t_size
    return tensors_dict, total_size, out_file_idx, out_files


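# CLI driver: converts a BF16 safetensors checkpoint to FP8, folding in
# activation-measurement scales collected from one or more ranks.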
def main():
    print(
        f"Usage: python {sys.argv[0]} [model_bf16_path] [model_fp8_path] [model_measurement_file_or_folder] <ranks_total_number>"
    )
    if len(sys.argv) < 4 or len(sys.argv) > 5:
        print("Error: Invalid number of arguments.")
        return
    model_bf16_path = sys.argv[1]
    model_fp8_path = sys.argv[2]
    model_measurement_file = sys.argv[3]
    ranks = sys.argv[4] if len(sys.argv) > 4 else "0"
    os.makedirs(model_fp8_path, exist_ok=True)

    if os.path.isdir(model_measurement_file):
        measurement_files = [
            os.path.join(model_measurement_file, f)
            for f in os.listdir(model_measurement_file)
            if os.path.isfile(os.path.join(model_measurement_file, f))
        ]
    elif ranks.isdigit() and int(ranks) > 1:
        # One measurement file per rank: <stem>_<rank><ext>
        measurement_files = [
            f"{os.path.splitext(model_measurement_file)[0]}_{i}{os.path.splitext(model_measurement_file)[1]}"
            for i in range(int(ranks))
        ]
    else:
        measurement_files = [model_measurement_file]

    for measurement_file in measurement_files:
        if not os.path.isfile(measurement_file):
            print(f"Error: Measurement file not found: {measurement_file}")
            return

    # Copy non-safetensors files (except model.safetensors.index.json) to the new folder.
    total_size = None
    for item_name in os.listdir(model_bf16_path):
        source_path = os.path.join(model_bf16_path, item_name)
        if os.path.isfile(source_path):
            if item_name == "model.safetensors.index.json":
                with open(source_path, "r") as f:
                    index_data = json.load(f)
                total_size = index_data.get("metadata", {}).get("total_size", None)
            elif item_name == "config.json":
                with open(source_path, "r") as f:
                    config_data = json.load(f)
                config_data["quantization_config"] = {
                    "dense_quant_type": "tensor_wise_fp8",
                    "moe_quant_type": "tensor_wise_fp8",
                    "quantization": "mix_quant",
                    "kv_cache_quant_type": "float8_e4m3",
                    "is_quantized": True,
                }
                destination_path = os.path.join(model_fp8_path, item_name)
                with open(destination_path, "w") as f:
                    json.dump(config_data, f, indent=2)
            elif not item_name.lower().endswith(".safetensors"):
                destination_path = os.path.join(model_fp8_path, item_name)
                try:
                    shutil.copy2(source_path, destination_path)
                except Exception as e:
                    print(f"Error copying {item_name}: {e}")

    if total_size is None:
        print("Error: model.safetensors.index.json with metadata.total_size not found.")
        return

    # Estimate the number of output files: FP8 weights plus scales come to
    # roughly 0.506x the BF16 checkpoint size (empirical ratio).
    total_size *= 0.506
    approximate_total_files = int((total_size + max_size_bytes - 1) // max_size_bytes)
    print(f"Approximate total files to be generated: {approximate_total_files}")
    total_size = 0
    out_file_idx = 1
    tensors_dict: Dict[str, paddle.Tensor] = {}
    out_files = []

    search_pattern = os.path.join(model_bf16_path, "*.safetensors")
    safetensor_files = glob.glob(search_pattern)

    if not safetensor_files:
        print("Warning: No *.safetensors files found in the source directory.")
        return

    for file in tqdm(
        safetensor_files,
        desc=f"Loading safetensor files from {model_bf16_path}",
        unit="file",
    ):
        (tensors_dict, total_size, out_file_idx, out_files,) = process_safetensors_file(
            tensors_dict,
            file,
            model_fp8_path,
            total_size=total_size,
            out_file_idx=out_file_idx,
            out_files=out_files,
            max_size_bytes=max_size_bytes,
            approximate_total_files=approximate_total_files,
        )

    save_tail_tensors_and_index(
        tensors_dict,
        measurement_files,
        model_fp8_path,
        total_size,
        out_file_idx,
        out_files,
        approximate_total_files,
    )


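# Example invocation (paths and rank count are illustrative):
#   python Model_convert.py ./model_bf16 ./model_fp8 ./measurements/amax.txt 8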
if __name__ == "__main__":
    main()
@@ -16,3 +16,4 @@
 from .layers import * # noqa
 from .llama_block_atten import * # noqa
 from .blockwise_quant import * # noqa
+from .reference_models import * # noqa
4 changes: 2 additions & 2 deletions backends/intel_hpu/custom_ops/python/paddlenlp_ops/layers.py
@@ -115,7 +115,7 @@ def __init__(self, scaling_factor, linear_weights):
         self.linear_weights = linear_weights

     def forward(self, i, query_states, key_states, value_states, attention_mask):
-        out_linear_out = fused_sdpa_proj(
+        out_linear_out = fused_sdpa_proj_legacy(
             query_states,
             key_states,
             value_states,
@@ -278,7 +278,7 @@ def __init__(self, proj_weight, up_weight, down_weight):
         self.up_weight = up_weight

     def forward(self, i, x):
-        fused_mlp_out = fused_mlp(
+        fused_mlp_out = fused_mlp_bf16(
             x,
             self.proj_weight[i],
             self.up_weight[i],