Use numpy backend for storing tokens (#26)
* Working file loading w/ numpy

* Working text encoding w/ numpy

* Use appropriate dtypes

* Fix saving/loading

* Train using numpy

* Ensure pbar completion

* Remove msgpack as a dependency

* Use numpy slicing for feeding to model

* Update docstrings

* Update dataset with newer implementation
minimaxir authored Jun 2, 2020
1 parent e14d0d8 commit f299262
Showing 5 changed files with 92 additions and 59 deletions.
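For orientation, here is a minimal, illustrative sketch of the storage scheme the diff below adopts (not code from this commit): token IDs are held in a flat NumPy array using a compact unsigned dtype, and the cache is written with `np.save` through a gzip file object instead of msgpack.

```python
import gzip
import numpy as np

# GPT-2's 50,257-token vocabulary fits in uint16 (the diff's get_dtype() picks the
# smallest unsigned dtype whose maximum value is still free to act as a placeholder),
# so each token costs 2 bytes instead of a full Python int.
dtype = np.uint16
tokens = np.array([15496, 11, 995, 50256], dtype=dtype)  # hypothetical token IDs

# Cache the array compressed, mirroring TokenDataset.save() with compress=True
with gzip.open("dataset_cache.tar.gz", "wb") as f:
    np.save(f, tokens)

# Reload it the way the from_cache code path does
with gzip.open("dataset_cache.tar.gz", "rb") as f:
    reloaded = np.load(f)

assert np.array_equal(tokens, reloaded)
```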
137 changes: 84 additions & 53 deletions aitextgen/TokenDataset.py
@@ -2,15 +2,14 @@
import logging
import csv
import os
import msgpack
import gzip
from torch.utils.data import Dataset
from typing import List
from transformers import GPT2TokenizerFast
from pkg_resources import resource_filename
import itertools
from tqdm.auto import tqdm
from array import array
import numpy as np

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -45,7 +44,8 @@ class TokenDataset(Dataset):
:param bos_token: String to override the beginning-of-string token
:param eos_token: String to override the end-of-string token
:param unk_token: String to override the unknown token
:param unk_token: String to override the padding token
:param pad_token: String to override the padding token
:param progress_bar_refresh_rate: How often to update progress bar when loading
"""

def __init__(
@@ -74,7 +74,7 @@ def __init__(
# Special case; load tokenized texts immediately
if tokenized_texts:
self.tokens = tokenized_texts
self.num_subsets = len(self.tokens) - block_size
self.num_subsets = self.tokens.shape[0] - block_size
self.block_size = block_size
self.file_path = "merged TokenDataset"
self.str_suffix = "by merging TokenDatasets."
@@ -96,8 +96,8 @@ def __init__(
open_func = gzip.open if file_path.endswith(".gz") else open

with open_func(file_path, "rb") as f:
self.tokens = msgpack.unpack(f)
self.num_subsets = len(self.tokens) - block_size
self.tokens = np.load(f)
self.num_subsets = self.tokens.shape[0] - block_size
self.block_size = block_size
self.str_suffix = "via cache."

@@ -150,9 +150,9 @@ def __init__(
)

assert (
len(self.tokens) >= block_size
self.tokens.shape[0] >= block_size
), f"There are fewer than {block_size} encoded tokens."
self.num_subsets = len(self.tokens) - block_size
self.num_subsets = self.tokens.shape[0] - block_size
self.block_size = block_size

if save_cache:
@@ -161,17 +161,15 @@ def save(
def save(
self, cache_destination: str = "dataset_cache.tar.gz", compress: bool = True
) -> None:
assert len(self.tokens) > 0, "No data loaded to save."
if not isinstance(self.tokens, list):
self.tokens = self.tokens.tolist()
assert self.tokens.shape[0] > 0, "No data loaded to save."

if compress:
open_func = gzip.open
compress_str = "and compressing "
else:
open_func = open
cache_destination = (
"dataset_cache.msgpack"
"dataset_cache.npy"
if cache_destination == "dataset_cache.tar.gz"
else cache_destination
)
@@ -180,17 +178,16 @@ def save(
logger.info(f"Caching {compress_str}dataset to {cache_destination}")

with open_func(cache_destination, "wb") as f:
msgpack.pack(self.tokens, f)
np.save(f, self.tokens)

def __len__(self) -> int:
return self.num_subsets

def __getitem__(self, item: int) -> torch.Tensor:
# assumes self.tokens is a torch.tensor,
# which is set during training.
if isinstance(self.tokens, list):
return self.tokens[item : (item + self.block_size)]
return self.tokens.narrow(0, item, self.block_size)
return torch.as_tensor(
self.tokens[item : (item + self.block_size)].astype(np.int64, copy=False),
dtype=torch.long,
)

def __str__(self) -> str:
return self.file_path if self.file_path is not None else "loaded dataset"
@@ -222,6 +219,22 @@ def get_lines_in_file_csv(file_path: str, header: bool = True) -> int:
return sum(1 for row in reader)


def get_dtype(vocab_size: int):
"""
Finds the appropriate numpy dtype depending on vocab size.
The highest value for the dtype serves as a placeholder.
"""
if vocab_size < 2 ** 8 - 1:
return np.uint8
elif vocab_size < 2 ** 16 - 1:
return np.uint16
elif vocab_size < 2 ** 32 - 1:
return np.uint32

return np.uint64


def encode_tokens_from_file(
file_path: str,
eos_token: str,
@@ -236,14 +249,15 @@ def encode_tokens_from_file(
"""

is_csv = file_path.endswith(".csv")
a_dtype = get_dtype(tokenizer.vocab_size)

if is_csv:
num_texts = get_lines_in_file_csv(file_path, header)
else:
num_texts = get_lines_in_file(file_path, newline)

pbar = tqdm(total=num_texts, smoothing=0, leave=True, dynamic_ncols=True,)
tokens = array("I")
tokens = np.full((num_texts, 1), -1, dtype=a_dtype)
num_batches = 0

with open(file_path, "r", encoding="utf-8", newline=newline) as f_load:
@@ -273,26 +287,37 @@ def encode_tokens_from_file(
if not batch:
break

encoded_tokens = array(
"I",
itertools.chain.from_iterable(
tokenizer.batch_encode_plus(
batch,
add_special_tokens=False,
return_token_type_ids=False,
return_attention_masks=False,
)["input_ids"]
),
)
tokens.extend(encoded_tokens)
encoded_texts = tokenizer.batch_encode_plus(
batch,
add_special_tokens=False,
return_token_type_ids=False,
return_attention_masks=False,
)["input_ids"]

for i, encoded_text in enumerate(encoded_texts):
if len(encoded_text) > tokens.shape[1]:
cols_to_add = len(encoded_text) - tokens.shape[1]
tokens = np.concatenate(
(
tokens,
np.full((num_texts, cols_to_add), -1, dtype=a_dtype,),
),
axis=1,
)
tokens[
(num_batches * batch_size) + i, : len(encoded_text)
] = encoded_text

num_batches += 1

if num_batches % progress_bar_refresh_rate == 0:
pbar.update(batch_size * progress_bar_refresh_rate)

pbar.n = num_texts
pbar.refresh()
pbar.close()
return tokens.tolist()
tokens = tokens.flatten()
return tokens[tokens < np.array(-1, dtype=a_dtype)]


def encode_tokens_from_list(
Expand All @@ -306,48 +331,54 @@ def encode_tokens_from_list(
Retrieves texts from a newline-delimited file/CSV and returns texts.
"""

logger.info(f"Encoding {len(texts):,} texts.")
num_texts = len(texts)
a_dtype = get_dtype(tokenizer.vocab_size)
logger.info(f"Encoding {num_texts:,} texts.")

pbar = tqdm(total=len(texts), smoothing=0, leave=True, dynamic_ncols=True,)
tokens = array("I")
pbar = tqdm(total=num_texts, smoothing=0, leave=True, dynamic_ncols=True,)
tokens = np.full((len(texts), 1), -1, dtype=a_dtype)

for i_start in range(len(texts) // batch_size + 1):
for i_start in range(num_texts // batch_size + 1):
batch = [
text + eos_token
for text in texts[
(i_start * batch_size) : ((i_start * batch_size) + batch_size)
]
]

encoded_tokens = array(
"I",
itertools.chain.from_iterable(
tokenizer.batch_encode_plus(
batch,
add_special_tokens=False,
return_token_type_ids=False,
return_attention_masks=False,
)["input_ids"]
),
)
tokens.extend(encoded_tokens)
encoded_texts = tokenizer.batch_encode_plus(
batch,
add_special_tokens=False,
return_token_type_ids=False,
return_attention_masks=False,
)["input_ids"]

for i, encoded_text in enumerate(encoded_texts):
if len(encoded_text) > tokens.shape[1]:
cols_to_add = len(encoded_text) - tokens.shape[1]
tokens = np.concatenate(
(tokens, np.full((num_texts, cols_to_add), -1, dtype=a_dtype,),),
axis=1,
)
tokens[(i_start * batch_size) + i, : len(encoded_text)] = encoded_text

if i_start % progress_bar_refresh_rate == 0:
pbar.update(batch_size * progress_bar_refresh_rate)

pbar.n = num_texts
pbar.refresh()
pbar.close()
return tokens.tolist()
tokens = tokens.flatten()
return tokens[tokens < np.array(-1, dtype=a_dtype)]


def merge_datasets(datasets: List[TokenDataset], equalize: bool = True) -> TokenDataset:
"""
Merges multiple TokenDatasets into a single TokenDataset.
This assumes that you are using the same tokenizer for all TokenDatasets.
## Parameters
* **datasets**: A list of TokenDatasets.
* **equalize**: Whether to take an equal amount of samples from all
:param datasets: A list of TokenDatasets.
:param equalize: Whether to take an equal amount of samples from all
input datasets (by taking random samples from
each dataset equal to the smallest dataset)
in order to balance out the result dataset.
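A small sketch of the padding trick used by the new `encode_tokens_from_file` / `encode_tokens_from_list` above (illustrative only): texts are encoded into a 2-D array pre-filled with the dtype's maximum value (written as `-1` in the diff, which wraps around to that maximum), then flattened and filtered so only real token IDs below the placeholder survive.

```python
import numpy as np

a_dtype = np.uint16
placeholder = np.iinfo(a_dtype).max  # 65535; the diff fills with -1, which wraps to this value

# Two hypothetical encoded texts of unequal length, right-padded with the placeholder
tokens = np.full((2, 4), placeholder, dtype=a_dtype)
tokens[0, :3] = [15496, 11, 995]
tokens[1, :2] = [31373, 0]

# Flatten row by row and drop the padding, leaving one contiguous token stream
flat = tokens.flatten()
real = flat[flat < placeholder]
print(real)  # [15496 11 995 31373 0]
```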
6 changes: 3 additions & 3 deletions aitextgen/aitextgen.py
@@ -345,6 +345,7 @@ def generate_to_file(
:param sample_delim: The text used to delimit each generated text.
:param seed: Seed used for the generation. The last part of a file name
will be the seed used to reproduce a generation.
:param cleanup: Whether to polish the text before returning
See generate() for more parameters.
"""
@@ -450,6 +451,8 @@ def train(
:param save_gdrive: If using Colab, whether to save the notebook
to Google Drive at each save_every
:param run_id: Run identifier; used for save_gdrive
:param progress_bar_refresh_rate: How often to update
the progress bar while training.
"""

assert not self.torchscript, "You cannot train a traced TorchScript model."
@@ -478,9 +481,6 @@ def train(
**kwargs,
)

if isinstance(train_data.tokens, list):
train_data.tokens = torch.tensor(train_data.tokens, dtype=torch.long)

if num_workers is None:
# Use all CPU cores as workers if not training on CPU
# Can overload 2x w/o diminishing returns
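The list-to-tensor conversion deleted from `train()` above is no longer needed because the dataset now serves slices on demand, converting each one to a `LongTensor` in `__getitem__`. A toy stand-in (not the library's actual class) showing the idea:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class TinyTokenDataset(Dataset):
    """Toy stand-in mirroring the new __getitem__ behavior; not the library class."""

    def __init__(self, tokens: np.ndarray, block_size: int):
        self.tokens = tokens
        self.block_size = block_size

    def __len__(self) -> int:
        return self.tokens.shape[0] - self.block_size

    def __getitem__(self, item: int) -> torch.Tensor:
        # Slice the compact uint16 array lazily and upcast to int64 only for the
        # subset actually being served to the model.
        return torch.as_tensor(
            self.tokens[item : item + self.block_size].astype(np.int64, copy=False),
            dtype=torch.long,
        )


data = TinyTokenDataset(np.arange(4096, dtype=np.uint16), block_size=1024)
batch = next(iter(DataLoader(data, batch_size=2)))
print(batch.shape, batch.dtype)  # torch.Size([2, 1024]) torch.int64
```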
6 changes: 5 additions & 1 deletion docs/dataset.md
@@ -33,7 +33,7 @@ data = TokenDataset(texts = ["Lorem", "Ipsum", "Dolor"])

## Saving/Loading a TokenDataset

When creating a TokenDataset, you can automatically save it as a compressed gzipped MessagePack binary when completed.
When creating a TokenDataset, you can automatically save it as a compressed gzipped numpy array when completed.

```python
data = TokenDataset("shakespeare.txt", save_cache=True)
@@ -52,6 +52,10 @@ By default, it will save to `dataset_cache.tar.gz`.
data = TokenDataset("dataset_cache.tar.gz", from_cache=True)
```

<!--prettier-ignore-->
!!! note "CLI"
You can quickly create a Tokenized dataset using the command line, e.g. `aitextgen encode text.txt`. This will drastically reduce the file size, and is recommended before moving the file to cloud services (where it can be loaded using the `from_cache` parameter noted above)

## Using TokenDatasets with a Custom GPT-2 Model

The default TokenDataset has a `block_size` of `1024`, which corresponds to the _context window of the default GPT-2 model_. If you're using a custom model w/ a different maximum context length, set `block_size` to match that length. Additionally, you must explicitly provide the vocab and merges files to rebuild the tokenizer, as the tokenizer will be different than the normal GPT-2 one.
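A hypothetical sketch of what that looks like (file names, parameter values, and the `vocab_file`/`merges_file` argument names are assumptions for illustration, not part of this commit):

```python
from aitextgen.TokenDataset import TokenDataset

# block_size should equal the custom model's context window; the vocab/merges
# files are the ones produced when the custom tokenizer was trained.
data = TokenDataset(
    "shakespeare.txt",
    vocab_file="aitextgen-vocab.json",
    merges_file="aitextgen-merges.txt",
    block_size=32,
)
```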
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,4 +1,3 @@
transformers>=2.9.1
fire>=0.3.0
msgpack
pytorch-lightning>=0.7.6
1 change: 0 additions & 1 deletion setup.py
@@ -19,7 +19,6 @@
install_requires=[
"transformers>=2.9.1",
"fire>=0.3.0",
"msgpack",
"pytorch-lightning>=0.7.6",
],
)
