Skip to content

Commit 5eb4e29

Browse files
authored
Update tran_4_5t_vl.py (#1333)
1 parent 80ec9c9 commit 5eb4e29

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

tools/paddle2torch/tran_4_5t_vl.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import paddle
32
from safetensors.numpy import load_file
43
from safetensors.torch import save_file as save_safetensors
54
import torch
@@ -40,24 +39,20 @@ def convert_pdparams_to_safetensors(pdparams_path, safetensors_path, config):
4039
"""
4140
print("----------------------------------------------------------------")
4241
print("pdparams_path:", pdparams_path)
43-
paddle.set_device('cpu')
4442
# Load the PaddlePaddle model state dictionary
4543
torch_state_dict = {}
46-
weight_map = {}
4744
pd_tensors = load_file(pdparams_path)
4845
for key, param in pd_tensors.items():
4946
if param.dtype != 'float32':
5047
param = (param.astype(np.uint32) << 16).view(np.float32)
5148
if key.endswith('.weight') and \
5249
"embed_tokens" not in key and \
50+
"lm_head" not in key and \
5351
".gate." not in key and \
5452
param.ndim == 2:
5553
param = param.T # Transpose the parameter
56-
# vision model参数为float16
5754
tensor = torch.from_numpy(param)
58-
if 'vision' in key:
59-
tensor = tensor.to(torch.float16)
60-
elif 'mlp.gate.weight' not in key and 'mlp.moe_statics.e_score_correction_bias' not in key:
55+
if 'mlp.gate.weight' not in key and 'mlp.moe_statics.e_score_correction_bias' not in key:
6156
tensor = tensor.to(torch.bfloat16)
6257

6358
key = key.replace('ernie.', 'model.')

0 commit comments

Comments
 (0)