forked from prophesier/diff-svc
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimplify.py
More file actions
88 lines (76 loc) · 3.43 KB
/
simplify.py
File metadata and controls
88 lines (76 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import re
import shutil
import torch
def get_model_folder(path):
model_lists = os.listdir(path)
res_list = []
filter_list = ["hubert", "xiaoma_pe", "hifigan", "checkpoints", ".yaml", ".zip"]
for path in model_lists:
if not any(word if word in path else False for word in filter_list):
res_list.append(path)
return res_list
def scan(path):
model_str = ""
path_lists = get_model_folder(path)
for i in range(0, len(path_lists)):
if re.search(u'[\u4e00-\u9fa5]', path_lists[i]):
print(f'{path_lists[i]}:中文路径!此项跳过')
continue
model_str += f"{i}:{path_lists[i]} "
if (i + 1) % 5 == 0:
print(f"{model_str}")
model_str = ""
if len(path_lists) % 5 != 0:
print(model_str)
return path_lists
def simplify_pth(model_name, proj_name, output_path):
model_path = f'./checkpoints/{proj_name}'
checkpoint_dict = torch.load(f'{model_path}/{model_name}')
torch.save({'epoch': checkpoint_dict['epoch'],
'state_dict': checkpoint_dict['state_dict'],
'global_step': None,
'checkpoint_callback_best': None,
'optimizer_states': None,
'lr_schedulers': None
}, output_path)
def mkdir(paths: list):
for path in paths:
if not os.path.exists(path):
os.mkdir(path)
if __name__ == '__main__':
if os.path.exists("./checkpoints"):
path_list = scan("./checkpoints")
else:
print("请检查checkpoints文件夹是否存在")
exit()
a = input("\r\n请输入序号并回车:")
project_name = path_list[int(a)]
path_list = scan(f"./checkpoints/{path_list[int(a)]}")
b = input("\r\n请输入序号并回车:")
pth_name = path_list[int(b)]
print("\r\n选择:\r\n"
"0.存储精简模型到对应模型目录(本地精简模型时推荐使用这个)\r\n"
"1.存储精简模型和config.yaml到程序根目录(新建文件夹,九天毕昇上导出精简模型推荐使用这个)\r\n"
"2.复制完整模型和config.yaml到程序根目录(新建文件夹,九天毕昇上导出完整模型推荐使用这个)\r\n"
"输入其他退出")
f = int(input("\r\n请输入序号并回车:"))
if f == 0:
print(f"已保存精简模型至对应模型目录")
shutil.copyfile(f'./checkpoints/{project_name}/config.yaml', f"./{project_name}/config.yaml")
output = f"./checkpoints/{project_name}/clean_{pth_name}"
simplify_pth(pth_name, project_name, output)
elif f == 1:
print(f"已保存精简模型至: 根目录下新建文件夹/{project_name}")
mkdir([f"./{project_name}"])
shutil.copyfile(f'./checkpoints/{project_name}/config.yaml', f"./{project_name}/config.yaml")
output = f"./{project_name}/clean_{pth_name}"
simplify_pth(pth_name, project_name, output)
elif f == 2:
print(f"已保存完整模型至: 根目录下新建文件夹/{project_name}")
mkdir([f"./{project_name}"])
shutil.copyfile(f'./checkpoints/{project_name}/config.yaml', f"./{project_name}/config.yaml")
shutil.copyfile(f'./checkpoints/{project_name}/{pth_name}', f"./{project_name}/{pth_name}")
else:
print("输入错误,程序退出")
exit()