From d8891630d994381ad1a20edbf8611f3e642e8c4e Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sat, 4 Jul 2020 19:07:47 -0700
Subject: [PATCH] Do not override optimizer_step (#44)

---
 aitextgen/train.py | 12 ------------
 setup.py           |  2 +-
 2 files changed, 1 insertion(+), 13 deletions(-)

diff --git a/aitextgen/train.py b/aitextgen/train.py
index 9e418a6..386d6ee 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -86,20 +86,8 @@ def configure_optimizers(self):
             num_training_steps=self.hparams["num_steps"],
         )
 
-        self.opt = optimizer
-        self.lr_scheduler = scheduler
         return [optimizer], [scheduler]
 
-    def optimizer_step(
-        self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None
-    ):
-        if self.hparams["tpu"]:
-            xm.optimizer_step(optimizer, barrier=True)
-        else:
-            optimizer.step()
-        optimizer.zero_grad()
-        self.lr_scheduler.step()
-
 
 class ATGProgressBar(ProgressBarBase):
     """A variant progress bar that works off of steps and prints periodically."""
diff --git a/setup.py b/setup.py
index e44eabd..fde0384 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="aitextgen",
     packages=["aitextgen"],  # this must be the same as the name above
-    version="0.2.2",
+    version="0.2.3",
     description="A robust Python tool for text-based AI training and generation using GPT-2.",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
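
Note (not part of the patch): a minimal sketch of the pattern this change relies on, assuming PyTorch Lightning's default training loop. When optimizer_step is not overridden, Lightning calls optimizer.step() and optimizer.zero_grad() itself and steps any scheduler returned from configure_optimizers (on an epoch interval by default when it is returned as a plain list). The TinyModule class, its layer sizes, and the learning rate below are illustrative placeholders, not aitextgen code.

    # Illustrative sketch only; names and hyperparameters are hypothetical.
    import torch
    import pytorch_lightning as pl


    class TinyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
            # Returning both lists is enough: Lightning handles the stepping,
            # so no custom optimizer_step override is needed.
            return [optimizer], [scheduler]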