From e14d0d88c866a675737850067ef7d44a04a07282 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 31 May 2020 14:31:45 -0700 Subject: [PATCH] More efficient data sampling during training --- aitextgen/TokenDataset.py | 16 ++++++++++------ aitextgen/aitextgen.py | 4 ++++ aitextgen/train.py | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/aitextgen/TokenDataset.py b/aitextgen/TokenDataset.py index 598dc77..1e46a63 100644 --- a/aitextgen/TokenDataset.py +++ b/aitextgen/TokenDataset.py @@ -162,6 +162,8 @@ def save( self, cache_destination: str = "dataset_cache.tar.gz", compress: bool = True ) -> None: assert len(self.tokens) > 0, "No data loaded to save." + if not isinstance(self.tokens, list): + self.tokens = self.tokens.tolist() if compress: open_func = gzip.open @@ -180,18 +182,20 @@ def save( with open_func(cache_destination, "wb") as f: msgpack.pack(self.tokens, f) - def __len__(self): + def __len__(self) -> int: return self.num_subsets def __getitem__(self, item: int) -> torch.Tensor: - return torch.tensor( - self.tokens[item : (item + self.block_size)], dtype=torch.long - ) + # assumes self.tokens is a torch.tensor, + # which is set during training. + if isinstance(self.tokens, list): + return self.tokens[item : (item + self.block_size)] + return self.tokens.narrow(0, item, self.block_size) - def __str__(self): + def __str__(self) -> str: return self.file_path if self.file_path is not None else "loaded dataset" - def __repr__(self): + def __repr__(self) -> str: return f"TokenDataset containing {self.num_subsets:,} subsets loaded {self.str_suffix}" diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index f271cb7..86618ae 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -478,8 +478,12 @@ def train( **kwargs, ) + if isinstance(train_data.tokens, list): + train_data.tokens = torch.tensor(train_data.tokens, dtype=torch.long) + if num_workers is None: # Use all CPU cores as workers if not training on CPU + # Can overload 2x w/o diminishing returns if is_gpu_used or n_tpu_cores > 0: num_workers = os.cpu_count() * 2 # If training on the CPU, use half the CPUs diff --git a/aitextgen/train.py b/aitextgen/train.py index 186fd8e..9e418a6 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -166,7 +166,7 @@ def on_batch_end(self, trainer, pl_module): desc = f"Loss: {current_loss:.3f} — Avg: {avg_loss:.3f}" - if self.progress_bar_refresh_rate % self.steps == 0: + if self.steps % self.progress_bar_refresh_rate == 0: if self.gpu: desc += f" — GPU Mem: {get_gpu_memory_map()['gpu_0']} MB" self.main_progress_bar.update(self.progress_bar_refresh_rate)