
ResNet_CIFAR_KD

Prerequisites

  • Python 3.6+
  • PyTorch 1.0+ (CUDA 11 builds require PyTorch 1.7 or later)
  • CUDA 11.0+
  • Windows or Linux (needed for PyTorch GPU support)

Installation

Python

Create and activate a new virtual environment with venv

python3 -m venv project
source project/bin/activate
# On Windows: project\Scripts\activate

Install PyTorch with CUDA 11.3 support into the virtual environment

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

Anaconda

Create a new conda environment with GPU-enabled PyTorch (CUDA 11.3), then activate it

conda create -n project pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
conda activate project
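After either installation route, it can help to confirm that the environment matches the prerequisites. The script below is a minimal sketch and is not part of the repo: a file name such as check_env.py and the helpers parse_version, meets_minimum and check_environment are hypothetical. It compares Python and PyTorch against the stated minimums (3.6 and 1.0) and reports whether PyTorch can see a CUDA GPU.

```python
import sys


def parse_version(version: str) -> tuple:
    """Turn a version string such as '1.12.1+cu113' into (1, 12, 1)."""
    core = version.split("+")[0]  # drop a local build tag such as '+cu113'
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:  # keep the leading digits only ('0a0' -> '0')
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_minimum(version: str, minimum: tuple) -> bool:
    """True if version >= minimum, comparing only as many fields as minimum has."""
    padded = parse_version(version) + (0,) * len(minimum)
    return padded[: len(minimum)] >= minimum


def check_environment() -> dict:
    """Collect the facts the prerequisites list cares about."""
    report = {"python_ok": sys.version_info >= (3, 6)}
    try:
        import torch
    except ImportError:
        report["torch"] = None  # PyTorch is not installed in this environment
        return report
    report["torch"] = str(torch.__version__)
    report["torch_ok"] = meets_minimum(str(torch.__version__), (1, 0))
    report["cuda_build"] = torch.version.cuda  # None for CPU-only wheels
    report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If cuda_available is False on a machine with an NVIDIA GPU, the usual causes are a CPU-only PyTorch wheel (cuda_build is None) or an NVIDIA driver too old for the CUDA 11.3 build.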

Teacher models:

  • DenseNet
  • DLA (Deep Layer Aggregation)
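Both training scripts select one of these teachers through the --teacher flag. The block below is a minimal sketch of how such a flag could be limited to the two names used in the training commands, assuming argparse; the real train_teacher.py and train_student.py may parse their arguments differently, and TEACHERS and build_parser are hypothetical names.

```python
import argparse

# Hypothetical: the teacher names accepted by --teacher, matching the list above.
TEACHERS = ("densenet", "dla")


def build_parser() -> argparse.ArgumentParser:
    """Build a parser that rejects any teacher name outside TEACHERS."""
    parser = argparse.ArgumentParser(description="Teacher / KD student training on CIFAR")
    parser.add_argument(
        "--teacher",
        choices=TEACHERS,  # argparse exits with an error for any other value
        required=True,
        help="teacher architecture to train, or to distill from",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(f"Selected teacher: {args.teacher}")
```

Restricting the flag with choices makes a typo such as --teacher "densnet" fail immediately with a usage message instead of partway through training.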

Structure of the repo:

  • checkpoint_teacher: stores the teacher model checkpoints
  • checkpoint: stores the ResNet student model checkpoints
  • teacher_models: stores the teacher model definitions
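A minimal sketch of saving and restoring weights in these folders. It assumes each checkpoint is either a dict with the state_dict under a "net" key (a common CIFAR-training convention, not confirmed for this repo) or a bare state_dict; save_checkpoint and load_checkpoint are hypothetical helpers. load_checkpoint also strips the "module." prefix that nn.DataParallel adds to parameter names, since checkpoints trained on GPU are often saved from such a wrapper.

```python
import os

import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, directory: str, name: str) -> str:
    """Save weights as {'net': state_dict} to directory/name.pth (assumed layout)."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"{name}.pth")
    torch.save({"net": model.state_dict()}, path)
    return path


def load_checkpoint(model: nn.Module, path: str) -> nn.Module:
    """Load a {'net': ...} dict or a bare state_dict into model, on CPU."""
    state = torch.load(path, map_location="cpu")
    state_dict = state.get("net", state)  # assumption: weights may sit under 'net'
    # nn.DataParallel prefixes every key with 'module.'; strip it for a plain model
    state_dict = {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    return model
```

For example, a provided teacher could be restored with something like load_checkpoint(teacher, "checkpoint_teacher/<file>.pth"), where teacher is an instance of the matching class from teacher_models and <file> stands for whichever checkpoint file the folder actually contains.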

Training


# [OPTIONAL] Train the teacher:
python train_teacher.py --teacher "densenet"
python train_teacher.py --teacher "dla"

# Or skip this step and use the teacher checkpoints already provided in checkpoint_teacher :)

# Train the ResNet student directly, distilling from the chosen teacher:
python train_student.py --teacher "densenet"
python train_student.py --teacher "dla"
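The README does not spell out the distillation objective. For orientation, the block below sketches the standard soft-target loss of Hinton et al. (2015): the student matches the teacher's temperature-softened class distribution through a KL term scaled by T^2, blended with ordinary cross-entropy on the true labels. The function names kd_loss and distill_step and the defaults T = 4 and alpha = 0.9 are assumptions for illustration, not values taken from train_student.py, whose actual loss may differ.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.9):
    """Hinton-style KD loss: alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE."""
    # Soften both distributions with the same temperature T
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # kl_div expects log-probabilities for the input and probabilities for the target
    soft = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard


def distill_step(student, teacher, images, labels, optimizer, **kd_kwargs):
    """One optimization step of the student against a frozen teacher."""
    teacher.eval()
    with torch.no_grad():  # the teacher only provides targets
        teacher_logits = teacher(images)
    student.train()
    student_logits = student(images)
    loss = kd_loss(student_logits, teacher_logits, labels, **kd_kwargs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Because the teacher runs under torch.no_grad(), only the student's parameters receive gradients. The T^2 factor keeps the soft-target gradients on roughly the same scale as the cross-entropy gradients as T changes.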

Figure: the overall structure of our ResNet architecture (panels: Neutron 1, Neutron 2, Neutron 3, Neutron 4, Neutron 5, Neutron 6)
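As a point of reference for the architecture figure, here is a sketch of the standard CIFAR-style ResNet from He et al. (2016) that students like this are typically built on: a 3×3 stem, three stages of residual BasicBlocks with 16, 32 and 64 channels (the first block of stages 2 and 3 downsamples with stride 2), global average pooling and a linear classifier. This is an assumption for illustration; the repo's ResNet, and the six "Neutron" panels in its figure, may differ. BasicBlock and SmallResNet are hypothetical names, and the block uses 1×1 projection shortcuts whenever the shape changes, which is one of the shortcut options in the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two 3x3 convs with batch norm plus a residual shortcut."""

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()  # identity when shapes already match
        if stride != 1 or in_planes != planes:
            # 1x1 projection so the shortcut matches the new spatial size and width
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))


class SmallResNet(nn.Module):
    """CIFAR ResNet with three stages of 16, 32 and 64 channels."""

    def __init__(self, blocks_per_stage=(2, 2, 2), num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        layers, in_planes = [], 16
        for stage, (planes, n_blocks) in enumerate(zip((16, 32, 64), blocks_per_stage)):
            for i in range(n_blocks):
                stride = 2 if (stage > 0 and i == 0) else 1  # halve resolution per stage
                layers.append(BasicBlock(in_planes, planes, stride))
                in_planes = planes
        self.layers = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.layers(self.stem(x))  # 32x32 -> 32x32 -> 16x16 -> 8x8
        return self.fc(torch.flatten(self.pool(x), 1))
```

With blocks_per_stage=(3, 3, 3) this gives the 20-layer configuration (6n + 2 layers with n = 3); the default (2, 2, 2) gives 14 layers.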