More logger entries + better tf_gpt2 handle
minimaxir committed Feb 23, 2021
1 parent a314c39 commit b4dec6c
Showing 1 changed file with 6 additions and 6 deletions.
--- a/aitextgen/aitextgen.py
+++ b/aitextgen/aitextgen.py
@@ -66,7 +66,7 @@ class aitextgen:
     :param unk_token: String to override the unknown token
     """
 
-    openai_gpt2_large = False
+    openai_tf_gpt2 = None
 
     # default values for GPT2Tokenizer
     tokenizer = None
@@ -110,9 +110,7 @@ def __init__(
         logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
 
         if tf_gpt2:
-            if tf_gpt2 != "124M":
-                self.openai_gpt2_large = True
-                gradient_checkpointing = True
+            self.openai_tf_gpt2 = tf_gpt2
 
             # Download + convert the TF weights if a PyTorch model has not been created
             if not os.path.isfile(
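
For illustration, a minimal sketch of the new handle (assuming the aitextgen constructor as shown in this diff; the TF weight download on first use is real but slow): the constructor now stores the exact model-size string instead of collapsing it into the old openai_gpt2_large boolean, so the checks below can distinguish the individual sizes.

```python
from aitextgen import aitextgen

# Loading one of OpenAI's original TF GPT-2 checkpoints; the size
# string itself is now kept on the instance rather than a True/False
# "large model" flag.
ai = aitextgen(tf_gpt2="355M")
assert ai.openai_tf_gpt2 == "355M"
```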
@@ -188,7 +186,8 @@ def __init__(
                 cache_dir=cache_dir,
             )
 
-        if gradient_checkpointing:
+        if gradient_checkpointing or tf_gpt2 in ["355M", "774M", "1558M"]:
+            logger.info("Gradient checkpointing enabled for model training.")
             setattr(self.model.config, "gradient_checkpointing", True)
             setattr(self.model.config, "use_cache", False)
 
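A sketch of the load-time effect of this hunk (the config attribute names are the ones set via setattr above): any of the three larger TF sizes now enables gradient checkpointing automatically, even when the caller never passed gradient_checkpointing=True, trading recomputation for the activation memory the bigger models need.

```python
from aitextgen import aitextgen

# "774M" is in ["355M", "774M", "1558M"], so checkpointing is forced
# on and the key/value cache is disabled, per the setattr calls above.
ai = aitextgen(tf_gpt2="774M")
# logs: "Gradient checkpointing enabled for model training."

print(ai.model.config.gradient_checkpointing)  # True
print(ai.model.config.use_cache)               # False
```
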
@@ -597,7 +596,8 @@ def train(
             **kwargs,
         )
 
-        if freeze_layers or self.openai_gpt2_large:
+        if freeze_layers or self.openai_tf_gpt2 == "1558M":
+            logger.info("Layer freezing enabled for model training.")
             freeze_layers = True
             if num_layers_freeze:
                 assert (

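Finally, a sketch of the training-side behavior after this commit (input.txt is a hypothetical training file): layer freezing is now forced only for the 1558M model, so 355M and 774M fine-tunes train all layers unless freeze_layers=True is passed explicitly.

```python
from aitextgen import aitextgen

ai = aitextgen(tf_gpt2="1558M")

# self.openai_tf_gpt2 == "1558M" trips the new check, so freeze_layers
# is forced to True (and logged) before training starts; a 774M model
# would no longer be frozen automatically.
ai.train("input.txt", num_steps=500)
# logs: "Layer freezing enabled for model training."
```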