train.py
import argparse
import os

import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchmetrics import Accuracy
from transformers import ViTImageProcessor, ViTForImageClassification
from torchvision import transforms
import matplotlib.pyplot as plt
########################################################################################################################
# Custom dataset: reads image file names and labels from a CSV file.
class CustomDataset(Dataset):
    def __init__(self, csv_file, img_dir, label_col, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.label_col = label_col
        unique_labels = self.annotations.iloc[:, 1].unique()
        # unique_labels = self.annotations[label_col].unique()
        unique_labels.sort()
        self.label2id = {label: id for id, label in enumerate(unique_labels)}
        self.id2label = {id: label for label, id in self.label2id.items()}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img_path = os.path.join(self.img_dir, img_id)
        image = Image.open(img_path)
        label = self.annotations.iloc[index, 1]
        # label = self.annotations.loc[index, self.label_col]
        label = 1 if label == 'Yes' else 0
        if self.transform:
            image = self.transform(image)
        return image, label
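# Note on the expected CSV layout (inferred from the indexing above, so treat it
# as an assumption): column 0 holds the image file name relative to img_dir and
# column 1 holds the 'Yes'/'No' target label, e.g.
#   image_id,Is Epic
#   0001.png,Yes
#   0002.png,No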
# Custom collator: turns a list of (PIL image, label) pairs into a model-ready batch.
class ImageClassificationCollator:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        # Convert images to RGB first, because some of them are RGBA.
        batch = [(x[0].convert('RGB'), x[1]) for x in batch]
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings
# Custom classifier (the core training module).
class Classifier(pl.LightningModule):
    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        self.automatic_optimization = True
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.forward = self.model.forward
        self.val_acc = Accuracy(
            task='binary',
            num_classes=2
        )

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
# Validation metric helpers.
def roc(y_true, y_score, pos_label):
    """
    y_true: ground-truth labels
    y_score: predicted scores from the model
    pos_label: label of the positive class, e.g. 1
    """
    # Count the positive and negative samples.
    num_positive_examples = (y_true == pos_label).sum()
    num_negative_examples = len(y_true) - num_positive_examples
    tp, fp = 0, 0
    tpr, fpr, thresholds = [], [], []
    score = max(y_score) + 1
    # Walk through the samples in descending score order and compute fpr/tpr.
    for i in np.flip(np.argsort(y_score)):
        # Handle samples with identical predicted scores.
        if y_score[i] != score:
            fpr.append(fp / num_negative_examples)
            tpr.append(tp / num_positive_examples)
            thresholds.append(score)
            score = y_score[i]
        if y_true[i] == pos_label:
            tp += 1
        else:
            fp += 1
    fpr.append(fp / num_negative_examples)
    tpr.append(tp / num_positive_examples)
    thresholds.append(score)
    return fpr, tpr, thresholds
def compute(y_true, y_score, thresholds):
    # Returns the confusion matrix [TP, FP, FN, TN] plus precision and recall
    # at the given decision threshold.
    matrix = [0, 0, 0, 0]  # TP, FP, FN, TN
    for i in range(len(y_score)):
        if y_score[i] >= thresholds and y_true[i] == 1:
            matrix[0] += 1
        elif y_score[i] >= thresholds and y_true[i] == 0:
            matrix[1] += 1
        elif y_score[i] < thresholds and y_true[i] == 1:
            matrix[2] += 1
        elif y_score[i] < thresholds and y_true[i] == 0:
            matrix[3] += 1
    precision = matrix[0] / (matrix[0] + matrix[1])  # TP / (TP + FP)
    recall = matrix[0] / (matrix[0] + matrix[2])     # TP / (TP + FN)
    return matrix, precision, recall
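# Illustrative usage sketch of roc() and compute() on toy data; the demo arrays
# below are assumptions for demonstration only, and the helper is never called,
# so running this script does not execute it.
def _metrics_usage_example():
    y_true_demo = np.array([1, 1, 0, 0, 1])
    y_score_demo = np.array([0.9, 0.6, 0.4, 0.3, 0.8])
    # ROC curve points over all thresholds.
    fpr_demo, tpr_demo, _ = roc(y_true_demo, y_score_demo, pos_label=1)
    plt.plot(fpr_demo, tpr_demo)
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    # Confusion matrix, precision and recall at a fixed 0.5 threshold.
    matrix_demo, precision_demo, recall_demo = compute(y_true_demo, y_score_demo, 0.5)
    f1_demo = 2 * (precision_demo * recall_demo) / (precision_demo + recall_demo)
    print(matrix_demo, precision_demo, recall_demo, f1_demo)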
def parse_args():
    """
    Helper function to parse command line arguments.
    :return: args object

    Example usage:
    python train.py -d "path/to/data/Data - Is Epic Intro 2024-03-25" -l "Labels-IsEpicIntro-2024-03-25.csv" -t "Is Epic" -o "path/to/models/Is Epic/"
    """
    # Argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--train_data_image_dir', required=True, help='Path to image data directory')
    parser.add_argument('-l', '--train_data_labels_csv', required=True, help='Path to labels CSV')
    parser.add_argument('-t', '--target_column_name', required=True, help='Name of the column with target label in CSV')
    parser.add_argument('-o', '--trained_model_output_dir', required=True, help='Output directory for trained model')
    args = parser.parse_args()
    return args
def main(train_input_dir: str, train_labels_file_name: str, target_column_name: str, train_output_dir: str):
    """
    The main body of train.py, responsible for
    1. loading resources
    2. loading labels
    3. loading data
    4. transforming data
    5. training the model
    6. saving the trained model
    :param train_input_dir: the folder with the CSV and training images.
    :param train_labels_file_name: the CSV file name
    :param target_column_name: Name of the target column within the CSV file
    :param train_output_dir: the folder to save training output.
    """
    csv_file = os.path.join(train_input_dir, train_labels_file_name)
    img_dir = train_input_dir
    # Image preprocessing. Preprocess according to the model's requirements; data
    # augmentation would also go here. Currently nothing is done, because the steps
    # below are already handled by ViTImageProcessor:
    transform = transforms.Compose([
        # transforms.Resize((224, 224)),   # resize images to 224x224
        # transforms.ToTensor(),           # convert images to tensors
        # transforms.Normalize(            # normalize the images
        #     mean=[0.5, 0.5, 0.5],        # using this mean
        #     std=[0.5, 0.5, 0.5]          # and this standard deviation
        # ),
    ])
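    # If augmentation were wanted, it could be sketched like this (an assumption,
    # left commented out so the current behaviour is unchanged; these transforms
    # keep PIL images, so they stay compatible with the collator above):
    # transform = transforms.Compose([
    #     transforms.RandomHorizontalFlip(),  # random left-right flip
    #     transforms.RandomRotation(10),      # small random rotation
    # ])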
    # Path to the pretrained model.
    model_path = 'resources/pretrained'
    dataset = CustomDataset(
        csv_file=csv_file,
        img_dir=img_dir,
        label_col=target_column_name,
        transform=None  # pass transform here if augmentation is used
    )
    # Split the prepared CustomDataset into training and validation sets:
    # 80% of the data for training, 20% for validation.
    indices = list(range(len(dataset)))  # build the indices
    np.random.shuffle(indices)           # shuffle them
    split = int(np.floor(0.8 * len(dataset)))  # 80% of the data goes to training
    train_indices, val_indices = indices[:split], indices[split:]
    # Build dataset subsets from the indices.
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    ##### Load the model and set up the 2-class classification head #####
    # Local path to the pretrained checkpoint.
    feature_extractor = ViTImageProcessor.from_pretrained(model_path)
    model = ViTForImageClassification.from_pretrained(
        model_path,
        num_labels=2,
        label2id=dataset.label2id,
        id2label=dataset.id2label
    )
    collator = ImageClassificationCollator(feature_extractor)
    # Create the data loaders.
    train_dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=collator, shuffle=True, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=2, collate_fn=collator, shuffle=False, pin_memory=True)
    # Because the sample size is small, batch_size is set to 2 (mini-batch gradient
    # descent); shuffle=True reshuffles the training data every epoch.
    pl.seed_everything(42)  # set the random seed so results are reproducible
    classifier = Classifier(model, lr=2e-5)  # pass in the model and learning rate
    trainer = pl.Trainer(accelerator='gpu', devices=1, precision='16-mixed', max_epochs=15)  # train on GPU for 15 epochs
    # trainer = pl.Trainer(max_epochs=15)  # if no GPU is available, train on CPU for 15 epochs
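    # A possible alternative (an assumption, not used here): let Lightning pick the
    # device automatically instead of hard-coding GPU vs. CPU:
    # trainer = pl.Trainer(accelerator='auto', devices=1, max_epochs=15)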
    trainer.fit(classifier, train_dataloader, val_dataloader)
    # Run a prediction on one validation batch and print the results.
    val_batch = next(iter(val_dataloader))
    outputs = model(**val_batch)
    logits = outputs.logits
    logits_softmax = logits.softmax(1).data.tolist()[0]  # validation scores for the first sample
    print('Preds: ', outputs.logits.softmax(1).argmax(1))
    print('Labels:', val_batch['labels'])
    # Compute the confusion matrix and related metrics.
    # fpr, tpr, thresholds = roc(val_batch['labels'], logits_softmax, pos_label=1)
    # matrix, precision, recall = compute(val_batch['labels'], logits_softmax, 0.5)
    # f1 = 2 * ((precision * recall) / (precision + recall))
    # print(f"precision = {precision}, recall = {recall}, f1_score = {f1}")
    # print(matrix)
    # Create the output directory and don't error if it already exists.
    os.makedirs(train_output_dir, exist_ok=True)
    # Save the trained model and processor.
    model.save_pretrained(train_output_dir)
    feature_extractor.save_pretrained(train_output_dir)
    # torch.save(model, train_output_dir + '/' + target_column_name + '.pth')
if __name__ == '__main__':
    """
    Example usage:
    python train.py -d "path/to/Data - Is Epic Intro 2024-03-25" -l "Labels-IsEpicIntro-2024-03-25.csv" -t "Is Epic" -o "path/to/models"
    """
    args = parse_args()
    train_data_image_dir = args.train_data_image_dir
    train_data_labels_csv = args.train_data_labels_csv
    target_column_name = args.target_column_name
    trained_model_output_dir = args.trained_model_output_dir
    main(train_data_image_dir, train_data_labels_csv, target_column_name, trained_model_output_dir)
########################################################################################################################