本文档在一个小数据集上展示了如何通过PaddleX进行训练,您可以阅读PaddleX的使用教程来了解更多模型任务的训练使用方式。本示例同步在AIStudio上,可直接在线体验模型训练
wget https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz
tar xzvf vegetables_cls.tar.gz
通过如下train.py
代码进行训练
设置使用0号GPU卡
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
定义训练和验证时的数据处理流程, 在
train_transforms
中加入了RandomCrop
和RandomHorizontalFlip
两种数据增强方式
from paddlex.cls import transforms
train_transforms = transforms.Compose([
transforms.RandomCrop(crop_size=224),
transforms.RandomHorizontalFlip(),
transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.ResizeByShort(short_size=256),
transforms.CenterCrop(crop_size=224),
transforms.Normalize()
])
定义数据集,
pdx.datasets.ImageNet
表示读取ImageNet格式的分类数据集
train_dataset = pdx.datasets.ImageNet(
data_dir='vegetables_cls',
file_list='vegetables_cls/train_list.txt',
label_list='vegetables_cls/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.ImageNet(
data_dir='vegetables_cls',
file_list='vegetables_cls/val_list.txt',
label_list='vegetables_cls/labels.txt',
transforms=eval_transforms)
模型训练
num_classes = len(train_dataset.labels)
model = pdx.cls.MobileNetV2(num_classes=num_classes)
model.train(num_epochs=10,
train_dataset=train_dataset,
train_batch_size=32,
eval_dataset=eval_dataset,
lr_decay_epochs=[4, 6, 8],
learning_rate=0.025,
save_dir='output/mobilenetv2',
use_vdl=True)
train.py
与解压后的数据集目录vegetables_cls
放在同一目录下,在此目录下运行train.py
即可开始训练。如果您的电脑上有GPU,这将会在10分钟内训练完成,如果为CPU也大概会在30分钟内训练完毕。
python train.py
模型在训练过程中,所有的迭代信息将以标注输出流的形式,输出到命令执行的终端上,用户也可通过visualdl以可视化的方式查看训练指标的变化,通过如下方式启动visualdl后,在浏览器打开https://0.0.0.0:8001即可。
visualdl --logdir output/mobilenetv2/vdl_log --port 8000
如使用训练过程中第8轮保存的模型进行测试
import paddlex as pdx
model = pdx.load_model('output/mobilenetv2/epoch_8')
result = model.predict('vegetables_cls/bocai/100.jpg', topk=3)
print("Predict Result:", result)
预测结果输出如下,预测按score进行排序,得到前三分类结果
Predict Result: Predict Result: [{'score': 0.9999393, 'category': 'bocai', 'category_id': 0}, {'score': 6.010089e-05, 'category': 'hongxiancai', 'category_id': 2}, {'score': 5.593914e-07, 'category': 'xilanhua', 'category_id': 5}]