diff --git a/aitextgen/TokenDataset.py b/aitextgen/TokenDataset.py index 1e46a63..bf69d29 100644 --- a/aitextgen/TokenDataset.py +++ b/aitextgen/TokenDataset.py @@ -2,7 +2,6 @@ import logging import csv import os -import msgpack import gzip from torch.utils.data import Dataset from typing import List @@ -10,7 +9,7 @@ from pkg_resources import resource_filename import itertools from tqdm.auto import tqdm -from array import array +import numpy as np logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -45,7 +44,8 @@ class TokenDataset(Dataset): :param bos_token: String to override the beginning-of-string token :param eos_token: String to override the end-of-string token :param unk_token: String to override the unknown token - :param unk_token: String to override the padding token + :param pad_token: String to override the padding token + :param progress_bar_refresh_rate: How often to update progress bar when loading """ def __init__( @@ -74,7 +74,7 @@ def __init__( # Special case; load tokenized texts immediately if tokenized_texts: self.tokens = tokenized_texts - self.num_subsets = len(self.tokens) - block_size + self.num_subsets = self.tokens.shape[0] - block_size self.block_size = block_size self.file_path = "merged TokenDataset" self.str_suffix = "by merging TokenDatasets." @@ -96,8 +96,8 @@ def __init__( open_func = gzip.open if file_path.endswith(".gz") else open with open_func(file_path, "rb") as f: - self.tokens = msgpack.unpack(f) - self.num_subsets = len(self.tokens) - block_size + self.tokens = np.load(f) + self.num_subsets = self.tokens.shape[0] - block_size self.block_size = block_size self.str_suffix = "via cache." @@ -150,9 +150,9 @@ def __init__( ) assert ( - len(self.tokens) >= block_size + self.tokens.shape[0] >= block_size ), f"There are fewer than {block_size} encoded tokens." - self.num_subsets = len(self.tokens) - block_size + self.num_subsets = self.tokens.shape[0] - block_size self.block_size = block_size if save_cache: @@ -161,9 +161,7 @@ def __init__( def save( self, cache_destination: str = "dataset_cache.tar.gz", compress: bool = True ) -> None: - assert len(self.tokens) > 0, "No data loaded to save." - if not isinstance(self.tokens, list): - self.tokens = self.tokens.tolist() + assert self.tokens.shape[0] > 0, "No data loaded to save." if compress: open_func = gzip.open @@ -171,7 +169,7 @@ def save( else: open_func = open cache_destination = ( - "dataset_cache.msgpack" + "dataset_cache.npy" if cache_destination == "dataset_cache.tar.gz" else cache_destination ) @@ -180,17 +178,16 @@ def save( logger.info(f"Caching {compress_str}dataset to {cache_destination}") with open_func(cache_destination, "wb") as f: - msgpack.pack(self.tokens, f) + np.save(f, self.tokens) def __len__(self) -> int: return self.num_subsets def __getitem__(self, item: int) -> torch.Tensor: - # assumes self.tokens is a torch.tensor, - # which is set during training. - if isinstance(self.tokens, list): - return self.tokens[item : (item + self.block_size)] - return self.tokens.narrow(0, item, self.block_size) + return torch.as_tensor( + self.tokens[item : (item + self.block_size)].astype(np.int64, copy=False), + dtype=torch.long, + ) def __str__(self) -> str: return self.file_path if self.file_path is not None else "loaded dataset" @@ -222,6 +219,22 @@ def get_lines_in_file_csv(file_path: str, header: bool = True) -> int: return sum(1 for row in reader) +def get_dtype(vocab_size: int): + """ + Finds the appropriate numpy dtype depending on vocab size. 
+ + The highest value for the dtype serves as a placeholder. + """ + if vocab_size < 2 ** 8 - 1: + return np.uint8 + elif vocab_size < 2 ** 16 - 1: + return np.uint16 + elif vocab_size < 2 ** 32 - 1: + return np.uint32 + + return np.uint64 + + def encode_tokens_from_file( file_path: str, eos_token: str, @@ -236,6 +249,7 @@ def encode_tokens_from_file( """ is_csv = file_path.endswith(".csv") + a_dtype = get_dtype(tokenizer.vocab_size) if is_csv: num_texts = get_lines_in_file_csv(file_path, header) @@ -243,7 +257,7 @@ def encode_tokens_from_file( num_texts = get_lines_in_file(file_path, newline) pbar = tqdm(total=num_texts, smoothing=0, leave=True, dynamic_ncols=True,) - tokens = array("I") + tokens = np.full((num_texts, 1), -1, dtype=a_dtype) num_batches = 0 with open(file_path, "r", encoding="utf-8", newline=newline) as f_load: @@ -273,26 +287,37 @@ def encode_tokens_from_file( if not batch: break - encoded_tokens = array( - "I", - itertools.chain.from_iterable( - tokenizer.batch_encode_plus( - batch, - add_special_tokens=False, - return_token_type_ids=False, - return_attention_masks=False, - )["input_ids"] - ), - ) - tokens.extend(encoded_tokens) + encoded_texts = tokenizer.batch_encode_plus( + batch, + add_special_tokens=False, + return_token_type_ids=False, + return_attention_masks=False, + )["input_ids"] + + for i, encoded_text in enumerate(encoded_texts): + if len(encoded_text) > tokens.shape[1]: + cols_to_add = len(encoded_text) - tokens.shape[1] + tokens = np.concatenate( + ( + tokens, + np.full((num_texts, cols_to_add), -1, dtype=a_dtype,), + ), + axis=1, + ) + tokens[ + (num_batches * batch_size) + i, : len(encoded_text) + ] = encoded_text num_batches += 1 if num_batches % progress_bar_refresh_rate == 0: pbar.update(batch_size * progress_bar_refresh_rate) + pbar.n = num_texts + pbar.refresh() pbar.close() - return tokens.tolist() + tokens = tokens.flatten() + return tokens[tokens < np.array(-1, dtype=a_dtype)] def encode_tokens_from_list( @@ -306,12 +331,14 @@ def encode_tokens_from_list( Retrieves texts from a newline-delimited file/CSV and returns texts. 
""" - logger.info(f"Encoding {len(texts):,} texts.") + num_texts = len(texts) + a_dtype = get_dtype(tokenizer.vocab_size) + logger.info(f"Encoding {num_texts:,} texts.") - pbar = tqdm(total=len(texts), smoothing=0, leave=True, dynamic_ncols=True,) - tokens = array("I") + pbar = tqdm(total=num_texts, smoothing=0, leave=True, dynamic_ncols=True,) + tokens = np.full((len(texts), 1), -1, dtype=a_dtype) - for i_start in range(len(texts) // batch_size + 1): + for i_start in range(num_texts // batch_size + 1): batch = [ text + eos_token for text in texts[ @@ -319,24 +346,30 @@ def encode_tokens_from_list( ] ] - encoded_tokens = array( - "I", - itertools.chain.from_iterable( - tokenizer.batch_encode_plus( - batch, - add_special_tokens=False, - return_token_type_ids=False, - return_attention_masks=False, - )["input_ids"] - ), - ) - tokens.extend(encoded_tokens) + encoded_texts = tokenizer.batch_encode_plus( + batch, + add_special_tokens=False, + return_token_type_ids=False, + return_attention_masks=False, + )["input_ids"] + + for i, encoded_text in enumerate(encoded_texts): + if len(encoded_text) > tokens.shape[1]: + cols_to_add = len(encoded_text) - tokens.shape[1] + tokens = np.concatenate( + (tokens, np.full((num_texts, cols_to_add), -1, dtype=a_dtype,),), + axis=1, + ) + tokens[(i_start * batch_size) + i, : len(encoded_text)] = encoded_text if i_start % progress_bar_refresh_rate == 0: pbar.update(batch_size * progress_bar_refresh_rate) + pbar.n = num_texts + pbar.refresh() pbar.close() - return tokens.tolist() + tokens = tokens.flatten() + return tokens[tokens < np.array(-1, dtype=a_dtype)] def merge_datasets(datasets: List[TokenDataset], equalize: bool = True) -> TokenDataset: @@ -344,10 +377,8 @@ def merge_datasets(datasets: List[TokenDataset], equalize: bool = True) -> Token Merges multiple TokenDatasets into a single TokenDataset. This assumes that you are using the same tokenizer for all TokenDatasets. - ## Parameters - - * **datasets**: A list of TokenDatasets. - * **equalize**: Whether to take an equal amount of samples from all + :param datasets: A list of TokenDatasets. + :param equalize: Whether to take an equal amount of samples from all input datasets (by taking random samples from each dataset equal to the smallest dataset) in order to balance out the result dataset. diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 86618ae..69e839d 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -345,6 +345,7 @@ def generate_to_file( :param sample_delim: The text used to delimit each generated text. :param seed: Seed used for the generation. The last part of a file name will be the seed used to reproduce a generation. + :param cleanup: Whether to polish the text before returning See generate() for more parameters. """ @@ -450,6 +451,8 @@ def train( :param save_gdrive: If using Colab, whether to save the notebook to Google Drive at each save_every :param run_id: Run identifier; used for save_gdrive + :param progress_bar_refresh_rate: How often to update + the progress bar while training. """ assert not self.torchscript, "You cannot train a traced TorchScript model." 
@@ -478,9 +481,6 @@ def train(
                 **kwargs,
             )
 
-        if isinstance(train_data.tokens, list):
-            train_data.tokens = torch.tensor(train_data.tokens, dtype=torch.long)
-
         if num_workers is None:
             # Use all CPU cores as workers if not training on CPU
             # Can overload 2x w/o diminishing returns
diff --git a/docs/dataset.md b/docs/dataset.md
index dac5396..b5d8162 100644
--- a/docs/dataset.md
+++ b/docs/dataset.md
@@ -33,7 +33,7 @@ data = TokenDataset(texts = ["Lorem", "Ipsum", "Dolor"])
 
 ## Saving/Loading a TokenDataset
 
-When creating a TokenDataset, you can automatically save it as a compressed gzipped MessagePack binary when completed.
+When creating a TokenDataset, you can automatically save it as a gzip-compressed numpy array when completed.
 
 ```python
 data = TokenDataset("shakespeare.txt", save_cache=True)
@@ -52,6 +52,10 @@ By default, it will save to `dataset_cache.tar.gz`. You can then reload that int
 data = TokenDataset("dataset_cache.tar.gz", from_cache=True)
 ```
 
+
+!!! note "CLI"
+    You can quickly create a tokenized dataset cache from the command line, e.g. `aitextgen encode text.txt`. This drastically reduces the file size and is recommended before moving the file to cloud services (where it can be loaded using the `from_cache` parameter noted above).
+
 ## Using TokenDatasets with a Custom GPT-2 Model
 
 The default TokenDataset has a `block_size` of `1024`, which corresponds to the _context window of the default GPT-2 model_. If you're using a custom model w/ a different maximum. Additionally, you must explicitly provide the vocab and merges files to rebuild the tokenizer, as the tokenizer will be different than the normal GPT-2 one.
diff --git a/requirements.txt b/requirements.txt
index a10d2cf..71955e5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
 transformers>=2.9.1
 fire>=0.3.0
-msgpack
 pytorch-lightning>=0.7.6
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 3ff2cbf..a743a35 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,6 @@
     install_requires=[
         "transformers>=2.9.1",
         "fire>=0.3.0",
-        "msgpack",
         "pytorch-lightning>=0.7.6",
     ],
 )
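
Below is a minimal, self-contained sketch of the caching scheme this patch moves to. `get_dtype` is taken from the diff; the padding, flattening, and cache round trip mirror what `encode_tokens_from_list` and `TokenDataset.save()` now do, but the token IDs and file name are made up for illustration, and the explicit `np.iinfo(a_dtype).max` placeholder stands in for the patch's `-1` fill value (which wraps to that same maximum on unsigned dtypes).

```python
import gzip

import numpy as np


def get_dtype(vocab_size: int):
    # Same selection as the patch: pick the smallest unsigned dtype whose
    # maximum value can still be reserved as a padding placeholder.
    if vocab_size < 2 ** 8 - 1:
        return np.uint8
    elif vocab_size < 2 ** 16 - 1:
        return np.uint16
    elif vocab_size < 2 ** 32 - 1:
        return np.uint32
    return np.uint64


vocab_size = 50257                   # e.g. GPT-2
a_dtype = get_dtype(vocab_size)      # -> np.uint16
placeholder = np.iinfo(a_dtype).max  # 65535, the value the patch's -1 fill wraps to

# Hypothetical encoded texts of unequal length (token IDs are made up).
encoded_texts = [[10, 11, 12], [20, 21]]

# Pad a 2-D buffer with the placeholder, one row per text.
tokens = np.full(
    (len(encoded_texts), max(len(t) for t in encoded_texts)),
    placeholder,
    dtype=a_dtype,
)
for row, ids in enumerate(encoded_texts):
    tokens[row, : len(ids)] = ids

# Flatten and drop the placeholder cells, as encode_tokens_from_* now does.
flat = tokens.flatten()
flat = flat[flat < placeholder]
assert flat.tolist() == [10, 11, 12, 20, 21]

# Cache round trip with gzip + np.save / np.load, mirroring TokenDataset.save()
# and the from_cache branch of __init__.
with gzip.open("dataset_cache.tar.gz", "wb") as f:
    np.save(f, flat)
with gzip.open("dataset_cache.tar.gz", "rb") as f:
    restored = np.load(f)
assert np.array_equal(restored, flat)
```

Since GPT-2's vocabulary fits in `np.uint16`, each cached token ID takes two bytes before gzip compression.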