Support distributed evaluation #176

Open · wants to merge 6 commits into base: main
13 changes: 10 additions & 3 deletions src/training/data.py
@@ -132,7 +132,7 @@ def get_imagenet(args, preprocess_fns, split):
idxs = idxs.astype('int')
sampler = SubsetRandomSampler(np.where(idxs)[0])
else:
sampler = None
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if args.distributed_evaluation else None

dataloader = torch.utils.data.DataLoader(
dataset,
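
Worth noting for the sampler added above: with the defaults, DistributedSampler pads each rank's index list to ceil(N / world_size) by repeating samples, so a non-divisible eval set scores a few items twice. A minimal sketch (mine, not part of this diff) of a deterministic eval loader:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_eval_loader(dataset, rank, world_size, batch_size=32):
    # shuffle=False keeps eval order deterministic; drop_last=False (the default)
    # pads with duplicates so len(sampler) == ceil(len(dataset) / world_size).
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=False, drop_last=False)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# Example: 10 samples across 4 ranks -> 3 indices per rank (12 total),
# so 2 samples are counted twice in the aggregated metrics.
```

The one-argument DistributedSampler(dataset) in the diff keeps the default shuffle=True; that only reorders samples, it does not drop any.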
@@ -331,6 +331,8 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False):
),
])
else:
if args.distributed_evaluation:
pipeline.append(wds.split_by_node)
pipeline.extend([
wds.split_by_worker,
# at this point, we have an iterator over the shards assigned to each worker
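
For context on split_by_node, a sketch of the webdataset v2 pipeline API (hypothetical shard names, not the PR's exact pipeline): it partitions the shard list across ranks, so each rank reads a disjoint subset of shards.

```python
import webdataset as wds

urls = "val-{000000..000063}.tar"       # hypothetical shard names
pipeline = [wds.SimpleShardList(urls)]
pipeline.append(wds.split_by_node)      # rank r keeps shards r, r+W, r+2W, ...
pipeline.append(wds.split_by_worker)    # then split across DataLoader workers
dataset = wds.DataPipeline(*pipeline)
```

Because the split is at shard granularity, per-rank sample counts are equal only if every shard holds the same number of samples, which is also why the num_samples adjustment below is an approximation.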
@@ -359,7 +361,8 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False):
num_samples = num_batches * global_batch_size
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
else:
# last batches are partial, eval is done on single (master) node
if args.distributed_evaluation:
num_samples = num_samples // args.world_size
A reviewer commented:

maybe I'm misunderstanding, but does it skip the last few samples due to the last partial batch?

@mehdidc (Contributor, Author) replied on Nov 14, 2022:

> maybe I'm misunderstanding, but does it skip the last few samples due to the last partial batch?

In evaluation, num_samples is only used for logging:

samples_per_val = dataloader.num_samples
f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"

It does not affect the dataloader, which depends only on the wds pipeline. But it would be good to get it correct anyway; I have to check exactly how many examples each worker receives.
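
A quick arithmetic sketch of the logging skew being discussed here (hypothetical numbers):

```python
num_samples = 50_000                          # full eval set (hypothetical)
world_size = 7
per_rank_logged = num_samples // world_size   # 7142
print(per_rank_logged * world_size)           # 49994: 6 samples unaccounted for
# The wds pipeline, not this value, decides what each rank actually iterates,
# so only the progress log is off, by at most world_size - 1 samples in total.
```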

num_batches = math.ceil(num_samples / args.batch_size)

dataloader = wds.WebLoader(
@@ -401,7 +404,11 @@ def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
caption_key=args.csv_caption_key,
sep=args.csv_separator)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
if is_train:
sampler = DistributedSampler(dataset) if args.distributed else None
else:
sampler = DistributedSampler(dataset) if args.distributed_evaluation else None

shuffle = is_train and sampler is None

dataloader = DataLoader(
7 changes: 7 additions & 0 deletions src/training/params.py
@@ -78,6 +78,13 @@ def parse_args():
default=None,
help="Path to imagenet v2 for conducting zero shot evaluation.",
)
parser.add_argument(
"--distributed-evaluation",
action="store_true",
default=False,
help="Do evaluation in a distributed way.",
)

parser.add_argument(
"--logs",
type=str,
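A quick sanity check of the flag wiring (a standalone parser mirroring the definition above; argparse maps the dashed flag to args.distributed_evaluation):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--distributed-evaluation", action="store_true", default=False,
                    help="Do evaluation in a distributed way.")

args = parser.parse_args(["--distributed-evaluation"])
assert args.distributed_evaluation                        # dashes become underscores
assert not parser.parse_args([]).distributed_evaluation   # defaults to False
```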
92 changes: 75 additions & 17 deletions src/training/train.py
@@ -155,7 +155,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):

def evaluate(model, data, epoch, args, tb_writer=None):
metrics = {}
if not is_master(args):
if not is_master(args) and not args.distributed_evaluation:
return metrics
device = torch.device(args.device)
model.eval()
@@ -200,30 +200,53 @@ def evaluate(model, data, epoch, args, tb_writer=None):

cumulative_loss += total_loss * batch_size
num_samples += batch_size
if is_master(args) and (i % 100) == 0:
logging.info(
f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
f"Loss: {cumulative_loss / num_samples:.6f}\t")

if (i % 100) == 0:
if args.distributed_evaluation:
total_num_samples = torch.Tensor([num_samples]).to(device)
torch.distributed.all_reduce(total_num_samples)
total_num_samples = total_num_samples.item()

total_cumulative_loss = cumulative_loss.clone()
torch.distributed.all_reduce(total_cumulative_loss)

loss = total_cumulative_loss / total_num_samples
else:
loss = cumulative_loss / num_samples
if is_master(args):
logging.info(
f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
f"Loss: {loss:.6f}\t")
val_metrics = get_metrics(
image_features=torch.cat(all_image_features),
text_features=torch.cat(all_text_features),
logit_scale=logit_scale.cpu(),
args=args,
)
loss = cumulative_loss / num_samples
if args.distributed_evaluation:
total_num_samples = torch.Tensor([num_samples]).to(device)
torch.distributed.all_reduce(total_num_samples)
total_num_samples = int(total_num_samples.item())

total_cumulative_loss = cumulative_loss.clone()
torch.distributed.all_reduce(total_cumulative_loss)

loss = total_cumulative_loss / total_num_samples
else:
loss = cumulative_loss / num_samples
metrics.update(
{**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
{**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": total_num_samples if args.distributed_evaluation else num_samples}
)

if not metrics:
return metrics

logging.info(
f"Eval Epoch: {epoch} "
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)
if is_master(args):
logging.info(
f"Eval Epoch: {epoch} "
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)

if args.save_logs:
if args.save_logs and is_master(args):
for name, val in metrics.items():
if tb_writer is not None:
tb_writer.add_scalar(f"val/{name}", val, epoch)
@@ -232,22 +255,24 @@
f.write(json.dumps(metrics))
f.write("\n")

if args.wandb:
if args.wandb and is_master(args):
assert wandb is not None, 'Please install wandb.'
for name, val in metrics.items():
wandb.log({f"val/{name}": val, 'epoch': epoch})

return metrics
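
The reduction in both branches above is the statistically right one: the global mean loss is the sum over ranks of cumulative_loss divided by the sum of num_samples, so each rank is weighted by its sample count; averaging per-rank means would be biased when counts differ. A compact equivalent sketch (not PR code) that also packs both values into a single collective:

```python
import torch
import torch.distributed as dist

def global_mean_loss(cumulative_loss, num_samples, device):
    # Pack the loss sum and the sample count into one tensor:
    # a single all_reduce (default op is SUM) instead of two.
    totals = torch.tensor([float(cumulative_loss), float(num_samples)],
                          device=device)
    dist.all_reduce(totals)
    total_loss, total_n = totals.tolist()
    return total_loss / total_n
```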


def get_metrics(image_features, text_features, logit_scale):
def get_metrics(image_features, text_features, logit_scale, args):
A reviewer commented:

nit pick but maybe better to pass a more specific argument e.g. gather_tensor=args.distributed_evaluation?

@mehdidc (Contributor, Author) replied:

> nit pick but maybe better to pass a more specific argument e.g. gather_tensor=args.distributed_evaluation?

Thanks, I agree; I think it's better.
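
A sketch of what the suggested narrower signature could look like (hypothetical, not merged code; varsize_tensor_all_gather is the helper added further down in this diff):

```python
def get_metrics(image_features, text_features, logit_scale,
                gather_tensors=False, device="cpu"):
    # Caller passes gather_tensors=args.distributed_evaluation and
    # device=args.device; get_metrics no longer needs the args namespace.
    if gather_tensors:
        image_features = varsize_tensor_all_gather(image_features.to(device)).cpu()
        text_features = varsize_tensor_all_gather(text_features.to(device)).cpu()
    # ...recall computation unchanged from the diff below...
```

This keeps get_metrics unit-testable without constructing a full args object.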

metrics = {}
if args.distributed_evaluation:
image_features = varsize_tensor_all_gather(image_features.to(args.device)).cpu()
text_features = varsize_tensor_all_gather(text_features.to(args.device)).cpu()

logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_image.t().detach().cpu()

logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
ground_truth = torch.arange(len(text_features)).view(-1, 1)

for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
@@ -258,3 +283,36 @@
metrics[f"{name}_R@{k}"] = np.mean(preds < k)

return metrics

def varsize_tensor_all_gather(tensor: torch.Tensor):
# https://discuss.pytorch.org/t/how-to-concatenate-different-size-tensors-from-distributed-processes/44819/4
# thanks to @mranzinger
device = tensor.device
size_tens = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=device)
size_tens = tensor_all_gather(size_tens).cpu()
max_size = size_tens.max()

padded = torch.empty(max_size, *tensor.shape[1:],
dtype=tensor.dtype,
device=device)
padded[:tensor.shape[0]] = tensor

ag = tensor_all_gather(padded)

slices = []
for i, sz in enumerate(size_tens):
start_idx = i * max_size
end_idx = start_idx + sz.item()

if end_idx > start_idx:
slices.append(ag[start_idx:end_idx])

ret = torch.cat(slices, dim=0)

return ret.to(tensor)

def tensor_all_gather(tensor):
world_size = torch.distributed.get_world_size()
tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, tensor)
return torch.cat(tensor_list, dim=0)
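
A self-contained smoke test for the two gather helpers (my sketch; it assumes a CPU gloo backend, and the import path is an assumption to adjust):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from training.train import varsize_tensor_all_gather  # assumed import path

def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    local = torch.full((rank + 1, 2), float(rank))  # rank r contributes r+1 rows
    gathered = varsize_tensor_all_gather(local)
    assert gathered.shape[0] == world_size * (world_size + 1) // 2  # 1+2+...+W
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(4,), nprocs=4)
```

One subtlety: these helpers issue collectives, so every rank must call them the same number of times; invoking them only on the master rank would deadlock.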
25 changes: 17 additions & 8 deletions src/training/zero_shot.py
@@ -3,16 +3,17 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm

from training.distributed import is_master
from open_clip import tokenize, get_cast_dtype
from .precision import get_autocast
from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template


def zero_shot_classifier(model, classnames, templates, args):
classnames = tqdm(classnames) if is_master(args) else classnames
with torch.no_grad():
zeroshot_weights = []
for classname in tqdm(classnames):
for classname in classnames:
texts = [template(classname) for template in templates] # format with class
texts = tokenize(texts).to(args.device) # tokenize
if args.distributed and not args.horovod:
@@ -34,10 +35,11 @@ def accuracy(output, target, topk=(1,)):

def run(model, classifier, dataloader, args):
autocast = get_autocast(args.precision)
dataloader = tqdm(dataloader, unit_scale=args.batch_size) if is_master(args) else dataloader
cast_dtype = get_cast_dtype(args.precision)
with torch.no_grad():
top1, top5, n = 0., 0., 0.
for images, target in tqdm(dataloader, unit_scale=args.batch_size):
for images, target in dataloader:
images = images.to(args.device)
if cast_dtype is not None:
images = images.to(dtype=cast_dtype)
@@ -57,7 +59,10 @@
top1 += acc1
top5 += acc5
n += images.size(0)

if args.distributed_evaluation:
vals = torch.Tensor([n, top1, top5]).to(args.device)
torch.distributed.all_reduce(vals)
n, top1, top5 = vals.tolist()
top1 = (top1 / n)
top5 = (top5 / n)
return top1, top5
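
This all_reduce of the raw counts before dividing is the correct aggregation; averaging per-rank accuracies would over-weight ranks with fewer samples. A worked example with hypothetical counts:

```python
per_rank = [(300, 210.0), (300, 240.0), (100, 90.0)]     # (n, top1-correct) per rank
naive = sum(c / n for n, c in per_rank) / len(per_rank)  # 0.800, biased
n_tot = sum(n for n, _ in per_rank)                      # 700
c_tot = sum(c for _, c in per_rank)                      # 540.0
correct = c_tot / n_tot                                  # ~0.771, what the all_reduce computes
```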
@@ -71,12 +76,15 @@ def zero_shot_eval(model, data, epoch, args):
if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
return {}

logging.info('Starting zero-shot imagenet.')
if is_master(args):
logging.info('Starting zero-shot imagenet.')
logging.info('Building zero-shot classifier')

logging.info('Building zero-shot classifier')
classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)

logging.info('Using classifier')
if is_master(args):
logging.info('Using classifier')

results = {}
if 'imagenet-val' in data:
top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
Expand All @@ -87,6 +95,7 @@ def zero_shot_eval(model, data, epoch, args):
results['imagenetv2-zeroshot-val-top1'] = top1
results['imagenetv2-zeroshot-val-top5'] = top5

logging.info('Finished zero-shot imagenet.')
if is_master(args):
logging.info('Finished zero-shot imagenet.')

return results