diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 7f72a5a..f30dffa 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -66,7 +66,7 @@ class aitextgen: :param unk_token: String to override the unknown token """ - openai_gpt2_large = False + openai_tf_gpt2 = None # default values for GPT2Tokenizer tokenizer = None @@ -110,9 +110,7 @@ def __init__( logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) if tf_gpt2: - if tf_gpt2 != "124M": - self.openai_gpt2_large = True - gradient_checkpointing = True + self.openai_tf_gpt2 = tf_gpt2 # Download + convert the TF weights if a PyTorch model has not been created if not os.path.isfile( @@ -188,7 +186,8 @@ def __init__( cache_dir=cache_dir, ) - if gradient_checkpointing: + if gradient_checkpointing or tf_gpt2 in ["355M", "774M", "1558M"]: + logger.info("Gradient checkpointing enabled for model training.") setattr(self.model.config, "gradient_checkpointing", True) setattr(self.model.config, "use_cache", False) @@ -597,7 +596,8 @@ def train( **kwargs, ) - if freeze_layers or self.openai_gpt2_large: + if freeze_layers or self.openai_tf_gpt2 == "1558M": + logger.info("Layer freezing enabled for model training.") freeze_layers = True if num_layers_freeze: assert (