
Commit dae451c

Move instantiate_from_config into module
1 parent e30e8fb

File tree

7 files changed: +25 -24 lines changed


main.py

Lines changed: 2 additions & 16 deletions
@@ -1,4 +1,4 @@
-import argparse, os, sys, datetime, glob, importlib
+import argparse, os, sys, datetime, glob
 from omegaconf import OmegaConf
 import numpy as np
 from PIL import Image
@@ -12,15 +12,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
 from taming.data.utils import custom_collate
-
-
-def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
-
+from taming.util import instantiate_from_config
 
 def get_parser(**parser_kwargs):
     def str2bool(v):
@@ -113,12 +105,6 @@ def nondefault_trainer_args(opt):
     return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
 
 
-def instantiate_from_config(config):
-    if not "target" in config:
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
 class WrappedDataset(Dataset):
     """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
     def __init__(self, dataset):
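
For orientation, a minimal sketch of how callers in main.py use the relocated helper with an OmegaConf config; the `model` block and its `collections.OrderedDict` target are illustrative stand-ins, not taken from this commit:

from omegaconf import OmegaConf
from taming.util import instantiate_from_config

# A config in the shape instantiate_from_config expects: `target` is a
# dotted import path, `params` holds constructor kwargs.
config = OmegaConf.create({
    "model": {
        "target": "collections.OrderedDict",  # stand-in for a real model class
        "params": {},
    }
})
model = instantiate_from_config(config.model)  # imports the class, calls it with params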

scripts/make_samples.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
-from main import instantiate_from_config, DataModuleFromConfig
+from taming.util import instantiate_from_config
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 from tqdm import trange

scripts/sample_conditional.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import streamlit as st
 from streamlit import caching
 from PIL import Image
-from main import instantiate_from_config, DataModuleFromConfig
+from taming.util import instantiate_from_config
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 

scripts/sample_fast.py

Lines changed: 1 addition & 1 deletion
@@ -7,8 +7,8 @@
 from tqdm import tqdm, trange
 from einops import repeat
 
-from main import instantiate_from_config
 from taming.modules.transformer.mingpt import sample_with_past
+from taming.util import instantiate_from_config
 
 
 rescale = lambda x: (x + 1.) / 2.

taming/models/cond_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 
-from main import instantiate_from_config
 from taming.modules.util import SOSProvider
+from taming.util import instantiate_from_config
 
 
 def disabled_train(self, mode=True):

taming/models/vqgan.py

Lines changed: 3 additions & 4 deletions
@@ -2,12 +2,11 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 
-from main import instantiate_from_config
-
 from taming.modules.diffusionmodules.model import Encoder, Decoder
 from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from taming.modules.vqvae.quantize import GumbelQuantize
-from taming.modules.vqvae.quantize import EMAVectorQuantizer
+from taming.util import instantiate_from_config
+
 
 class VQModel(pl.LightningModule):
     def __init__(self,
@@ -401,4 +400,4 @@ def configure_optimizers(self):
                                     lr=lr, betas=(0.5, 0.9))
         opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                     lr=lr, betas=(0.5, 0.9))
-        return [opt_ae, opt_disc], []
\ No newline at end of file
+        return [opt_ae, opt_disc], []

taming/util.py

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,4 @@
+import importlib
 import os, hashlib
 import requests
 from tqdm import tqdm
@@ -142,6 +143,20 @@ def retrieve(
     return list_or_dict, success
 
 
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
 if __name__ == "__main__":
     config = {"keya": "a",
               "keyb": "b",
@@ -155,3 +170,4 @@ def retrieve(
     print(config)
     retrieve(config, "keya")
 
+
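
After this commit the helpers can be exercised directly from taming.util. A minimal sanity-check sketch; the `collections.OrderedDict` target is an arbitrary importable path chosen for illustration, not something used in the repo:

from taming.util import get_obj_from_str, instantiate_from_config

# get_obj_from_str resolves a dotted path: rsplit(".", 1) splits it into
# ("collections", "OrderedDict"); importlib imports the module and getattr
# returns the attribute (here, a class object).
cls = get_obj_from_str("collections.OrderedDict")
assert cls.__name__ == "OrderedDict"

# instantiate_from_config requires a "target" key and forwards the optional
# "params" mapping as constructor keyword arguments.
obj = instantiate_from_config({"target": "collections.OrderedDict", "params": {}})
assert type(obj) is cls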
