Dataset mixture at the shard level #425
Conversation
So, these sorts of changes are pretty high risk re changing train behaviour. I'm not updating the current code on the cluster because the last change got merged earlier than I would have liked (w/o verification of value). Are we sure this is correct? It's still not equivalent to #107, which samples per local batch from the same dataset (which I think is probably a bit more desired). If sampling across datasets at the sample level is not as good, then the batch at the transition between two shards is maybe not as good either, and mixing local batches from different datasets into a global batch is not as good as having a global batch all from one dataset (that last one is not easily solvable)... Either way, my point is that doing experiments and figuring out the best approach is not something to do on main for this sort of change.
Why these changes are (subtle) trouble and need lots of testing: this change will alter the sample progression for the same seed in training. From https://docs.python.org/dev/library/random.html#random.choices: "For a given seed, the choices() function with equal weighting typically produces a different sequence than repeated calls to choice(). The algorithm used by choices() uses floating point arithmetic for internal consistency and speed. The algorithm used by choice() defaults to integer arithmetic with repeated selections to avoid small biases from round-off error."
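To make that concrete, here is a small standalone sketch (the shard names are made up) of how the two APIs diverge for the same seed:

```python
import random

population = ["shard-a", "shard-b", "shard-c"]

# Repeated choice() calls: integer arithmetic under the hood.
rng = random.Random(42)
via_choice = [rng.choice(population) for _ in range(5)]

# A single choices() call with uniform weights: floating point under the hood.
rng = random.Random(42)
via_choices = rng.choices(population, weights=[1, 1, 1], k=5)

# For the same seed the two sequences will generally differ, so swapping
# one for the other silently changes the sample progression in training.
print(via_choice)
print(via_choices)
```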
Thanks for the comments! I agree with you that we should be careful about these changes. I also agree that this is not equivalent to #107 (it is not intended to be). Maybe we could implement that, but I think it should be a separate PR (this PR deals with upsampling/downsampling data sources; we could have a separate one adding the option to form global batches from only one data source). Re. correctness, I've tested both #398 and this PR on standard runs (without using the new flag) and observed virtually no difference in performance when training on CC12M. I'm happy to do any other training runs you think would better gauge the impact of this PR. This code has also been tested on two other experiments where the upsampling weights are integers; the new code yielded the same results as duplicating the upsampled shards in the input string.
Re. random seeds, indeed the code currently won't produce exactly the same sequence as before for the same random seed, because of the difference between choice and choices noted above.
@gabrielilharco yeah, those were actually my review comments, but as usual I forgot to hit the final button to make the review active. If no weights are passed, self.weights should be None. If weights are None, use rng.choice; otherwise use choices with the weights.
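A minimal sketch of what that branching could look like (the class and attribute names are illustrative, not the repository's actual code):

```python
import random

class ShardSampler:
    """Illustrative only: picks the next shard URL to read from."""

    def __init__(self, urls, weights=None, seed=0):
        self.urls = urls
        self.weights = weights          # None when no weights are passed
        self.rng = random.Random(seed)

    def next_shard(self):
        if self.weights is None:
            # Unweighted path keeps the historical rng.choice() behaviour,
            # so the sample progression for a given seed is unchanged.
            return self.rng.choice(self.urls)
        # Weighted path: choices() supports per-element weights.
        return self.rng.choices(self.urls, weights=self.weights, k=1)[0]
```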
Great, I'll make the change and test it with another CC12M run and will let you know once it's done |
Oh yeah, one other comment... does this actually sample dataset A 2x more frequently (in samples) than B? Wouldn't the mix of samples seen from the datasets change significantly based on the shard composition (# samples per shard) of the two datasets, and be more complicated than just 1::2? EDIT: looping back to the original motivation to switch to this approach from the per-sample one, when you compared them, was there verification that the ratios of samples seen (across the datasets) were the same?
In this PR, 1::2 weights for datasets A::B does not mean that B will be sampled twice as often as A; it means that we will sample from B 2x more often than normal (and 1x as often for A). This is different from the previous PR, where 1::1 meant sampling from both datasets with equal frequency. In this PR, 1::1 is equivalent to not passing the new flag (in expectation we sample proportionally to the sizes of the datasets). I took this into account when comparing the PRs.
1::1 made intuitive sense for the previous approach though; selecting per sample or per batch from 1 of N datasets with those ratios, you know what you get. With mixing across datasets like this, the end mix depends on how the datasets are sharded, and that changes per instance. If you mix, say, ImageNet-22k and LAION-2B on two different clusters, you'd end up with different ratios of samples for the same 'input' weight, which seems rather non-obvious and confusing.
It should depend only on the dataset sizes (not the shard sizes) in expectation. With this approach, sampling with equal frequency from the different sources does require knowing the sizes of the datasets, though (i.e. if dataset A has size 10 and dataset B has size 100, using 10::1 would lead to seeing each dataset with equal frequency in expectation). We could add another flag for specifying the sizes, but that seems a bit messy to me. Maybe we can change the flag name to avoid confusion? E.g.
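To make that weighting semantics concrete, a hypothetical worked example (dataset sizes and weights are made-up numbers):

```python
# Expected per-dataset sample fractions under this weighting scheme:
# each dataset is seen proportionally to (its size) x (its weight).
sizes   = {"A": 10, "B": 100}
weights = {"A": 10, "B": 1}

scaled = {name: sizes[name] * weights[name] for name in sizes}
total = sum(scaled.values())
fractions = {name: scaled[name] / total for name in scaled}
print(fractions)  # {'A': 0.5, 'B': 0.5} -> equal frequency in expectation
```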
We now use rng.choice when no weights are passed and choices with the weights otherwise, as suggested. I also added some new tests for this.
@gabrielilharco right, I convinced myself that changing the ratio between # of shards and samples per shard for same ds sample count does not alter the overall sampling frequency per dataset |
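A quick Monte Carlo sketch of that sanity check. It assumes a picked shard contributes all of its samples, and the shard layouts, sizes, and weights are made up:

```python
import random

def simulate(shard_sizes, weights, picks=200_000, seed=0):
    """shard_sizes: {dataset: [samples per shard, ...]}; weights: {dataset: w}.
    Picks shards with probability proportional to the dataset weight and
    counts how many samples each dataset contributes."""
    rng = random.Random(seed)
    shards = [(name, n) for name, sizes in shard_sizes.items() for n in sizes]
    w = [weights[name] for name, _ in shards]
    counts = {name: 0 for name in shard_sizes}
    for name, n in rng.choices(shards, weights=w, k=picks):
        counts[name] += n  # a picked shard contributes all of its samples
    total = sum(counts.values())
    return {name: round(counts[name] / total, 3) for name in counts}

# Same total sample count per dataset, two different shard layouts.
layout_1 = {"A": [100] * 10, "B": [100] * 20}  # A: 1000 samples, B: 2000
layout_2 = {"A": [50] * 20,  "B": [200] * 10}  # A: 1000 samples, B: 2000
weights = {"A": 1, "B": 1}

print(simulate(layout_1, weights))  # roughly {'A': 0.33, 'B': 0.67}
print(simulate(layout_2, weights))  # roughly the same, despite different sharding
```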
@gabrielilharco tests are great to have, thanks!! |
Following #398, this moves the logic for mixing datasets from the sample level to the shard level.
Empirically this gives better results than the previous sample-level logic and also simplifies the code: on two experiments mixing 6 data sources, I got 60.1% vs 59.5% and 61.6% vs 59.8% ImageNet accuracy respectively for shard-level vs sample-level mixing. This is in line with findings from #107 and https://arxiv.org/abs/2112.09331.
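For context, a rough sketch of the two mixing strategies being compared here (hypothetical generators, not the repository's actual webdataset pipeline):

```python
import random

def sample_level_mix(sample_iters, weights, rng):
    """Previous approach (sketch): pick the source dataset for every sample."""
    names = list(sample_iters)
    w = [weights[n] for n in names]
    while True:
        name = rng.choices(names, weights=w, k=1)[0]
        yield next(sample_iters[name])

def shard_level_mix(shards_by_dataset, weights, rng):
    """This PR (sketch): pick a shard, weighted by its dataset, then read it."""
    shards = [(name, shard) for name, shard_list in shards_by_dataset.items()
              for shard in shard_list]
    w = [weights[name] for name, _ in shards]
    while True:
        _, shard = rng.choices(shards, weights=w, k=1)[0]
        yield from shard  # consecutive samples come from the same shard

# Tiny usage example for the shard-level variant (toy shards).
rng = random.Random(0)
shards = {"A": [["a1", "a2"], ["a3", "a4"]], "B": [["b1", "b2", "b3"]]}
stream = shard_level_mix(shards, {"A": 1, "B": 2}, rng)
print([next(stream) for _ in range(8)])
```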
I also tested this with the standard workflow (i.e. no --train-data-weights is used), and did not observe any impact when training on CC12M (25.04% vs 24.95% zero-shot ImageNet accuracy).

CC @rwightman @rom1504