Skip to content

Commit

Permalink
More efficient data sampling during training
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed May 31, 2020
1 parent 5a1d1e9 commit e14d0d8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
16 changes: 10 additions & 6 deletions aitextgen/TokenDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def save(
self, cache_destination: str = "dataset_cache.tar.gz", compress: bool = True
) -> None:
assert len(self.tokens) > 0, "No data loaded to save."
if not isinstance(self.tokens, list):
self.tokens = self.tokens.tolist()

if compress:
open_func = gzip.open
Expand All @@ -180,18 +182,20 @@ def save(
with open_func(cache_destination, "wb") as f:
msgpack.pack(self.tokens, f)

def __len__(self):
def __len__(self) -> int:
return self.num_subsets

def __getitem__(self, item: int) -> torch.Tensor:
return torch.tensor(
self.tokens[item : (item + self.block_size)], dtype=torch.long
)
# assumes self.tokens is a torch.tensor,
# which is set during training.
if isinstance(self.tokens, list):
return self.tokens[item : (item + self.block_size)]
return self.tokens.narrow(0, item, self.block_size)

def __str__(self):
def __str__(self) -> str:
return self.file_path if self.file_path is not None else "loaded dataset"

def __repr__(self):
def __repr__(self) -> str:
return f"TokenDataset containing {self.num_subsets:,} subsets loaded {self.str_suffix}"


Expand Down
4 changes: 4 additions & 0 deletions aitextgen/aitextgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,12 @@ def train(
**kwargs,
)

if isinstance(train_data.tokens, list):
train_data.tokens = torch.tensor(train_data.tokens, dtype=torch.long)

if num_workers is None:
# Use all CPU cores as workers if not training on CPU
# Can overload 2x w/o diminishing returns
if is_gpu_used or n_tpu_cores > 0:
num_workers = os.cpu_count() * 2
# If training on the CPU, use half the CPUs
Expand Down
2 changes: 1 addition & 1 deletion aitextgen/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def on_batch_end(self, trainer, pl_module):

desc = f"Loss: {current_loss:.3f} — Avg: {avg_loss:.3f}"

if self.progress_bar_refresh_rate % self.steps == 0:
if self.steps % self.progress_bar_refresh_rate == 0:
if self.gpu:
desc += f" — GPU Mem: {get_gpu_memory_map()['gpu_0']} MB"
self.main_progress_bar.update(self.progress_bar_refresh_rate)
Expand Down

0 comments on commit e14d0d8

Please sign in to comment.