From c50c9cfe4324bc337824db28b6017edc70536542 Mon Sep 17 00:00:00 2001 From: YeonwooSung Date: Sun, 31 Mar 2024 17:12:34 +0900 Subject: [PATCH] feat: Add codes for meta pseudo labeling --- README.md | 9 +- TrainingTricks/meta_pseudo_label/.gitignore | 148 ++++ TrainingTricks/meta_pseudo_label/README.md | 109 +++ .../meta_pseudo_label/augmentation.py | 328 +++++++++ TrainingTricks/meta_pseudo_label/data.py | 254 +++++++ TrainingTricks/meta_pseudo_label/main.py | 648 ++++++++++++++++++ TrainingTricks/meta_pseudo_label/models.py | 157 +++++ TrainingTricks/meta_pseudo_label/utils.py | 142 ++++ 8 files changed, 1791 insertions(+), 4 deletions(-) create mode 100644 TrainingTricks/meta_pseudo_label/.gitignore create mode 100644 TrainingTricks/meta_pseudo_label/README.md create mode 100644 TrainingTricks/meta_pseudo_label/augmentation.py create mode 100644 TrainingTricks/meta_pseudo_label/data.py create mode 100644 TrainingTricks/meta_pseudo_label/main.py create mode 100644 TrainingTricks/meta_pseudo_label/models.py create mode 100644 TrainingTricks/meta_pseudo_label/utils.py diff --git a/README.md b/README.md index 7fa844c5..8b870a04 100644 --- a/README.md +++ b/README.md @@ -60,10 +60,11 @@ Free online book on Artificial Intelligence to help people learn AI easily. - [Self-Supervised Learning](./SelfSupervisedLearning) -- [Training](./Training) - * [Adversarial ML](./Training/AdversarialML/) - * [Knowledge Distillation](./Training/KnowledgeDistillation/) - * [Transfer Learning](./Training/TransferLearning/) +- [TrainingTricks](./TrainingTricks) + * [Adversarial ML](./TrainingTricks/AdversarialML/) + * [Meta Pseudo Labeling](./TrainingTricks/meta_pseudo_label/) + * [Knowledge Distillation](./TrainingTricks/KnowledgeDistillation/) + * [Transfer Learning](./TrainingTricks/TransferLearning/) - [XAI](./XAI) diff --git a/TrainingTricks/meta_pseudo_label/.gitignore b/TrainingTricks/meta_pseudo_label/.gitignore new file mode 100644 index 00000000..419ae1e6 --- /dev/null +++ b/TrainingTricks/meta_pseudo_label/.gitignore @@ -0,0 +1,148 @@ +.vscode +wandb/ +results/ +*.npz +*.jpg +*.JPG +*.jpeg +*.JPEG +*.png +*.PNG +*.webp +*.WEBP +*.gif +*.GIF +*.zip +*.tar +checkpoint/ +data/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/TrainingTricks/meta_pseudo_label/README.md b/TrainingTricks/meta_pseudo_label/README.md new file mode 100644 index 00000000..99641beb --- /dev/null +++ b/TrainingTricks/meta_pseudo_label/README.md @@ -0,0 +1,109 @@ +# Meta Pseudo Labels +This is an unofficial PyTorch implementation of [Meta Pseudo Labels](https://arxiv.org/abs/2003.10580). +The official Tensorflow implementation is [here](https://github.com/google-research/google-research/tree/master/meta_pseudo_labels). + + +## Results + +| | CIFAR-10-4K | SVHN-1K | ImageNet-10% | +|:---:|:---:|:---:|:---:| +| Paper (w/ finetune) | 96.11 ± 0.07 | 98.01 ± 0.07 | 73.89 | +| This code (w/o finetune) | 96.01 | - | - | +| This code (w/ finetune) | 96.08 | - | - | +| Acc. curve | [w/o finetune](https://tensorboard.dev/experiment/ehMVEk39SrGiqM43ye2c7w/)
[w/ finetune](https://tensorboard.dev/experiment/vbqR7dt2Q9aw6rf8yVu56g/) | - | - |
+
+* Retested in February 2022.
+
+## Usage
+
+Train the model with 4000 labeled examples from the CIFAR-10 dataset:
+
+```
+python main.py \
+    --seed 2 \
+    --name cifar10-4K.2 \
+    --expand-labels \
+    --dataset cifar10 \
+    --num-classes 10 \
+    --num-labeled 4000 \
+    --total-steps 300000 \
+    --eval-step 1000 \
+    --randaug 2 16 \
+    --batch-size 128 \
+    --teacher_lr 0.05 \
+    --student_lr 0.05 \
+    --weight-decay 5e-4 \
+    --ema 0.995 \
+    --nesterov \
+    --mu 7 \
+    --label-smoothing 0.15 \
+    --temperature 0.7 \
+    --threshold 0.6 \
+    --lambda-u 8 \
+    --warmup-steps 5000 \
+    --uda-steps 5000 \
+    --student-wait-steps 3000 \
+    --teacher-dropout 0.2 \
+    --student-dropout 0.2 \
+    --finetune-epochs 625 \
+    --finetune-batch-size 512 \
+    --finetune-lr 3e-5 \
+    --finetune-weight-decay 0 \
+    --finetune-momentum 0.9 \
+    --amp
+```
+
+Train the model with 10000 labeled examples from the CIFAR-100 dataset using DistributedDataParallel:
+```
+python -m torch.distributed.launch --nproc_per_node 4 main.py \
+    --seed 2 \
+    --name cifar100-10K.2 \
+    --dataset cifar100 \
+    --num-classes 100 \
+    --num-labeled 10000 \
+    --expand-labels \
+    --total-steps 300000 \
+    --eval-step 1000 \
+    --randaug 2 16 \
+    --batch-size 128 \
+    --teacher_lr 0.05 \
+    --student_lr 0.05 \
+    --weight-decay 5e-4 \
+    --ema 0.995 \
+    --nesterov \
+    --mu 7 \
+    --label-smoothing 0.15 \
+    --temperature 0.7 \
+    --threshold 0.6 \
+    --lambda-u 8 \
+    --warmup-steps 5000 \
+    --uda-steps 5000 \
+    --student-wait-steps 3000 \
+    --teacher-dropout 0.2 \
+    --student-dropout 0.2 \
+    --finetune-epochs 250 \
+    --finetune-batch-size 512 \
+    --finetune-lr 3e-5 \
+    --finetune-weight-decay 0 \
+    --finetune-momentum 0.9 \
+    --amp
+```
+
+## Monitoring training progress
+
+Use TensorBoard:
+```
+tensorboard --logdir results
+```
+or
+
+use wandb (the `wandb` calls in `main.py` are commented out by default).
+
+## Requirements
+- python 3.6+
+- torch 1.10+
+- torchvision 0.8+
+- tensorboard
+- wandb
+- numpy
+- tqdm
diff --git a/TrainingTricks/meta_pseudo_label/augmentation.py b/TrainingTricks/meta_pseudo_label/augmentation.py
new file mode 100644
index 00000000..c8b8acab
--- /dev/null
+++ b/TrainingTricks/meta_pseudo_label/augmentation.py
@@ -0,0 +1,328 @@
+# code in this file is adapted from
+# https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
+# https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
+# https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
+import logging
+import random
+
+import numpy as np
+import PIL
+import PIL.ImageOps
+import PIL.ImageEnhance
+import PIL.ImageDraw
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+PARAMETER_MAX = 10
+RESAMPLE_MODE = Image.BICUBIC
+FILL_COLOR = (128, 128, 128)
+
+
+def AutoContrast(img, **kwarg):
+    return PIL.ImageOps.autocontrast(img)
+
+
+def Brightness(img, v, max_v, bias=0):
+    v = _float_parameter(v, max_v) + bias
+    return PIL.ImageEnhance.Brightness(img).enhance(v)
+
+
+def Color(img, v, max_v, bias=0):
+    v = _float_parameter(v, max_v) + bias
+    return PIL.ImageEnhance.Color(img).enhance(v)
+
+
+def Contrast(img, v, max_v, bias=0):
+    v = _float_parameter(v, max_v) + bias
+    return PIL.ImageEnhance.Contrast(img).enhance(v)
+
+
+def Cutout(img, v, max_v, **kwarg):
+    if v == 0:
+        return img
+    v = _float_parameter(v, max_v)
+    v = int(v * min(img.size))
+
+    w, h = img.size
+    x0 = np.random.uniform(0, w)
+    y0 = np.random.uniform(0, h)
+    x0 = int(max(0, x0 - v / 2.))
+    y0 = int(max(0, y0 - v / 2.))
+    x1 =
int(min(w, x0 + v)) + y1 = int(min(h, y0 + v)) + xy = (x0, y0, x1, y1) + # gray + color = FILL_COLOR + img = img.copy() + PIL.ImageDraw.Draw(img).rectangle(xy, color) + return img + + +def CutoutConst(img, v, max_v, **kwarg): + v = _int_parameter(v, max_v) + w, h = img.size + x0 = np.random.uniform(0, w) + y0 = np.random.uniform(0, h) + x0 = int(max(0, x0 - v / 2.)) + y0 = int(max(0, y0 - v / 2.)) + x1 = int(min(w, x0 + v)) + y1 = int(min(h, y0 + v)) + xy = (x0, y0, x1, y1) + # gray + color = FILL_COLOR + img = img.copy() + PIL.ImageDraw.Draw(img).rectangle(xy, color) + return img + + +def Equalize(img, **kwarg): + return PIL.ImageOps.equalize(img) + + +def Identity(img, **kwarg): + return img + + +def Invert(img, **kwarg): + return PIL.ImageOps.invert(img) + + +def Posterize(img, v, max_v, bias=0): + v = 8 - _round_parameter(v, max_v) + bias + return PIL.ImageOps.posterize(img, v) + + +def Rotate(img, v, max_v, **kwarg): + v = _float_parameter(v, max_v) + if random.random() < 0.5: + v = -v + return img.rotate(v, RESAMPLE_MODE, fillcolor=FILL_COLOR) + + +def Sharpness(img, v, max_v, bias): + v = _float_parameter(v, max_v) + bias + return PIL.ImageEnhance.Sharpness(img).enhance(v) + + +def ShearX(img, v, max_v, **kwarg): + v = _float_parameter(v, max_v) + if random.random() < 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0), RESAMPLE_MODE, fillcolor=FILL_COLOR) + + +def ShearY(img, v, max_v, **kwarg): + v = _float_parameter(v, max_v) + if random.random() < 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0), RESAMPLE_MODE, fillcolor=FILL_COLOR) + + +def Solarize(img, v, max_v, **kwarg): + v = _int_parameter(v, max_v) + return PIL.ImageOps.solarize(img, 255 - v) + + +def SolarizeAdd(img, v, max_v, threshold=128, **kwarg): + v = _int_parameter(v, max_v) + img_np = np.array(img).astype(np.int) + img_np = img_np + v + img_np = np.clip(img_np, 0, 255) + img_np = img_np.astype(np.uint8) + img = Image.fromarray(img_np) + return PIL.ImageOps.solarize(img, threshold) + + +def TranslateX(img, v, max_v, **kwarg): + v = _float_parameter(v, max_v) + if random.random() < 0.5: + v = -v + v = int(v * img.size[0]) + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), RESAMPLE_MODE, fillcolor=FILL_COLOR) + + +def TranslateY(img, v, max_v, **kwarg): + v = _float_parameter(v, max_v) + if random.random() < 0.5: + v = -v + v = int(v * img.size[1]) + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), RESAMPLE_MODE, fillcolor=FILL_COLOR) + + +def TranslateXConst(img, v, max_v, **kwarg): + v = _float_parameter(v, max_v) + if random.random() < 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), RESAMPLE_MODE, fillcolor=FILL_COLOR) + + +def TranslateYConst(img, v, max_v, **kwarg): + v = _float_parameter(v, max_v) + if random.random() < 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), RESAMPLE_MODE, fillcolor=FILL_COLOR) + + +def _float_parameter(v, max_v): + return float(v) * max_v / PARAMETER_MAX + + +def _int_parameter(v, max_v): + return int(v * max_v / PARAMETER_MAX) + + +def _round_parameter(v, max_v): + return int(round(v * max_v / PARAMETER_MAX)) + + +def rand_augment_pool(): + # Test + augs = [ + (AutoContrast, None, None), + (Equalize, None, None), + (Invert, None, None), + (Rotate, 30, None), + (Posterize, 4, 0), + (Solarize, 256, None), + (Color, 1.8, 0.1), + (Contrast, 1.8, 0.1), + (Brightness, 1.8, 0.1), + (Sharpness, 1.8, 0.1), + 
        (ShearX, 0.3, None),
+        (ShearY, 0.3, None),
+        (TranslateXConst, 100, None),
+        (TranslateYConst, 100, None),
+        (CutoutConst, 40, None),  # Use RandomErasing instead of Cutout.
+    ]
+    return augs
+
+
+def fixmatch_augment_pool():
+    # FixMatch paper
+    augs = [
+        (AutoContrast, None, None),
+        (Brightness, 0.9, 0.05),
+        (Color, 0.9, 0.05),
+        (Contrast, 0.9, 0.05),
+        (Equalize, None, None),
+        (Identity, None, None),
+        (Posterize, 4, 4),
+        (Rotate, 30, 0),
+        (Sharpness, 0.9, 0.05),
+        (ShearX, 0.3, 0),
+        (ShearY, 0.3, 0),
+        (Solarize, 256, None),
+        (TranslateX, 0.3, 0),
+        (TranslateY, 0.3, 0)
+    ]
+    return augs
+
+
+def cifar_augment_pool():
+    # Test
+    augs = [
+        (AutoContrast, None, None),
+        (Equalize, None, None),
+        (Invert, None, None),
+        (Rotate, 30, None),
+        (Posterize, 4, 0),
+        (Solarize, 256, None),
+        (Color, 1.8, 0.1),
+        (Contrast, 1.8, 0.1),
+        (Brightness, 1.8, 0.1),
+        (Sharpness, 1.8, 0.1),
+        (ShearX, 0.3, None),
+        (ShearY, 0.3, None),
+        (TranslateXConst, 32 // 8, None),
+        (TranslateYConst, 32 // 8, None),
+        (CutoutConst, 32 // 8, None),
+    ]
+    return augs
+
+
+def soft_augment_pool():
+    # Test
+    augs = [(AutoContrast, None, None),
+            (Brightness, 1.8, 0.1),
+            (Color, 1.8, 0.1),
+            (Contrast, 1.8, 0.1),
+            (CutoutConst, 40, None),
+            (Posterize, 4, 0),
+            (Sharpness, 1.8, 0.1),
+            ]
+    return augs
+
+
+class SoftAugment(object):
+    def __init__(self, n, m, resample_mode=PIL.Image.BICUBIC):
+        global RESAMPLE_MODE
+        RESAMPLE_MODE = resample_mode
+        self.n = n
+        self.m = m
+        self.augment_pool = soft_augment_pool()
+
+    def __call__(self, img):
+        ops = random.choices(self.augment_pool, k=self.n)
+        for op, max_v, bias in ops:
+            prob = np.random.uniform(0.2, 0.8)
+            if random.random() + prob >= 1:
+                img = op(img, v=self.m, max_v=max_v, bias=bias)
+        return img
+
+
+class RandAugment(object):
+    def __init__(self, n, m, resample_mode=PIL.Image.BICUBIC, fill=(128, 128, 128)):
+        global RESAMPLE_MODE, FILL_COLOR
+        RESAMPLE_MODE = resample_mode
+        FILL_COLOR = fill
+        self.n = int(n)
+        self.m = m
+        self.augment_pool = rand_augment_pool()
+
+    def __call__(self, img):
+        ops = random.choices(self.augment_pool, k=self.n)
+        for op, max_v, bias in ops:
+            prob = np.random.uniform(0.2, 0.8)
+            if random.random() <= prob:
+                img = op(img, v=self.m, max_v=max_v, bias=bias)
+        return img
+
+
+class RandAugmentCIFAR(object):
+    def __init__(self, n, m, resample_mode=PIL.Image.BICUBIC, fill=(128, 128, 128)):
+        global RESAMPLE_MODE, FILL_COLOR
+        RESAMPLE_MODE = resample_mode
+        FILL_COLOR = fill
+        self.n = int(n)
+        self.m = m
+        self.augment_pool = cifar_augment_pool()
+
+    def __call__(self, img):
+        ops = random.choices(self.augment_pool, k=self.n)
+        for op, max_v, bias in ops:
+            prob = np.random.uniform(0.2, 0.8)
+            if random.random() <= prob:
+                img = op(img, v=self.m, max_v=max_v, bias=bias)
+        img = CutoutConst(img, v=self.m, max_v=32 // 4)
+        return img
+
+
+class RandAugmentMC(object):
+    def __init__(self, n, m, resample_mode=PIL.Image.BICUBIC, fill=(128, 128, 128)):
+        global RESAMPLE_MODE, FILL_COLOR
+        RESAMPLE_MODE = resample_mode
+        FILL_COLOR = fill
+        self.n = int(n)
+        self.m = m
+        self.augment_pool = fixmatch_augment_pool()
+
+    def __call__(self, img):
+        ops = random.choices(self.augment_pool, k=self.n)
+        for op, max_v, bias in ops:
+            v = np.random.randint(1, self.m)
+            if random.random() < 0.5:
+                img = op(img, v=v, max_v=max_v, bias=bias)
+        img = CutoutConst(img, v=40, max_v=PARAMETER_MAX)
+        return img
diff --git a/TrainingTricks/meta_pseudo_label/data.py b/TrainingTricks/meta_pseudo_label/data.py
new file mode 100644
index 00000000..f43018aa
--- /dev/null
+++ b/TrainingTricks/meta_pseudo_label/data.py
@@ -0,0 +1,254 @@
+import logging
+import math
+
+import numpy as np
+from PIL import Image
+from torchvision import datasets
+from torchvision import transforms
+
+from augmentation import RandAugmentCIFAR
+
+logger = logging.getLogger(__name__)
+
+cifar10_mean = (0.491400, 0.482158, 0.4465231)
+cifar10_std = (0.247032, 0.243485, 0.2615877)
+cifar100_mean = (0.507075, 0.486549, 0.440918)
+cifar100_std = (0.267334, 0.256438, 0.276151)
+normal_mean = (0.5, 0.5, 0.5)
+normal_std = (0.5, 0.5, 0.5)
+
+
+def get_cifar10(args):
+    if args.randaug:
+        n, m = args.randaug
+    else:
+        n, m = 2, 10  # default
+    transform_labeled = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomCrop(size=args.resize,
+                              padding=int(args.resize * 0.125),
+                              fill=128,
+                              padding_mode='constant'),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
+    ])
+    transform_finetune = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomCrop(size=args.resize,
+                              padding=int(args.resize * 0.125),
+                              fill=128,
+                              padding_mode='constant'),
+        RandAugmentCIFAR(n=n, m=m),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
+    ])
+    transform_val = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
+    ])
+    base_dataset = datasets.CIFAR10(args.data_path, train=True, download=True)
+
+    train_labeled_idxs, train_unlabeled_idxs, finetune_idxs = x_u_split(args, base_dataset.targets)
+
+    train_labeled_dataset = CIFAR10SSL(
+        args.data_path, train_labeled_idxs, train=True,
+        transform=transform_labeled
+    )
+    finetune_dataset = CIFAR10SSL(
+        args.data_path, finetune_idxs, train=True,
+        transform=transform_finetune
+    )
+    train_unlabeled_dataset = CIFAR10SSL(
+        args.data_path, train_unlabeled_idxs,
+        train=True,
+        transform=TransformMPL(args, mean=cifar10_mean, std=cifar10_std)
+    )
+
+    test_dataset = datasets.CIFAR10(args.data_path, train=False,
+                                    transform=transform_val, download=False)
+
+    return train_labeled_dataset, train_unlabeled_dataset, test_dataset, finetune_dataset
+
+
+def get_cifar100(args):
+    if args.randaug:
+        n, m = args.randaug
+    else:
+        n, m = 2, 10  # default
+    transform_labeled = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomCrop(size=args.resize,
+                              padding=int(args.resize * 0.125),
+                              fill=128,
+                              padding_mode='constant'),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=cifar100_mean, std=cifar100_std)])
+    transform_finetune = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomCrop(size=args.resize,
+                              padding=int(args.resize * 0.125),
+                              fill=128,
+                              padding_mode='constant'),
+        RandAugmentCIFAR(n=n, m=m),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=cifar100_mean, std=cifar100_std)])
+
+    transform_val = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=cifar100_mean, std=cifar100_std)])
+
+    base_dataset = datasets.CIFAR100(args.data_path, train=True, download=True)
+
+    train_labeled_idxs, train_unlabeled_idxs, finetune_idxs = x_u_split(args, base_dataset.targets)
+
+    train_labeled_dataset = CIFAR100SSL(
+        args.data_path, train_labeled_idxs, train=True,
+        transform=transform_labeled
+    )
+    finetune_dataset = CIFAR100SSL(
+        args.data_path, finetune_idxs, train=True,
+        transform=transform_finetune
+    )
+    train_unlabeled_dataset = CIFAR100SSL(
+        args.data_path, train_unlabeled_idxs, train=True,
+        transform=TransformMPL(args, mean=cifar100_mean, std=cifar100_std)
+    )
+
+    test_dataset = datasets.CIFAR100(args.data_path, train=False,
+                                     transform=transform_val, download=False)
+
+    return train_labeled_dataset, train_unlabeled_dataset, test_dataset, finetune_dataset
+
+
+def x_u_split(args, labels):
+    label_per_class = args.num_labeled // args.num_classes
+    labels = np.array(labels)
+    labeled_idx = []
+    # unlabeled data: all training data
+    unlabeled_idx = np.array(range(len(labels)))
+    for i in range(args.num_classes):
+        idx = np.where(labels == i)[0]
+        idx = np.random.choice(idx, label_per_class, False)
+        labeled_idx.extend(idx)
+    labeled_idx = np.array(labeled_idx)
+    assert len(labeled_idx) == args.num_labeled
+
+    if args.expand_labels or args.num_labeled < args.batch_size:
+        num_expand_x = math.ceil(
+            args.batch_size * args.eval_step / args.num_labeled)
+        labeled_idx_ex = np.hstack([labeled_idx for _ in range(num_expand_x)])
+        np.random.shuffle(labeled_idx_ex)
+        np.random.shuffle(labeled_idx)
+        return labeled_idx_ex, unlabeled_idx, labeled_idx
+    else:
+        np.random.shuffle(labeled_idx)
+        return labeled_idx, unlabeled_idx, labeled_idx
+
+
+def x_u_split_test(args, labels):
+    label_per_class = args.num_labeled // args.num_classes
+    labels = np.array(labels)
+    labeled_idx = []
+    unlabeled_idx = []
+    for i in range(args.num_classes):
+        idx = np.where(labels == i)[0]
+        np.random.shuffle(idx)
+        labeled_idx.extend(idx[:label_per_class])
+        unlabeled_idx.extend(idx[label_per_class:])
+    labeled_idx = np.array(labeled_idx)
+    unlabeled_idx = np.array(unlabeled_idx)
+    assert len(labeled_idx) == args.num_labeled
+
+    if args.expand_labels or args.num_labeled < args.batch_size:
+        num_expand_x = math.ceil(
+            args.batch_size * args.eval_step / args.num_labeled)
+        labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
+
+    np.random.shuffle(labeled_idx)
+    np.random.shuffle(unlabeled_idx)
+    return labeled_idx, unlabeled_idx
+
+
+class TransformMPL(object):
+    def __init__(self, args, mean, std):
+        if args.randaug:
+            n, m = args.randaug
+        else:
+            n, m = 2, 10  # default
+
+        self.ori = transforms.Compose([
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomCrop(size=args.resize,
+                                  padding=int(args.resize * 0.125),
+                                  fill=128,
+                                  padding_mode='constant')])
+        self.aug = transforms.Compose([
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomCrop(size=args.resize,
+                                  padding=int(args.resize * 0.125),
+                                  fill=128,
+                                  padding_mode='constant'),
+            RandAugmentCIFAR(n=n, m=m)])
+        self.normalize = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)])
+
+    def __call__(self, x):
+        ori = self.ori(x)
+        aug = self.aug(x)
+        return self.normalize(ori), self.normalize(aug)
+
+
+class CIFAR10SSL(datasets.CIFAR10):
+    def __init__(self, root, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super().__init__(root, train=train,
+                         transform=transform,
+                         target_transform=target_transform,
+                         download=download)
+        if indexs is not None:
+            self.data = self.data[indexs]
+            self.targets = np.array(self.targets)[indexs]
+
+    def __getitem__(self, index):
+        img, target = self.data[index], self.targets[index]
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+
+class CIFAR100SSL(datasets.CIFAR100):
+    def __init__(self, root, indexs, train=True,
transform=None, target_transform=None, + download=False): + super().__init__(root, train=train, + transform=transform, + target_transform=target_transform, + download=download) + if indexs is not None: + self.data = self.data[indexs] + self.targets = np.array(self.targets)[indexs] + + def __getitem__(self, index): + img, target = self.data[index], self.targets[index] + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + +DATASET_GETTERS = {'cifar10': get_cifar10, + 'cifar100': get_cifar100} diff --git a/TrainingTricks/meta_pseudo_label/main.py b/TrainingTricks/meta_pseudo_label/main.py new file mode 100644 index 00000000..6ba4792e --- /dev/null +++ b/TrainingTricks/meta_pseudo_label/main.py @@ -0,0 +1,648 @@ +import argparse +import logging +import math +import os +import random +import time + +import numpy as np +import torch +from torch.cuda import amp +from torch import nn +from torch.nn import functional as F +from torch import optim +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter +# import wandb +from tqdm import tqdm + +from data import DATASET_GETTERS +from models import WideResNet, ModelEMA +from utils import (AverageMeter, accuracy, create_loss_fn, + save_checkpoint, reduce_tensor, model_load_state_dict) + +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser() +parser.add_argument('--name', type=str, required=True, help='experiment name') +parser.add_argument('--data-path', default='./data', type=str, help='data path') +parser.add_argument('--save-path', default='./checkpoint', type=str, help='save path') +parser.add_argument('--dataset', default='cifar10', type=str, + choices=['cifar10', 'cifar100'], help='dataset name') +parser.add_argument('--num-labeled', type=int, default=4000, help='number of labeled data') +parser.add_argument("--expand-labels", action="store_true", help="expand labels to fit eval steps") +parser.add_argument('--total-steps', default=300000, type=int, help='number of total steps to run') +parser.add_argument('--eval-step', default=1000, type=int, help='number of eval steps to run') +parser.add_argument('--start-step', default=0, type=int, + help='manual epoch number (useful on restarts)') +parser.add_argument('--workers', default=4, type=int, help='number of workers') +parser.add_argument('--num-classes', default=10, type=int, help='number of classes') +parser.add_argument('--resize', default=32, type=int, help='resize image') +parser.add_argument('--batch-size', default=64, type=int, help='train batch size') +parser.add_argument('--teacher-dropout', default=0, type=float, help='dropout on last dense layer') +parser.add_argument('--student-dropout', default=0, type=float, help='dropout on last dense layer') +parser.add_argument('--teacher_lr', default=0.01, type=float, help='train learning late') +parser.add_argument('--student_lr', default=0.01, type=float, help='train learning late') +parser.add_argument('--momentum', default=0.9, type=float, help='SGD Momentum') +parser.add_argument('--nesterov', action='store_true', help='use nesterov') +parser.add_argument('--weight-decay', default=0, type=float, help='train weight decay') +parser.add_argument('--ema', default=0, type=float, help='EMA decay rate') 
+parser.add_argument('--warmup-steps', default=0, type=int, help='warmup steps') +parser.add_argument('--student-wait-steps', default=0, type=int, help='warmup steps') +parser.add_argument('--grad-clip', default=1e9, type=float, help='gradient norm clipping') +parser.add_argument('--resume', default='', type=str, help='path to checkpoint') +parser.add_argument('--evaluate', action='store_true', help='only evaluate model on validation set') +parser.add_argument('--finetune', action='store_true', + help='only finetune model on labeled dataset') +parser.add_argument('--finetune-epochs', default=625, type=int, help='finetune epochs') +parser.add_argument('--finetune-batch-size', default=512, type=int, help='finetune batch size') +parser.add_argument('--finetune-lr', default=3e-5, type=float, help='finetune learning late') +parser.add_argument('--finetune-weight-decay', default=0, type=float, help='finetune weight decay') +parser.add_argument('--finetune-momentum', default=0.9, type=float, help='finetune SGD Momentum') +parser.add_argument('--seed', default=None, type=int, help='seed for initializing training') +parser.add_argument('--label-smoothing', default=0, type=float, help='label smoothing alpha') +parser.add_argument('--mu', default=7, type=int, help='coefficient of unlabeled batch size') +parser.add_argument('--threshold', default=0.95, type=float, help='pseudo label threshold') +parser.add_argument('--temperature', default=1, type=float, help='pseudo label temperature') +parser.add_argument('--lambda-u', default=1, type=float, help='coefficient of unlabeled loss') +parser.add_argument('--uda-steps', default=1, type=float, help='warmup steps of lambda-u') +parser.add_argument("--randaug", nargs="+", type=int, help="use it like this. --randaug 2 10") +parser.add_argument("--amp", action="store_true", help="use 16-bit (mixed) precision") +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument("--local_rank", type=int, default=-1, + help="For distributed training: local_rank") + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + +def get_cosine_schedule_with_warmup(optimizer, + num_warmup_steps, + num_training_steps, + num_wait_steps=0, + num_cycles=0.5, + last_epoch=-1): + def lr_lambda(current_step): + if current_step < num_wait_steps: + return 0.0 + + if current_step < num_warmup_steps + num_wait_steps: + return float(current_step) / float(max(1, num_warmup_steps + num_wait_steps)) + + progress = float(current_step - num_warmup_steps - num_wait_steps) / \ + float(max(1, num_training_steps - num_warmup_steps - num_wait_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_lr(optimizer): + return optimizer.param_groups[0]['lr'] + + +def train_loop(args, labeled_loader, unlabeled_loader, test_loader, finetune_dataset, + teacher_model, student_model, avg_student_model, criterion, + t_optimizer, s_optimizer, t_scheduler, s_scheduler, t_scaler, s_scaler): + logger.info("***** Running Training *****") + logger.info(f" Task = {args.dataset}@{args.num_labeled}") + logger.info(f" Total steps = {args.total_steps}") + + if args.world_size > 1: + labeled_epoch = 0 + unlabeled_epoch = 0 + labeled_loader.sampler.set_epoch(labeled_epoch) + unlabeled_loader.sampler.set_epoch(unlabeled_epoch) + + labeled_iter = 
iter(labeled_loader) + unlabeled_iter = iter(unlabeled_loader) + + # for author's code formula + # moving_dot_product = torch.empty(1).to(args.device) + # limit = 3.0**(0.5) # 3 = 6 / (f_in + f_out) + # nn.init.uniform_(moving_dot_product, -limit, limit) + + for step in range(args.start_step, args.total_steps): + if step % args.eval_step == 0: + pbar = tqdm(range(args.eval_step), disable=args.local_rank not in [-1, 0]) + batch_time = AverageMeter() + data_time = AverageMeter() + s_losses = AverageMeter() + t_losses = AverageMeter() + t_losses_l = AverageMeter() + t_losses_u = AverageMeter() + t_losses_mpl = AverageMeter() + mean_mask = AverageMeter() + + teacher_model.train() + student_model.train() + end = time.time() + + try: + # error occurs ↓ + # images_l, targets = labeled_iter.next() + images_l, targets = next(labeled_iter) + except: + if args.world_size > 1: + labeled_epoch += 1 + labeled_loader.sampler.set_epoch(labeled_epoch) + labeled_iter = iter(labeled_loader) + # error occurs ↓ + # images_l, targets = labeled_iter.next() + images_l, targets = next(labeled_iter) + + try: + # error occurs ↓ + # (images_uw, images_us), _ = unlabeled_iter.next() + (images_uw, images_us), _ = next(unlabeled_iter) + except: + if args.world_size > 1: + unlabeled_epoch += 1 + unlabeled_loader.sampler.set_epoch(unlabeled_epoch) + unlabeled_iter = iter(unlabeled_loader) + # error occurs ↓ + # (images_uw, images_us), _ = unlabeled_iter.next() + (images_uw, images_us), _ = next(unlabeled_iter) + + data_time.update(time.time() - end) + + images_l = images_l.to(args.device) + images_uw = images_uw.to(args.device) + images_us = images_us.to(args.device) + targets = targets.to(args.device) + with amp.autocast(enabled=args.amp): + batch_size = images_l.shape[0] + t_images = torch.cat((images_l, images_uw, images_us)) + t_logits = teacher_model(t_images) + t_logits_l = t_logits[:batch_size] + t_logits_uw, t_logits_us = t_logits[batch_size:].chunk(2) + del t_logits + + t_loss_l = criterion(t_logits_l, targets) + + soft_pseudo_label = torch.softmax(t_logits_uw.detach() / args.temperature, dim=-1) + max_probs, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1) + mask = max_probs.ge(args.threshold).float() + t_loss_u = torch.mean( + -(soft_pseudo_label * torch.log_softmax(t_logits_us, dim=-1)).sum(dim=-1) * mask + ) + weight_u = args.lambda_u * min(1., (step + 1) / args.uda_steps) + t_loss_uda = t_loss_l + weight_u * t_loss_u + + s_images = torch.cat((images_l, images_us)) + s_logits = student_model(s_images) + s_logits_l = s_logits[:batch_size] + s_logits_us = s_logits[batch_size:] + del s_logits + + s_loss_l_old = F.cross_entropy(s_logits_l.detach(), targets) + s_loss = criterion(s_logits_us, hard_pseudo_label) + + s_scaler.scale(s_loss).backward() + if args.grad_clip > 0: + s_scaler.unscale_(s_optimizer) + nn.utils.clip_grad_norm_(student_model.parameters(), args.grad_clip) + s_scaler.step(s_optimizer) + s_scaler.update() + s_scheduler.step() + if args.ema > 0: + avg_student_model.update_parameters(student_model) + + with amp.autocast(enabled=args.amp): + with torch.no_grad(): + s_logits_l = student_model(images_l) + s_loss_l_new = F.cross_entropy(s_logits_l.detach(), targets) + + # theoretically correct formula (https://github.com/kekmodel/MPL-pytorch/issues/6) + # dot_product = s_loss_l_old - s_loss_l_new + + # author's code formula + dot_product = s_loss_l_new - s_loss_l_old + # moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01 + # dot_product = dot_product - moving_dot_product + + _, 
hard_pseudo_label = torch.max(t_logits_us.detach(), dim=-1) + t_loss_mpl = dot_product * F.cross_entropy(t_logits_us, hard_pseudo_label) + # test + # t_loss_mpl = torch.tensor(0.).to(args.device) + t_loss = t_loss_uda + t_loss_mpl + + t_scaler.scale(t_loss).backward() + if args.grad_clip > 0: + t_scaler.unscale_(t_optimizer) + nn.utils.clip_grad_norm_(teacher_model.parameters(), args.grad_clip) + t_scaler.step(t_optimizer) + t_scaler.update() + t_scheduler.step() + + teacher_model.zero_grad() + student_model.zero_grad() + + if args.world_size > 1: + s_loss = reduce_tensor(s_loss.detach(), args.world_size) + t_loss = reduce_tensor(t_loss.detach(), args.world_size) + t_loss_l = reduce_tensor(t_loss_l.detach(), args.world_size) + t_loss_u = reduce_tensor(t_loss_u.detach(), args.world_size) + t_loss_mpl = reduce_tensor(t_loss_mpl.detach(), args.world_size) + mask = reduce_tensor(mask, args.world_size) + + s_losses.update(s_loss.item()) + t_losses.update(t_loss.item()) + t_losses_l.update(t_loss_l.item()) + t_losses_u.update(t_loss_u.item()) + t_losses_mpl.update(t_loss_mpl.item()) + mean_mask.update(mask.mean().item()) + + batch_time.update(time.time() - end) + pbar.set_description( + f"Train Iter: {step+1:3}/{args.total_steps:3}. " + f"LR: {get_lr(s_optimizer):.4f}. Data: {data_time.avg:.2f}s. " + f"Batch: {batch_time.avg:.2f}s. S_Loss: {s_losses.avg:.4f}. " + f"T_Loss: {t_losses.avg:.4f}. Mask: {mean_mask.avg:.4f}. ") + pbar.update() + if args.local_rank in [-1, 0]: + args.writer.add_scalar("lr", get_lr(s_optimizer), step) +# wandb.log({"lr": get_lr(s_optimizer)}) + + args.num_eval = step // args.eval_step + if (step + 1) % args.eval_step == 0: + pbar.close() + if args.local_rank in [-1, 0]: + args.writer.add_scalar("train/1.s_loss", s_losses.avg, args.num_eval) + args.writer.add_scalar("train/2.t_loss", t_losses.avg, args.num_eval) + args.writer.add_scalar("train/3.t_labeled", t_losses_l.avg, args.num_eval) + args.writer.add_scalar("train/4.t_unlabeled", t_losses_u.avg, args.num_eval) + args.writer.add_scalar("train/5.t_mpl", t_losses_mpl.avg, args.num_eval) + args.writer.add_scalar("train/6.mask", mean_mask.avg, args.num_eval) +# wandb.log({"train/1.s_loss": s_losses.avg, +# "train/2.t_loss": t_losses.avg, +# "train/3.t_labeled": t_losses_l.avg, +# "train/4.t_unlabeled": t_losses_u.avg, +# "train/5.t_mpl": t_losses_mpl.avg, +# "train/6.mask": mean_mask.avg}) + + test_model = avg_student_model if avg_student_model is not None else student_model + test_loss, top1, top5 = evaluate(args, test_loader, test_model, criterion) + + args.writer.add_scalar("test/loss", test_loss, args.num_eval) + args.writer.add_scalar("test/acc@1", top1, args.num_eval) + args.writer.add_scalar("test/acc@5", top5, args.num_eval) +# wandb.log({"test/loss": test_loss, +# "test/acc@1": top1, +# "test/acc@5": top5}) + + is_best = top1 > args.best_top1 + if is_best: + args.best_top1 = top1 + args.best_top5 = top5 + + logger.info(f"top-1 acc: {top1:.2f}") + logger.info(f"Best top-1 acc: {args.best_top1:.2f}") + + save_checkpoint(args, { + 'step': step + 1, + 'teacher_state_dict': teacher_model.state_dict(), + 'student_state_dict': student_model.state_dict(), + 'avg_state_dict': avg_student_model.state_dict() if avg_student_model is not None else None, + 'best_top1': args.best_top1, + 'best_top5': args.best_top5, + 'teacher_optimizer': t_optimizer.state_dict(), + 'student_optimizer': s_optimizer.state_dict(), + 'teacher_scheduler': t_scheduler.state_dict(), + 'student_scheduler': s_scheduler.state_dict(), + 
'teacher_scaler': t_scaler.state_dict(), + 'student_scaler': s_scaler.state_dict(), + }, is_best) + + if args.local_rank in [-1, 0]: + args.writer.add_scalar("result/test_acc@1", args.best_top1) +# wandb.log({"result/test_acc@1": args.best_top1}) + + # finetune + del t_scaler, t_scheduler, t_optimizer, teacher_model, labeled_loader, unlabeled_loader + del s_scaler, s_scheduler, s_optimizer + ckpt_name = f'{args.save_path}/{args.name}_best.pth.tar' + loc = f'cuda:{args.gpu}' + checkpoint = torch.load(ckpt_name, map_location=loc) + logger.info(f"=> loading checkpoint '{ckpt_name}'") + if checkpoint['avg_state_dict'] is not None: + model_load_state_dict(student_model, checkpoint['avg_state_dict']) + else: + model_load_state_dict(student_model, checkpoint['student_state_dict']) + finetune(args, finetune_dataset, test_loader, student_model, criterion) + return + + +def evaluate(args, test_loader, model, criterion): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + model.eval() + test_iter = tqdm(test_loader, disable=args.local_rank not in [-1, 0]) + with torch.no_grad(): + end = time.time() + for step, (images, targets) in enumerate(test_iter): + data_time.update(time.time() - end) + batch_size = images.shape[0] + images = images.to(args.device) + targets = targets.to(args.device) + with amp.autocast(enabled=args.amp): + outputs = model(images) + loss = criterion(outputs, targets) + + acc1, acc5 = accuracy(outputs, targets, (1, 5)) + losses.update(loss.item(), batch_size) + top1.update(acc1[0], batch_size) + top5.update(acc5[0], batch_size) + batch_time.update(time.time() - end) + end = time.time() + test_iter.set_description( + f"Test Iter: {step+1:3}/{len(test_loader):3}. Data: {data_time.avg:.2f}s. " + f"Batch: {batch_time.avg:.2f}s. Loss: {losses.avg:.4f}. " + f"top1: {top1.avg:.2f}. top5: {top5.avg:.2f}. 
") + + test_iter.close() + return losses.avg, top1.avg, top5.avg + + +def finetune(args, finetune_dataset, test_loader, model, criterion): + model.drop = nn.Identity() + train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler + labeled_loader = DataLoader( + finetune_dataset, + batch_size=args.finetune_batch_size, + num_workers=args.workers, + pin_memory=True) + optimizer = optim.SGD(model.parameters(), + lr=args.finetune_lr, + momentum=args.finetune_momentum, + weight_decay=args.finetune_weight_decay, + nesterov=True) + scaler = amp.GradScaler(enabled=args.amp) + + logger.info("***** Running Finetuning *****") + logger.info(f" Finetuning steps = {len(labeled_loader)*args.finetune_epochs}") + + for epoch in range(args.finetune_epochs): + if args.world_size > 1: + labeled_loader.sampler.set_epoch(epoch + 624) + + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + model.train() + end = time.time() + labeled_iter = tqdm(labeled_loader, disable=args.local_rank not in [-1, 0]) + for step, (images, targets) in enumerate(labeled_iter): + data_time.update(time.time() - end) + batch_size = images.shape[0] + images = images.to(args.device) + targets = targets.to(args.device) + with amp.autocast(enabled=args.amp): + model.zero_grad() + outputs = model(images) + loss = criterion(outputs, targets) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + if args.world_size > 1: + loss = reduce_tensor(loss.detach(), args.world_size) + losses.update(loss.item(), batch_size) + batch_time.update(time.time() - end) + labeled_iter.set_description( + f"Finetune Epoch: {epoch+1:2}/{args.finetune_epochs:2}. Data: {data_time.avg:.2f}s. " + f"Batch: {batch_time.avg:.2f}s. Loss: {losses.avg:.4f}. ") + labeled_iter.close() + if args.local_rank in [-1, 0]: + args.writer.add_scalar("finetune/train_loss", losses.avg, epoch) + test_loss, top1, top5 = evaluate(args, test_loader, model, criterion) + args.writer.add_scalar("finetune/test_loss", test_loss, epoch) + args.writer.add_scalar("finetune/acc@1", top1, epoch) + args.writer.add_scalar("finetune/acc@5", top5, epoch) +# wandb.log({"finetune/train_loss": losses.avg, +# "finetune/test_loss": test_loss, +# "finetune/acc@1": top1, +# "finetune/acc@5": top5}) + + is_best = top1 > args.best_top1 + if is_best: + args.best_top1 = top1 + args.best_top5 = top5 + + logger.info(f"top-1 acc: {top1:.2f}") + logger.info(f"Best top-1 acc: {args.best_top1:.2f}") + + save_checkpoint(args, { + 'step': step + 1, + 'best_top1': args.best_top1, + 'best_top5': args.best_top5, + 'student_state_dict': model.state_dict(), + 'avg_state_dict': None, + 'student_optimizer': optimizer.state_dict(), + }, is_best, finetune=True) + if args.local_rank in [-1, 0]: + args.writer.add_scalar("result/finetune_acc@1", args.best_top1) +# wandb.log({"result/finetune_acc@1": args.best_top1}) + return + + +def main(): + args = parser.parse_args() + args.best_top1 = 0. + args.best_top5 = 0. 
+ + if args.local_rank != -1: + args.gpu = args.local_rank + torch.distributed.init_process_group(backend='nccl') + args.world_size = torch.distributed.get_world_size() + else: + args.gpu = 0 + args.world_size = 1 + + args.device = torch.device('cuda', args.gpu) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARNING) + + logger.warning( + f"Process rank: {args.local_rank}, " + f"device: {args.device}, " + f"distributed training: {bool(args.local_rank != -1)}, " + f"16-bits training: {args.amp}") + + logger.info(dict(args._get_kwargs())) + + if args.local_rank in [-1, 0]: + args.writer = SummaryWriter(f"results/{args.name}") +# wandb.init(name=args.name, project='MPL', config=args) + + if args.seed is not None: + set_seed(args) + + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + labeled_dataset, unlabeled_dataset, test_dataset, finetune_dataset = DATASET_GETTERS[args.dataset](args) + + if args.local_rank == 0: + torch.distributed.barrier() + + train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler + labeled_loader = DataLoader( + labeled_dataset, + sampler=train_sampler(labeled_dataset), + batch_size=args.batch_size, + num_workers=args.workers, + drop_last=True) + + unlabeled_loader = DataLoader( + unlabeled_dataset, + sampler=train_sampler(unlabeled_dataset), + batch_size=args.batch_size * args.mu, + num_workers=args.workers, + drop_last=True) + + test_loader = DataLoader(test_dataset, + sampler=SequentialSampler(test_dataset), + batch_size=args.batch_size, + num_workers=args.workers) + + if args.dataset == "cifar10": + depth, widen_factor = 28, 2 + elif args.dataset == 'cifar100': + depth, widen_factor = 28, 8 + + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + teacher_model = WideResNet(num_classes=args.num_classes, + depth=depth, + widen_factor=widen_factor, + dropout=0, + dense_dropout=args.teacher_dropout) + student_model = WideResNet(num_classes=args.num_classes, + depth=depth, + widen_factor=widen_factor, + dropout=0, + dense_dropout=args.student_dropout) + + if args.local_rank == 0: + torch.distributed.barrier() + + logger.info(f"Model: WideResNet {depth}x{widen_factor}") + logger.info(f"Params: {sum(p.numel() for p in teacher_model.parameters())/1e6:.2f}M") + + teacher_model.to(args.device) + student_model.to(args.device) + avg_student_model = None + if args.ema > 0: + avg_student_model = ModelEMA(student_model, args.ema) + + criterion = create_loss_fn(args) + + no_decay = ['bn'] + teacher_parameters = [ + {'params': [p for n, p in teacher_model.named_parameters() if not any( + nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in teacher_model.named_parameters() if any( + nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + student_parameters = [ + {'params': [p for n, p in student_model.named_parameters() if not any( + nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in student_model.named_parameters() if any( + nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + + t_optimizer = optim.SGD(teacher_parameters, + lr=args.teacher_lr, + momentum=args.momentum, + nesterov=args.nesterov) + s_optimizer = optim.SGD(student_parameters, + lr=args.student_lr, + momentum=args.momentum, + nesterov=args.nesterov) + + t_scheduler = get_cosine_schedule_with_warmup(t_optimizer, + args.warmup_steps, + 
args.total_steps) + s_scheduler = get_cosine_schedule_with_warmup(s_optimizer, + args.warmup_steps, + args.total_steps, + args.student_wait_steps) + + t_scaler = amp.GradScaler(enabled=args.amp) + s_scaler = amp.GradScaler(enabled=args.amp) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + logger.info(f"=> loading checkpoint '{args.resume}'") + loc = f'cuda:{args.gpu}' + checkpoint = torch.load(args.resume, map_location=loc) + args.best_top1 = checkpoint['best_top1'].to(torch.device('cpu')) + args.best_top5 = checkpoint['best_top5'].to(torch.device('cpu')) + if not (args.evaluate or args.finetune): + args.start_step = checkpoint['step'] + t_optimizer.load_state_dict(checkpoint['teacher_optimizer']) + s_optimizer.load_state_dict(checkpoint['student_optimizer']) + t_scheduler.load_state_dict(checkpoint['teacher_scheduler']) + s_scheduler.load_state_dict(checkpoint['student_scheduler']) + t_scaler.load_state_dict(checkpoint['teacher_scaler']) + s_scaler.load_state_dict(checkpoint['student_scaler']) + model_load_state_dict(teacher_model, checkpoint['teacher_state_dict']) + if avg_student_model is not None: + model_load_state_dict(avg_student_model, checkpoint['avg_state_dict']) + + else: + if checkpoint['avg_state_dict'] is not None: + model_load_state_dict(student_model, checkpoint['avg_state_dict']) + else: + model_load_state_dict(student_model, checkpoint['student_state_dict']) + + logger.info(f"=> loaded checkpoint '{args.resume}' (step {checkpoint['step']})") + else: + logger.info(f"=> no checkpoint found at '{args.resume}'") + + if args.local_rank != -1: + teacher_model = nn.parallel.DistributedDataParallel( + teacher_model, device_ids=[args.local_rank], + output_device=args.local_rank, find_unused_parameters=True) + student_model = nn.parallel.DistributedDataParallel( + student_model, device_ids=[args.local_rank], + output_device=args.local_rank, find_unused_parameters=True) + + if args.finetune: + del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader + del s_scaler, s_scheduler, s_optimizer + finetune(args, finetune_dataset, test_loader, student_model, criterion) + return + + if args.evaluate: + del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader, labeled_loader + del s_scaler, s_scheduler, s_optimizer + evaluate(args, test_loader, student_model, criterion) + return + + teacher_model.zero_grad() + student_model.zero_grad() + train_loop(args, labeled_loader, unlabeled_loader, test_loader, finetune_dataset, + teacher_model, student_model, avg_student_model, criterion, + t_optimizer, s_optimizer, t_scheduler, s_scheduler, t_scaler, s_scaler) + return + + +if __name__ == '__main__': + main() diff --git a/TrainingTricks/meta_pseudo_label/models.py b/TrainingTricks/meta_pseudo_label/models.py new file mode 100644 index 00000000..3fffde6b --- /dev/null +++ b/TrainingTricks/meta_pseudo_label/models.py @@ -0,0 +1,157 @@ +from copy import deepcopy +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + + +class ModelEMA(nn.Module): + def __init__(self, model, decay=0.9999, device=None): + super().__init__() + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device + if self.device is not None: + self.module.to(device=device) + + def forward(self, input): + return self.module(input) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.parameters(), 
model.parameters()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + for ema_v, model_v in zip(self.module.buffers(), model.buffers()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(model_v) + + def update_parameters(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def state_dict(self): + return self.module.state_dict() + + def load_state_dict(self, state_dict): + self.module.load_state_dict(state_dict) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, out_planes, stride, dropout=0.0, activate_before_residual=False): + super(BasicBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) + self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) + self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.dropout = dropout + self.equalInOut = (in_planes == out_planes) + self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, + kernel_size=1, stride=stride, + padding=0, bias=False) or None + self.activate_before_residual = activate_before_residual + + def forward(self, x): + if not self.equalInOut and self.activate_before_residual is True: + x = self.relu1(self.bn1(x)) + else: + out = self.relu1(self.bn1(x)) + out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) + if self.dropout > 0: + out = F.dropout(out, p=self.dropout, training=self.training) + out = self.conv2(out) + return torch.add(x if self.equalInOut else self.convShortcut(x), out) + + +class NetworkBlock(nn.Module): + def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout=0.0, + activate_before_residual=False): + super(NetworkBlock, self).__init__() + self.layer = self._make_layer( + block, in_planes, out_planes, nb_layers, stride, dropout, activate_before_residual) + + def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout, + activate_before_residual): + layers = [] + for i in range(int(nb_layers)): + layers.append(block(i == 0 and in_planes or out_planes, out_planes, + i == 0 and stride or 1, dropout, activate_before_residual)) + return nn.Sequential(*layers) + + def forward(self, x): + return self.layer(x) + + +class WideResNet(nn.Module): + def __init__(self, num_classes, depth=28, widen_factor=2, dropout=0.0, dense_dropout=0.0): + super(WideResNet, self).__init__() + channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] + assert((depth - 4) % 6 == 0) + n = (depth - 4) / 6 + block = BasicBlock + # 1st conv before any network block + self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, + padding=1, bias=False) + # 1st block + self.block1 = NetworkBlock( + n, channels[0], channels[1], block, 1, dropout, activate_before_residual=True) + # 2nd block + self.block2 = NetworkBlock( + n, channels[1], channels[2], block, 2, dropout) + # 3rd block + self.block3 = NetworkBlock( + n, channels[2], channels[3], block, 2, dropout) + # global average pooling and classifier + self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001) + self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.drop = nn.Dropout(dense_dropout) + self.fc = nn.Linear(channels[3], 
num_classes) + self.channels = channels[3] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, + mode='fan_out', + nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0.0) + + def forward(self, x): + out = self.conv1(x) + out = self.block1(out) + out = self.block2(out) + out = self.block3(out) + out = self.relu(self.bn1(out)) + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(-1, self.channels) + return self.fc(self.drop(out)) + + +def build_wideresnet(args): + if args.dataset == "cifar10": + depth, widen_factor = 28, 2 + elif args.dataset == 'cifar100': + depth, widen_factor = 28, 8 + + model = WideResNet(num_classes=args.num_classes, + depth=depth, + widen_factor=widen_factor, + dropout=0, + dense_dropout=args.dense_dropout) + if args.local_rank in [-1, 0]: + logger.info(f"Model: WideResNet {depth}x{widen_factor}") + logger.info(f"Total params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M") + return model diff --git a/TrainingTricks/meta_pseudo_label/utils.py b/TrainingTricks/meta_pseudo_label/utils.py new file mode 100644 index 00000000..af3bb407 --- /dev/null +++ b/TrainingTricks/meta_pseudo_label/utils.py @@ -0,0 +1,142 @@ +import logging +import os +import shutil +from collections import OrderedDict + +import torch +from torch import distributed as dist +from torch import nn +from torch.nn import functional as F + +logger = logging.getLogger(__name__) + + +def reduce_tensor(tensor, n): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= n + return rt + + +def create_loss_fn(args): + # if args.label_smoothing > 0: + # criterion = SmoothCrossEntropyV2(alpha=args.label_smoothing) + # else: + criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + return criterion.to(args.device) + + +def module_load_state_dict(model, state_dict): + try: + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + except: + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = f'module.{k}' # add `module.` + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + + +def model_load_state_dict(model, state_dict): + try: + model.load_state_dict(state_dict) + except: + module_load_state_dict(model, state_dict) + + +def save_checkpoint(args, state, is_best, finetune=False): + os.makedirs(args.save_path, exist_ok=True) + if finetune: + name = f'{args.name}_finetune' + else: + name = args.name + filename = f'{args.save_path}/{name}_last.pth.tar' + torch.save(state, filename, _use_new_zipfile_serialization=False) + if is_best: + shutil.copyfile(filename, f'{args.save_path}/{args.name}_best.pth.tar') + + +def accuracy(output, target, topk=(1,)): + output = output.to(torch.device('cpu')) + target = target.to(torch.device('cpu')) + maxk = max(topk) + batch_size = target.shape[0] + + _, idx = output.sort(dim=1, descending=True) + pred = idx.narrow(1, 0, maxk).t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class SmoothCrossEntropy(nn.Module): + def __init__(self, alpha=0.1): + super(SmoothCrossEntropy, self).__init__() + self.alpha 
= alpha + + def forward(self, logits, labels): + if self.alpha == 0: + loss = F.cross_entropy(logits, labels) + else: + num_classes = logits.shape[-1] + alpha_div_k = self.alpha / num_classes + target_probs = F.one_hot(labels, num_classes=num_classes).float() * \ + (1. - self.alpha) + alpha_div_k + loss = (-(target_probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1)).mean() + return loss + + +class SmoothCrossEntropyV2(nn.Module): + """ + NLL loss with label smoothing. + """ + + def __init__(self, label_smoothing=0.1): + """ + Constructor for the LabelSmoothing module. + :param smoothing: label smoothing factor + """ + super().__init__() + assert label_smoothing < 1.0 + self.smoothing = label_smoothing + self.confidence = 1. - label_smoothing + + def forward(self, x, target): + if self.smoothing == 0: + loss = F.cross_entropy(x, target) + else: + logprobs = F.log_softmax(x, dim=-1) + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = (self.confidence * nll_loss + self.smoothing * smooth_loss).mean() + return loss + + +class AverageMeter(object): + """Computes and stores the average and current value + Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count
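
For readers who want the gist of this patch without reading all of `main.py`, the sketch below condenses the teacher/student update performed inside `train_loop`: the teacher is trained on its labeled loss plus a UDA consistency loss, the student is trained on the teacher's hard pseudo labels, and the teacher then receives the MPL feedback term built from the change in the student's labeled loss (`s_loss_l_new - s_loss_l_old`, the author's sign convention). This is a minimal, illustrative sketch only: the tiny linear model, random tensors, plain SGD steps, and separate forward passes are assumptions made for brevity and stand in for the WideResNet models, CIFAR loaders, AMP scalers, EMA, and schedulers used in the actual code; the hyperparameter values mirror the CLI flags shown in the README above.

```
# Minimal sketch of one Meta Pseudo Labels step, mirroring train_loop() in main.py.
import torch
import torch.nn.functional as F
from torch import nn, optim

num_classes, batch_size, mu = 10, 8, 2
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))
t_optimizer = optim.SGD(teacher.parameters(), lr=0.05, momentum=0.9)
s_optimizer = optim.SGD(student.parameters(), lr=0.05, momentum=0.9)

images_l = torch.randn(batch_size, 3, 32, 32)         # labeled batch (random stand-in)
targets = torch.randint(0, num_classes, (batch_size,))
images_uw = torch.randn(batch_size * mu, 3, 32, 32)   # weakly augmented unlabeled batch
images_us = torch.randn(batch_size * mu, 3, 32, 32)   # strongly augmented unlabeled batch
temperature, threshold, lambda_u = 0.7, 0.6, 8.0

# Teacher forward: labeled loss + UDA consistency loss with confidence masking.
t_logits_l = teacher(images_l)
t_logits_uw, t_logits_us = teacher(images_uw), teacher(images_us)
t_loss_l = F.cross_entropy(t_logits_l, targets)
soft_pseudo = torch.softmax(t_logits_uw.detach() / temperature, dim=-1)
max_probs, hard_pseudo = soft_pseudo.max(dim=-1)
mask = max_probs.ge(threshold).float()
t_loss_u = torch.mean(-(soft_pseudo * torch.log_softmax(t_logits_us, dim=-1)).sum(dim=-1) * mask)
t_loss_uda = t_loss_l + lambda_u * t_loss_u

# Student step on the teacher's hard pseudo labels; record its labeled loss before the update.
with torch.no_grad():
    s_loss_l_old = F.cross_entropy(student(images_l), targets)
s_loss = F.cross_entropy(student(images_us), hard_pseudo)
s_optimizer.zero_grad()
s_loss.backward()
s_optimizer.step()

# Teacher step: UDA loss + MPL feedback from how the student's labeled loss changed.
with torch.no_grad():
    s_loss_l_new = F.cross_entropy(student(images_l), targets)
dot_product = s_loss_l_new - s_loss_l_old
_, hard_pseudo_us = t_logits_us.detach().max(dim=-1)
t_loss_mpl = dot_product * F.cross_entropy(t_logits_us, hard_pseudo_us)
t_loss = t_loss_uda + t_loss_mpl
t_optimizer.zero_grad()
t_loss.backward()
t_optimizer.step()
```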