diff --git a/grok/training.py b/grok/training.py index 43b43df..3ffe804 100755 --- a/grok/training.py +++ b/grok/training.py @@ -47,7 +47,7 @@ def __init__(self, hparams: Namespace) -> None: self.add_model_specific_args(). """ super().__init__() - self.hparams = hparams # type: ignore + self.hparams.update(vars(hparams)) # type: ignore self.prepare_data() self.transformer = Transformer( diff --git a/setup.py b/setup.py index c31f060..7545493 100644 --- a/setup.py +++ b/setup.py @@ -5,9 +5,9 @@ packages=find_packages(), version="0.0.1", install_requires=[ - "pytorch_lightning", + "pytorch_lightning==1.5.10", "blobfile", - "numpy", + "numpy==1.23.0", "torch", "tqdm", "scipy",