
Commit d9b9ce0

Move instantiate_from_config into module
1 parent 3e94fa3 commit d9b9ce0

7 files changed (+25, −23 lines)


main.py

Lines changed: 2 additions & 15 deletions
@@ -1,4 +1,4 @@
-import argparse, os, sys, datetime, glob, importlib
+import argparse, os, sys, datetime, glob
 from omegaconf import OmegaConf
 import numpy as np
 from PIL import Image
@@ -10,14 +10,7 @@
 from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
 from pytorch_lightning.utilities.distributed import rank_zero_only
-
-def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
-
+from taming.util import instantiate_from_config
 
 def get_parser(**parser_kwargs):
     def str2bool(v):
@@ -110,12 +103,6 @@ def nondefault_trainer_args(opt):
     return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
 
 
-def instantiate_from_config(config):
-    if not "target" in config:
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
 class WrappedDataset(Dataset):
     """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
     def __init__(self, dataset):
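
The relocated helper is the repo's config-driven factory: get_obj_from_str resolves a dotted "module.Class" string via importlib, and instantiate_from_config calls the resolved class with the config's params. With it living in taming.util, code inside the taming package no longer has to reach into the training entry point via "from main import instantiate_from_config", as the remaining diffs below show. A minimal usage sketch, assuming the taming package is importable; the datetime target is a hypothetical example, not from this commit:

from taming.util import get_obj_from_str

# Resolve module "datetime", attribute "timedelta", then construct an instance.
cls = get_obj_from_str("datetime.timedelta")
print(cls(days=2))  # -> 2 days, 0:00:00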

scripts/make_samples.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
-from main import instantiate_from_config, DataModuleFromConfig
+from taming.util import instantiate_from_config
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 from tqdm import trange

scripts/sample_conditional.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import streamlit as st
 from streamlit import caching
 from PIL import Image
-from main import instantiate_from_config, DataModuleFromConfig
+from taming.util import instantiate_from_config
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 
scripts/sample_fast.py

Lines changed: 1 addition & 1 deletion
@@ -7,8 +7,8 @@
 from tqdm import tqdm, trange
 from einops import repeat
 
-from main import instantiate_from_config
 from taming.modules.transformer.mingpt import sample_with_past
+from taming.util import instantiate_from_config
 
 
 rescale = lambda x: (x + 1.) / 2.

taming/models/cond_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 
-from main import instantiate_from_config
 from taming.modules.util import SOSProvider
+from taming.util import instantiate_from_config
 
 
 def disabled_train(self, mode=True):

taming/models/vqgan.py

Lines changed: 3 additions & 4 deletions
@@ -2,12 +2,11 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 
-from main import instantiate_from_config
-
 from taming.modules.diffusionmodules.model import Encoder, Decoder
 from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from taming.modules.vqvae.quantize import GumbelQuantize
-from taming.modules.vqvae.quantize import EMAVectorQuantizer
+from taming.util import instantiate_from_config
+
 
 class VQModel(pl.LightningModule):
     def __init__(self,
@@ -401,4 +400,4 @@ def configure_optimizers(self):
                               lr=lr, betas=(0.5, 0.9))
         opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                     lr=lr, betas=(0.5, 0.9))
-        return [opt_ae, opt_disc], []
+        return [opt_ae, opt_disc], []

taming/util.py

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,4 @@
+import importlib
 import os, hashlib
 import requests
 from tqdm import tqdm
@@ -142,6 +143,20 @@ def retrieve(
     return list_or_dict, success
 
 
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
 if __name__ == "__main__":
     config = {"keya": "a",
               "keyb": "b",
@@ -155,3 +170,4 @@ def retrieve(
     print(config)
     retrieve(config, "keya")
 
+

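Taken together, the two functions let any {"target": dotted_path, "params": kwargs} mapping be turned into a constructed object; a plain dict or an OmegaConf DictConfig both work, since DictConfig also supports "in" and .get(). A usage sketch, assuming the taming package is on the path; the target and params values are hypothetical stand-ins for the model configs this repo usually supplies:

from taming.util import instantiate_from_config

config = {
    "target": "datetime.timedelta",     # dotted path to any importable class
    "params": {"days": 1, "hours": 6},  # forwarded as keyword arguments
}
print(instantiate_from_config(config))  # -> 1 day, 6:00:00

# A config without a "target" key raises KeyError, per the check above.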