From 25bf954ff1e7db9dc8b5344a061c671cb849c60a Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 24 Sep 2022 15:05:59 +0200 Subject: [PATCH 1/5] support distributed validation for zero-shot and retrieval --- src/training/data.py | 16 +++++++++---- src/training/params.py | 7 ++++++ src/training/train.py | 47 ++++++++++++++++++++++++++------------- src/training/zero_shot.py | 24 ++++++++++++++------ 4 files changed, 68 insertions(+), 26 deletions(-) diff --git a/src/training/data.py b/src/training/data.py index 23ff21400..e3f9b4c03 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -132,7 +132,7 @@ def get_imagenet(args, preprocess_fns, split): idxs = idxs.astype('int') sampler = SubsetRandomSampler(np.where(idxs)[0]) else: - sampler = None + sampler = torch.utils.data.distributed.DistributedSampler(dataset) if args.distributed_evaluation else None dataloader = torch.utils.data.DataLoader( dataset, @@ -331,6 +331,8 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): ), ]) else: + if args.distributed_evaluation: + pipeline.append(wds.split_by_node) pipeline.extend([ wds.split_by_worker, # at this point, we have an iterator over the shards assigned to each worker @@ -359,8 +361,10 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): num_samples = num_batches * global_batch_size dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this else: - # last batches are partial, eval is done on single (master) node - num_batches = math.ceil(num_samples / args.batch_size) + if args.distributed_evaluation: + num_batches = math.ceil(num_samples / (args.batch_size * args.world_size)) + else: + num_batches = math.ceil(num_samples / args.batch_size) dataloader = wds.WebLoader( dataset, @@ -401,7 +405,11 @@ def get_csv_dataset(args, preprocess_fn, is_train, epoch=0): caption_key=args.csv_caption_key, sep=args.csv_separator) num_samples = len(dataset) - sampler = DistributedSampler(dataset) if args.distributed and is_train else None + if is_train: + sampler = DistributedSampler(dataset) if args.distributed else None + else: + sampler = DistributedSampler(dataset) if args.distributed_evaluation else None + shuffle = is_train and sampler is None dataloader = DataLoader( diff --git a/src/training/params.py b/src/training/params.py index 41dfb08c5..4b2d07d94 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -78,6 +78,13 @@ def parse_args(): default=None, help="Path to imagenet v2 for conducting zero shot evaluation.", ) + parser.add_argument( + "--distributed-evaluation", + action="store_true", + default=False, + help="Do evaluation in a distributed way.", + ) + parser.add_argument( "--logs", type=str, diff --git a/src/training/train.py b/src/training/train.py index 028fa31d6..5520557d6 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -152,7 +152,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w def evaluate(model, data, epoch, args, tb_writer=None): metrics = {} - if not is_master(args): + if not is_master(args) and not args.distributed_evaluation: return metrics device = torch.device(args.device) model.eval() @@ -197,17 +197,27 @@ def evaluate(model, data, epoch, args, tb_writer=None): cumulative_loss += total_loss * batch_size num_samples += batch_size - if is_master(args) and (i % 100) == 0: - logging.info( - f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" - f"Loss: {cumulative_loss / num_samples:.6f}\t") - + if (i % 100) == 0: + if 
args.distributed_evaluation: + cumulative_loss_all_reduced = cumulative_loss.clone() + torch.distributed.all_reduce(cumulative_loss_all_reduced, op=torch.distributed.ReduceOp.AVG) + else: + cumulative_loss_all_reduced = cumulative_loss + if is_master(args): + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" + f"Loss: {cumulative_loss_all_reduced / num_samples:.6f}\t") + if args.distributed_evaluation: + image_features = all_gather(torch.cat(all_image_features)) + text_features = all_gather(torch.cat(all_text_features)) val_metrics = get_metrics( - image_features=torch.cat(all_image_features), - text_features=torch.cat(all_text_features), + image_features=image_features, + text_features=text_features, logit_scale=logit_scale.cpu(), ) loss = cumulative_loss / num_samples + if args.distributed_evaluation: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) metrics.update( {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} ) @@ -215,12 +225,13 @@ def evaluate(model, data, epoch, args, tb_writer=None): if not metrics: return metrics - logging.info( - f"Eval Epoch: {epoch} " - + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) - ) + if is_master(args): + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) - if args.save_logs: + if args.save_logs and is_master(args): for name, val in metrics.items(): if tb_writer is not None: tb_writer.add_scalar(f"val/{name}", val, epoch) @@ -229,7 +240,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): f.write(json.dumps(metrics)) f.write("\n") - if args.wandb: + if args.wandb and is_master(args): assert wandb is not None, 'Please install wandb.' for name, val in metrics.items(): wandb.log({f"val/{name}": val, 'epoch': epoch}) @@ -241,7 +252,7 @@ def get_metrics(image_features, text_features, logit_scale): metrics = {} logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() - + logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} ground_truth = torch.arange(len(text_features)).view(-1, 1) @@ -255,3 +266,9 @@ def get_metrics(image_features, text_features, logit_scale): metrics[f"{name}_R@{k}"] = np.mean(preds < k) return metrics + +def all_gather(tensor): + world_size = torch.distributed.get_world_size() + tensor_list = [torch.ones_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, tensor, async_op=False) + return torch.cat(tensor_list, dim=0) \ No newline at end of file diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 4bfe7b861..64bdba294 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -3,16 +3,17 @@ import torch import torch.nn.functional as F from tqdm import tqdm - +from training.distributed import is_master from open_clip import tokenize from .precision import get_autocast from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template def zero_shot_classifier(model, classnames, templates, args): + classnames = tqdm(classnames) if is_master(args) else classnames with torch.no_grad(): zeroshot_weights = [] - for classname in tqdm(classnames): + for classname in classnames: texts = [template(classname) for template in templates] # format with class texts = tokenize(texts).to(args.device) # tokenize if args.distributed and not args.horovod: @@ -34,9 +35,10 @@ def accuracy(output, 
target, topk=(1,)): def run(model, classifier, dataloader, args): autocast = get_autocast(args.precision) + dataloader = tqdm(dataloader, unit_scale=args.batch_size) if is_master(args) else dataloader with torch.no_grad(): top1, top5, n = 0., 0., 0. - for images, target in tqdm(dataloader, unit_scale=args.batch_size): + for images, target in dataloader: images = images.to(args.device) target = target.to(args.device) @@ -57,6 +59,10 @@ def run(model, classifier, dataloader, args): top1 = (top1 / n) top5 = (top5 / n) + if args.distributed_evaluation: + top = torch.Tensor([top1, top5]).to(args.device) + torch.distributed.all_reduce(top, op=torch.distributed.ReduceOp.AVG) + top1, top5 = top.tolist() return top1, top5 @@ -68,12 +74,15 @@ def zero_shot_eval(model, data, epoch, args): if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: return {} - logging.info('Starting zero-shot imagenet.') + if is_master(args): + logging.info('Starting zero-shot imagenet.') + logging.info('Building zero-shot classifier') - logging.info('Building zero-shot classifier') classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) - logging.info('Using classifier') + if is_master(args): + logging.info('Using classifier') + results = {} if 'imagenet-val' in data: top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) @@ -84,6 +93,7 @@ def zero_shot_eval(model, data, epoch, args): results['imagenetv2-zeroshot-val-top1'] = top1 results['imagenetv2-zeroshot-val-top5'] = top5 - logging.info('Finished zero-shot imagenet.') + if is_master(args): + logging.info('Finished zero-shot imagenet.') return results From 9b26fefc5a80844ccdf199ef502b026d7f09762e Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 24 Sep 2022 17:29:37 +0200 Subject: [PATCH 2/5] update with distributed evaluation for retrieval --- src/training/train.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 5520557d6..40e729514 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -207,13 +207,11 @@ def evaluate(model, data, epoch, args, tb_writer=None): logging.info( f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" f"Loss: {cumulative_loss_all_reduced / num_samples:.6f}\t") - if args.distributed_evaluation: - image_features = all_gather(torch.cat(all_image_features)) - text_features = all_gather(torch.cat(all_text_features)) val_metrics = get_metrics( - image_features=image_features, - text_features=text_features, + image_features=torch.cat(all_image_features), + text_features=torch.cat(all_text_features), logit_scale=logit_scale.cpu(), + args=args, ) loss = cumulative_loss / num_samples if args.distributed_evaluation: @@ -248,11 +246,13 @@ def evaluate(model, data, epoch, args, tb_writer=None): return metrics -def get_metrics(image_features, text_features, logit_scale): +def get_metrics(image_features, text_features, logit_scale, args): metrics = {} + if args.distributed_evaluation: + image_features = all_gather(image_features.to(args.device)).cpu() + text_features = all_gather(text_features.to(args.device)).cpu() logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() - logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} ground_truth = torch.arange(len(text_features)).view(-1, 1) @@ -270,5 +270,5 @@ def get_metrics(image_features, text_features, logit_scale): def 
all_gather(tensor): world_size = torch.distributed.get_world_size() tensor_list = [torch.ones_like(tensor) for _ in range(world_size)] - torch.distributed.all_gather(tensor_list, tensor, async_op=False) + torch.distributed.all_gather(tensor_list, tensor) return torch.cat(tensor_list, dim=0) \ No newline at end of file From 5f9a978af5b7500438a41cf924a02b475bad6bfd Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 24 Sep 2022 18:27:07 +0200 Subject: [PATCH 3/5] retrieval eval: fix gather with variable length tensors on different ranks --- src/training/data.py | 5 ++-- src/training/train.py | 63 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/src/training/data.py b/src/training/data.py index e3f9b4c03..08af2196e 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -362,9 +362,8 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this else: if args.distributed_evaluation: - num_batches = math.ceil(num_samples / (args.batch_size * args.world_size)) - else: - num_batches = math.ceil(num_samples / args.batch_size) + num_samples = num_samples // args.world_size + num_batches = math.ceil(num_samples / args.batch_size) dataloader = wds.WebLoader( dataset, diff --git a/src/training/train.py b/src/training/train.py index 40e729514..4a6884b73 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -199,25 +199,39 @@ def evaluate(model, data, epoch, args, tb_writer=None): num_samples += batch_size if (i % 100) == 0: if args.distributed_evaluation: - cumulative_loss_all_reduced = cumulative_loss.clone() - torch.distributed.all_reduce(cumulative_loss_all_reduced, op=torch.distributed.ReduceOp.AVG) + total_num_samples = torch.Tensor([num_samples]).to(device) + torch.distributed.all_reduce(total_num_samples) + total_num_samples = total_num_samples.item() + + total_cumulative_loss = cumulative_loss.clone() + torch.distributed.all_reduce(total_cumulative_loss) + + loss = total_cumulative_loss / total_num_samples else: - cumulative_loss_all_reduced = cumulative_loss + loss = cumulative_loss / num_samples if is_master(args): logging.info( f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" - f"Loss: {cumulative_loss_all_reduced / num_samples:.6f}\t") + f"Loss: {loss:.6f}\t") val_metrics = get_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), logit_scale=logit_scale.cpu(), args=args, ) - loss = cumulative_loss / num_samples if args.distributed_evaluation: - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + total_num_samples = torch.Tensor([num_samples]).to(device) + torch.distributed.all_reduce(total_num_samples) + total_num_samples = int(total_num_samples.item()) + + total_cumulative_loss = cumulative_loss.clone() + torch.distributed.all_reduce(total_cumulative_loss) + + loss = total_cumulative_loss / total_num_samples + else: + loss = cumulative_loss / num_samples metrics.update( - {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": total_num_samples if args.distributed_evaluation else num_samples} ) if not metrics: @@ -249,13 +263,13 @@ def evaluate(model, data, epoch, args, tb_writer=None): def get_metrics(image_features, text_features, logit_scale, args): metrics = {} if args.distributed_evaluation: - image_features = 
all_gather(image_features.to(args.device)).cpu() - text_features = all_gather(text_features.to(args.device)).cpu() + image_features = varsize_tensor_all_gather(image_features.to(args.device)).cpu() + text_features = varsize_tensor_all_gather(text_features.to(args.device)).cpu() + logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} ground_truth = torch.arange(len(text_features)).view(-1, 1) - for name, logit in logits.items(): ranking = torch.argsort(logit, descending=True) preds = torch.where(ranking == ground_truth)[1] @@ -267,7 +281,34 @@ def get_metrics(image_features, text_features, logit_scale, args): return metrics -def all_gather(tensor): +def varsize_tensor_all_gather(tensor: torch.Tensor): + # https://discuss.pytorch.org/t/how-to-concatenate-different-size-tensors-from-distributed-processes/44819/4 + # thanks to @mranzinger + device = tensor.device + size_tens = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=device) + size_tens = tensor_all_gather(size_tens).cpu() + max_size = size_tens.max() + + padded = torch.empty(max_size, *tensor.shape[1:], + dtype=tensor.dtype, + device=device) + padded[:tensor.shape[0]] = tensor + + ag = tensor_all_gather(padded) + + slices = [] + for i, sz in enumerate(size_tens): + start_idx = i * max_size + end_idx = start_idx + sz.item() + + if end_idx > start_idx: + slices.append(ag[start_idx:end_idx]) + + ret = torch.cat(slices, dim=0) + + return ret.to(tensor) + +def tensor_all_gather(tensor): world_size = torch.distributed.get_world_size() tensor_list = [torch.ones_like(tensor) for _ in range(world_size)] torch.distributed.all_gather(tensor_list, tensor) From ff36243851007d69d1068b69b5ee2a8b1cb36072 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 24 Sep 2022 18:35:26 +0200 Subject: [PATCH 4/5] fix zero-shot distributed eval --- src/training/zero_shot.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 64bdba294..42a82e4e1 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -56,13 +56,15 @@ def run(model, classifier, dataloader, args): top1 += acc1 top5 += acc5 n += images.size(0) - - top1 = (top1 / n) - top5 = (top5 / n) if args.distributed_evaluation: + n = torch.Tensor([n]).to(args.device) + torch.distributed.all_reduce(n) + n = n.item() top = torch.Tensor([top1, top5]).to(args.device) - torch.distributed.all_reduce(top, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(top) top1, top5 = top.tolist() + top1 = (top1 / n) + top5 = (top5 / n) return top1, top5 From 84b2dc77632d966917b7ed0989717dab05dc1abd Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 24 Sep 2022 18:51:32 +0200 Subject: [PATCH 5/5] zero-shot distributed eval: all reduce metrics all together --- src/training/zero_shot.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 42a82e4e1..6af87f62c 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -57,12 +57,9 @@ def run(model, classifier, dataloader, args): top5 += acc5 n += images.size(0) if args.distributed_evaluation: - n = torch.Tensor([n]).to(args.device) - torch.distributed.all_reduce(n) - n = n.item() - top = torch.Tensor([top1, top5]).to(args.device) - torch.distributed.all_reduce(top) - top1, top5 = top.tolist() + 
vals = torch.Tensor([n, top1, top5]).to(args.device) + torch.distributed.all_reduce(vals) + n, top1, top5 = vals.tolist() top1 = (top1 / n) top5 = (top5 / n) return top1, top5
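
The pattern the series converges on generalizes beyond ImageNet zero-shot: each rank accumulates raw counts, the counts are summed across ranks with all_reduce, and the division happens once at the end; variable-length feature tensors are gathered by padding to the per-rank maximum before all_gather. Below is a minimal standalone sketch of both ideas, not part of the patches above; it assumes torch.distributed is already initialized (e.g. via torchrun), and the function names are illustrative only.

    # Illustrative sketch of the distributed-evaluation reductions used in this series.
    # Assumes an initialized process group (e.g. launched with torchrun).
    import torch
    import torch.distributed as dist

    def reduce_topk_counts(top1, top5, n, device):
        # Sum raw correct-counts and sample counts over all ranks, then normalize once.
        # Summing counts (instead of averaging per-rank accuracies) stays exact even
        # when ranks evaluate different numbers of samples.
        vals = torch.tensor([n, top1, top5], dtype=torch.float64, device=device)
        dist.all_reduce(vals)  # default op is SUM
        n, top1, top5 = vals.tolist()
        return top1 / n, top5 / n

    def varsize_all_gather(tensor):
        # Gather tensors whose first dimension differs across ranks:
        # 1) gather the per-rank lengths, 2) pad to the max length,
        # 3) all_gather the padded tensors, 4) strip the padding.
        world_size = dist.get_world_size()
        size = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=tensor.device)
        sizes = [torch.zeros_like(size) for _ in range(world_size)]
        dist.all_gather(sizes, size)
        max_size = int(torch.stack(sizes).max())

        padded = torch.zeros(max_size, *tensor.shape[1:],
                             dtype=tensor.dtype, device=tensor.device)
        padded[:tensor.shape[0]] = tensor
        gathered = [torch.zeros_like(padded) for _ in range(world_size)]
        dist.all_gather(gathered, padded)
        return torch.cat([g[:int(s)] for g, s in zip(gathered, sizes)], dim=0)

Summing raw counts is also why patches 4 and 5 drop the earlier ReduceOp.AVG reduction: averaging per-rank ratios is only correct when every rank sees the same number of samples, whereas all-reducing n, top1 and top5 together and dividing afterwards handles uneven shards.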