Commit

general: Update

braun-steven committed Jan 20, 2023
1 parent 32808c3 commit c0fb9fa
Showing 10 changed files with 600 additions and 84 deletions.
2 changes: 1 addition & 1 deletion PROJECT_NAME
@@ -1 +1 @@
simple-einet
simple-einet
29 changes: 23 additions & 6 deletions conf/config.yaml
@@ -4,13 +4,27 @@ defaults:
- override hydra/job_logging: colorlog
- override hydra/hydra_logging: colorlog

# Hydra config
hydra:
run:
dir: "${results_dir}/${dataset}/${now:%Y-%m-%d_%H-%M-%S}_${oc.select:tag,}"
sweep:
dir: "${hydra.run.dir}/${hydra.job.name}"
subdir: "${hydra.run.dir}/${hydra.job.num}"
job_logging:
handlers:
file:
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log # Fixed in hydra-colorlog version 1.2.1


# Default set of configurations.
data_dir: ???
results_dir: ???
lr: 0.1
data_dir: "${oc.env:DATA_DIR}/"
results_dir: "${oc.env:RESULTS_DIR}/${project_name}"
project_name: "simple-einet-dc"
lr: 0.0005
batch_size: 64
n_bits: 8
num_workers: 8
num_workers: 16
temperature_leaves: 1.0
temperature_sums: 1.0
dropout: 0.0
@@ -19,18 +33,19 @@ max_sigma: 2.0
dry_run: False
seed: 1
log_interval: 10
classification: False
classification: True
device: "cuda"
debug: False
S: 10
I: 10
D: 3
R: 1
K: -1
gpu: 0
epochs: 10
load_and_eval: False
cp: False
dist: "binomial"
dist: "normal"
precision: 16
group_tag: ???
tag: ???
@@ -40,3 +55,5 @@ profiler: ???
log_weights: False
dataset: ???
num_classes: 10
init_leaf_data: True
einet_mixture: False
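
For reference, here is a minimal sketch (not part of this commit) of how the new environment-variable interpolations in conf/config.yaml resolve. It assumes omegaconf >= 2.1 and uses illustrative paths for DATA_DIR and RESULTS_DIR; the ${now:...} and ${hydra.*} interpolations only resolve inside an actual Hydra run, and ${oc.select:tag,} falls back to an empty value when no tag is set.

import os
from omegaconf import OmegaConf

# Illustrative values only; in practice these point at real data/result paths.
os.environ.setdefault("DATA_DIR", "/tmp/data")
os.environ.setdefault("RESULTS_DIR", "/tmp/results")

cfg = OmegaConf.create(
    {
        "project_name": "simple-einet-dc",
        "data_dir": "${oc.env:DATA_DIR}/",
        "results_dir": "${oc.env:RESULTS_DIR}/${project_name}",
    }
)

print(OmegaConf.to_yaml(cfg, resolve=True))
# project_name: simple-einet-dc
# data_dir: /tmp/data/
# results_dir: /tmp/results/simple-einet-dc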
2 changes: 1 addition & 1 deletion main.py
@@ -324,4 +324,4 @@ def test(model, device, loader, tag):
)

print(f"Result directory: {result_dir}")
print("Done.")
print("Done.")
189 changes: 141 additions & 48 deletions main_pl.py
@@ -1,7 +1,16 @@
#!/usr/bin/env python
from simple_einet.distributions.normal import Normal
from simple_einet.einet import Einet
import omegaconf
import time
import wandb
from hydra.core.hydra_config import HydraConfig
import logging
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig, OmegaConf, open_dict
import os
import sys
from rich.traceback import install
install()

import hydra
import pytorch_lightning as pl
@@ -21,73 +30,137 @@
from models_pl import SpnDiscriminative, SpnGenerative
from simple_einet.data import Dist
from simple_einet.data import build_dataloader
from tqdm import tqdm

# A logger for this file
logger = logging.getLogger(__name__)

@hydra.main(version_base=None, config_path="./conf", config_name="config")

def init_einet_stats(einet: Einet, dataloader: torch.utils.data.DataLoader):
stats_mean = None
stats_std = None
for batch in tqdm(dataloader, desc="Leaf Parameter Initialization"):
data, label = batch
if stats_mean is None:
stats_mean = data.mean(dim=0)
stats_std = data.std(dim=0)
else:
stats_mean += data.mean(dim=0)
stats_std += data.std(dim=0)

stats_mean /= len(dataloader)
stats_std /= len(dataloader)

if einet.config.leaf_type == Normal:
if type(einet) == Einet:
einets = [einet]
else:
einets = einet.einets

stats_mean_v = (
stats_mean.view(-1, 1, 1)
.repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions)
.view_as(einets[0].leaf.base_leaf.means)
)
stats_std_v = (
stats_std.view(-1, 1, 1)
.repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions)
.view_as(einets[0].leaf.base_leaf.log_stds)
)
for net in einets:
net.leaf.base_leaf.means.data = stats_mean_v

net.leaf.base_leaf.log_stds.data = torch.log(stats_std_v + 1e-3)


def main(cfg: DictConfig):
preprocess_cfg(cfg)

logger.info(OmegaConf.to_yaml(cfg))
hydra_cfg = HydraConfig.get()
run_dir = hydra_cfg.runtime.output_dir
print("Working directory : {}".format(os.getcwd()))

# Save config
with open(os.path.join(run_dir, "config.yaml"), "w") as f:
OmegaConf.save(config=cfg, f=f)

# Save run_dir in config (use open_dict to make config writable)
with open_dict(cfg):
cfg.run_dir = run_dir

results_dir, cfg = setup_experiment(name="simple-einet", cfg=cfg, remove_if_exists=True)
logger.info("\n" + OmegaConf.to_yaml(cfg, resolve=True))
logger.info("Run dir: " + run_dir)

seed_everything(cfg.seed, workers=True)

if not cfg.wandb:
os.environ["WANDB_MODE"] = "offline"

seed_everything(cfg.seed, workers=True)

# Setup devices
if torch.cuda.is_available():
accelerator = "gpu"
if type(cfg.gpu) == int:
devices = [int(cfg.gpu)]
else:
devices = [int(g) for g in cfg.gpu]
else:
accelerator = "cpu"
devices = None

print("Training model...")
# Create dataloader
normalize = cfg.dist == Dist.NORMAL
train_loader, val_loader, test_loader = build_dataloader(cfg=cfg, loop=False, normalize=normalize)

# Create callbacks
cfg_container = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
logger_wandb = WandbLogger(
name=cfg.tag,
project=cfg.project_name,
group=cfg.group_tag,
offline=not cfg.wandb,
config=cfg_container,
reinit=False,
save_dir=run_dir,
settings=wandb.Settings(start_method="thread"),
)

# Load or create model
if cfg.load_and_eval:
model = load_from_checkpoint(
results_dir, load_fn=SpnGenerative.load_from_checkpoint, cfg=cfg
run_dir,
load_fn=SpnGenerative.load_from_checkpoint,
cfg=cfg,
)
else:
if cfg.classification:
model = SpnDiscriminative(cfg)
model = SpnDiscriminative(cfg, steps_per_epoch=len(train_loader))
else:
model = SpnGenerative(cfg)
model = SpnGenerative(cfg, steps_per_epoch=len(train_loader))

seed_everything(cfg.seed, workers=True)

print("Training model...")
# Create dataloader
normalize = cfg.dist == Dist.NORMAL
train_loader, val_loader, test_loader = build_dataloader(
cfg=cfg, loop=False, normalize=normalize
)
if cfg.einet_mixture:
model.spn.initialize(dataloader=train_loader, device=devices[0])

# Create callbacks
logger_wandb = WandbLogger(name=cfg.tag, project="einet", group=cfg.group_tag,
offline=not cfg.wandb)
if cfg.init_leaf_data:
logger.info("Initializing leaf distributions from data statistics")
init_einet_stats(model.spn, train_loader)

# Store number of model parameters
summary = ModelSummary(model, max_depth=-1)
print("Model:")
print(model)
print("Summary:")
print(summary)
logger_wandb.experiment.config[
"trainable_parameters"
] = summary.trainable_parameters
logger_wandb.experiment.config["trainable_parameters_leaf"] = summary.param_nums[
summary.layer_names.index("spn.leaf")
]
logger_wandb.experiment.config["trainable_parameters_sums"] = summary.param_nums[
summary.layer_names.index("spn.einsum_layers")
]
# logger_wandb.experiment.config["trainable_parameters"] = summary.trainable_parameters
# logger_wandb.experiment.config["trainable_parameters_leaf"] = summary.param_nums[
# summary.layer_names.index("spn.leaf")
# ]
# logger_wandb.experiment.config["trainable_parameters_sums"] = summary.param_nums[
# summary.layer_names.index("spn.einsum_layers")
# ]

# Setup devices
if torch.cuda.is_available():
accelerator = "gpu"
devices = [int(cfg.gpu)]
# elif torch.backends.mps.is_available(): # Currently leads to errors
# accelerator = "mps"
# devices = 1
else:
accelerator = "cpu"
devices = None

# Setup callbacks
callbacks = []
@@ -111,28 +184,30 @@ def main(cfg: DictConfig):
precision=cfg.precision,
fast_dev_run=cfg.debug,
profiler=cfg.profiler,
default_root_dir=run_dir,
enable_checkpointing=False,
)

if not cfg.load_and_eval:
# Fit model
trainer.fit(
model=model, train_dataloaders=train_loader, val_dataloaders=val_loader
)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

print("Evaluating model...")

if "synth" in cfg.dataset and not cfg.classification:
plot_distribution(
model=model.spn, dataset_name=cfg.dataset, logger_wandb=logger_wandb
)
plot_distribution(model=model.spn, dataset_name=cfg.dataset, logger_wandb=logger_wandb)

# Evaluate spn reconstruction error
trainer.test(
model=model, dataloaders=[train_loader, val_loader, test_loader], verbose=True
)
trainer.test(model=model, dataloaders=[train_loader, val_loader, test_loader], verbose=True)

print("Finished evaluation...")

# Save checkpoint in general models directory to be used across experiments
chpt_path = os.path.join(run_dir, "model.pt")
logger.info("Saving checkpoint: " + chpt_path)
trainer.save_checkpoint(chpt_path)


def preprocess_cfg(cfg: DictConfig):
"""
Preprocesses the config file.
@@ -151,7 +226,6 @@ def preprocess_cfg(cfg: DictConfig):
if "results_dir" not in cfg:
cfg.results_dir = os.getenv("RESULTS_DIR", os.path.join(home, "results"))


# If FP16/FP32 is given, convert to int (else it's "bf16", keep string)
if cfg.precision == "16" or cfg.precision == "32":
cfg.precision = int(cfg.precision)
@@ -165,7 +239,26 @@ def preprocess_cfg(cfg: DictConfig):
if "group_tag" not in cfg:
cfg.group_tag = None

if "seed" not in cfg:
cfg.seed = int(time.time())

if cfg.K > 0:
cfg.I = cfg.K
cfg.S = cfg.K

cfg.dist = Dist[cfg.dist.upper()]


@hydra.main(version_base=None, config_path="./conf", config_name="config")
def main_hydra(cfg: DictConfig):
try:
main(cfg)
except Exception as e:
logging.critical(e, exc_info=True) # log exception info at CRITICAL log level
finally:
# Close wandb instance. Necessary for hydra multi-runs where main() is called multiple times
wandb.finish()


if __name__ == "__main__":
main()
main_hydra()
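
For context, a standalone sketch (illustrative, not from the repository) of the per-feature statistics accumulation that the new init_einet_stats helper performs, shown on a toy tensor dataset. Averaging per-batch means and standard deviations approximates the dataset statistics (exact for the mean only when all batches have equal size); the commit then broadcasts these statistics into the shape of the Normal leaf's means and log_stds via repeat/view_as.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data with known per-feature mean [0, 5, -1] and std [1, 2, 0.5].
data = torch.randn(1000, 3) * torch.tensor([1.0, 2.0, 0.5]) + torch.tensor([0.0, 5.0, -1.0])
loader = DataLoader(TensorDataset(data, torch.zeros(1000)), batch_size=64)

stats_mean, stats_std = None, None
for x, _ in loader:
    if stats_mean is None:
        stats_mean, stats_std = x.mean(dim=0), x.std(dim=0)
    else:
        stats_mean += x.mean(dim=0)
        stats_std += x.std(dim=0)

# Average the per-batch statistics over the number of batches.
stats_mean /= len(loader)
stats_std /= len(loader)

print(stats_mean)  # roughly [0.0, 5.0, -1.0]
print(stats_std)   # roughly [1.0, 2.0, 0.5]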