Support a sampling strategy for multiple training datasets #107

Open · wants to merge 2 commits into main
Conversation

@zerovl commented Jun 9, 2022

This PR adds the debiased sampling method proposed in the ZeroVL paper. When training on multiple datasets, debiased sampling improves the accuracy of the CLIP model. It introduces a new flag:

  • --debias-sample, a boolean flag that enables the debiased sampling method

Introduction to Debiased Sampling

[Figure: Fig. 2 illustrates the sampling methods; Fig. 3 shows the resulting image and text feature distributions.]
As shown in Fig. 2, random sampling is the most intuitive sampling method: it constructs each training batch at random from all available data. However, as shown in Fig. 3, random sampling leads to biased feature distributions in both the image and text modalities.
Debiased sampling instead ensures that the instances within each batch come from the same dataset. Training with debiased sampling improves the quality of the learned representations and leads to better results on many downstream tasks.
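
Conceptually, the sampler shuffles each dataset independently, cuts it into single-source batches, and shuffles the batch order so the datasets interleave across training steps. A minimal sketch of this idea (not the code from this PR; the class name and the assumption of contiguous per-dataset indices are ours):

import random
from torch.utils.data import Sampler

class DebiasedBatchSampler(Sampler):
    # Yields index batches in which every sample comes from one dataset.
    # Assumes the datasets are concatenated, so dataset k owns a contiguous
    # index range inside the combined dataset.
    def __init__(self, dataset_sizes, batch_size, seed=0):
        self.batch_size = batch_size
        self.seed = seed
        self.ranges, start = [], 0
        for n in dataset_sizes:
            self.ranges.append(range(start, start + n))
            start += n

    def __iter__(self):
        rng = random.Random(self.seed)
        batches = []
        for r in self.ranges:
            idxs = list(r)
            rng.shuffle(idxs)  # shuffle within each dataset
            # cut into full single-source batches (drop the last partial one)
            for i in range(0, len(idxs) - self.batch_size + 1, self.batch_size):
                batches.append(idxs[i:i + self.batch_size])
        rng.shuffle(batches)  # interleave datasets across steps
        return iter(batches)

    def __len__(self):
        return sum(len(r) // self.batch_size for r in self.ranges)

Such a sampler would plug into a torch DataLoader via its batch_sampler argument, e.g. DataLoader(ConcatDataset([cc3m, sbu]), batch_sampler=DebiasedBatchSampler([len(cc3m), len(sbu)], 256)).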

Experiments on sampling methods

We use two datasets, CC3M and SBU, to show the improvement from debiased sampling.

Experiment 1: Random Sampling

1. Setting & Acc

dataset: CC3M + SBU (2.79M + 0.86M)
batchsize: 2048 (256 per GPU, 8 V100 32GB)
learning rate: 1e-3
weight decay: 0.1
sampling: random
zero-shot acc on ImageNet: top1 21.36, top5 40.98

2. Training script

torchrun --nproc_per_node 8 -m training.main \
    --train-data "/data/cc3m/cc3m_sbu_train_anno.csv" \
    --dataset-type auto \
    --batch-size 256 \
    --precision amp \
    --workers 4 \
    --imagenet-val "/data/ILSVRC/Data/CLS-LOC/val" \
    --csv-separator , \
    --lr=1e-3 \
    --wd=0.1

'/data/cc3m/cc3m_sbu_train_anno.csv' contains all samples from CC3M and SBU.
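
For reference, a combined file like this can be produced by concatenating the two per-dataset annotation CSVs. A hypothetical one-off script (assuming both files already share the same columns, e.g. open_clip's default filepath/title pair):

import pandas as pd

cc3m = pd.read_csv("/data/cc3m/cc3m_train_anno.csv")
sbu = pd.read_csv("/data/sbu/sbu_train_anno.csv")
pd.concat([cc3m, sbu], ignore_index=True).to_csv(
    "/data/cc3m/cc3m_sbu_train_anno.csv", index=False)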

3. Log

cc3m+sbu+random_sample.log

Experiment 2: Debiased Sampling

1. Setting & Acc

dataset: CC3M + SBU (2.79M + 0.86M)
batchsize: 2048 (256 per GPU, 8 V100 32GB)
learning rate: 1e-3
weight decay: 0.1
sampling: debias
zero-shot acc on ImageNet: top1 22.33, top5 42.29

2. Training script

torchrun --nproc_per_node 8 -m training.main \
    --train-data "/data/cc3m/cc3m_train_anno.csv, /data/sbu/sbu_train_anno.csv" \
    --dataset-type auto \
    --batch-size 256 \
    --precision amp \
    --workers 4 \
    --imagenet-val "/data/ILSVRC/Data/CLS-LOC/val" \
    --csv-separator , \
    --lr=1e-3 \
    --wd=0.1 \
    --debias-sample

3. Log

cc3m+sbu+debias_sample.log

@zerovl zerovl changed the title support a sampling strategy for multiple training datasets Support a sampling strategy for multiple training datasets Jun 9, 2022
@rwightman (Collaborator) commented Jun 10, 2022

@zerovl thanks, couldn't this logic be placed in a dataset wrapper so we don't have to repeat the train loop and incur more long-term maintenance? Either one that covers both csv & wds, or a separate one for each, that handles all of the length calcs, sampling, etc. internally for each batch grabbed ...
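
(For illustration, a rough sketch of the wrapper shape being suggested, with hypothetical names rather than code from this PR: the object owns one loader per dataset and serves whole single-source batches, so length calculations and sampling stay out of the train loop.)

import random

class DebiasedMultiLoader:
    # Wraps one DataLoader per dataset; each step yields a full batch
    # drawn from exactly one of them.
    def __init__(self, loaders, seed=0):
        self.loaders = loaders
        self.seed = seed

    def __len__(self):
        # length calcs live here instead of in the train loop
        return sum(len(loader) for loader in self.loaders)

    def __iter__(self):
        rng = random.Random(self.seed)
        iters = [iter(loader) for loader in self.loaders]
        # one slot per remaining batch of each dataset, shuffled so the
        # single-source batches interleave across steps
        order = [i for i, loader in enumerate(self.loaders)
                 for _ in range(len(loader))]
        rng.shuffle(order)
        for i in order:
            yield next(iters[i])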

@zerovl (Author) commented Jun 11, 2022

@rwightman Thanks for replying. We agree that the logic should be placed in a dataset wrapper. We are working on implementing it and making sure the experimental results are correct.

@zerovl (Author) commented Jun 15, 2022

@rwightman hi, the implementation is done, and the log is attached. The results are almost the same as in the former version. Could you review the code when you are available?
new_cc3m+sbu+debias_sample.log

@rwightman (Collaborator) commented
@zerovl thanks for updating this and your other PR, I'll try to find some time to take a closer look next week.

@zerovl (Author) commented Jun 19, 2022

@rwightman thanks for your time. I am happy to discuss the implementation details.
