From 236c33326da0f3a0b61189db35b35b53d871117f Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 5 Dec 2020 12:35:39 -0800 Subject: [PATCH 01/40] version bump --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a9f31b9..99004b3 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="aitextgen", packages=["aitextgen"], # this must be the same as the name above - version="0.3.0", + version="0.4.0", description="A robust Python tool for text-based AI training and generation using GPT-2.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", From 17a4915cf9f2b3b11f388abc45ee37e6ce18f40f Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 5 Dec 2020 15:45:34 -0800 Subject: [PATCH 02/40] Set serialize to default to True --- aitextgen/tokenizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aitextgen/tokenizers.py b/aitextgen/tokenizers.py index c3ba725..89926af 100644 --- a/aitextgen/tokenizers.py +++ b/aitextgen/tokenizers.py @@ -1,4 +1,4 @@ -from tokenizers import Tokenizer, trainers, models, ByteLevelBPETokenizer +from tokenizers import ByteLevelBPETokenizer from typing import Union, List import logging @@ -15,7 +15,7 @@ def train_tokenizer( bos_token: str = "<|endoftext|>", eos_token: str = "<|endoftext|>", unk_token: str = "<|endoftext|>", - serialize: bool = False, + serialize: bool = True, ) -> None: """ Tokenizes the text(s) as a tokenizer, wrapping the tokenizer package. From f8c2e317ec8b777f915353394679fd91c191a75a Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 5 Dec 2020 19:34:33 -0800 Subject: [PATCH 03/40] Add tokens as special tokens --- aitextgen/tokenizers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/aitextgen/tokenizers.py b/aitextgen/tokenizers.py index 89926af..a90e174 100644 --- a/aitextgen/tokenizers.py +++ b/aitextgen/tokenizers.py @@ -47,13 +47,11 @@ def train_tokenizer( tokenizer.train( files=files, - vocab_size=vocab_size - len(added_tokens), + vocab_size=vocab_size, min_frequency=min_frequency, - special_tokens=[bos_token, eos_token, unk_token], + special_tokens=[bos_token, eos_token, unk_token] + added_tokens, ) - tokenizer.add_tokens(added_tokens) - PREFIX = "aitextgen" save_path_str = "the current directory" if save_path == "" else save_path if serialize: From 1f0529ab8b3d6191f6fe0acc3a09ae72cc00d08e Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 5 Dec 2020 19:49:42 -0800 Subject: [PATCH 04/40] Remove logging, add prefix param --- aitextgen/tokenizers.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/aitextgen/tokenizers.py b/aitextgen/tokenizers.py index a90e174..68aa306 100644 --- a/aitextgen/tokenizers.py +++ b/aitextgen/tokenizers.py @@ -1,8 +1,5 @@ from tokenizers import ByteLevelBPETokenizer from typing import Union, List -import logging - -logger = logging.getLogger(__name__) def train_tokenizer( @@ -10,6 +7,7 @@ def train_tokenizer( dropout: float = None, vocab_size: int = 1000, min_frequency: int = 2, + prefix: str = "aitextgen", save_path: str = "", added_tokens: List[str] = [], bos_token: str = "<|endoftext|>", @@ -27,6 +25,7 @@ def train_tokenizer( :param dropout: Training dropout :param vocab_size: Final vocabulary size :param min_frequency: Minimum number of occurences to add to vocab + :param prefix: File name prefix of the final tokenizer :param save_path: Where to save the final tokenizer :param added_tokens: List of tokens 
to add to the tokenizer (currently not working) :param bos_token: Beginning-of-string special token @@ -52,17 +51,7 @@ def train_tokenizer( special_tokens=[bos_token, eos_token, unk_token] + added_tokens, ) - PREFIX = "aitextgen" - save_path_str = "the current directory" if save_path == "" else save_path if serialize: - logger.info( - f"Saving {PREFIX}.tokenizer.json to {save_path_str}. " - + "You will need this file to build the GPT2Tokenizer." - ) - tokenizer.save(f"{PREFIX}.tokenizer.json") + tokenizer.save(f"{prefix}.tokenizer.json") else: - logger.info( - f"Saving {PREFIX}-vocab.json and {PREFIX}-merges.txt to {save_path_str}. " - + "You will need both files to build the GPT2Tokenizer." - ) - tokenizer.save_model(save_path, PREFIX) + tokenizer.save_model(save_path, prefix) From 473140abde44d95c7a03a137eb70c8fc7d466299 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 5 Dec 2020 21:16:59 -0800 Subject: [PATCH 05/40] Explicit fast tokenizers, allow serialized loading --- aitextgen/aitextgen.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index c1fe644..3c295aa 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -1,6 +1,6 @@ from transformers import ( GPT2LMHeadModel, - GPT2Tokenizer, + GPT2TokenizerFast, GPT2Config, AutoConfig, ) @@ -30,7 +30,7 @@ import shutil try: - import torch_xla.core.xla_model as xm + import torch_xla.core.xla_model as xm # noqa except ImportError: pass @@ -80,6 +80,7 @@ def __init__( config: Union[str, GPT2Config] = None, vocab_file: str = None, merges_file: str = None, + tokenizer_file: str = None, cache_dir: str = "aitextgen", tf_gpt2: str = None, to_gpu: bool = False, @@ -172,7 +173,7 @@ def __init__( ) if model and "gpt2" not in model: logger.info(f"Using the tokenizer for {model}.") - self.tokenizer = GPT2Tokenizer.from_pretrained( + self.tokenizer = GPT2TokenizerFast.from_pretrained( model, cache_dir=cache_dir, ) @@ -197,14 +198,26 @@ def __init__( else: logger.info("Using the default GPT-2 Tokenizer.") - self.tokenizer = GPT2Tokenizer( - vocab_file=self.vocab_file, - merges_file=self.merges_file, - bos_token=self.bos_token, - eos_token=self.eos_token, - unk_token=self.unk_token, - pad_token=self.pad_token, - ) + if tokenizer_file: + # load the custom GPT-3 tokenizer from a serialized tokenizer + self.tokenizer = GPT2TokenizerFast( + vocab_file=None, + merges_file=None, + tokenizer_file=tokenizer_file, + bos_token=self.bos_token, + eos_token=self.eos_token, + unk_token=self.unk_token, + pad_token=self.pad_token, + ) + else: + self.tokenizer = GPT2TokenizerFast( + vocab_file=self.vocab_file, + merges_file=self.merges_file, + bos_token=self.bos_token, + eos_token=self.eos_token, + unk_token=self.unk_token, + pad_token=self.pad_token, + ) self.tokenizer.padding_side = "left" From 29a84041ce8f904cac1eccc11f570814d0749909 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 5 Dec 2020 21:32:24 -0800 Subject: [PATCH 06/40] Serialized tokenizers for TokenDataset --- aitextgen/TokenDataset.py | 31 +++++++++++++++++++++++-------- aitextgen/aitextgen.py | 5 ++--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/aitextgen/TokenDataset.py b/aitextgen/TokenDataset.py index a0cb985..2752a0c 100644 --- a/aitextgen/TokenDataset.py +++ b/aitextgen/TokenDataset.py @@ -53,6 +53,8 @@ def __init__( file_path: str = None, vocab_file: str = os.path.join(STATIC_PATH, "gpt2_vocab.json"), merges_file: str = os.path.join(STATIC_PATH, 
"gpt2_merges.txt"), + tokenizer: GPT2TokenizerFast = None, + tokenizer_file: str = None, texts: List[str] = None, line_by_line: bool = False, from_cache: bool = False, @@ -82,14 +84,27 @@ def __init__( assert any([texts, file_path]), "texts or file_path must be specified." - tokenizer = GPT2TokenizerFast( - vocab_file=vocab_file, - merges_file=merges_file, - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - ) + if not tokenizer: + if tokenizer_file: + # load the custom GPT-2 tokenizer from a serialized tokenizer + tokenizer = GPT2TokenizerFast( + vocab_file=None, + merges_file=None, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + ) + else: + tokenizer = GPT2TokenizerFast( + vocab_file=vocab_file, + merges_file=merges_file, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + ) # If a cache path is provided, load it. if from_cache: diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 3c295aa..642ae70 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -199,7 +199,7 @@ def __init__( logger.info("Using the default GPT-2 Tokenizer.") if tokenizer_file: - # load the custom GPT-3 tokenizer from a serialized tokenizer + # load the custom GPT-2 tokenizer from a serialized tokenizer self.tokenizer = GPT2TokenizerFast( vocab_file=None, merges_file=None, @@ -487,8 +487,7 @@ def train( if isinstance(train_data, str): train_data = TokenDataset( - vocab_file=self.vocab_file, - merges_file=self.merges_file, + tokenizer=self.tokenizer, bos_token=self.bos_token, eos_token=self.eos_token, unk_token=self.unk_token, From b38632ab2adc6ffea0e6b6e11dbcb6af7b4e533c Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 6 Dec 2020 20:54:10 -0800 Subject: [PATCH 07/40] Loading of schema tokens --- aitextgen/aitextgen.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 642ae70..d235e6a 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -28,6 +28,7 @@ from typing import Union, Optional, List from pkg_resources import resource_filename import shutil +import json try: import torch_xla.core.xla_model as xm # noqa @@ -209,6 +210,11 @@ def __init__( unk_token=self.unk_token, pad_token=self.pad_token, ) + with open(tokenizer_file, "r", encoding="utf-8") as f: + data = json.load(f) + self.schema_tokens = { + x["id"]: x["content"] for x in data["added_tokens"] + } else: self.tokenizer = GPT2TokenizerFast( vocab_file=self.vocab_file, From 60af6c6ccc417f4a54ec648f9ea3eaf8b4072348 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 6 Dec 2020 20:54:40 -0800 Subject: [PATCH 08/40] Set refresh rate to 20 for consistency --- aitextgen/TokenDataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aitextgen/TokenDataset.py b/aitextgen/TokenDataset.py index 2752a0c..b2d39ac 100644 --- a/aitextgen/TokenDataset.py +++ b/aitextgen/TokenDataset.py @@ -69,7 +69,7 @@ def __init__( eos_token: str = "<|endoftext|>", unk_token: str = "<|endoftext|>", pad_token: str = "<|endoftext|>", - progress_bar_refresh_rate: int = 10, + progress_bar_refresh_rate: int = 20, **kwargs, ) -> None: @@ -260,7 +260,7 @@ def encode_tokens_from_file( tokenizer: GPT2TokenizerFast, newline: str, header: bool = True, - progress_bar_refresh_rate: int = 10, + progress_bar_refresh_rate: int = 20, batch_size: int = 1024, ) -> List[int]: """ @@ -352,7 +352,7 @@ def 
encode_tokens_from_list( texts: List[str], eos_token: str, tokenizer: GPT2TokenizerFast, - progress_bar_refresh_rate: int = 10, + progress_bar_refresh_rate: int = 20, batch_size: int = 1024, ) -> List[int]: """ From 62679e77d4328a469f942e88aaecfe8d36d566aa Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Wed, 9 Dec 2020 22:05:35 -0800 Subject: [PATCH 09/40] Do not add special tokens as schema tokens --- aitextgen/aitextgen.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index d235e6a..44ea6d0 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -213,7 +213,10 @@ def __init__( with open(tokenizer_file, "r", encoding="utf-8") as f: data = json.load(f) self.schema_tokens = { - x["id"]: x["content"] for x in data["added_tokens"] + x["id"]: x["content"] + for x in data["added_tokens"] + if x["content"] + not in [self.bos_token, self.eos_token, self.unk_token] } else: self.tokenizer = GPT2TokenizerFast( From 589ef31dc5de8a8ad7af4bed0449011cad29e622 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 13 Dec 2020 19:23:43 -0800 Subject: [PATCH 10/40] Use model config instead for schema tokens --- aitextgen/aitextgen.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 44ea6d0..d76f08a 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -82,6 +82,8 @@ def __init__( vocab_file: str = None, merges_file: str = None, tokenizer_file: str = None, + schema_tokens: List[str] = None, + schema_return: List[str] = None, cache_dir: str = "aitextgen", tf_gpt2: str = None, to_gpu: bool = False, @@ -179,6 +181,12 @@ def __init__( cache_dir=cache_dir, ) + if schema_tokens: + self.model.config["schema_tokens"] = schema_tokens + + if schema_tokens: + self.model.config["schema_return"] = schema_return + if self.tokenizer is None: # Update tokenizer settings (if not set already) args = locals() @@ -210,14 +218,6 @@ def __init__( unk_token=self.unk_token, pad_token=self.pad_token, ) - with open(tokenizer_file, "r", encoding="utf-8") as f: - data = json.load(f) - self.schema_tokens = { - x["id"]: x["content"] - for x in data["added_tokens"] - if x["content"] - not in [self.bos_token, self.eos_token, self.unk_token] - } else: self.tokenizer = GPT2TokenizerFast( vocab_file=self.vocab_file, @@ -250,6 +250,9 @@ def generate( return_as_list: bool = False, seed: int = None, pad_token_id: str = None, + schema: str = None, + schema_tokens: List[str] = None, + schema_return: List[str] = None, **kwargs, ) -> Optional[str]: """ From 434f4177758040f7f6a4e406b7b5ca9999692b97 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 13 Dec 2020 20:03:39 -0800 Subject: [PATCH 11/40] find_index_of_subset() --- aitextgen/utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/aitextgen/utils.py b/aitextgen/utils.py index b06bed3..602d17a 100644 --- a/aitextgen/utils.py +++ b/aitextgen/utils.py @@ -136,3 +136,20 @@ def GPT2ConfigCPU( eos_token_id=eos_token_id, **kwargs, ) + + +def find_index_of_subset(large_list, small_list): + """ + Returns the index after the small_list within the large list, + Returns None if not present. + + Adapted from https://stackoverflow.com/a/45819222 which shows that it is + performant for input lengths < 1000, which is the common case for this function. 
+ """ + length_small_list = len(small_list) + firstneedle = small_list[0] + for idx, item in enumerate(large_list): + if item == firstneedle: + if large_list[idx : idx + length_small_list] == small_list: + return idx + length_small_list + return None From 62656d4ddc031243cce6a94aa02437ebbbe2d403 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Mon, 14 Dec 2020 19:12:17 -0800 Subject: [PATCH 12/40] First pass of schema-aware generation --- aitextgen/aitextgen.py | 70 ++++++++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index d76f08a..237e13e 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -22,6 +22,7 @@ download_gpt2, set_seed, reset_seed, + find_index_of_subset, ) from .train import ATGTransformer, ATGProgressBar from .colab import create_gdrive_folder @@ -250,9 +251,7 @@ def generate( return_as_list: bool = False, seed: int = None, pad_token_id: str = None, - schema: str = None, - schema_tokens: List[str] = None, - schema_return: List[str] = None, + schema: str = False, **kwargs, ) -> Optional[str]: """ @@ -310,25 +309,58 @@ def generate( if seed: reset_seed() - if n > 1: - gen_texts = [ - self.tokenizer.decode(output, skip_special_tokens=True) - for output in outputs - ] + # Schema token handling + if schema: + schema_tokens = self.model.config.get("schema_tokens") + schema_return = self.model.config.get("schema_return") + schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"] + + outputs = outputs.tolist() + gen_texts = [] + for output in outputs: + gen_text_dict = {} + index = 0 + for i, token_enc in enumerate(schema_tokens_enc): + end_index = find_index_of_subset(output, token_enc) + gen_text_dict[schema_tokens[i]] = output[index:end_index] + index = end_index + + # remove fields not in schema_return + if schema_return: + for key in gen_text_dict.keys(): + if key not in schema_return: + gen_text_dict.pop(key, None) + + gen_texts.append(gen_text_dict) + + if not return_as_list: + print(*gen_texts, sep="\n" + "=" * 10 + "\n") + else: + if n > 1: + return gen_texts + else: + return gen_texts[0] + + # Typical use case else: - gen_texts = [self.tokenizer.decode(outputs[0], skip_special_tokens=True)] + if n > 1: + gen_texts = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + else: + gen_texts = self.tokenizer.decode(outputs[0], skip_special_tokens=True) - if not return_as_list: - if prompt is not None: - # Bold the prompt if printing to console - gen_texts = [ - text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1) - for text in gen_texts - ] + if not return_as_list: + if prompt is not None: + # Bold the prompt if printing to console + gen_texts = [ + text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1) + for text in gen_texts + ] - print(*gen_texts, sep="\n" + "=" * 10 + "\n") - else: - return gen_texts + print(*gen_texts, sep="\n" + "=" * 10 + "\n") + else: + return gen_texts def generate_one(self, **kwargs) -> None: """ From 4d508d4e9d6aef117399a68e9f4ca48965b9ed8e Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Mon, 14 Dec 2020 19:18:41 -0800 Subject: [PATCH 13/40] Return as text if only returning one field --- aitextgen/aitextgen.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 237e13e..3b776c6 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -322,11 +322,15 @@ def generate( index = 0 for i, token_enc in 
enumerate(schema_tokens_enc): end_index = find_index_of_subset(output, token_enc) - gen_text_dict[schema_tokens[i]] = output[index:end_index] + gen_text_dict[schema_tokens[i]] = self.tokenizer.decode( + output[index:end_index], skip_special_tokens=True + ) index = end_index # remove fields not in schema_return if schema_return: + if len(schema_return) == 1: + gen_text_dict = gen_text_dict[schema_return[0]] for key in gen_text_dict.keys(): if key not in schema_return: gen_text_dict.pop(key, None) From cbdd1b5fc24d2d116d87528401297ef2bc7f124a Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 12:36:22 -0800 Subject: [PATCH 14/40] Allow training transformer layers only --- aitextgen/aitextgen.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 3b776c6..9478748 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -65,7 +65,7 @@ class aitextgen: :param unk_token: String to override the unknown token """ - torchscript = False + openai_gpt2_large = False # default values for GPT2Tokenizer tokenizer = None @@ -119,6 +119,9 @@ def __init__( "1558M", ], "Invalid TensorFlow GPT-2 model size." + if tf_gpt2 != "124M": + self.openai_gpt2_large = True + logger.info( f"Downloading the {tf_gpt2} GPT-2 TensorFlow weights/config " + "from Google's servers" @@ -484,6 +487,7 @@ def train( save_gdrive: bool = False, run_id: str = f"ATG_{datetime.utcnow():%Y%m%d_%H%M%S}", progress_bar_refresh_rate: int = 20, + train_transformers_only: bool = False, **kwargs, ) -> None: """ @@ -519,8 +523,6 @@ def train( the progress bar while training. """ - assert not self.torchscript, "You cannot train a traced TorchScript model." - if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -622,9 +624,21 @@ def train( if n_gpu > 1: train_params["distributed_backend"] = "ddp" + if train_transformers_only or self.openai_gpt2_large: + logger.info("Training Transformer layers only.") + for name, param in train_model.model.named_parameters(): + if ".h." in name: + param.requires_grad = False + trainer = pl.Trainer(**train_params) trainer.fit(train_model) + # Unfreeze Transformer layers after done + if train_transformers_only or self.openai_gpt2_large: + for name, param in train_model.model.named_parameters(): + if ".h." 
in name: + param.requires_grad = True + logger.info(f"Saving trained model pytorch_model.bin to /{output_dir}") self.model.save_pretrained(output_dir) From b8630a383a4ee625baa1efe43f241cf7004c0b33 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 12:58:03 -0800 Subject: [PATCH 15/40] Add gradient_checkpointing config option --- aitextgen/aitextgen.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 9478748..6011d8a 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -90,6 +90,7 @@ def __init__( to_gpu: bool = False, to_fp16: bool = False, verbose: bool = False, + gradient_checkpointing: bool = False, bos_token: str = None, eos_token: str = None, unk_token: str = None, @@ -121,6 +122,7 @@ def __init__( if tf_gpt2 != "124M": self.openai_gpt2_large = True + gradient_checkpointing = True logger.info( f"Downloading the {tf_gpt2} GPT-2 TensorFlow weights/config " @@ -185,6 +187,9 @@ def __init__( cache_dir=cache_dir, ) + if gradient_checkpointing: + self.model.config["gradient_checkpointing"] = True + if schema_tokens: self.model.config["schema_tokens"] = schema_tokens From fe1a0b95f7dafddcb3c302a9dad1cc3310af0374 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 13:31:49 -0800 Subject: [PATCH 16/40] usse setattr for config assignment --- aitextgen/aitextgen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 6011d8a..0d61db7 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -188,13 +188,13 @@ def __init__( ) if gradient_checkpointing: - self.model.config["gradient_checkpointing"] = True + setattr(self.model.config, "gradient_checkpointing", True) if schema_tokens: - self.model.config["schema_tokens"] = schema_tokens + setattr(self.model.config, "schema_tokens", schema_tokens) if schema_tokens: - self.model.config["schema_return"] = schema_return + setattr(self.model.config, "schema_return", schema_return) if self.tokenizer is None: # Update tokenizer settings (if not set already) From b2af94cb0db99fdf4355887fe96ec3b0dc14c106 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 14:45:47 -0800 Subject: [PATCH 17/40] Move layer freezing logic to train.py --- aitextgen/aitextgen.py | 26 ++++++++++---------------- aitextgen/train.py | 25 ++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 0d61db7..58d3a04 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -109,6 +109,10 @@ def __init__( logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) if tf_gpt2: + if tf_gpt2 != "124M": + self.openai_gpt2_large = True + gradient_checkpointing = True + # Download + convert the TF weights if a PyTorch model has not been created if not os.path.isfile( os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin") @@ -120,10 +124,6 @@ def __init__( "1558M", ], "Invalid TensorFlow GPT-2 model size." 
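As a usage note on the relocated size check: with this series, pointing aitextgen at one of the larger OpenAI TensorFlow checkpoints flips `openai_gpt2_large`, which enables gradient checkpointing at load time and forces transformer-layer-only training later. A minimal sketch, assuming only the parameters visible in these diffs (the cache directory is the default value):

```python
from aitextgen import aitextgen

# Sketch: load the 355M OpenAI TensorFlow checkpoint. Because the size is not
# "124M", __init__ sets openai_gpt2_large = True and turns on gradient
# checkpointing, and train() will later force freeze_layers on as well.
ai = aitextgen(tf_gpt2="355M", to_gpu=True, cache_dir="aitextgen")
```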
- if tf_gpt2 != "124M": - self.openai_gpt2_large = True - gradient_checkpointing = True - logger.info( f"Downloading the {tf_gpt2} GPT-2 TensorFlow weights/config " + "from Google's servers" @@ -551,6 +551,10 @@ def train( **kwargs, ) + if train_transformers_only or self.openai_gpt2_large: + logger.info("Training Transformer layers only.") + train_transformers_only = True + if num_workers is None and tpu_cores == 0: # Use all CPU cores as workers if not training on CPU # Can overload 2x w/o diminishing returns @@ -597,6 +601,7 @@ def train( checkpoint_callback=False, logger=loggers if loggers else False, weights_summary=None, + progress_bar_refresh_rate=progress_bar_refresh_rate, # ignored callbacks=[ ATGProgressBar( save_every, @@ -608,6 +613,7 @@ def train( run_id, save_gdrive, progress_bar_refresh_rate, + train_transformers_only, ) ], ) @@ -629,21 +635,9 @@ def train( if n_gpu > 1: train_params["distributed_backend"] = "ddp" - if train_transformers_only or self.openai_gpt2_large: - logger.info("Training Transformer layers only.") - for name, param in train_model.model.named_parameters(): - if ".h." in name: - param.requires_grad = False - trainer = pl.Trainer(**train_params) trainer.fit(train_model) - # Unfreeze Transformer layers after done - if train_transformers_only or self.openai_gpt2_large: - for name, param in train_model.model.named_parameters(): - if ".h." in name: - param.requires_grad = True - logger.info(f"Saving trained model pytorch_model.bin to /{output_dir}") self.model.save_pretrained(output_dir) diff --git a/aitextgen/train.py b/aitextgen/train.py index 440d583..3e01cba 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -98,6 +98,7 @@ def __init__( run_id, save_gdrive, progress_bar_refresh_rate, + train_transformers_only, ): super().__init__() self.enabled = True @@ -112,6 +113,7 @@ def __init__( self.run_id = run_id self.save_gdrive = save_gdrive self.progress_bar_refresh_rate = progress_bar_refresh_rate + self.train_transformers_only = train_transformers_only def enabled(self): self.enabled = True @@ -129,6 +131,11 @@ def on_train_start(self, trainer, pl_module): dynamic_ncols=True, file=sys.stdout, ) + self.freeze_nontransformer_layers(pl_module) + + def on_train_end(self, trainer, pl_module): + self.main_progress_bar.close() + self.unfreeze_nontransformer_layers(pl_module) def on_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) @@ -170,13 +177,16 @@ def on_batch_end(self, trainer, pl_module): self.main_progress_bar.set_description(desc) if self.enabled: - if self.save_every > 0 and self.steps % self.save_every == 0: + self.unfreeze_nontransformer_layers(pl_module) self.save_pytorch_model(trainer, pl_module) if self.generate_every > 0 and self.steps % self.generate_every == 0: + self.unfreeze_nontransformer_layers(pl_module) self.generate_sample_text(trainer, pl_module) + self.freeze_nontransformer_layers(pl_module) + def generate_sample_text(self, trainer, pl_module): self.main_progress_bar.write( f"\033[1m{self.steps:,} steps reached: generating sample texts.\033[0m" @@ -219,3 +229,16 @@ def average_loss(self, current_loss, prev_avg_loss, smoothing): return current_loss else: return (smoothing * current_loss) + (1 - smoothing) * prev_avg_loss + + def modify_nontransformer_layers(self, pl_module, unfreeze): + if self.train_transformers_only: + layers = ["transformer.wte.weight"] + for name, param in pl_module.model.named_parameters(): + if name in layers: + param.requires_grad = unfreeze + + def 
freeze_nontransformer_layers(self, pl_module): + self.modify_nontransformer_layers(pl_module, False) + + def unfreeze_nontransformer_layers(self, pl_module): + self.modify_nontransformer_layers(pl_module, True) From fc4a7df23ed4a008ab242a7fc07f24287b4b05ed Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 15:30:15 -0800 Subject: [PATCH 18/40] Tweaks for generating text during train --- aitextgen/train.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/aitextgen/train.py b/aitextgen/train.py index 3e01cba..516565a 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -166,9 +166,8 @@ def on_batch_end(self, trainer, pl_module): "--format=csv,nounits,noheader", ], encoding="utf-8", - # capture_output=True, # valid for python version >=3.7 stdout=subprocess.PIPE, - stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 + stderr=subprocess.PIPE, check=True, ) gpu_memory = result.stdout.strip().split(os.linesep)[0] @@ -195,16 +194,23 @@ def generate_sample_text(self, trainer, pl_module): gen_length = min(pl_module.model.config.n_positions, 256) outputs = pl_module.model.generate( + input_ids=None, max_length=gen_length, do_sample=True, num_return_sequences=self.n_generate, temperature=0.7, pad_token_id=pl_module.tokenizer.pad_token_id, ) - gen_texts = [ - pl_module.tokenizer.decode(output, skip_special_tokens=True) - for output in outputs - ] + + if self.n_generate > 1: + gen_texts = pl_module.tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + else: + gen_texts = [ + pl_module.tokenizer.decode(outputs[0], skip_special_tokens=True) + ] + for text in gen_texts: self.main_progress_bar.write("=" * 10) self.main_progress_bar.write(text) @@ -232,9 +238,8 @@ def average_loss(self, current_loss, prev_avg_loss, smoothing): def modify_nontransformer_layers(self, pl_module, unfreeze): if self.train_transformers_only: - layers = ["transformer.wte.weight"] for name, param in pl_module.model.named_parameters(): - if name in layers: + if name == "transformer.wte.weight": param.requires_grad = unfreeze def freeze_nontransformer_layers(self, pl_module): From ffee20c45ad6fa22d2babf5fb98bbbe9e48924a9 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 16:09:21 -0800 Subject: [PATCH 19/40] Use getattr instead --- aitextgen/aitextgen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 58d3a04..93a1f7b 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -319,8 +319,8 @@ def generate( # Schema token handling if schema: - schema_tokens = self.model.config.get("schema_tokens") - schema_return = self.model.config.get("schema_return") + schema_tokens = getattr(self.model.config, "schema_tokens") + schema_return = getattr(self.model.config, "schema_return") schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"] outputs = outputs.tolist() From d0c8414dff9537ce22bdf4dd2405d2b92dc865e4 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 16:26:32 -0800 Subject: [PATCH 20/40] Add num_layers_freeze --- aitextgen/aitextgen.py | 2 ++ aitextgen/train.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 93a1f7b..e93af00 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -493,6 +493,7 @@ def train( run_id: str = f"ATG_{datetime.utcnow():%Y%m%d_%H%M%S}", progress_bar_refresh_rate: int = 20, train_transformers_only: 
bool = False, + num_layers_freeze: int = None, **kwargs, ) -> None: """ @@ -614,6 +615,7 @@ def train( save_gdrive, progress_bar_refresh_rate, train_transformers_only, + num_layers_freeze, ) ], ) diff --git a/aitextgen/train.py b/aitextgen/train.py index 516565a..735757e 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -99,6 +99,7 @@ def __init__( save_gdrive, progress_bar_refresh_rate, train_transformers_only, + num_layers_freeze, ): super().__init__() self.enabled = True @@ -114,6 +115,7 @@ def __init__( self.save_gdrive = save_gdrive self.progress_bar_refresh_rate = progress_bar_refresh_rate self.train_transformers_only = train_transformers_only + self.num_layers_freeze = num_layers_freeze def enabled(self): self.enabled = True @@ -239,7 +241,12 @@ def average_loss(self, current_loss, prev_avg_loss, smoothing): def modify_nontransformer_layers(self, pl_module, unfreeze): if self.train_transformers_only: for name, param in pl_module.model.named_parameters(): - if name == "transformer.wte.weight": + if self.num_layers_freeze: + layer_num = int(name.split(".")[2]) if ".h." in name else None + to_freeze = layer_num and layer_num < self.num_layers_freeze + else: + to_freeze = False + if name == "transformer.wte.weight" or to_freeze: param.requires_grad = unfreeze def freeze_nontransformer_layers(self, pl_module): From 8bf2539950926b48f2d7ddd3115fb7efd53a7a23 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 16:43:23 -0800 Subject: [PATCH 21/40] make naming more consistent for freezing --- aitextgen/aitextgen.py | 13 ++++++++----- aitextgen/train.py | 24 ++++++++++++++---------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index e93af00..0cb6d61 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -492,7 +492,7 @@ def train( save_gdrive: bool = False, run_id: str = f"ATG_{datetime.utcnow():%Y%m%d_%H%M%S}", progress_bar_refresh_rate: int = 20, - train_transformers_only: bool = False, + freeze_layers: bool = False, num_layers_freeze: int = None, **kwargs, ) -> None: @@ -552,9 +552,12 @@ def train( **kwargs, ) - if train_transformers_only or self.openai_gpt2_large: - logger.info("Training Transformer layers only.") - train_transformers_only = True + if freeze_layers or self.openai_gpt2_large: + freeze_layers = True + if num_layers_freeze: + assert ( + num_layers_freeze < self.model.config.n_layer + ), "You are freezing more Transformer layers than in the model." 
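For reference, a short sketch of how the freezing flags introduced here are meant to be passed to training; the dataset path and step counts are placeholders, not values taken from the patches:

```python
from aitextgen import aitextgen
from aitextgen.TokenDataset import TokenDataset

ai = aitextgen(tf_gpt2="1558M", to_gpu=True)
data = TokenDataset("input.txt", tokenizer=ai.tokenizer, block_size=1024)

# freeze_layers trains only the transformer blocks (the wte embeddings stay
# frozen); num_layers_freeze additionally freezes the first n blocks, which is
# what allows the 1.5B model to be trained with a high n.
ai.train(
    data,
    batch_size=1,
    num_steps=2000,
    freeze_layers=True,
    num_layers_freeze=40,
)
```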
if num_workers is None and tpu_cores == 0: # Use all CPU cores as workers if not training on CPU @@ -614,7 +617,7 @@ def train( run_id, save_gdrive, progress_bar_refresh_rate, - train_transformers_only, + freeze_layers, num_layers_freeze, ) ], diff --git a/aitextgen/train.py b/aitextgen/train.py index 735757e..1e64019 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -133,11 +133,11 @@ def on_train_start(self, trainer, pl_module): dynamic_ncols=True, file=sys.stdout, ) - self.freeze_nontransformer_layers(pl_module) + self.freeze_layers(pl_module) def on_train_end(self, trainer, pl_module): self.main_progress_bar.close() - self.unfreeze_nontransformer_layers(pl_module) + self.unfreeze_layers(pl_module) def on_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) @@ -178,15 +178,19 @@ def on_batch_end(self, trainer, pl_module): self.main_progress_bar.set_description(desc) if self.enabled: + did_unfreeze = False if self.save_every > 0 and self.steps % self.save_every == 0: - self.unfreeze_nontransformer_layers(pl_module) + self.unfreeze_layers(pl_module) self.save_pytorch_model(trainer, pl_module) + did_unfreeze = True if self.generate_every > 0 and self.steps % self.generate_every == 0: - self.unfreeze_nontransformer_layers(pl_module) + self.unfreeze_layers(pl_module) self.generate_sample_text(trainer, pl_module) + did_unfreeze = True - self.freeze_nontransformer_layers(pl_module) + if did_unfreeze: + self.freeze_layers(pl_module) def generate_sample_text(self, trainer, pl_module): self.main_progress_bar.write( @@ -238,7 +242,7 @@ def average_loss(self, current_loss, prev_avg_loss, smoothing): else: return (smoothing * current_loss) + (1 - smoothing) * prev_avg_loss - def modify_nontransformer_layers(self, pl_module, unfreeze): + def modify_layers(self, pl_module, unfreeze): if self.train_transformers_only: for name, param in pl_module.model.named_parameters(): if self.num_layers_freeze: @@ -249,8 +253,8 @@ def modify_nontransformer_layers(self, pl_module, unfreeze): if name == "transformer.wte.weight" or to_freeze: param.requires_grad = unfreeze - def freeze_nontransformer_layers(self, pl_module): - self.modify_nontransformer_layers(pl_module, False) + def freeze_layers(self, pl_module): + self.modify_layers(pl_module, False) - def unfreeze_nontransformer_layers(self, pl_module): - self.modify_nontransformer_layers(pl_module, True) + def unfreeze_layers(self, pl_module): + self.modify_layers(pl_module, True) From 6fc77f9e558658f98919cfb4a2cfdb7d594e0d3c Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 24 Dec 2020 17:58:13 -0800 Subject: [PATCH 22/40] Add trim_offsets tokenizer param --- aitextgen/tokenizers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aitextgen/tokenizers.py b/aitextgen/tokenizers.py index 68aa306..b7ed3db 100644 --- a/aitextgen/tokenizers.py +++ b/aitextgen/tokenizers.py @@ -14,6 +14,7 @@ def train_tokenizer( eos_token: str = "<|endoftext|>", unk_token: str = "<|endoftext|>", serialize: bool = True, + trim_offsets: bool = True, ) -> None: """ Tokenizes the text(s) as a tokenizer, wrapping the tokenizer package. 
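Taken together with the earlier tokenizer patches (serialized output by default, a `prefix` argument, special tokens trained in directly), a typical call now looks like the sketch below; `input.txt` is a placeholder file and the values shown are the defaults from the signature above:

```python
from aitextgen.tokenizers import train_tokenizer

# With serialize=True (the new default) this writes a single
# aitextgen.tokenizer.json instead of the older vocab.json/merges.txt pair.
train_tokenizer(
    "input.txt",
    vocab_size=1000,
    min_frequency=2,
    prefix="aitextgen",
    serialize=True,
    trim_offsets=True,
)
```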
@@ -42,7 +43,7 @@ def train_tokenizer( if isinstance(files, str): files = [files] - tokenizer = ByteLevelBPETokenizer(dropout=dropout) + tokenizer = ByteLevelBPETokenizer(dropout=dropout, trim_offsets=trim_offsets) tokenizer.train( files=files, From e60eb716a55dca2746d5f7a0c28edb2755de8612 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Fri, 25 Dec 2020 17:21:32 -0800 Subject: [PATCH 23/40] some potential TPU fixes --- aitextgen/aitextgen.py | 8 ++++++-- aitextgen/train.py | 20 +++++++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 0cb6d61..c1fe01c 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -559,11 +559,14 @@ def train( num_layers_freeze < self.model.config.n_layer ), "You are freezing more Transformer layers than in the model." - if num_workers is None and tpu_cores == 0: + if num_workers is None: # Use all CPU cores as workers if not training on CPU # Can overload 2x w/o diminishing returns if is_gpu_used: num_workers = os.cpu_count() * 2 + # TPUs want same amount of workers as CPUs + elif tpu_cores > 0: + num_workers = os.cpu_count() # If training on the CPU, use half the CPUs else: num_workers = int(os.cpu_count() / 2) @@ -575,10 +578,11 @@ def train( warmup_steps=warmup_steps, batch_size=batch_size, num_steps=num_steps, - pin_memory=True if is_gpu_used else False, + pin_memory=is_gpu_used, num_workers=num_workers, save_every=save_every, generate_every=generate_every, + use_tpu=tpu_cores > 0, ) # Wrap the model in a pytorch-lightning module diff --git a/aitextgen/train.py b/aitextgen/train.py index 1e64019..c45c84a 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -10,6 +10,11 @@ import shutil import subprocess +try: + import torch_xla.core.xla_model as xm # noqa +except ImportError: + pass + class ATGTransformer(pl.LightningModule): """ @@ -38,11 +43,20 @@ def training_step(self, batch, batch_num): def train_dataloader(self): "Load datasets. Called after prepare data." 
+ sampler = None + if self.hparams.use_tpu: + sampler = torch.utils.data.distributed.DistributedSampler( + self.dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal(), + shuffle=True, + ) return DataLoader( self.dataset, + sampler=sampler, batch_size=self.hparams["batch_size"], - shuffle=True, + shuffle=not sampler, pin_memory=self.hparams["pin_memory"], num_workers=self.hparams["num_workers"], ) @@ -210,11 +224,11 @@ def generate_sample_text(self, trainer, pl_module): if self.n_generate > 1: gen_texts = pl_module.tokenizer.batch_decode( - outputs, skip_special_tokens=True + outputs, skip_special_tokens=False ) else: gen_texts = [ - pl_module.tokenizer.decode(outputs[0], skip_special_tokens=True) + pl_module.tokenizer.decode(outputs[0], skip_special_tokens=False) ] for text in gen_texts: From 6f0f497b5eddf9c0e5b72f8f6d4227b979186d2d Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Fri, 25 Dec 2020 18:48:59 -0800 Subject: [PATCH 24/40] minor cleanup --- aitextgen/TokenDataset.py | 4 ++-- aitextgen/aitextgen.py | 3 +-- aitextgen/train.py | 6 ------ 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/aitextgen/TokenDataset.py b/aitextgen/TokenDataset.py index b2d39ac..89f07f9 100644 --- a/aitextgen/TokenDataset.py +++ b/aitextgen/TokenDataset.py @@ -311,7 +311,7 @@ def encode_tokens_from_file( if not batch: break - encoded_texts = tokenizer.batch_encode_plus( + encoded_texts = tokenizer( batch, add_special_tokens=False, return_token_type_ids=False, @@ -379,7 +379,7 @@ def encode_tokens_from_list( ] ] - encoded_texts = tokenizer.batch_encode_plus( + encoded_texts = tokenizer( batch, add_special_tokens=False, return_token_type_ids=False, diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index c1fe01c..1e2f241 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -29,7 +29,6 @@ from typing import Union, Optional, List from pkg_resources import resource_filename import shutil -import json try: import torch_xla.core.xla_model as xm # noqa @@ -638,7 +637,7 @@ def train( # benchmark gives a boost for GPUs if input size is constant, # which will always be the case with aitextgen training - if n_gpu != 0 and benchmark: + if is_gpu_used and benchmark: train_params["benchmark"] = True if n_gpu > 1: diff --git a/aitextgen/train.py b/aitextgen/train.py index c45c84a..55cb148 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -1,4 +1,3 @@ -from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks.progress import ProgressBarBase from tqdm.auto import tqdm @@ -10,11 +9,6 @@ import shutil import subprocess -try: - import torch_xla.core.xla_model as xm # noqa -except ImportError: - pass - class ATGTransformer(pl.LightningModule): """ From 94b6eebe77f2df3f2bf77cf13617d22d797d4935 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 2 Jan 2021 15:24:02 -0800 Subject: [PATCH 25/40] fix TPU changes --- aitextgen/train.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/aitextgen/train.py b/aitextgen/train.py index 55cb148..1d65f2b 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -4,6 +4,7 @@ import sys import torch from torch.optim import AdamW +from torch.utils.data import DataLoader from transformers import get_linear_schedule_with_warmup import os import shutil @@ -28,29 +29,16 @@ def forward(self, inputs): return self.model(**inputs, return_dict=False) def training_step(self, batch, batch_num): - "Compute loss and log." 
- outputs = self({"input_ids": batch, "labels": batch}) loss = outputs[0] return {"loss": loss} def train_dataloader(self): - "Load datasets. Called after prepare data." - sampler = None - if self.hparams.use_tpu: - sampler = torch.utils.data.distributed.DistributedSampler( - self.dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=True, - ) - return DataLoader( self.dataset, - sampler=sampler, batch_size=self.hparams["batch_size"], - shuffle=not sampler, + shuffle=True, pin_memory=self.hparams["pin_memory"], num_workers=self.hparams["num_workers"], ) From 202e03170b06749b2f6d3ead0033c9dd53b6a2c1 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 2 Jan 2021 20:08:55 -0800 Subject: [PATCH 26/40] More correct implementation of schema extraction --- aitextgen/aitextgen.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 1e2f241..eb34bba 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -29,6 +29,7 @@ from typing import Union, Optional, List from pkg_resources import resource_filename import shutil +import re try: import torch_xla.core.xla_model as xm # noqa @@ -259,6 +260,7 @@ def generate( seed: int = None, pad_token_id: str = None, schema: str = False, + normalize_key: bool = True, **kwargs, ) -> Optional[str]: """ @@ -322,17 +324,35 @@ def generate( schema_return = getattr(self.model.config, "schema_return") schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"] + nonalphanum_pattern = re.compile(r"[\W_]+") + outputs = outputs.tolist() gen_texts = [] for output in outputs: gen_text_dict = {} - index = 0 - for i, token_enc in enumerate(schema_tokens_enc): - end_index = find_index_of_subset(output, token_enc) - gen_text_dict[schema_tokens[i]] = self.tokenizer.decode( - output[index:end_index], skip_special_tokens=True + + # Get indices of each schema token within the text + schema_token_indices = [ + (schema_tokens[i], find_index_of_subset(output, token_enc)) + for i, token_enc in enumerate(schema_tokens_enc) + ] + schema_token_indices.sort(key=lambda x: x[1]) + + for i, token_tuple in enumerate(schema_token_indices): + start_index = token_tuple[1] + end_index = ( + schema_token_indices[i + 1][1] - 1 + if i + 1 < len(schema_token_indices) + else None + ) + key = ( + nonalphanum_pattern.sub("", token_tuple[0]) + if normalize_key + else token_tuple[0] + ) + gen_text_dict[key] = self.tokenizer.decode( + output[start_index:end_index], skip_special_tokens=False ) - index = end_index # remove fields not in schema_return if schema_return: From fdb6778d819ad9810cb72797a6623df7321a9e98 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 2 Jan 2021 21:08:55 -0800 Subject: [PATCH 27/40] Update CHANGELOG --- CHANGELOG.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9303655..4227d9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,24 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [0.4.0] - TBD + +- Made Fast tokenizers the default (as it is the default in `transformers` 4.0.0) +- Made serialized tokenizers the default for custom tokenizers, and added support for loading them for both `aitextgen` and `TokenDataset`s +- Added gradient checkpointing for GPT-2, and set it to the default for training 355M and 774M. 
+- Added layer freezing to freeze the first `n` layers of GPT-2 while training. This allows 1.5B GPT-2 to be trained with a high `n`. +- Added schema-based generation for specificed schema_tokens (which can be encoded in the Transformers config). This can be used with an appropriate dataset for schema-based generation. + +## [0.3.0] - 2020-11-30 + +- Increased minimum versions of dependencies (`transformers` to 4.0.0, `pytorch-lightning` to 1.0.8, Pytorch to 1.6) +- Fixed imports to account for new Transfomers file architecture +- Fixed training to account for new transformer/pytorch-lightning minimums +- Fully removed TorchScript code (ONNX implementation will supercede it) +- Made prompt specification for generation more canonical with Transformers +- Set default `vocab` size for new tokenizers to `1000` +- Began work on serializing tokenizers in accordance to the new `tokenizers` approach + ## [0.2.1] - 2020-06-28 ### Added From a842bc329fd0c8d6d1125cfe83417ef69f16ded1 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sat, 16 Jan 2021 17:13:52 -0800 Subject: [PATCH 28/40] Bump dependencies --- aitextgen/aitextgen.py | 2 +- requirements.txt | 4 ++-- setup.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index eb34bba..e5cce37 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -324,7 +324,7 @@ def generate( schema_return = getattr(self.model.config, "schema_return") schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"] - nonalphanum_pattern = re.compile(r"[\W_]+") + nonalphanum_pattern = re.compile(r"[\W_]+", re.UNICODE) outputs = outputs.tolist() gen_texts = [] diff --git a/requirements.txt b/requirements.txt index b63ad53..4fda52f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -transformers>=4.0.0 +transformers>=4.2.1 fire>=0.3.0 -pytorch-lightning>=1.0.8 +pytorch-lightning>=1.1.4 tokenizers>=0.9.4 torch>=1.6.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 99004b3..69bea4a 100644 --- a/setup.py +++ b/setup.py @@ -17,9 +17,9 @@ python_requires=">=3.6", include_package_data=True, install_requires=[ - "transformers>=4.0.0", + "transformers>=4.2.1", "fire>=0.3.0", - "pytorch-lightning>=1.0.8", + "pytorch-lightning>=1.1.4", "tokenizers>=0.9.4", "torch>=1.6.0", ], From fbbdc1bc65b0a0df44f93ec9645a9a396c305e98 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Thu, 21 Jan 2021 20:16:55 -0800 Subject: [PATCH 29/40] bump dependencies --- requirements.txt | 3 +-- setup.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4fda52f..5ca40e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ -transformers>=4.2.1 +transformers>=4.2.2 fire>=0.3.0 pytorch-lightning>=1.1.4 -tokenizers>=0.9.4 torch>=1.6.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 69bea4a..f0d2d47 100644 --- a/setup.py +++ b/setup.py @@ -17,10 +17,9 @@ python_requires=">=3.6", include_package_data=True, install_requires=[ - "transformers>=4.2.1", + "transformers>=4.2.2", "fire>=0.3.0", "pytorch-lightning>=1.1.4", - "tokenizers>=0.9.4", "torch>=1.6.0", ], ) From 1796bd605172390243449bf2bcd7fbed7b928259 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Mon, 8 Feb 2021 17:29:48 -0800 Subject: [PATCH 30/40] Bump transformer version --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5ca40e1..5d8adca 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers>=4.2.2 +transformers>=4.3.0 fire>=0.3.0 pytorch-lightning>=1.1.4 torch>=1.6.0 \ No newline at end of file diff --git a/setup.py b/setup.py index f0d2d47..242e24b 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ python_requires=">=3.6", include_package_data=True, install_requires=[ - "transformers>=4.2.2", + "transformers>=4.3.0", "fire>=0.3.0", "pytorch-lightning>=1.1.4", "torch>=1.6.0", From a91ee6c56defca1eaa87f8fe2742bcddc148f0e5 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Wed, 10 Feb 2021 18:56:32 -0800 Subject: [PATCH 31/40] fix use_cache warning --- aitextgen/aitextgen.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index e5cce37..d89e0e6 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -189,6 +189,7 @@ def __init__( if gradient_checkpointing: setattr(self.model.config, "gradient_checkpointing", True) + setattr(self.model.config, "use_cache", False) if schema_tokens: setattr(self.model.config, "schema_tokens", schema_tokens) @@ -261,6 +262,7 @@ def generate( pad_token_id: str = None, schema: str = False, normalize_key: bool = True, + use_cache: bool = True, **kwargs, ) -> Optional[str]: """ @@ -311,6 +313,7 @@ def generate( do_sample=do_sample, num_return_sequences=n, pad_token_id=pad_token_id, + use_cache=use_cache, **kwargs, ) From 5828198a9d76e25e57a06d415465836c84527a68 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Wed, 10 Feb 2021 19:18:46 -0800 Subject: [PATCH 32/40] make single-printing more explicit --- aitextgen/aitextgen.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index d89e0e6..4f603ff 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -392,7 +392,10 @@ def generate( for text in gen_texts ] - print(*gen_texts, sep="\n" + "=" * 10 + "\n") + if n > 1: + print(*gen_texts, sep="\n" + "=" * 10 + "\n") + else: + print(gen_texts) else: return gen_texts From 034d97e4780ac53e4b5eb8fb3f1adaa4628c7387 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 14 Feb 2021 18:06:32 -0800 Subject: [PATCH 33/40] change TF download url --- aitextgen/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aitextgen/utils.py b/aitextgen/utils.py index 602d17a..5cff49b 100644 --- a/aitextgen/utils.py +++ b/aitextgen/utils.py @@ -35,7 +35,7 @@ def download_gpt2(model_dir: str = "tf_model", model_name: str = "124M") -> None ]: if not os.path.isfile(os.path.join(sub_dir, file_name)): download_file_with_progress( - url_base="https://storage.googleapis.com/gpt-2", + url_base="https://openaipublic.blob.core.windows.net/gpt-2", sub_dir=sub_dir, model_name=model_name, file_name=file_name, From 1f41afe2fe1c129f76de62c3445976d232cbf1aa Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Mon, 15 Feb 2021 21:09:41 -0800 Subject: [PATCH 34/40] Fix generation issues + slow tokenizers --- aitextgen/aitextgen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 4f603ff..72b081f 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -1,6 +1,6 @@ from transformers import ( GPT2LMHeadModel, - GPT2TokenizerFast, + GPT2Tokenizer, GPT2Config, AutoConfig, ) @@ -219,7 +219,7 @@ def __init__( if tokenizer_file: # load the custom GPT-2 tokenizer from a serialized tokenizer - self.tokenizer = GPT2TokenizerFast( + self.tokenizer = GPT2Tokenizer( vocab_file=None, 
merges_file=None, tokenizer_file=tokenizer_file, @@ -229,7 +229,7 @@ def __init__( pad_token=self.pad_token, ) else: - self.tokenizer = GPT2TokenizerFast( + self.tokenizer = GPT2Tokenizer( vocab_file=self.vocab_file, merges_file=self.merges_file, bos_token=self.bos_token, @@ -385,7 +385,7 @@ def generate( gen_texts = self.tokenizer.decode(outputs[0], skip_special_tokens=True) if not return_as_list: - if prompt is not None: + if prompt: # Bold the prompt if printing to console gen_texts = [ text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1) From cbdf0efd5e160e410019d10c0b8cd55176a07899 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 21 Feb 2021 13:11:57 -0800 Subject: [PATCH 35/40] Handle token cleanup in aitextgen (#90) --- aitextgen/aitextgen.py | 42 ++++++++++++++++++++++++++++-------------- aitextgen/train.py | 21 +++++++++++++++------ aitextgen/utils.py | 18 ++++++++++++++++++ 3 files changed, 61 insertions(+), 20 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 72b081f..be671b5 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -1,6 +1,6 @@ from transformers import ( GPT2LMHeadModel, - GPT2Tokenizer, + GPT2TokenizerFast, GPT2Config, AutoConfig, ) @@ -23,6 +23,7 @@ set_seed, reset_seed, find_index_of_subset, + skip_special_tokens, ) from .train import ATGTransformer, ATGProgressBar from .colab import create_gdrive_folder @@ -219,7 +220,7 @@ def __init__( if tokenizer_file: # load the custom GPT-2 tokenizer from a serialized tokenizer - self.tokenizer = GPT2Tokenizer( + self.tokenizer = GPT2TokenizerFast( vocab_file=None, merges_file=None, tokenizer_file=tokenizer_file, @@ -229,7 +230,7 @@ def __init__( pad_token=self.pad_token, ) else: - self.tokenizer = GPT2Tokenizer( + self.tokenizer = GPT2TokenizerFast( vocab_file=self.vocab_file, merges_file=self.merges_file, bos_token=self.bos_token, @@ -263,6 +264,7 @@ def generate( schema: str = False, normalize_key: bool = True, use_cache: bool = True, + lstrip: bool = True, **kwargs, ) -> Optional[str]: """ @@ -284,16 +286,17 @@ def generate( and model. """ - if prompt: - assert ( - len(prompt) < self.model.config.n_positions - ), "The prompt is too large for the model." - prompt_text = prompt prompt_tensors = self.tokenizer(text=prompt, return_tensors="pt") + if prompt: + prompt_num_tokens = list(prompt_tensors["input_ids"].shape)[1] + assert ( + prompt_num_tokens < self.model.config.n_positions + ), f"The prompt is too large for the model. 
({prompt_num_tokens} tokens)" + input_ids = ( - prompt_tensors["input_ids"].to(self.model.device) if prompt else None + prompt_tensors["input_ids"].to(self.get_device()) if prompt else None ) if seed: @@ -377,12 +380,23 @@ def generate( # Typical use case else: - if n > 1: - gen_texts = self.tokenizer.batch_decode( - outputs, skip_special_tokens=True + # Handle special token stripping at the PyTorch level + gen_texts = [ + skip_special_tokens( + text, + self.get_device(), + [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id], ) + for text in outputs + ] + if n > 1: + gen_texts = self.tokenizer.batch_decode(gen_texts) else: - gen_texts = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + gen_texts = [self.tokenizer.decode(gen_texts[0])] + + # Handle stripping tokenization spaces w/ regex + if lstrip: + gen_texts = [re.sub(r"^\W+", "", text) for text in gen_texts] if not return_as_list: if prompt: @@ -395,7 +409,7 @@ def generate( if n > 1: print(*gen_texts, sep="\n" + "=" * 10 + "\n") else: - print(gen_texts) + print(gen_texts[0]) else: return gen_texts diff --git a/aitextgen/train.py b/aitextgen/train.py index 1d65f2b..f7f6fc4 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -204,14 +204,23 @@ def generate_sample_text(self, trainer, pl_module): pad_token_id=pl_module.tokenizer.pad_token_id, ) - if self.n_generate > 1: - gen_texts = pl_module.tokenizer.batch_decode( - outputs, skip_special_tokens=False + special_token_id_tensor = torch.unique( + torch.as_tensor( + [pl_module.tokenizer.bos_token_id, pl_module.tokenizer.eos_token_id] ) + ).to(pl_module.model.device.type) + + outputs = [ + output[ + ~output.unsqueeze(1).eq(special_token_id_tensor.unsqueeze(1)).any(1) + ].tolist() + for output in outputs + ] + + if self.n_generate > 1: + gen_texts = pl_module.tokenizer.batch_decode(outputs) else: - gen_texts = [ - pl_module.tokenizer.decode(outputs[0], skip_special_tokens=False) - ] + gen_texts = [pl_module.tokenizer.decode(outputs[0])] for text in gen_texts: self.main_progress_bar.write("=" * 10) diff --git a/aitextgen/utils.py b/aitextgen/utils.py index 5cff49b..e7bccac 100644 --- a/aitextgen/utils.py +++ b/aitextgen/utils.py @@ -153,3 +153,21 @@ def find_index_of_subset(large_list, small_list): if large_list[idx : idx + length_small_list] == small_list: return idx + length_small_list return None + + +def skip_special_tokens(tensor, device, special_token_ids): + """Filters out special tokens by ids in the given 1D tensor. 
+ + Adapted from https://stackoverflow.com/a/62588955 + + Args: + tensor (tensor): PyTorch Tensor + device (str): Device, usually "cpu" or "cuda:0" + token_ids (set): List of Token IDs + """ + special_token_id_tensor = torch.unique(torch.as_tensor(special_token_ids)).to( + device + ) + return tensor[ + ~tensor.unsqueeze(1).eq(special_token_id_tensor.unsqueeze(1)).any(1) + ].tolist() From 2e7aaf77ebe129e57284ff58b0cef80024404966 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 21 Feb 2021 13:44:57 -0800 Subject: [PATCH 36/40] Begin draft of DeepSpeed integration --- aitextgen/aitextgen.py | 17 +++++++++++++++- aitextgen/utils.py | 45 ++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 2 +- setup.py | 2 +- 4 files changed, 63 insertions(+), 3 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index be671b5..bd6689d 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -34,7 +34,7 @@ try: import torch_xla.core.xla_model as xm # noqa -except ImportError: +except ModuleNotFoundError: pass logger = logging.getLogger("aitextgen") @@ -533,6 +533,7 @@ def train( progress_bar_refresh_rate: int = 20, freeze_layers: bool = False, num_layers_freeze: int = None, + use_deepspeed: bool = True, **kwargs, ) -> None: """ @@ -640,6 +641,19 @@ def train( if not is_gpu_used: n_gpu = 0 + # use the deepseed plugin if installed and specified + deepspeed_plugin = None + # if is_gpu_used and use_deepspeed: + # deepspeed_config = gen_deepspeed_config( + # self.get_device(), learning_rate, weight_decay + # ) + # deepspeed_plugin = DeepSpeedPlugin(deepseed_config) + # logger.info("Using DeepSpeed training.") + # logger.warning( + # "deepspeed was attempted to be used, but was not installed. " + # + "Using normal training behavior." + # ) + train_params = dict( accumulate_grad_batches=gradient_accumulation_steps, gpus=n_gpu, @@ -664,6 +678,7 @@ def train( num_layers_freeze, ) ], + plugins=deepspeed_plugin, ) if fp16: diff --git a/aitextgen/utils.py b/aitextgen/utils.py index e7bccac..aebd582 100644 --- a/aitextgen/utils.py +++ b/aitextgen/utils.py @@ -171,3 +171,48 @@ def skip_special_tokens(tensor, device, special_token_ids): return tensor[ ~tensor.unsqueeze(1).eq(special_token_id_tensor.unsqueeze(1)).any(1) ].tolist() + + +def gen_deepspeed_config(device, lr, weight_decay): + """Deepspeed OneBitAdam config. + + Adapted from https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html#deepspeed + + Args: + device ([type]): Device for training + lr ([type]): Learning rate + weight_decay ([type]): Weight decay + """ + + deepspeed_config = { + "zero_allow_untested_optimizer": True, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": lr, + "betas": [0.998, 0.999], + "eps": 1e-5, + "weight_decay": weight_decay, + "cuda_aware": "cuda" in device, + }, + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + }, + }, + "zero_optimization": { + "stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning) + "cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU + "contiguous_gradients": True, # Reduce gradient fragmentation. + "overlap_comm": True, # Overlap reduce/backward operation of gradients for speed. + "allgather_bucket_size": 2e8, # Number of elements to all gather at once. + "reduce_bucket_size": 2e8, # Number of elements we reduce/allreduce at once. 
+ }, + } + + return deepspeed_config diff --git a/requirements.txt b/requirements.txt index 5d8adca..c87cc73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ transformers>=4.3.0 fire>=0.3.0 -pytorch-lightning>=1.1.4 +pytorch-lightning>=1.2.0 torch>=1.6.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 242e24b..3165427 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ install_requires=[ "transformers>=4.3.0", "fire>=0.3.0", - "pytorch-lightning>=1.1.4", + "pytorch-lightning>=1.2.0", "torch>=1.6.0", ], ) From 6aa04ecad2e393b88d55dbaffe6a567c46447279 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 21 Feb 2021 15:40:33 -0800 Subject: [PATCH 37/40] Update training hello world --- README.md | 31 ++- notebooks/training_hello_world.ipynb | 395 +++++++++++++++++++++++++++ 2 files changed, 414 insertions(+), 12 deletions(-) create mode 100644 notebooks/training_hello_world.ipynb diff --git a/README.md b/README.md index 6ef08bc..da242da 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,8 @@ A robust Python tool for text-based AI training and generation using [OpenAI's]( aitextgen is a Python package that leverages [PyTorch](https://pytorch.org), [Hugging Face Transformers](https://github.com/huggingface/transformers) and [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning) with specific optimizations for text generation using GPT-2, plus _many_ added features. It is the successor to [textgenrnn](https://github.com/minimaxir/textgenrnn) and [gpt-2-simple](https://github.com/minimaxir/gpt-2-simple), taking the best of both packages: - Finetunes on a pretrained 124M GPT-2 model from OpenAI...or create your own GPT-2 model + tokenizer and train from scratch! -- Generates text faster than gpt-2-simple and with better memory efficiency! (even [from the 1.5B GPT-2 model](https://docs.aitextgen.io/tutorials/generate_1_5b/)!) -- With Transformers, aitextgen preserves compatibility with the base package, allowing you to use the model for other NLP tasks, download custom GPT-2 models from the Hugging Face model repository, and upload your own models! Also, it uses the included `generate()` function to allow a massive amount of control over the generated text. +- Generates text faster than gpt-2-simple and with better memory efficiency! +- With Transformers, aitextgen preserves compatibility with the base package, allowing you to use the model for other NLP tasks, download custom GPT-2 models from the HuggingFace model repository, and upload your own models! Also, it uses the included `generate()` function to allow a massive amount of control over the generated text. - With pytorch-lightning, aitextgen trains models not just on CPUs and GPUs, but also _multiple_ GPUs and (eventually) TPUs! It also includes a pretty training progress bar, with the ability to add optional loggers. - The input dataset is its own object, allowing you to not only easily encode megabytes of data in seconds, cache, and compress it on a local computer before transporting to a remote server, but you are able to _merge_ datasets without biasing the resulting dataset, or _cross-train_ on multiple datasets to create blended output. @@ -54,7 +54,7 @@ aitextgen generate aitextgen generate --prompt "I believe in unicorns because" --to_file False ``` -Want to train your own mini GPT-2 model on your own computer? 
Download this [text file of Shakespeare's plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), cd to that directory in a Terminal, open up a `python3` console and go: +Want to train your own mini GPT-2 model on your own computer? You can follow along [in this Jupyter Notebook](/notebooks/training_hello_world.ipynb) or, download this [text file of Shakespeare's plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), cd to that directory in a Terminal, open up a `python3` console and go: ```python from aitextgen.TokenDataset import TokenDataset @@ -66,29 +66,36 @@ from aitextgen import aitextgen file_name = "input.txt" # Train a custom BPE Tokenizer on the downloaded text -# This will save two files: aitextgen-vocab.json and aitextgen-merges.txt, -# which are needed to rebuild the tokenizer. +# This will save one file: `aitextgen.tokenizer.json`, which contains the +# information needed to rebuild the tokenizer. train_tokenizer(file_name) -vocab_file = "aitextgen-vocab.json" -merges_file = "aitextgen-merges.txt" +tokenizer_file = "aitextgen.tokenizer.json" # GPT2ConfigCPU is a mini variant of GPT-2 optimized for CPU-training # e.g. the # of input tokens here is 64 vs. 1024 for base GPT-2. config = GPT2ConfigCPU() # Instantiate aitextgen using the created tokenizer and config -ai = aitextgen(vocab_file=vocab_file, merges_file=merges_file, config=config) +ai = aitextgen(tokenizer_file=tokenizer_file, config=config) # You can build datasets for training by creating TokenDatasets, # which automatically processes the dataset with the appropriate size. -data = TokenDataset(file_name, vocab_file=vocab_file, merges_file=merges_file, block_size=64) +data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64) # Train the model! It will save pytorch_model.bin periodically and after completion. -# On a 2016 MacBook Pro, this took ~25 minutes to run. -ai.train(data, batch_size=16, num_steps=5000) +# On a 2020 8-core iMac, this took ~25 minutes to run. +ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000) # Generate text from it! ai.generate(10, prompt="ROMEO:") + +# With your trained model, you can reload the model at any time by +# providing the pytorch_model.bin model weights, the config, and the tokenizer. +ai2 = aitextgen(model="trained_model/pytorch_model.bin", + tokenizer_file="aitextgen.tokenizer.json", + config="trained_model/config.json") + +ai2.generate(10, prompt="ROMEO:") ``` Want to run aitextgen and finetune GPT-2? Use the Colab notebooks in the Demos section, or [follow the documentation](https://docs.aitextgen.io/) to get more information and learn some helpful tips! @@ -102,7 +109,7 @@ Want to run aitextgen and finetune GPT-2? Use the Colab notebooks in the Demos s ## Upcoming Features -The current release (v0.2.X) of aitextgen **is considered to be a beta**, targeting the most common use cases. The Notebooks and examples written so far are tested to work, but more fleshing out of the docs/use cases will be done over the next few months in addition to fixing the known issues noted above. +The current release (v0.4.X) of aitextgen **is considered to be a beta**, targeting the most common use cases. The Notebooks and examples written so far are tested to work, but more fleshing out of the docs/use cases will be done over the next few months in addition to fixing the known issues noted above. 
The next versions of aitextgen (and one of the reasons I made this package in the first place) will have native support for _schema-based generation_. (See [this repo](https://github.com/minimaxir/gpt-2-keyword-generation) for a rough proof-of-concept.) diff --git a/notebooks/training_hello_world.ipynb b/notebooks/training_hello_world.ipynb new file mode 100644 index 0000000..05f2f18 --- /dev/null +++ b/notebooks/training_hello_world.ipynb @@ -0,0 +1,395 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# aitextgen Training Hello World\n", + "\n", + "_Last Updated: Feb 21, 2021 (v.0.4.0)_\n", + "\n", + "by Max Woolf\n", + "\n", + "A \"Hello World\" Tutorial to show how training works with aitextgen, even on a CPU!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from aitextgen.TokenDataset import TokenDataset\n", + "from aitextgen.tokenizers import train_tokenizer\n", + "from aitextgen.utils import GPT2ConfigCPU\n", + "from aitextgen import aitextgen" + ] + }, + { + "source": [ + "First, download this [text file of Shakespeare's plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), to the folder with this notebook, then put the name of the downloaded Shakespeare text for training into the cell below." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "file_name = \"input.txt\"" + ] + }, + { + "source": [ + "You can now train a custom Byte Pair Encoding Tokenizer on the downloaded text!\n", + "\n", + "This will save one file: `aitextgen.tokenizer.json`, which contains the information needed to rebuild the tokenizer." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "train_tokenizer(file_name)\n", + "tokenizer_file = \"aitextgen.tokenizer.json\"" + ] + }, + { + "source": [ + "`GPT2ConfigCPU()` is a mini variant of GPT-2 optimized for CPU-training.\n", + "\n", + "e.g. the # of input tokens here is 64 vs. 1024 for base GPT-2. This dramatically speeds training up." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "config = GPT2ConfigCPU()" + ] + }, + { + "source": [ + "Instantiate aitextgen using the created tokenizer and config" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "ai = aitextgen(tokenizer_file=tokenizer_file, config=config)" + ] + }, + { + "source": [ + "You can build datasets for training by creating TokenDatasets, which automatically processes the dataset with the appropriate size." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 40000/40000 [00:00<00:00, 86712.61it/s]\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TokenDataset containing 462,820 subsets loaded from file at input.txt." + ] + }, + "metadata": {}, + "execution_count": 6 + } + ], + "source": [ + "data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)\n", + "data" + ] + }, + { + "source": [ + "Train the model! 
It will save pytorch_model.bin periodically and after completion. On a 2020 8-core iMac, this took ~25 minutes to run.\n", + "\n", + "The configuration below processes 400,000 subsets of tokens (8 * 50000), which is about just one pass through all the data (1 epoch). Ideally you'll want multiple passes through the data and a training loss less than `2.0` for coherent output; when training a model from scratch, that's more difficult, but with long enough training you can get there!" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "pytorch_model.bin already exists in /trained_model and will be overwritten!\n", + "GPU available: False, used: False\n", + "TPU available: None, using: 0 TPU cores\n", + "\u001b[1m5,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m5,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + "'s dead;\n", + "But is no winted in his northeritiff\n", + "Tave passage, and eleve your hours.\n", + "\n", + "PETRUCHIO:\n", + "What is this I does, I will, sir;\n", + "That, you have, nor tolding we\n", + "==========\n", + "\u001b[1m10,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m10,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + ".\n", + "\n", + "QUEEN ELIZABETH:\n", + "I know, to, fair beat, to my soul is wonder'd intend.\n", + "\n", + "KING RICHARD III:\n", + "Hold, and threaten, my lord, and my shame!\n", + "\n", + "QUEEN ELIZAB\n", + "==========\n", + "\u001b[1m15,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m15,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + "s of capitcts!\n", + "\n", + "EDWARD:\n", + "Gardener, what is this hour will not say.\n", + "What, shall the joint, I pray, if they\n", + "Harry, let bid me as he would readness so.\n", + "\n", + "B\n", + "==========\n", + "\u001b[1m20,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m20,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + " for.\n", + "\n", + "ROMEO:\n", + "Fair to the iercing wide's fretch,\n", + "And happy talk of the master,\n", + "And waste their justice with the feet and punning,\n", + "And therefore be ben\n", + "==========\n", + "\u001b[1m25,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m25,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + ",\n", + "That we we will have not lose such.\n", + "\n", + "See, to the kingdom of our virtue,\n", + "You banish'd our purpose, for our own ignorse,\n", + "Dispon I remain, and seem'd in\n", + "==========\n", + "\u001b[1m30,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m30,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + ".\n", + "\n", + "BENVOLIO:\n", + "O, she's dead!\n", + "\n", + "CAMILLO:\n", + "No, my lord;\n", + "These accession will be hous.\n", + "\n", + "DERBY:\n", + "No, my lord.\n", + "\n", + "GLOUCESTER:\n", + "What is the\n", + "==========\n", + "\u001b[1m35,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m35,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + ",\n", + "And whiles it is but the castle,\n", + "That stavin'd in the gods of men.\n", + "\n", + "COMFEY:\n", + "What, then?\n", + "\n", + "ELBOW:\n", + "Peace, my 
lord,\n", + "And weat your greats\n", + "==========\n", + "\u001b[1m40,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m40,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + "\n", + "The white mercy of the sun upon my past,\n", + "Of my father's son be first, thy sake,\n", + "His son's chief son, and my includy;\n", + "And if thy brother's loss, thy thrief,\n", + "\n", + "==========\n", + "\u001b[1m45,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m45,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + " to the crown,\n", + "Or I'll privy I have.\n", + "\n", + "POLIXENES:\n", + "I have been a stir.\n", + "\n", + "LEONTES:\n", + "The worshiped, the benefition of the crown.\n", + "\n", + "His somet\n", + "==========\n", + "\u001b[1m50,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m50,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + ":\n", + "Catesby, girls, and make avoides;\n", + "But, welcome a far\n", + "That ever home, like a villain, and behold\n", + "Canusy not passing nonquial at the g\n", + "==========\n", + "Loss: 2.940 — Avg: 2.884: 100%|██████████| 50000/50000 [31:39<00:00, 26.32it/s]\n" + ] + } + ], + "source": [ + "ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)" + ] + }, + { + "source": [ + "Generate text from your trained model!" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[1mROMEO:\u001b[0m\nAbook, ho! forthing me, gentle Earl's royal king,\nAnd this, I, with that I do not beseech you\nTo visit the battle, that I should believe you,\nWhich I would never\n==========\n\u001b[1mROMEO:\u001b[0m\nConfound is gone, thou art a maid into the widow;\nPut up my life and make me no harmony\nAnd make thee I know uncle,\nUnconted and curses: therefore in my\n==========\n\u001b[1mROMEO:\u001b[0m\nGod push! but what days to see\nThe giving bleedom's heart I do? Therefore,\nAnd most unless I had rather. 
He saddle\nTake your cold shack down; and so far I\n==========\n\u001b[1mROMEO:\u001b[0m\nPersetain'd up the earth of mercy,\nAnd never yet, the sun to make him all the\nMore than my battle.\n\nROMEO:\nI warrant him, to know, we'll not do't, but hate me\n==========\n\u001b[1mROMEO:\u001b[0m\nMethinks I am a mile, and trench one\nThy winded makes, in faults and cast\nWith one to meether, of twenty days,\nThat in my waters, that f\n==========\n\u001b[1mROMEO:\u001b[0m\nO, here is such a woman guilty.\n\nROMEO:\nI do not think it; I should be renowned\nThat I am in that which can controy\nA bawd I take it to the purpose.\n\nJU\n==========\n\u001b[1mROMEO:\u001b[0m\nI know not what I am.\n\nFLORIZEL:\nAy, as I did,\nI would be adverpite of the homely treason\nFrom the doubled in the farm of his bed.\nTa\n==========\n\u001b[1mROMEO:\u001b[0m\nI pray you, he would have taken to him but,\nAnd freely mark his into a fine of it,\nSpeak to the second to our cheek;\nAnd every day, and sanctious cover\n==========\n\u001b[1mROMEO:\u001b[0m\nI had left me--born to be drawn.\n\nJULIET:\nMy husbour, I will have thee here:\nAnd, I have found to seek thyself.\n\nJULIET:\nI will be not b\n==========\n\u001b[1mROMEO:\u001b[0m\nThat is a hour,\nThe castard is, I'll not buy, or indeeding.\n\nNurse:\nLADY CAPULET:\nThe matter, that ta'en as I may find thee.\n\n" + ] + } + ], + "source": [ + "ai.generate(10, prompt=\"ROMEO:\")" + ] + }, + { + "source": [ + "With your trained model, you can reload the model at any time by providing the `pytorch_model.bin` model weights, the `config`, and the `tokenizer`." + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "ai2 = aitextgen(model=\"trained_model/pytorch_model.bin\",\n", + " tokenizer_file=\"aitextgen.tokenizer.json\",\n", + " config=\"trained_model/config.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[1mROMEO:\u001b[0m\nBoy, unreacher, unhallupony, in Padua,\nUntimely fall till I be learn'd.\n\nROMEO:\nFie, good friar, be quick, for I am,\nI'll\n==========\n\u001b[1mROMEO:\u001b[0m\nI'll be plain, I am a tail of blessed wounds;\nFor I am dead, I have not borne to make\nA couple of her fortune, but that I'll bear,\nAnd say 'Ay, chur\n==========\n\u001b[1mROMEO:\u001b[0m\nAnd yet I am a resolution of my dear dear:\nIf I have not reason to do me say\nI'll deny the sea of my body to answer,\nAnd all thy tale, or I have my m\n==========\n\u001b[1mROMEO:\u001b[0m\nIntenty to a bawd of my bait,--\n\nJULIET:\nNo, I hope to know the title,\nFor that I wish her place.\n\nJULIET:\nDo I assure her?\n==========\n\u001b[1mROMEO:\u001b[0m\nO, what's the parle that I chide thee,\nThat honourable may be, that I have still'd thee:\nI pray thee, my lord.\n\nMERCUTIO:\nI', my lord.\n\nROMEO:\nHere is a\n==========\n\u001b[1mROMEO:\u001b[0m\nAnd, for I am, and not talk of that?\n\nROMEO:\nWhere's my child, I would guess thee here.\n\nJULIET:\nNay, boy, I'll not be bowling why I;\nO thou\n==========\n\u001b[1mROMEO:\u001b[0m\nO, but thou hast seen thee of mine own.\n\nROMEO:\nI would assist thee--\n\nJULIET:\nAy, it is, and not so.\n\nROMEO:\nNo, but that I must told me with it.\n\nROMEO\n==========\n\u001b[1mROMEO:\u001b[0m\nNo, no, nor I am. 
I am content.\n\nBENVOLIO:\nI will not, sir: but I have required\nAs I am grown in the lawful virtue\nThat it hath bid you think, and I\n==========\n\u001b[1mROMEO:\u001b[0m\nThat I should pardon, I would be gone.\n\nESCALUS:\nI should believe you, sir, sir, ay, I would not\nnot know more, but that I can, but I would have savour me.\n\nP\n==========\n\u001b[1mROMEO:\u001b[0m\nAnd thou, I will find out thy life the wind of love.\n\nROMEO:\nIt is the morning groom of it.\n\nJULIET:\nFie, good sweet boy, I will take my leave to a happy day,\n" + ] + } + ], + "source": [ + "ai2.generate(10, prompt=\"ROMEO:\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MIT License\n", + "\n", + "Copyright (c) 2021 Max Woolf\n", + "\n", + "Permission is hereby granted, free of charge, to any person obtaining a copy\n", + "of this software and associated documentation files (the \"Software\"), to deal\n", + "in the Software without restriction, including without limitation the rights\n", + "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", + "copies of the Software, and to permit persons to whom the Software is\n", + "furnished to do so, subject to the following conditions:\n", + "\n", + "The above copyright notice and this permission notice shall be included in all\n", + "copies or substantial portions of the Software.\n", + "\n", + "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", + "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", + "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", + "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", + "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", + "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", + "SOFTWARE." + ] + } + ], + "metadata": { + "kernelspec": { + "name": "python3", + "display_name": "Python 3.9.1 64-bit", + "metadata": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1-final" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file From 1fbe9d8f42df7402e5cbb3082bc805dbc32b0281 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 21 Feb 2021 17:15:01 -0800 Subject: [PATCH 38/40] Update CHANGELOG --- CHANGELOG.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4227d9f..faffb71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,13 +4,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). -## [0.4.0] - TBD +## [0.4.0] - 2021-02-21 +- Increased minimum versions of dependencies (`transformers` to 4.3.0, `pytorch-lightning` to 1.2.0) + - Remove dependency on `tokenizers` as `transformers` pins it. - Made Fast tokenizers the default (as it is the default in `transformers` 4.0.0) - Made serialized tokenizers the default for custom tokenizers, and added support for loading them for both `aitextgen` and `TokenDataset`s - Added gradient checkpointing for GPT-2, and set it to the default for training 355M and 774M. 
- Added layer freezing to freeze the first `n` layers of GPT-2 while training. This allows 1.5B GPT-2 to be trained with a high `n`. - Added schema-based generation for specificed schema_tokens (which can be encoded in the Transformers config). This can be used with an appropriate dataset for schema-based generation. +- Switched TensorFlow weight download URL from GCP (as OpenAI removed it from there) to Azure +- Fixed issue where prompt character length was used to check for a too-long assert instead of prompt token length (#90) +- Workaround breaking issue in Transformers 4.3.0 by moving special token stripping into aitextgen instead of the tokenizer (#90) +- Added an `lstrip` param to generation, which strips all whitespace at the beginning of generated text (related to point above) ## [0.3.0] - 2020-11-30 From a314c39fb4acadaa32ed1710e5e51b56f52bd887 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 21 Feb 2021 21:08:38 -0800 Subject: [PATCH 39/40] do not use skip special tokens for schema --- aitextgen/aitextgen.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index bd6689d..7f72a5a 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -356,10 +356,15 @@ def generate( if normalize_key else token_tuple[0] ) - gen_text_dict[key] = self.tokenizer.decode( - output[start_index:end_index], skip_special_tokens=False + + gen_text = skip_special_tokens( + output[start_index:end_index], + self.get_device(), + [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id], ) + gen_text_dict[key] = self.tokenizer.decode(gen_text) + # remove fields not in schema_return if schema_return: if len(schema_return) == 1: From b4dec6c5b18beb3a45b0ac241fb4e5f386f9cab0 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Mon, 22 Feb 2021 19:47:56 -0800 Subject: [PATCH 40/40] More logger entries + better tf_gpt2 handle --- aitextgen/aitextgen.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 7f72a5a..f30dffa 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -66,7 +66,7 @@ class aitextgen: :param unk_token: String to override the unknown token """ - openai_gpt2_large = False + openai_tf_gpt2 = None # default values for GPT2Tokenizer tokenizer = None @@ -110,9 +110,7 @@ def __init__( logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) if tf_gpt2: - if tf_gpt2 != "124M": - self.openai_gpt2_large = True - gradient_checkpointing = True + self.openai_tf_gpt2 = tf_gpt2 # Download + convert the TF weights if a PyTorch model has not been created if not os.path.isfile( @@ -188,7 +186,8 @@ def __init__( cache_dir=cache_dir, ) - if gradient_checkpointing: + if gradient_checkpointing or tf_gpt2 in ["355M", "774M", "1558M"]: + logger.info("Gradient checkpointing enabled for model training.") setattr(self.model.config, "gradient_checkpointing", True) setattr(self.model.config, "use_cache", False) @@ -597,7 +596,8 @@ def train( **kwargs, ) - if freeze_layers or self.openai_gpt2_large: + if freeze_layers or self.openai_tf_gpt2 == "1558M": + logger.info("Layer freezing enabled for model training.") freeze_layers = True if num_layers_freeze: assert (