From 3fc69fc931cb31e84fd70536df998af7c2bda763 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 8 Jun 2022 15:50:02 -0700 Subject: [PATCH 1/9] Add posembed resizing for fine-tune (native ViT models), support alternate image pre-proc resize and mean/std, add WDS resampling and more robust epoch synchronization --- src/open_clip/factory.py | 18 +++-- src/open_clip/model.py | 44 ++++++++++-- src/open_clip/transform.py | 59 +++++++++++++-- src/open_clip/utils.py | 21 +++++- src/training/data.py | 143 +++++++++++++++++++++++++++++++------ src/training/params.py | 8 ++- src/training/train.py | 5 +- 7 files changed, 259 insertions(+), 39 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 5cb71c8c2..fdf2b571a 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -5,10 +5,11 @@ import re from copy import deepcopy from pathlib import Path +from typing import Optional, Tuple import torch -from .model import CLIP, convert_weights_to_fp16 +from .model import CLIP, convert_weights_to_fp16, resize_pos_embed from .openai import load_openai_model from .pretrained import get_pretrained_url, download_pretrained from .transform import image_transform @@ -57,6 +58,13 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): return state_dict +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + def create_model( model_name: str, pretrained: str = '', @@ -105,7 +113,7 @@ def create_model( if checkpoint_path: logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') - model.load_state_dict(load_state_dict(checkpoint_path)) + load_checkpoint(model, checkpoint_path) else: logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.') raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.') @@ -129,13 +137,15 @@ def create_model_and_transforms( jit: bool = False, force_quick_gelu: bool = False, pretrained_image: bool = False, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, ): model = create_model( model_name, pretrained, precision, device, jit, force_quick_gelu=force_quick_gelu, pretrained_image=pretrained_image) - preprocess_train = image_transform(model.visual.image_size, is_train=True) - preprocess_val = image_transform(model.visual.image_size, is_train=False) + preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std) + preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std) return model, preprocess_train, preprocess_val diff --git a/src/open_clip/model.py b/src/open_clip/model.py index de6f7d262..9d427b9ca 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -2,9 +2,10 @@ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
""" - from collections import OrderedDict from dataclasses import dataclass +import logging +import math from typing import Tuple, Union, Callable, Optional import numpy as np @@ -14,7 +15,7 @@ from torch.utils.checkpoint import checkpoint from .timm_model import TimmModel -from .utils import freeze_batch_norm_2d +from .utils import freeze_batch_norm_2d, to_2tuple class Bottleneck(nn.Module): @@ -255,13 +256,15 @@ def __init__( self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, output_dim: int, act_layer: Callable = nn.GELU): super().__init__() - self.image_size = image_size + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) self.output_dim = output_dim self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((image_size // patch_size) ** 2 + 1, width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) self.ln_pre = LayerNorm(width) self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer) @@ -556,3 +559,36 @@ def trace_model(model, batch_size=256, device=torch.device('cpu')): )) model.visual.image_size = image_size return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index 7014c926f..ed5752758 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -1,7 +1,40 @@ +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + + from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ CenterCrop +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise 
TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + def _convert_to_rgb(image): return image.convert('RGB') @@ -9,9 +42,17 @@ def _convert_to_rgb(image): def image_transform( image_size: int, is_train: bool, - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711) + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, ): + mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean + std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + normalize = Normalize(mean=mean, std=std) if is_train: return Compose([ @@ -21,10 +62,18 @@ def image_transform( normalize, ]) else: - return Compose([ - Resize(image_size, interpolation=InterpolationMode.BICUBIC), - CenterCrop(image_size), + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ _convert_to_rgb, ToTensor(), normalize, ]) + return Compose(transforms) diff --git a/src/open_clip/utils.py b/src/open_clip/utils.py index bf8221266..51e80c5e2 100644 --- a/src/open_clip/utils.py +++ b/src/open_clip/utils.py @@ -1,3 +1,6 @@ +from itertools import repeat +import collections.abc + from torch import nn as nn from torchvision.ops.misc import FrozenBatchNorm2d @@ -38,4 +41,20 @@ def freeze_batch_norm_2d(module, module_match={}, name=''): new_child = freeze_batch_norm_2d(child, module_match, full_child_name) if new_child is not child: res.add_module(child_name, new_child) - return res \ No newline at end of file + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) diff --git a/src/training/data.py b/src/training/data.py index 622e6207f..6cce8b867 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -4,7 +4,10 @@ import math import os import random +import sys +import time from dataclasses import dataclass +from multiprocessing import Value import braceexpand import numpy as np @@ -12,11 +15,11 @@ import torch import torchvision.datasets as datasets import webdataset as wds -from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample from PIL import Image -from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, 
IterableDataset from torch.utils.data.distributed import DistributedSampler - +from webdataset.filters import _shuffle +from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample try: import horovod.torch as hvd @@ -45,10 +48,28 @@ def __getitem__(self, idx): return images, texts +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + + @dataclass class DataInfo: dataloader: DataLoader - sampler: DistributedSampler + sampler: DistributedSampler = None + shared_epoch: SharedEpoch = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) def preprocess_txt(text): @@ -119,7 +140,7 @@ def get_imagenet(args, preprocess_fns, split): sampler=sampler, ) - return DataInfo(dataloader, sampler) + return DataInfo(dataloader=dataloader, sampler=sampler) def count_samples(dataloader): @@ -184,9 +205,83 @@ def tarfile_to_samples_nothrow(src, handler=log_and_continue): _SAMPLE_SHUFFLE_INITIAL = 1000 -def get_wds_dataset(args, preprocess_img, is_train, epoch=0): +class detshuffle2(wds.PipelineStage): + def __init__( + self, + bufsize=1000, + initial=100, + seed=0, + epoch=-1 + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + rng = random.Random() + rng.seed((self.seed, epoch)) + print(f'detshuffle epoch: {epoch}, seed: {self.seed}') # FIXME temporary debug print + return _shuffle(src, self.bufsize, self.initial, rng) + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + epoch=-1, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = wds.shardlists.expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = ( + wds.utils.pytorch_worker_seed if worker_seed is None else worker_seed + ) + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). 
+ self.epoch += 1 + epoch = self.epoch + if self.deterministic: + seed = (self.worker_seed(), epoch) + else: + seed = (self.worker_seed(), epoch, os.getpid(), time.time()) + self.rng.seed(seed) + print(f'resampled epoch: {epoch}, seed: {seed}') # FIXME temporary debug print + for _ in range(self.nshards): + yield dict(url=self.rng.choice(self.urls)) + + +def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): input_shards = args.train_data if is_train else args.val_data assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train num_samples, num_shards = get_dataset_size(input_shards) if not num_samples: @@ -199,18 +294,26 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0): else: num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified - pipeline = [wds.SimpleShardList(input_shards)] + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + if resampled: + pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)] + else: + pipeline = [wds.SimpleShardList(input_shards)] + # at this point we have an iterator over all the shards if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) pipeline.extend([ - wds.detshuffle( - bufsize=_SHARD_SHUFFLE_SIZE, - initial=_SHARD_SHUFFLE_INITIAL, - seed=args.seed, - epoch=epoch - 1, - ), - wds.split_by_node, - wds.split_by_worker, # at this point, we have an iterator over the shards assigned to each worker at each node tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), wds.shuffle( @@ -218,7 +321,6 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0): initial=_SAMPLE_SHUFFLE_INITIAL, rng=random.Random(args.seed), ), - #wds.repeatedly, # FIXME determine if this is beneficial ]) else: pipeline.extend([ @@ -238,10 +340,11 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0): dataset = wds.DataPipeline(*pipeline) if is_train: # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil global_batch_size = args.batch_size * args.world_size - num_batches = math.ceil(num_samples / global_batch_size) + num_batches = round_fn(num_samples / global_batch_size) num_workers = max(1, args.workers) - num_worker_batches = math.ceil(num_batches / num_workers) # per dataloader worker + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker num_batches = num_worker_batches * num_workers num_samples = num_batches * global_batch_size dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this @@ -254,8 +357,6 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0): batch_size=None, shuffle=False, num_workers=args.workers, - # FIXME detshuffle uses same seed each epoch unless workers are persistent - # this seems like a WDS bug, currently waiting for clarification persistent_workers=True, ) @@ -277,7 +378,7 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0): dataloader.num_batches = num_batches dataloader.num_samples = num_samples - return DataInfo(dataloader, None) + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) def get_csv_dataset(args, preprocess_fn, is_train, epoch=0): diff --git a/src/training/params.py 
b/src/training/params.py index ef2b0990a..721bfd10b 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -42,6 +42,12 @@ def parse_args(): default="auto", help="Which type of dataset to process." ) + parser.add_argument( + "--dataset-resampled", + default=False, + action="store_true", + help="Whether to use sampling with replacement for webdataset shard selection." + ) parser.add_argument( "--csv-separator", type=str, @@ -91,7 +97,7 @@ def parse_args(): help="Optional identifier for the experiment when storing logs. Otherwise use current time.", ) parser.add_argument( - "--workers", type=int, default=1, help="Number of workers per GPU." + "--workers", type=int, default=1, help="Number of dataloader workers per GPU." ) parser.add_argument( "--batch-size", type=int, default=64, help="Batch size per GPU." diff --git a/src/training/train.py b/src/training/train.py index 71c3dc285..cda429a21 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -57,9 +57,8 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w world_size=args.world_size, use_horovod=args.horovod) - dataloader, sampler = data['train'].dataloader, data['train'].sampler - if args.distributed and sampler is not None: - sampler.set_epoch(epoch) + data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch + dataloader = data['train'].dataloader num_batches_per_epoch = dataloader.num_batches sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) From c4897d7f70b9b4be045015e1d0fbda17ac13dfd2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 8 Jun 2022 17:01:08 -0700 Subject: [PATCH 2/9] Add assert to ensure number of shards in dataset >= workers * world_size --- src/training/data.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/training/data.py b/src/training/data.py index 6cce8b867..7e6774ede 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -90,9 +90,10 @@ def get_dataset_size(shards): else: total_size = None # num samples undefined # some common dataset sizes (at time of authors last download) - # cc3m-train: 2905954 - # cc12m: 10968539 - # LAION-400m: 407332084 + # CC3M (train): 2905954 + # CC12M: 10968539 + # LAION-400M: 407332084 + # LAION-2B (english): 2170337258 num_shards = len(shards_list) return total_size, num_shards @@ -339,6 +340,8 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): dataset = wds.DataPipeline(*pipeline) if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' # roll over and repeat a few samples to get same number of full batches on each node round_fn = math.floor if floor else math.ceil global_batch_size = args.batch_size * args.world_size From 610c9c6e446626d2420274269079e4c864f4f817 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 8 Jun 2022 21:58:07 -0700 Subject: [PATCH 3/9] Add support for per-worker seed in detshuffle2 for some tests, remove explicit seed of sample shuffle --- src/training/data.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/training/data.py b/src/training/data.py index 7e6774ede..a4950cd1b 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -212,7 +212,7 @@ def __init__( bufsize=1000, initial=100, seed=0, - epoch=-1 + epoch=-1, ): self.bufsize = bufsize self.initial = initial @@ -228,8 +228,12 @@ def run(self, src): self.epoch += 1 epoch = self.epoch rng = 
random.Random() - rng.seed((self.seed, epoch)) - print(f'detshuffle epoch: {epoch}, seed: {self.seed}') # FIXME temporary debug print + if self.seed < 0: + seed = (wds.utils.pytorch_worker_seed(), epoch) + else: + seed = (self.seed, epoch) + rng.seed(seed) + print(f'detshuffle epoch: {epoch}, seed: {seed}') # FIXME temporary debug print return _shuffle(src, self.bufsize, self.initial, rng) @@ -254,9 +258,7 @@ def __init__( assert isinstance(self.urls[0], str) self.nshards = nshards self.rng = random.Random() - self.worker_seed = ( - wds.utils.pytorch_worker_seed if worker_seed is None else worker_seed - ) + self.worker_seed = wds.utils.pytorch_worker_seed if worker_seed is None else worker_seed self.deterministic = deterministic self.epoch = epoch @@ -320,7 +322,6 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): wds.shuffle( bufsize=_SAMPLE_SHUFFLE_SIZE, initial=_SAMPLE_SHUFFLE_INITIAL, - rng=random.Random(args.seed), ), ]) else: From df48eab8d1f5ddc66536ced344ff469cc1e391ce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 8 Jun 2022 22:26:17 -0700 Subject: [PATCH 4/9] More seed details --- src/training/data.py | 16 +++++++++++++--- src/training/main.py | 2 ++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/training/data.py b/src/training/data.py index a4950cd1b..33aed4f3b 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -16,7 +16,7 @@ import torchvision.datasets as datasets import webdataset as wds from PIL import Image -from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info from torch.utils.data.distributed import DistributedSampler from webdataset.filters import _shuffle from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample @@ -200,6 +200,16 @@ def tarfile_to_samples_nothrow(src, handler=log_and_continue): return samples +def pytorch_worker_seed(): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour the seed already created for pytorch dataloader workers if it exists + return worker_info.seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + _SHARD_SHUFFLE_SIZE = 2000 _SHARD_SHUFFLE_INITIAL = 500 _SAMPLE_SHUFFLE_SIZE = 5000 @@ -229,7 +239,7 @@ def run(self, src): epoch = self.epoch rng = random.Random() if self.seed < 0: - seed = (wds.utils.pytorch_worker_seed(), epoch) + seed = (pytorch_worker_seed(), epoch) else: seed = (self.seed, epoch) rng.seed(seed) @@ -258,7 +268,7 @@ def __init__( assert isinstance(self.urls[0], str) self.nshards = nshards self.rng = random.Random() - self.worker_seed = wds.utils.pytorch_worker_seed if worker_seed is None else worker_seed + self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed self.deterministic = deterministic self.epoch = epoch diff --git a/src/training/main.py b/src/training/main.py index ef470ad44..216c367ff 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -112,6 +112,7 @@ def main(): else: logging.info(f'Running with a single process. 
Device {args.device}.') + random_seed(args.seed, 0) model, preprocess_train, preprocess_val = create_model_and_transforms( args.model, args.pretrained, @@ -121,6 +122,7 @@ def main(): force_quick_gelu=args.force_quick_gelu, pretrained_image=args.pretrained_image, ) + random_seed(args.seed, args.rank) if args.trace: model = trace_model(model, batch_size=args.batch_size, device=device) From 23726f28211f712faa6047e44ccfd8d69a59add9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 9 Jun 2022 11:45:10 -0700 Subject: [PATCH 5/9] Change default seed to match latest changes (and closer match past behaviour) --- src/training/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/training/params.py b/src/training/params.py index 721bfd10b..fdc7aed9e 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -279,7 +279,7 @@ def parse_args(): help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." ) parser.add_argument( - "--seed", type=int, default=4242, help="Default random seed." + "--seed", type=int, default=0, help="Default random seed." ) args = parser.parse_args() From f1945f6628a78be1a121e9c43cd1d5b5388d82ac Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 10 Jun 2022 17:24:38 -0700 Subject: [PATCH 6/9] Mind blown, wds was not even seeding properly in the first place. Switching to integer instead of tuple for deterministic seeds --- src/training/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/training/data.py b/src/training/data.py index 33aed4f3b..2dbff2dfe 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -239,9 +239,9 @@ def run(self, src): epoch = self.epoch rng = random.Random() if self.seed < 0: - seed = (pytorch_worker_seed(), epoch) + seed = pytorch_worker_seed() + epoch else: - seed = (self.seed, epoch) + seed = self.seed + epoch rng.seed(seed) print(f'detshuffle epoch: {epoch}, seed: {seed}') # FIXME temporary debug print return _shuffle(src, self.bufsize, self.initial, rng) @@ -282,7 +282,7 @@ def __iter__(self): self.epoch += 1 epoch = self.epoch if self.deterministic: - seed = (self.worker_seed(), epoch) + seed = self.worker_seed() + epoch else: seed = (self.worker_seed(), epoch, os.getpid(), time.time()) self.rng.seed(seed) From 03dac1aa4abce87be0d0b129c5747963ab9c073a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 10 Jun 2022 21:37:37 -0700 Subject: [PATCH 7/9] Another silly seed thing, no point in a poor seed for non-deterministic shard sampling, only reset seed if deterministic --- src/training/data.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/training/data.py b/src/training/data.py index 2dbff2dfe..755e6f4e8 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -282,11 +282,8 @@ def __iter__(self): self.epoch += 1 epoch = self.epoch if self.deterministic: - seed = self.worker_seed() + epoch - else: - seed = (self.worker_seed(), epoch, os.getpid(), time.time()) - self.rng.seed(seed) - print(f'resampled epoch: {epoch}, seed: {seed}') # FIXME temporary debug print + # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed + self.rng.seed(self.worker_seed() + epoch) for _ in range(self.nshards): yield dict(url=self.rng.choice(self.urls)) From e390448220ea570c43b648269d6e1a27c68eea18 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 13 Jun 2022 16:23:15 -0700 Subject: [PATCH 8/9] Remvoe debug print from detshuffle2 --- src/training/data.py | 1 - 
1 file changed, 1 deletion(-) diff --git a/src/training/data.py b/src/training/data.py index 755e6f4e8..23ff21400 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -243,7 +243,6 @@ def run(self, src): else: seed = self.seed + epoch rng.seed(seed) - print(f'detshuffle epoch: {epoch}, seed: {seed}') # FIXME temporary debug print return _shuffle(src, self.bufsize, self.initial, rng) From 9078293a67501d9489b22b8ee5b6c10b061a6b54 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 12:10:07 -0700 Subject: [PATCH 9/9] Fix #77, guard torch.distributed imports so it won't break on Windows --- src/open_clip/loss.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 191096644..de31426df 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -1,8 +1,14 @@ import torch -import torch.distributed.nn -from torch import distributed as dist, nn as nn +import torch.nn as nn from torch.nn import functional as F +try: + import torch.distributed.nn + from torch import distributed as dist + has_distributed = True +except ImportError: + has_distributed = False + try: import horovod.torch as hvd except ImportError: @@ -18,6 +24,7 @@ def gather_features( world_size=1, use_horovod=False ): + assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' if use_horovod: assert hvd is not None, 'Please install horovod' if gather_with_grad:
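
Note on the position-embedding resize added in patch 1: the ViT positional embedding is stored as a flat (seq_len, width) table, so fine-tuning at a new resolution means splitting off the class-token row, reshaping the patch-grid portion back to 2D, interpolating it to the new grid, and concatenating the pieces again. A minimal standalone sketch of the same idea follows; the function name, the single-class-token assumption, and the square source grid are illustrative rather than taken from the patch.

import math
import torch
import torch.nn.functional as F

def resize_vit_pos_embed(pos_embed, new_grid, num_prefix_tokens=1, mode='bicubic'):
    # Interpolate a (seq_len, width) ViT positional embedding to a new patch grid,
    # mirroring what resize_pos_embed() does when loading a checkpoint: split off the
    # prefix (class) tokens, resample the 2D grid, then concatenate them back.
    prefix, grid = pos_embed[:num_prefix_tokens], pos_embed[num_prefix_tokens:]
    old_side = int(math.sqrt(grid.shape[0]))                              # assumes a square source grid
    grid = grid.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)    # (1, width, H, W)
    grid = F.interpolate(grid, size=new_grid, mode=mode, align_corners=True)
    grid = grid.permute(0, 2, 3, 1).reshape(new_grid[0] * new_grid[1], -1)
    return torch.cat([prefix, grid], dim=0)

# e.g. a ViT-B/32 table trained at 224px (7x7 grid + class token) resized for 448px input
old = torch.randn(7 * 7 + 1, 768)
new = resize_vit_pos_embed(old, (14, 14))
assert new.shape == (14 * 14 + 1, 768)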
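
The webdataset changes (SharedEpoch, detshuffle2, ResampledShards2) all hinge on one mechanism: the training loop writes the current epoch into a multiprocessing.Value, and persistent dataloader workers read that value to reseed their shard shuffle, so shard order changes every epoch yet stays reproducible for a given seed. Below is a toy illustration of that shared-counter mechanism only, not the patch's code; the names and timings are illustrative.

import time
from multiprocessing import Process, Value

def worker(shared_epoch):
    # a persistent worker polls the shared counter and reacts when the main
    # process advances it, analogous to detshuffle2 reseeding once per epoch
    last_seen = -1
    for _ in range(3):
        while shared_epoch.value == last_seen:
            time.sleep(0.01)
        last_seen = shared_epoch.value
        print(f'worker reseeding shard shuffle for epoch {last_seen}')

if __name__ == '__main__':
    epoch = Value('i', 0)                     # the same primitive SharedEpoch wraps
    p = Process(target=worker, args=(epoch,))
    p.start()
    for e in range(1, 3):
        time.sleep(0.1)
        epoch.value = e                       # what DataInfo.set_epoch() effectively does
    p.join()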