diff --git a/examples/weather/graphcast/validation_base.py b/examples/weather/graphcast/validation_base.py index 765285cfe..f119d616c 100644 --- a/examples/weather/graphcast/validation_base.py +++ b/examples/weather/graphcast/validation_base.py @@ -82,7 +82,9 @@ def step(self, channels=[0, 1, 2], iter=0, time_idx=None): "num_samples_per_year": self.num_samples_per_year_train, "device": self.dist.device, } - for i, data in enumerate(self.val_datapipe): + val_data_iter = iter(itertools.islice(self.val_datapipe, len(self.val_datapipe))) + for i in range(len(self.val_datapipe)): + data = next(val_data_iter) invar = data[0]["invar"] outvar = data[0]["outvar"][0] try: