diff --git a/main.py b/main.py index 3d83cb21..f07510ac 100644 --- a/main.py +++ b/main.py @@ -11,17 +11,10 @@ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from pytorch_lightning.utilities.distributed import rank_zero_only +from taming import get_obj_from_str, instantiate_from_config from taming.data.utils import custom_collate -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - def get_parser(**parser_kwargs): def str2bool(v): if isinstance(v, bool): diff --git a/taming/__init__.py b/taming/__init__.py new file mode 100644 index 00000000..ac572368 --- /dev/null +++ b/taming/__init__.py @@ -0,0 +1,13 @@ +import importlib + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) diff --git a/taming/data/__init__.py b/taming/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/models/__init__.py b/taming/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/models/cond_transformer.py b/taming/models/cond_transformer.py index e4c63730..98cbab51 100644 --- a/taming/models/cond_transformer.py +++ b/taming/models/cond_transformer.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config +from taming import instantiate_from_config from taming.modules.util import SOSProvider diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py index a6950baa..d60c5021 100644 --- a/taming/models/vqgan.py +++ b/taming/models/vqgan.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config +from taming import instantiate_from_config from taming.modules.diffusionmodules.model import Encoder, Decoder from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer diff --git a/taming/modules/__init__.py b/taming/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/diffusionmodules/__init__.py b/taming/modules/diffusionmodules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/discriminator/__init__.py b/taming/modules/discriminator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/misc/__init__.py b/taming/modules/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/transformer/__init__.py b/taming/modules/transformer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/vqvae/__init__.py b/taming/modules/vqvae/__init__.py new file mode 100644 index 00000000..e69de29b