Support distributed evaluation #176
base: main
Changes from all commits: 25bf954, 9b26fef, 5f9a978, ff36243, 84b2dc7, 03839c5
src/training/train.py
@@ -155,7 +155,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w

 def evaluate(model, data, epoch, args, tb_writer=None):
     metrics = {}
-    if not is_master(args):
+    if not is_master(args) and not args.distributed_evaluation:
         return metrics
     device = torch.device(args.device)
     model.eval()
@@ -200,30 +200,53 @@ def evaluate(model, data, epoch, args, tb_writer=None):

                 cumulative_loss += total_loss * batch_size
                 num_samples += batch_size
-                if is_master(args) and (i % 100) == 0:
-                    logging.info(
-                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
-                        f"Loss: {cumulative_loss / num_samples:.6f}\t")
+                if (i % 100) == 0:
+                    if args.distributed_evaluation:
+                        total_num_samples = torch.Tensor([num_samples]).to(device)
+                        torch.distributed.all_reduce(total_num_samples)
+                        total_num_samples = total_num_samples.item()
+
+                        total_cumulative_loss = cumulative_loss.clone()
+                        torch.distributed.all_reduce(total_cumulative_loss)
+
+                        loss = total_cumulative_loss / total_num_samples
+                    else:
+                        loss = cumulative_loss / num_samples
+                    if is_master(args):
+                        logging.info(
+                            f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
+                            f"Loss: {loss:.6f}\t")

             val_metrics = get_metrics(
                 image_features=torch.cat(all_image_features),
                 text_features=torch.cat(all_text_features),
                 logit_scale=logit_scale.cpu(),
+                args=args,
             )
-            loss = cumulative_loss / num_samples
+            if args.distributed_evaluation:
+                total_num_samples = torch.Tensor([num_samples]).to(device)
+                torch.distributed.all_reduce(total_num_samples)
+                total_num_samples = int(total_num_samples.item())
+
+                total_cumulative_loss = cumulative_loss.clone()
+                torch.distributed.all_reduce(total_cumulative_loss)
+
+                loss = total_cumulative_loss / total_num_samples
+            else:
+                loss = cumulative_loss / num_samples
             metrics.update(
-                {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
+                {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": total_num_samples if args.distributed_evaluation else num_samples}
             )

     if not metrics:
         return metrics

-    logging.info(
-        f"Eval Epoch: {epoch} "
-        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
-    )
+    if is_master(args):
+        logging.info(
+            f"Eval Epoch: {epoch} "
+            + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
+        )

-    if args.save_logs:
+    if args.save_logs and is_master(args):
         for name, val in metrics.items():
             if tb_writer is not None:
                 tb_writer.add_scalar(f"val/{name}", val, epoch)
@@ -232,22 +255,24 @@ def evaluate(model, data, epoch, args, tb_writer=None):
             f.write(json.dumps(metrics))
             f.write("\n")

-    if args.wandb:
+    if args.wandb and is_master(args):
         assert wandb is not None, 'Please install wandb.'
         for name, val in metrics.items():
             wandb.log({f"val/{name}": val, 'epoch': epoch})

     return metrics


-def get_metrics(image_features, text_features, logit_scale):
+def get_metrics(image_features, text_features, logit_scale, args):
Review comment: nit pick, but maybe better to pass a more specific argument, e.g. […]
Reply: Thanks, I agree. I think it's better.
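The reviewer's concrete example is truncated above. One hypothetical reading of "a more specific argument" is to pass only the flag and device that get_metrics actually needs, rather than the whole args namespace; a minimal sketch, with names that are illustrative rather than from the PR:

```python
# Hypothetical sketch of "a more specific argument" (the reviewer's concrete
# example was truncated): take an explicit flag and device instead of args,
# so the function's dependencies are visible in its signature.
def get_metrics(image_features, text_features, logit_scale,
                distributed_evaluation: bool = False, device: str = "cpu"):
    metrics = {}
    if distributed_evaluation:
        # gather features from all ranks before computing retrieval metrics
        image_features = varsize_tensor_all_gather(image_features.to(device)).cpu()
        text_features = varsize_tensor_all_gather(text_features.to(device)).cpu()
    ...  # recall@k computation as in the diff above
```

The call site in evaluate() would then pass distributed_evaluation=args.distributed_evaluation and device=args.device explicitly.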
     metrics = {}
+    if args.distributed_evaluation:
+        image_features = varsize_tensor_all_gather(image_features.to(args.device)).cpu()
+        text_features = varsize_tensor_all_gather(text_features.to(args.device)).cpu()
+
     logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
     logits_per_text = logits_per_image.t().detach().cpu()

     logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
     ground_truth = torch.arange(len(text_features)).view(-1, 1)

     for name, logit in logits.items():
         ranking = torch.argsort(logit, descending=True)
         preds = torch.where(ranking == ground_truth)[1]
@@ -258,3 +283,36 @@ def get_metrics(image_features, text_features, logit_scale):
         metrics[f"{name}_R@{k}"] = np.mean(preds < k)

     return metrics
+
+
+def varsize_tensor_all_gather(tensor: torch.Tensor):
+    # https://discuss.pytorch.org/t/how-to-concatenate-different-size-tensors-from-distributed-processes/44819/4
+    # thanks to @mranzinger
+    device = tensor.device
+    size_tens = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=device)
+    size_tens = tensor_all_gather(size_tens).cpu()
+    max_size = size_tens.max()
+
+    padded = torch.empty(max_size, *tensor.shape[1:],
+                         dtype=tensor.dtype,
+                         device=device)
+    padded[:tensor.shape[0]] = tensor
+
+    ag = tensor_all_gather(padded)
+
+    slices = []
+    for i, sz in enumerate(size_tens):
+        start_idx = i * max_size
+        end_idx = start_idx + sz.item()
+
+        if end_idx > start_idx:
+            slices.append(ag[start_idx:end_idx])
+
+    ret = torch.cat(slices, dim=0)
+
+    return ret.to(tensor)
+
+
+def tensor_all_gather(tensor):
+    world_size = torch.distributed.get_world_size()
+    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
+    torch.distributed.all_gather(tensor_list, tensor)
+    return torch.cat(tensor_list, dim=0)
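As a quick illustration of what the two gather helpers above do, here is a self-contained two-process demo. This is a sketch, not part of the PR: the import path, gloo backend, and port are assumptions.

```python
# Two-process demo of the padding-based variable-size gather used in the diff.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from training.train import varsize_tensor_all_gather  # assumed import path

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # rank 0 holds 3 feature rows, rank 1 holds 2: uneven shards are exactly
    # the case a plain all_gather cannot handle without padding.
    n_rows = 3 if rank == 0 else 2
    local = torch.full((n_rows, 4), float(rank))

    gathered = varsize_tensor_all_gather(local)
    if rank == 0:
        # padding rows are stripped again, so the result is 3 + 2 = 5 rows
        print(gathered.shape)  # torch.Size([5, 4])

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```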
Review comment: maybe I'm misunderstanding, but does it skip the last few samples due to the last partial batch?
Reply: in evaluation, num_samples is only used for logging (open_clip/src/training/train.py, line 172 in 03839c5, and line 217 in 03839c5); it does not affect the dataloader, which depends only on the wds pipeline. But it would be good to get it correct anyway; I still have to check how many examples each worker exactly receives.
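One way to check that, as a sketch rather than part of the PR: gather the per-worker counts after the evaluation loop and compare their sum against the dataset size. The helper name check_sample_coverage and the expected_total parameter are hypothetical.

```python
# Sketch for verifying per-worker sample counts (hypothetical helper, not in
# the PR). Assumes torch.distributed is initialized and num_samples is the
# local count accumulated in the evaluation loop.
import torch
import torch.distributed as dist

def check_sample_coverage(num_samples: int, expected_total: int, device) -> bool:
    local = torch.tensor([num_samples], dtype=torch.int64, device=device)
    counts = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(counts, local)
    per_worker = [int(c.item()) for c in counts]
    total = sum(per_worker)
    # a shortfall here would point at samples dropped by a last partial batch
    print(f"per-worker: {per_worker}, total: {total}, expected: {expected_total}")
    return total == expected_total
```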