From c3f42f61ebd0ff9d96cc0d815f6e391b0d580302 Mon Sep 17 00:00:00 2001
From: tsy <1002548612@qq.com>
Date: Fri, 10 Mar 2023 06:07:00 -0500
Subject: [PATCH 1/4] [Feature] Add onecycle lr scheduler

---
 mindcv/scheduler/dynamic_lr.py        | 28 +++++++++++++++++++++++++++
 mindcv/scheduler/scheduler_factory.py |  9 +++++++--
 tests/modules/test_scheduler.py       | 10 ++++++++++
 3 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/mindcv/scheduler/dynamic_lr.py b/mindcv/scheduler/dynamic_lr.py
index 14ed3012c..530c8d192 100644
--- a/mindcv/scheduler/dynamic_lr.py
+++ b/mindcv/scheduler/dynamic_lr.py
@@ -22,6 +22,8 @@
 import math
 from bisect import bisect_right
 
+import numpy as np
+
 
 def constant_lr(factor, total_iters, *, lr, steps_per_epoch, epochs):
     steps = steps_per_epoch * epochs
@@ -191,6 +193,32 @@ def cosine_annealing_warm_restarts_lr(te, tm, eta_min, *, eta_max, steps_per_epo
     return lrs
 
 
+def onecycle_lr(lr, min_lr, steps_per_epoch, epochs):
+    """Get OneCycle lr: warm up to the peak `lr`, anneal to `lr / 20`, then decay towards 0, floored at `min_lr`."""
+    lrs = []
+
+    def _lr_adjuster(base_lr, epoch):
+        new_lr = np.interp(
+            [epoch],
+            [
+                0,
+                epochs * 0.5 // 5,  # warmup ends at the peak lr after ~10% of the epochs
+                epochs * 4 // 5,  # annealed down to base_lr / 20 after 80% of the epochs
+                epochs
+            ],
+            [0, base_lr, base_lr / 20.0, 0]
+        )[0]
+
+        return new_lr
+
+    for epoch in range(epochs):
+        for batch in range(steps_per_epoch):
+            lrs.append(_lr_adjuster(lr, epoch + batch / steps_per_epoch))
+    lrs = np.array(lrs)
+    lrs = np.clip(lrs, min_lr, max(lrs)).tolist()  # return a plain list of floats, as the other schedulers do
+    return lrs
+
+
 if __name__ == "__main__":
     # Demonstrate how these schedulers work by printing & visualizing the returned list.
     import matplotlib.pyplot as plt
diff --git a/mindcv/scheduler/scheduler_factory.py b/mindcv/scheduler/scheduler_factory.py
index 3db35f771..18b8b93cb 100644
--- a/mindcv/scheduler/scheduler_factory.py
+++ b/mindcv/scheduler/scheduler_factory.py
@@ -7,6 +7,7 @@
     linear_lr,
     linear_refined_lr,
     multi_step_lr,
+    onecycle_lr,
     polynomial_lr,
     polynomial_refined_lr,
     step_lr,
@@ -35,9 +36,9 @@ def create_scheduler(
     Args:
         steps_per_epoch: number of steps per epoch.
         scheduler: scheduler name like 'constant', 'cosine_decay', 'step_decay',
-            'exponential_decay', 'polynomial_decay', 'multi_step_decay'. Default: 'constant'.
+            'exponential_decay', 'polynomial_decay', 'multi_step_decay', 'onecycle'. Default: 'constant'.
         lr: learning rate value. Default: 0.01.
-        min_lr: lower lr bound for 'cosine_decay' schedulers. Default: 1e-6.
+        min_lr: lower lr bound for 'cosine_decay' and 'onecycle' schedulers. Default: 1e-6.
         warmup_epochs: epochs to warmup LR, if scheduler supports. Default: 3.
         warmup_factor: the warmup phase of scheduler is a linearly increasing lr,
            the beginning factor is `warmup_factor`, i.e., the lr of the first step/epoch is lr*warmup_factor,
@@ -63,6 +64,8 @@ def create_scheduler(
     # lr warmup phase
     warmup_lr_scheduler = []
    if warmup_epochs > 0:
+        if scheduler == "onecycle":
+            raise ValueError("The 'onecycle' scheduler has a built-in warmup phase, please set warmup_epochs to 0.")
         if warmup_factor == 0 and lr_epoch_stair:
             print(
                 "[WARNING]: The warmup factor is set to 0, lr of 0-th epoch is always zero! " "Recommend value is 0.01."
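A quick sanity check of the schedule the new `onecycle_lr` helper produces, reusing the values from the unit test added below (an illustrative sketch, not part of the diff):

```python
# Probe the new onecycle_lr helper; the values mirror tests/modules/test_scheduler.py.
from mindcv.scheduler.dynamic_lr import onecycle_lr

lrs = onecycle_lr(lr=0.001, min_lr=1e-6, steps_per_epoch=2, epochs=10)
print(len(lrs))   # 20: one lr value per optimization step
print(lrs[0])     # 1e-06: the min_lr floor applied by np.clip
print(max(lrs))   # 0.001: the peak lr, reached after roughly 10% of the epochs
```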
@@ -108,6 +111,8 @@ def create_scheduler(
         main_lr_scheduler = multi_step_lr(
             milestones=milestones, gamma=decay_rate, lr=lr, steps_per_epoch=steps_per_epoch, epochs=main_epochs
         )
+    elif scheduler == "onecycle":
+        main_lr_scheduler = onecycle_lr(lr=lr, min_lr=min_lr, steps_per_epoch=steps_per_epoch, epochs=main_epochs)
     elif scheduler == "constant":
         main_lr_scheduler = [lr for _ in range(steps_per_epoch * main_epochs)]
     else:
diff --git a/tests/modules/test_scheduler.py b/tests/modules/test_scheduler.py
index b45eb7ab8..ab385546a 100644
--- a/tests/modules/test_scheduler.py
+++ b/tests/modules/test_scheduler.py
@@ -89,4 +89,14 @@ def test_scheduler_dynamic():
                     0.00615582970243117]
     lrs_ms = dynamic_lr.cosine_annealing_warm_restarts_lr(5, 2, 0.0, eta_max=1.0, steps_per_epoch=2, epochs=15)
     assert np.allclose(lrs_ms, lrs_manually)
+
+    # onecycle_lr
+    lrs_manually = [1.00000000e-06, 5.00000000e-04, 1.00000000e-03, 9.32142857e-04,
+                    8.64285714e-04, 7.96428571e-04, 7.28571429e-04, 6.60714286e-04,
+                    5.92857143e-04, 5.25000000e-04, 4.57142857e-04, 3.89285714e-04,
+                    3.21428571e-04, 2.53571429e-04, 1.85714286e-04, 1.17857143e-04,
+                    5.00000000e-05, 3.75000000e-05, 2.50000000e-05, 1.25000000e-05]
+    lrs_ms = dynamic_lr.onecycle_lr(lr=0.001, min_lr=0.000001, steps_per_epoch=2, epochs=10)
+    assert np.allclose(lrs_ms, lrs_manually)
+
     # fmt: on

From 0045e4e09bf7dc979af826b5c4af9ccb6033494f Mon Sep 17 00:00:00 2001
From: tsy <1002548612@qq.com>
Date: Fri, 10 Mar 2023 06:10:44 -0500
Subject: [PATCH 2/4] [Feature] Add model script, training configs and training
 weights of ConvMixer

---
 configs/convmixer/README.md              | 103 +++++++++++
 configs/convmixer/convmixer_1024_20.yaml |  67 ++++++++
 configs/convmixer/convmixer_1536_20.yaml |  67 ++++++++
 configs/convmixer/convmixer_768_32.yaml  |  66 ++++++++
 mindcv/models/__init__.py                |   3 +
 mindcv/models/convmixer.py               | 207 +++++++++++++++++++++++
 6 files changed, 513 insertions(+)
 create mode 100644 configs/convmixer/README.md
 create mode 100644 configs/convmixer/convmixer_1024_20.yaml
 create mode 100644 configs/convmixer/convmixer_1536_20.yaml
 create mode 100644 configs/convmixer/convmixer_768_32.yaml
 create mode 100644 mindcv/models/convmixer.py

diff --git a/configs/convmixer/README.md b/configs/convmixer/README.md
new file mode 100644
index 000000000..c984b6c19
--- /dev/null
+++ b/configs/convmixer/README.md
@@ -0,0 +1,103 @@
+
+# ConvMixer
+> [Patches Are All You Need?](https://arxiv.org/pdf/2201.09792.pdf)
+
+## Introduction
+
+Although convolutional networks have been the dominant architecture for vision
+tasks for many years, recent experiments have shown that Transformer-based models, most notably the Vision Transformer (ViT), may exceed their performance in
+some settings. However, due to the quadratic runtime of the self-attention layers
+in Transformers, ViTs require the use of patch embeddings, which group together
+small regions of the image into single input features, in order to be applied to
+larger image sizes. This raises a question: Is the performance of ViTs due to the
+inherently-more-powerful Transformer architecture, or is it at least partly due to
+using patches as the input representation? In this paper, we present some evidence
+for the latter: specifically, we propose the ConvMixer, an extremely simple model
+that is similar in spirit to the ViT and the even-more-basic MLP-Mixer in that it
+operates directly on patches as input, separates the mixing of spatial and channel
+dimensions, and maintains equal size and resolution throughout the network. In
+contrast, however, the ConvMixer uses only standard convolutions to achieve the
+mixing steps. Despite its simplicity, we show that the ConvMixer outperforms the
+ViT, MLP-Mixer, and some of their variants for similar parameter counts and data
+set sizes, in addition to outperforming classical vision models such as the ResNet.
+
+
+## Results
+
+**Implementation and configs for training were taken and adjusted from [this repository](https://gitee.com/cvisionlab/models/tree/convmixer/release/research/cv/convmixer), which implements ConvMixer models in MindSpore.**
+
+Our reproduced model performance on ImageNet-1K is reported as follows.
+
+<div align="center">
+
+| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
+|----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
+| convmixer_768_32 | Converted from PyTorch | 79.68 | 94.92 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_768_32.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_768_32.ckpt) |
+| convmixer_768_32 | 8xRTX3090 | 73.05 | 90.53 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_768_32.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/convmixer_768_32_trained.ckpt) |
+| convmixer_1024_20 | Converted from PyTorch | 76.68 | 93.30 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_1024_20.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1024_20.ckpt) |
+| convmixer_1536_20 | Converted from PyTorch | 80.98 | 95.51 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_1536_20.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1536_20.ckpt) |
+
+</div>
+
+#### Notes
+
+- Context: The weights in the table were taken from the [official repository](https://github.com/locuslab/convmixer) and converted to MindSpore format.
+- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.
+
+## Quick Start
+
+### Preparation
+
+#### Installation
+Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV.
+
+#### Dataset Preparation
+Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.
+
+### Training
+
+* Distributed Training
+
+```shell
+# distributed training on multiple GPU/Ascend devices
+mpirun -n 8 python train.py --config configs/convmixer/convmixer_768_32.yaml --data_dir /path/to/imagenet --distribute True
+```
+
+> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.
+
+Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.
+
+For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).
+
+**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction, or to adjust the learning rate linearly to a new global batch size.
+
+* Standalone Training
+
+If you want to train or finetune the model on a smaller dataset without distributed training, please run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python train.py --config configs/convmixer/convmixer_768_32.yaml --data_dir /path/to/dataset --distribute False
+```
+
+### Validation
+
+To validate the accuracy of the trained model, you can use `validate.py` and pass the checkpoint path with `--ckpt_path`.
+
+```shell
+python validate.py -c configs/convmixer/convmixer_768_32.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+```
+
+### Deployment
+
+Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV.
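+
+### Python API
+
+A minimal sketch of building the model programmatically (illustrative only; it assumes `create_model` is exported by `mindcv.models`, while the `convmixer_*` variants themselves are registered via `@register_model` in this PR):
+
+```python
+import numpy as np
+import mindspore as ms
+
+from mindcv.models import create_model  # assumed entry point
+
+# Build the ConvMixer variant added in this PR and run a dummy forward pass.
+network = create_model("convmixer_768_32", pretrained=False, num_classes=1000)
+network.set_train(False)
+dummy = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
+logits = network(dummy)
+print(logits.shape)  # expected: (1, 1000)
+```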
+ +## References + +Paper - https://arxiv.org/pdf/2201.09792.pdf + +Official repo - https://github.com/locuslab/convmixer + +Mindspore implementation - https://gitee.com/cvisionlab/models/tree/convmixer/release/research/cv/convmixer diff --git a/configs/convmixer/convmixer_1024_20.yaml b/configs/convmixer/convmixer_1024_20.yaml new file mode 100644 index 000000000..39f652d38 --- /dev/null +++ b/configs/convmixer/convmixer_1024_20.yaml @@ -0,0 +1,67 @@ +# system +mode: 0 +distribute: False +num_parallel_workers: 4 +val_while_train: True + +# dataset +dataset: 'imagenet' +data_dir: 'path/to/imagenet/' +shuffle: True +dataset_download: False +batch_size: 32 +drop_remainder: True +val_split: val + +# augmentation +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +auto_augment: 'randaug-m9-mstd0.5-inc1' +interpolation: bilinear +re_prob: 0.25 +re_value: 'random' +cutmix: 0.5 +mixup: 0.5 +mixup_prob: 1.0 +mixup_mode: batch +mixup_off_epoch: 0.0 +switch_prob: 0.5 +crop_pct: 0.96 + +# model +model: 'convmixer_1024_20' +num_classes: 1000 +pretrained: False +ckpt_path: '' + +keep_checkpoint_max: 10 +ckpt_save_dir: './ckpt' + +epoch_size: 300 +dataset_sink_mode: True +amp_level: 'O0' +use_ema: False +ema_decay: 0.9999 +use_clip_grad: True +clip_value: 1.0 + +# loss +loss: 'CE' +label_smoothing: 0.1 + +# lr scheduler +lr_scheduler: 'cosine_decay' +lr: 0.001 +warmup_epochs: 0 +warmup_factor: 0.007 +min_lr: 0.0001 +decay_epochs: 300 + +# optimizer +opt: 'adamw' +momentum: 0.9 +weight_decay: 0.0001 +loss_scale: 1024 +dynamic_loss_scale: False diff --git a/configs/convmixer/convmixer_1536_20.yaml b/configs/convmixer/convmixer_1536_20.yaml new file mode 100644 index 000000000..ca09ff90c --- /dev/null +++ b/configs/convmixer/convmixer_1536_20.yaml @@ -0,0 +1,67 @@ +# system +mode: 0 +distribute: False +num_parallel_workers: 4 +val_while_train: True + +# dataset +dataset: 'imagenet' +data_dir: 'path/to/imagenet/' +shuffle: True +dataset_download: False +batch_size: 32 +drop_remainder: True +val_split: val + +# augmentation +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +auto_augment: 'randaug-m9-mstd0.5-inc1' +interpolation: bilinear +re_prob: 0.25 +re_value: 'random' +cutmix: 0.5 +mixup: 0.5 +mixup_prob: 1.0 +mixup_mode: batch +mixup_off_epoch: 0.0 +switch_prob: 0.5 +crop_pct: 0.96 + +# model +model: 'convmixer_1536_20' +num_classes: 1000 +pretrained: False +ckpt_path: '' + +keep_checkpoint_max: 10 +ckpt_save_dir: './ckpt' + +epoch_size: 300 +dataset_sink_mode: True +amp_level: 'O0' +use_ema: False +ema_decay: 0.9999 +use_clip_grad: True +clip_value: 1.0 + +# loss +loss: 'CE' +label_smoothing: 0.1 + +# lr scheduler +lr_scheduler: 'cosine_decay' +lr: 0.001 +warmup_epochs: 0 +warmup_factor: 0.007 +min_lr: 0.0001 +decay_epochs: 300 + +# optimizer +opt: 'adamw' +momentum: 0.9 +weight_decay: 0.0001 +loss_scale: 1024 +dynamic_loss_scale: False diff --git a/configs/convmixer/convmixer_768_32.yaml b/configs/convmixer/convmixer_768_32.yaml new file mode 100644 index 000000000..c58effdb6 --- /dev/null +++ b/configs/convmixer/convmixer_768_32.yaml @@ -0,0 +1,66 @@ +# system +mode: 0 +distribute: False +num_parallel_workers: 4 +val_while_train: True + +# dataset +dataset: 'imagenet' +data_dir: 'path/to/imagenet/' +shuffle: True +dataset_download: False +batch_size: 32 +drop_remainder: True +val_split: val + +# augmentation +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +auto_augment: 'randaug-m9-mstd0.5-inc1' +interpolation: bilinear +re_prob: 
0.25 +re_value: 'random' +cutmix: 0.5 +mixup: 0.5 +mixup_prob: 1.0 +mixup_mode: batch +mixup_off_epoch: 0.0 +switch_prob: 0.5 +crop_pct: 0.96 + +# model +model: 'convmixer_768_32' +num_classes: 1000 +pretrained: False +ckpt_path: '' + +keep_checkpoint_max: 10 +ckpt_save_dir: './ckpt' + +epoch_size: 300 +dataset_sink_mode: True +amp_level: 'O0' +use_ema: False +ema_decay: 0.9999 +use_clip_grad: True +clip_value: 1.0 + +# loss +loss: 'CE' +label_smoothing: 0.1 + +# lr scheduler +lr_scheduler: 'onecycle' +lr: 0.001 +warmup_epochs: 0 +min_lr: 0.000001 +decay_epochs: 300 + +# optimizer +opt: 'adamw' +momentum: 0.9 +weight_decay: 0.0001 +loss_scale: 1024 +dynamic_loss_scale: False diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py index d0521efff..68530a906 100644 --- a/mindcv/models/__init__.py +++ b/mindcv/models/__init__.py @@ -2,6 +2,7 @@ from . import ( bit, convit, + convmixer, convnext, densenet, dpn, @@ -44,6 +45,7 @@ ) from .bit import * from .convit import * +from .convmixer import * from .convnext import * from .densenet import * from .dpn import * @@ -90,6 +92,7 @@ __all__ = [] __all__.extend(bit.__all__) __all__.extend(convit.__all__) +__all__.extend(convmixer.__all__) __all__.extend(convnext.__all__) __all__.extend(densenet.__all__) __all__.extend(dpn.__all__) diff --git a/mindcv/models/convmixer.py b/mindcv/models/convmixer.py new file mode 100644 index 000000000..2246739d3 --- /dev/null +++ b/mindcv/models/convmixer.py @@ -0,0 +1,207 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Implementation of the ConvMixer model. +Refer to "Patches Are All You Need?" +""" +import mindspore.nn as nn +from mindspore.ops import ReduceMean + +from .registry import register_model +from .utils import load_pretrained + +__all__ = [ + "convmixer_768_32", + "convmixer_1024_20", + "convmixer_1536_20" +] + + +def _cfg(classifier, url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "first_conv": 'network.0', + "classifier": classifier, + **kwargs + } + + +default_cfgs = { + "convmixer_768_32": _cfg( + classifier="network.37", + url="https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_768_32.ckpt"), + "convmixer_1024_20": _cfg( + classifier="network.25", + url="https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1024_20.ckpt"), + "convmixer_1536_20": _cfg( + classifier="network.25", + url="https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1536_20.ckpt"), +} + + +class Residual(nn.Cell): + """Residual connection. 
""" + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def construct(self, *inputs, **kwargs): + x = inputs[0] + return self.fn(x) + x + + +class AvgPoolReduceMean(nn.Cell): + """AvgPool cell implemented on the basis of ReduceMean op.""" + + def construct(self, *inputs, **kwargs): + """Forward pass.""" + x = inputs[0] + return ReduceMean(True)(x, (2, 3)) + + +class ConvMixer(nn.Cell): + """ConvMixer model.""" + + def __init__( + self, + dim, + depth, + kernel_size=9, + patch_size=7, + in_channels=3, + n_classes=1000, + act_type='gelu', + onnx_export=False, + ): + super().__init__() + if act_type.lower() == 'gelu': + act = nn.GELU + elif act_type.lower() == 'relu': + act = nn.ReLU + else: + raise NotImplementedError() + + avg_pool = AvgPoolReduceMean() if onnx_export \ + else nn.AdaptiveAvgPool2d((1, 1)) + + self.network = nn.SequentialCell( + nn.Conv2d( + in_channels, + dim, + kernel_size=patch_size, + stride=patch_size, + has_bias=True, + pad_mode='pad', + padding=0, + ), + act(), + nn.BatchNorm2d(dim), + *[nn.SequentialCell( + Residual( + nn.SequentialCell( + nn.Conv2d( + dim, + dim, + kernel_size, + group=dim, + pad_mode='same', + has_bias=True + ), + act(), + nn.BatchNorm2d(dim) + ) + ), + nn.Conv2d( + dim, + dim, + kernel_size=1, + has_bias=True, + pad_mode='pad', + padding=0, + ), + act(), + nn.BatchNorm2d(dim) + ) for _ in range(depth)], + avg_pool, + nn.Flatten(), + nn.Dense(dim, n_classes) + ) + + def construct(self, *inputs, **kwargs): + """Forward pass.""" + x = inputs[0] + x = self.network(x) + return x + + +@register_model +def convmixer_1536_20(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs): + """Create ConvMixer-1536/20 model.""" + model = ConvMixer( + 1536, + 20, + kernel_size=9, + patch_size=7, + in_channels=in_channels, + n_classes=num_classes, + ) + default_cfg = default_cfgs['convmixer_1536_20'] + + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model + + +@register_model +def convmixer_1024_20(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs): + """Create ConvMixer-1024/20 model.""" + model = ConvMixer( + 1024, + 20, + kernel_size=9, + patch_size=14, + in_channels=in_channels, + n_classes=num_classes, + ) + default_cfg = default_cfgs['convmixer_1024_20'] + + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model + + +@register_model +def convmixer_768_32(pretrained: bool = False, num_classes=1000, in_channels=3, + act_type='relu', **kwargs): + """Create ConvMixer-768/32 model.""" + model = ConvMixer( + 768, + 32, + kernel_size=7, + patch_size=7, + in_channels=in_channels, + n_classes=num_classes, + act_type=act_type, + ) + default_cfg = default_cfgs['convmixer_768_32'] + + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model From 994b3ca5cb37164a492cb12d5b5eb2b03ba7c2e2 Mon Sep 17 00:00:00 2001 From: tsy <1002548612@qq.com> Date: Fri, 10 Mar 2023 06:53:09 -0500 Subject: [PATCH 3/4] Renamed configs --- configs/convmixer/README.md | 14 +++++++------- ...xer_1024_20.yaml => convmixer_1024_20_gpu.yaml} | 0 ...xer_1536_20.yaml => convmixer_1536_20_gpu.yaml} | 0 ...mixer_768_32.yaml => convmixer_768_32_gpu.yaml} | 0 4 files changed, 7 insertions(+), 7 deletions(-) rename configs/convmixer/{convmixer_1024_20.yaml => convmixer_1024_20_gpu.yaml} (100%) rename configs/convmixer/{convmixer_1536_20.yaml => 
convmixer_1536_20_gpu.yaml} (100%)
 rename configs/convmixer/{convmixer_768_32.yaml => convmixer_768_32_gpu.yaml} (100%)

diff --git a/configs/convmixer/README.md b/configs/convmixer/README.md
index c984b6c19..a846869e1 100644
--- a/configs/convmixer/README.md
+++ b/configs/convmixer/README.md
@@ -32,10 +32,10 @@ Our reproduced model performance on ImageNet-1K is reported as follows.
 
 | Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
 |----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
-| convmixer_768_32 | Converted from PyTorch | 79.68 | 94.92 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_768_32.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_768_32.ckpt) |
-| convmixer_768_32 | 8xRTX3090 | 73.05 | 90.53 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_768_32.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/convmixer_768_32_trained.ckpt) |
-| convmixer_1024_20 | Converted from PyTorch | 76.68 | 93.30 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_1024_20.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1024_20.ckpt) |
-| convmixer_1536_20 | Converted from PyTorch | 80.98 | 95.51 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_1536_20.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1536_20.ckpt) |
+| convmixer_768_32 | Converted from PyTorch | 79.68 | 94.92 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_768_32_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_768_32.ckpt) |
+| convmixer_768_32 | 8xRTX3090 | 73.05 | 90.53 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_768_32_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/convmixer_768_32_trained.ckpt) |
+| convmixer_1024_20 | Converted from PyTorch | 76.68 | 93.30 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_1024_20_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1024_20.ckpt) |
+| convmixer_1536_20 | Converted from PyTorch | 80.98 | 95.51 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/convmixer/convmixer_1536_20_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/ConvMixer/Converted/convmixer_1536_20.ckpt) |
 
 </div>
 
@@ -62,7 +62,7 @@ Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/201
 
 ```shell
 # distributed training on multiple GPU/Ascend devices
-mpirun -n 8 python train.py --config configs/convmixer/convmixer_768_32.yaml --data_dir /path/to/imagenet --distribute True
+mpirun -n 8 python train.py --config configs/convmixer/convmixer_768_32_gpu.yaml --data_dir /path/to/imagenet --distribute True
 ```
 
 > If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.
@@ -79,7 +79,7 @@ If you want to train or finetune the model on a smaller dataset without distribu
 
 ```shell
 # standalone training on a CPU/GPU/Ascend device
-python train.py --config configs/convmixer/convmixer_768_32.yaml --data_dir /path/to/dataset --distribute False
+python train.py --config configs/convmixer/convmixer_768_32_gpu.yaml --data_dir /path/to/dataset --distribute False
 ```
 
 ### Validation
@@ -87,7 +87,7 @@ python train.py --config configs/convmixer/convmixer_768_32.yaml --data_dir /pat
 To validate the accuracy of the trained model, you can use `validate.py` and pass the checkpoint path with `--ckpt_path`.
 
 ```shell
-python validate.py -c configs/convmixer/convmixer_768_32.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+python validate.py -c configs/convmixer/convmixer_768_32_gpu.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
 ```
 
 ### Deployment
diff --git a/configs/convmixer/convmixer_1024_20.yaml b/configs/convmixer/convmixer_1024_20_gpu.yaml
similarity index 100%
rename from configs/convmixer/convmixer_1024_20.yaml
rename to configs/convmixer/convmixer_1024_20_gpu.yaml
diff --git a/configs/convmixer/convmixer_1536_20.yaml b/configs/convmixer/convmixer_1536_20_gpu.yaml
similarity index 100%
rename from configs/convmixer/convmixer_1536_20.yaml
rename to configs/convmixer/convmixer_1536_20_gpu.yaml
diff --git a/configs/convmixer/convmixer_768_32.yaml b/configs/convmixer/convmixer_768_32_gpu.yaml
similarity index 100%
rename from configs/convmixer/convmixer_768_32.yaml
rename to configs/convmixer/convmixer_768_32_gpu.yaml

From c1b06f8cbfe68f3611eccbacb804a4426a59b264 Mon Sep 17 00:00:00 2001
From: tsy <1002548612@qq.com>
Date: Fri, 10 Mar 2023 07:41:21 -0500
Subject: [PATCH 4/4] Remove copyright header from model script

---
 mindcv/models/convmixer.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/mindcv/models/convmixer.py b/mindcv/models/convmixer.py
index 2246739d3..268acef24 100644
--- a/mindcv/models/convmixer.py
+++ b/mindcv/models/convmixer.py
@@ -1,17 +1,3 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
 """
 Implementation of the ConvMixer model.
 Refer to "Patches Are All You Need?"
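For reference, the mixing step that mindcv/models/convmixer.py repeats `depth` times can be restated as the condensed, self-contained sketch below (an illustration of the same structure, not an alternative implementation; the PR builds these layers inline inside one `nn.SequentialCell`):

```python
# Condensed restatement of the ConvMixer mixing block from this PR (illustration only).
import mindspore.nn as nn


class Residual(nn.Cell):
    """Adds the block input back to its output: y = fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x):
        return self.fn(x) + x


def convmixer_block(dim, kernel_size):
    """One ConvMixer layer: a depthwise conv mixes space, a pointwise conv mixes channels."""
    return nn.SequentialCell(
        Residual(nn.SequentialCell(
            nn.Conv2d(dim, dim, kernel_size, group=dim, pad_mode='same', has_bias=True),  # depthwise: spatial mixing
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )),
        nn.Conv2d(dim, dim, kernel_size=1, pad_mode='pad', padding=0, has_bias=True),  # pointwise: channel mixing
        nn.GELU(),
        nn.BatchNorm2d(dim),
    )
```

The depthwise convolution (with `group=dim`) mixes only spatial locations within each channel, while the 1x1 convolution mixes only channels, which is exactly the spatial/channel separation the paper argues the patch representation enables.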