feat: Add codes for meta pseudo labeling
YeonwooSung committed Mar 31, 2024
1 parent 2765815 commit c50c9cf
Showing 8 changed files with 1,791 additions and 4 deletions.
9 changes: 5 additions & 4 deletions README.md
@@ -60,10 +60,11 @@ Free online book on Artificial Intelligence to help people learn AI easily.

- [Self-Supervised Learning](./SelfSupervisedLearning)

- - [Training](./Training)
-   * [Adversarial ML](./Training/AdversarialML/)
-   * [Knowledge Distillation](./Training/KnowledgeDistillation/)
-   * [Transfer Learning](./Training/TransferLearning/)
+ - [TrainingTricks](./TrainingTricks)
+   * [Adversarial ML](./TrainingTricks/AdversarialML/)
+   * [Meta Pseudo Labeling](./TrainingTricks/meta_pseudo_label/)
+   * [Knowledge Distillation](./TrainingTricks/KnowledgeDistillation/)
+   * [Transfer Learning](./TrainingTricks/TransferLearning/)

- [XAI](./XAI)

148 changes: 148 additions & 0 deletions TrainingTricks/meta_pseudo_label/.gitignore
@@ -0,0 +1,148 @@
.vscode
wandb/
results/
*.npz
*.jpg
*.JPG
*.jpeg
*.JPEG
*.png
*.PNG
*.webp
*.WEBP
*.gif
*.GIF
*.zip
*.tar
checkpoint/
data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
109 changes: 109 additions & 0 deletions TrainingTricks/meta_pseudo_label/README.md
@@ -0,0 +1,109 @@
# Meta Pseudo Labels
This is an unofficial PyTorch implementation of [Meta Pseudo Labels](https://arxiv.org/abs/2003.10580).
The official TensorFlow implementation is [here](https://github.com/google-research/google-research/tree/master/meta_pseudo_labels).
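
In Meta Pseudo Labels, a teacher network generates pseudo labels on unlabeled data to train a student, and the teacher itself is updated using the student's performance on labeled data as a feedback signal. Below is a minimal PyTorch sketch of a single MPL step, not this repository's code: it uses toy linear models, hard pseudo labels, and the common first-order approximation (the change in the student's labeled loss) in place of the exact meta-gradient, and it omits the UDA consistency loss, label smoothing, confidence threshold, EMA, and warmup that the full training script configures.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the teacher/student classifiers used in this repo.
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
t_opt = torch.optim.SGD(teacher.parameters(), lr=0.05, momentum=0.9, nesterov=True)
s_opt = torch.optim.SGD(student.parameters(), lr=0.05, momentum=0.9, nesterov=True)

x_l = torch.randn(8, 3, 32, 32)    # labeled images
y_l = torch.randint(0, 10, (8,))   # their labels
x_u = torch.randn(56, 3, 32, 32)   # unlabeled images (mu x batch size)

# 1) Teacher assigns hard pseudo labels to the unlabeled batch.
with torch.no_grad():
    pseudo = teacher(x_u).argmax(dim=-1)

# 2) Student loss on labeled data *before* its update (feedback baseline).
with torch.no_grad():
    s_loss_old = F.cross_entropy(student(x_l), y_l)

# 3) Student takes a gradient step on the pseudo-labeled batch.
s_loss = F.cross_entropy(student(x_u), pseudo)
s_opt.zero_grad()
s_loss.backward()
s_opt.step()

# 4) Student loss on labeled data *after* the update.
with torch.no_grad():
    s_loss_new = F.cross_entropy(student(x_l), y_l)

# 5) Teacher step: supervised loss plus the MPL feedback term, which rewards
#    pseudo labels that made the student improve on the labeled set.
h = s_loss_old - s_loss_new  # first-order approximation of the meta-gradient signal
mpl_loss = h * F.cross_entropy(teacher(x_u), pseudo)
t_loss = F.cross_entropy(teacher(x_l), y_l) + mpl_loss
t_opt.zero_grad()
t_loss.backward()
t_opt.step()
```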


## Results

| | CIFAR-10-4K | SVHN-1K | ImageNet-10% |
|:---:|:---:|:---:|:---:|
| Paper (w/ finetune) | 96.11 ± 0.07 | 98.01 ± 0.07 | 73.89 |
| This code (w/o finetune) | 96.01 | - | - |
| This code (w/ finetune) | 96.08 | - | - |
| Acc. curve | [w/o finetune](https://tensorboard.dev/experiment/ehMVEk39SrGiqM43ye2c7w/)<br>[w/ finetune](https://tensorboard.dev/experiment/vbqR7dt2Q9aw6rf8yVu56g/) | - | - |

* Retested in February 2022.

## Usage

Train the model with 4,000 labeled examples from the CIFAR-10 dataset:

```
python main.py \
--seed 2 \
--name cifar10-4K.2 \
--expand-labels \
--dataset cifar10 \
--num-classes 10 \
--num-labeled 4000 \
--total-steps 300000 \
--eval-step 1000 \
--randaug 2 16 \
--batch-size 128 \
--teacher_lr 0.05 \
--student_lr 0.05 \
--weight-decay 5e-4 \
--ema 0.995 \
--nesterov \
--mu 7 \
--label-smoothing 0.15 \
--temperature 0.7 \
--threshold 0.6 \
--lambda-u 8 \
--warmup-steps 5000 \
--uda-steps 5000 \
--student-wait-steps 3000 \
--teacher-dropout 0.2 \
--student-dropout 0.2 \
--finetune-epochs 625 \
--finetune-batch-size 512 \
--finetune-lr 3e-5 \
--finetune-weight-decay 0 \
--finetune-momentum 0.9 \
--amp
```

Train the model with 10,000 labeled examples from the CIFAR-100 dataset using DistributedDataParallel:
```
python -m torch.distributed.launch --nproc_per_node 4 main.py \
--seed 2 \
--name cifar100-10K.2 \
--dataset cifar100 \
--num-classes 100 \
--num-labeled 10000 \
--expand-labels \
--total-steps 300000 \
--eval-step 1000 \
--randaug 2 16 \
--batch-size 128 \
--teacher_lr 0.05 \
--student_lr 0.05 \
--weight-decay 5e-4 \
--ema 0.995 \
--nesterov \
--mu 7 \
--label-smoothing 0.15 \
--temperature 0.7 \
--threshold 0.6 \
--lambda-u 8 \
--warmup-steps 5000 \
--uda-steps 5000 \
--student-wait-steps 3000 \
--teacher-dropout 0.2 \
--student-dropout 0.2 \
--finetune-epochs 250 \
--finetune-batch-size 512 \
--finetune-lr 3e-5 \
--finetune-weight-decay 0 \
--finetune-momentum 0.9 \
--amp
```
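
Note: recent PyTorch releases deprecate `torch.distributed.launch` in favor of `torchrun`. Assuming the script reads the rank from the `LOCAL_RANK` environment variable (not verified against this code), an equivalent launch would look like:

```
torchrun --nproc_per_node 4 main.py [same arguments as above]
```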

### Monitoring training progress

With TensorBoard:

```
tensorboard --logdir results
```

or with wandb.

## Requirements
- python 3.6+
- torch 1.7+
- torchvision 0.8+
- tensorboard
- wandb
- numpy
- tqdm
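
If you install the dependencies by hand, a hypothetical one-liner (pick the torch/torchvision build that matches your CUDA version) is:

```
pip install torch torchvision tensorboard wandb numpy tqdm
```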