diff --git a/PROJECT_NAME b/PROJECT_NAME index 09a1127..01feef8 100644 --- a/PROJECT_NAME +++ b/PROJECT_NAME @@ -1 +1 @@ -simple-einet \ No newline at end of file +simple-einet diff --git a/conf/config.yaml b/conf/config.yaml index 7e5684e..4445055 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -4,13 +4,27 @@ defaults: - override hydra/job_logging: colorlog - override hydra/hydra_logging: colorlog +# Hydra config +hydra: + run: + dir: "${results_dir}/${dataset}/${now:%Y-%m-%d_%H-%M-%S}_${oc.select:tag,}" + sweep: + dir: "${hydra.run.dir}/${hydra.job.name}" + subdir: "${hydra.run.dir}/${hydra.job.num}" + job_logging: + handlers: + file: + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log # Fixed in hydra-colorlog version 1.2.1 + + # Default set of configurations. -data_dir: ??? -results_dir: ??? -lr: 0.1 +data_dir: "${oc.env:DATA_DIR}/" +results_dir: "${oc.env:RESULTS_DIR}/${project_name}" +project_name: "simple-einet-dc" +lr: 0.0005 batch_size: 64 n_bits: 8 -num_workers: 8 +num_workers: 16 temperature_leaves: 1.0 temperature_sums: 1.0 dropout: 0.0 @@ -19,18 +33,19 @@ max_sigma: 2.0 dry_run: False seed: 1 log_interval: 10 -classification: False +classification: True device: "cuda" debug: False S: 10 I: 10 D: 3 R: 1 +K: -1 gpu: 0 epochs: 10 load_and_eval: False cp: False -dist: "binomial" +dist: "normal" precision: 16 group_tag: ??? tag: ??? @@ -40,3 +55,5 @@ profiler: ??? log_weights: False dataset: ??? num_classes: 10 +init_leaf_data: True +einet_mixture: False diff --git a/main.py b/main.py index 03e2bd7..56fc6e6 100644 --- a/main.py +++ b/main.py @@ -324,4 +324,4 @@ def test(model, device, loader, tag): ) print(f"Result directory: {result_dir}") - print("Done.") \ No newline at end of file + print("Done.") diff --git a/main_pl.py b/main_pl.py index 9e0208b..ca8f882 100644 --- a/main_pl.py +++ b/main_pl.py @@ -1,7 +1,16 @@ #!/usr/bin/env python +from simple_einet.distributions.normal import Normal +from simple_einet.einet import Einet +import omegaconf +import time +import wandb +from hydra.core.hydra_config import HydraConfig import logging -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, open_dict import os +import sys +from rich.traceback import install +install() import hydra import pytorch_lightning as pl @@ -21,46 +30,122 @@ from models_pl import SpnDiscriminative, SpnGenerative from simple_einet.data import Dist from simple_einet.data import build_dataloader +from tqdm import tqdm # A logger for this file logger = logging.getLogger(__name__) -@hydra.main(version_base=None, config_path="./conf", config_name="config") + +def init_einet_stats(einet: Einet, dataloader: torch.utils.data.DataLoader): + stats_mean = None + stats_std = None + for batch in tqdm(dataloader, desc="Leaf Parameter Initialization"): + data, label = batch + if stats_mean == None: + stats_mean = data.mean(dim=0) + stats_std = data.std(dim=0) + else: + stats_mean += data.mean(dim=0) + stats_std += data.std(dim=0) + + stats_mean /= len(dataloader) + stats_std /= len(dataloader) + + if einet.config.leaf_type == Normal: + if type(einet) == Einet: + einets = [einet] + else: + einets = einet.einets + + stats_mean_v = ( + stats_mean.view(-1, 1, 1) + .repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions) + .view_as(einets[0].leaf.base_leaf.means) + ) + stats_std_v = ( + stats_std.view(-1, 1, 1) + .repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions) + .view_as(einets[0].leaf.base_leaf.log_stds) + ) + for net in einets: + 
net.leaf.base_leaf.means.data = stats_mean_v + + net.leaf.base_leaf.log_stds.data = torch.log(stats_std_v + 1e-3) + + def main(cfg: DictConfig): preprocess_cfg(cfg) - logger.info(OmegaConf.to_yaml(cfg)) + hydra_cfg = HydraConfig.get() + run_dir = hydra_cfg.runtime.output_dir + print("Working directory : {}".format(os.getcwd())) + + # Save config + with open(os.path.join(run_dir, "config.yaml"), "w") as f: + OmegaConf.save(config=cfg, f=f) + + # Safe run_dir in config (use open_dict to make config writable) + with open_dict(cfg): + cfg.run_dir = run_dir - results_dir, cfg = setup_experiment(name="simple-einet", cfg=cfg, remove_if_exists=True) + logger.info("\n" + OmegaConf.to_yaml(cfg, resolve=True)) + logger.info("Run dir: " + run_dir) seed_everything(cfg.seed, workers=True) if not cfg.wandb: os.environ["WANDB_MODE"] = "offline" + seed_everything(cfg.seed, workers=True) + + # Setup devices + if torch.cuda.is_available(): + accelerator = "gpu" + if type(cfg.gpu) == int: + devices = [int(cfg.gpu)] + else: + devices = [int(g) for g in cfg.gpu] + else: + accelerator = "cpu" + devices = None + + print("Training model...") + # Create dataloader + normalize = cfg.dist == Dist.NORMAL + train_loader, val_loader, test_loader = build_dataloader(cfg=cfg, loop=False, normalize=normalize) + + # Create callbacks + cfg_container = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + logger_wandb = WandbLogger( + name=cfg.tag, + project=cfg.project_name, + group=cfg.group_tag, + offline=not cfg.wandb, + config=cfg_container, + reinit=False, + save_dir=run_dir, + settings=wandb.Settings(start_method="thread"), + ) + # Load or create model if cfg.load_and_eval: model = load_from_checkpoint( - results_dir, load_fn=SpnGenerative.load_from_checkpoint, cfg=cfg + run_dir, + load_fn=SpnGenerative.load_from_checkpoint, + cfg=cfg, ) else: if cfg.classification: - model = SpnDiscriminative(cfg) + model = SpnDiscriminative(cfg, steps_per_epoch=len(train_loader)) else: - model = SpnGenerative(cfg) + model = SpnGenerative(cfg, steps_per_epoch=len(train_loader)) - seed_everything(cfg.seed, workers=True) - - print("Training model...") - # Create dataloader - normalize = cfg.dist == Dist.NORMAL - train_loader, val_loader, test_loader = build_dataloader( - cfg=cfg, loop=False, normalize=normalize - ) + if cfg.einet_mixture: + model.spn.initialize(dataloader=train_loader, device=devices[0]) - # Create callbacks - logger_wandb = WandbLogger(name=cfg.tag, project="einet", group=cfg.group_tag, - offline=not cfg.wandb) + if cfg.init_leaf_data: + logger.info("Initializing leaf distributions from data statistics") + init_einet_stats(model.spn, train_loader) # Store number of model parameters summary = ModelSummary(model, max_depth=-1) @@ -68,26 +153,14 @@ def main(cfg: DictConfig): print(model) print("Summary:") print(summary) - logger_wandb.experiment.config[ - "trainable_parameters" - ] = summary.trainable_parameters - logger_wandb.experiment.config["trainable_parameters_leaf"] = summary.param_nums[ - summary.layer_names.index("spn.leaf") - ] - logger_wandb.experiment.config["trainable_parameters_sums"] = summary.param_nums[ - summary.layer_names.index("spn.einsum_layers") - ] + # logger_wandb.experiment.config["trainable_parameters"] = summary.trainable_parameters + # logger_wandb.experiment.config["trainable_parameters_leaf"] = summary.param_nums[ + # summary.layer_names.index("spn.leaf") + # ] + # logger_wandb.experiment.config["trainable_parameters_sums"] = summary.param_nums[ + # 
summary.layer_names.index("spn.einsum_layers") + # ] - # Setup devices - if torch.cuda.is_available(): - accelerator = "gpu" - devices = [int(cfg.gpu)] - # elif torch.backends.mps.is_available(): # Currently leads to errors - # accelerator = "mps" - # devices = 1 - else: - accelerator = "cpu" - devices = None # Setup callbacks callbacks = [] @@ -111,28 +184,30 @@ def main(cfg: DictConfig): precision=cfg.precision, fast_dev_run=cfg.debug, profiler=cfg.profiler, + default_root_dir=run_dir, + enable_checkpointing=False, ) if not cfg.load_and_eval: # Fit model - trainer.fit( - model=model, train_dataloaders=train_loader, val_dataloaders=val_loader - ) + trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader) print("Evaluating model...") if "synth" in cfg.dataset and not cfg.classification: - plot_distribution( - model=model.spn, dataset_name=cfg.dataset, logger_wandb=logger_wandb - ) + plot_distribution(model=model.spn, dataset_name=cfg.dataset, logger_wandb=logger_wandb) # Evaluate spn reconstruction error - trainer.test( - model=model, dataloaders=[train_loader, val_loader, test_loader], verbose=True - ) + trainer.test(model=model, dataloaders=[train_loader, val_loader, test_loader], verbose=True) print("Finished evaluation...") + # Save checkpoint in general models directory to be used across experiments + chpt_path = os.path.join(run_dir, "model.pt") + logger.info("Saving checkpoint: " + chpt_path) + trainer.save_checkpoint(chpt_path) + + def preprocess_cfg(cfg: DictConfig): """ Preprocesses the config file. @@ -151,7 +226,6 @@ def preprocess_cfg(cfg: DictConfig): if "results_dir" not in cfg: cfg.results_dir = os.getenv("RESULTS_DIR", os.path.join(home, "results")) - # If FP16/FP32 is given, convert to int (else it's "bf16", keep string) if cfg.precision == "16" or cfg.precision == "32": cfg.precision = int(cfg.precision) @@ -165,7 +239,26 @@ def preprocess_cfg(cfg: DictConfig): if "group_tag" not in cfg: cfg.group_tag = None + if "seed" not in cfg: + cfg.env.seed = int(time.time()) + + if cfg.K > 0: + cfg.I = cfg.K + cfg.S = cfg.K + cfg.dist = Dist[cfg.dist.upper()] + +@hydra.main(version_base=None, config_path="./conf", config_name="config") +def main_hydra(cfg: DictConfig): + try: + main(cfg) + except Exception as e: + logging.critical(e, exc_info=True) # log exception info at CRITICAL log level + finally: + # Close wandb instance. Necessary for hydra multi-runs where main() is called multipel times + wandb.finish() + + if __name__ == "__main__": - main() + main_hydra() diff --git a/models_pl.py b/models_pl.py index 156e9ba..f43a129 100644 --- a/models_pl.py +++ b/models_pl.py @@ -1,4 +1,5 @@ from abc import ABC +from torch.optim.lr_scheduler import OneCycleLR import argparse import os from argparse import Namespace @@ -24,7 +25,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from simple_einet.data import build_dataloader -from simple_einet.einet import EinetConfig, Einet +from simple_einet.einet import EinetConfig, Einet, EinetMixture from simple_einet.distributions.binomial import Binomial @@ -42,6 +43,7 @@ def make_einet(cfg, num_classes: int = 1): Returns: EinsumNetworks model. 
""" + image_shape = get_data_shape(cfg.dataset) # leaf_kwargs, leaf_type = {"total_count": 255}, Binomial leaf_kwargs, leaf_type = get_distribution( @@ -62,20 +64,24 @@ def make_einet(cfg, num_classes: int = 1): cross_product=cfg.cp, log_weights=cfg.log_weights, ) - return Einet(config) + if cfg.einet_mixture: + return EinetMixture(n_components=num_classes, einet_config=config) + else: + return Einet(config) class LitModel(pl.LightningModule, ABC): - def __init__(self, cfg: DictConfig, name: str) -> None: + def __init__(self, cfg: DictConfig, name: str, steps_per_epoch: int) -> None: super().__init__() self.cfg = cfg self.image_shape = get_data_shape(cfg.dataset) self.rtpt = RTPT( - name_initials="SL", + name_initials="SB", experiment_name="einet_" + name + ("_" + str(cfg.tag) if cfg.tag else ""), max_iterations=cfg.epochs + 1, ) self.save_hyperparameters() + self.steps_per_epoch = steps_per_epoch def preprocess(self, data: torch.Tensor): if self.cfg.dist == Dist.BINOMIAL: @@ -90,6 +96,16 @@ def configure_optimizers(self): milestones=[int(0.7 * self.cfg.epochs), int(0.9 * self.cfg.epochs)], gamma=0.1, ) + + # lr_scheduler = { + # "scheduler": OneCycleLR( + # optimizer, + # max_lr=self.cfg.lr, + # total_steps=self.cfg.epochs * self.steps_per_epoch + # + 1, # +1 b/c 1cycle has a bug in its last step where it upticks the lr again + # ), + # "interval": "step", + # } return [optimizer], [lr_scheduler] def on_train_start(self) -> None: @@ -100,8 +116,8 @@ def on_train_epoch_end(self) -> None: class SpnGenerative(LitModel): - def __init__(self, cfg: DictConfig): - super().__init__(cfg=cfg, name="gen") + def __init__(self, cfg: DictConfig, steps_per_epoch: int): + super().__init__(cfg=cfg, name="gen", steps_per_epoch=steps_per_epoch) self.spn = make_einet(cfg) def training_step(self, train_batch, batch_idx): @@ -164,8 +180,8 @@ class SpnDiscriminative(LitModel): Discriminative SPN model. Models the class conditional data distribution at its C root nodes. 
""" - def __init__(self, cfg: DictConfig): - super().__init__(cfg, name="disc") + def __init__(self, cfg: DictConfig, steps_per_epoch: int): + super().__init__(cfg, name="disc", steps_per_epoch=steps_per_epoch) # Construct SPN self.spn = make_einet(cfg, num_classes=10) @@ -181,7 +197,7 @@ def training_step(self, train_batch, batch_idx): def validation_step(self, val_batch, batch_idx): loss, accuracy = self._get_cross_entropy_and_accuracy(val_batch) - self.log("Val/accuracy", accuracy) + self.log("Val/accuracy", accuracy, prog_bar=True) self.log("Val/loss", loss) return loss @@ -218,5 +234,6 @@ def _get_cross_entropy_and_accuracy( # NOTE: Don't use nn.CrossEntropyLoss because it expects unnormalized logits # and applies LogSoftmax first loss = self.criterion(ll_y_g_x, labels) + # loss = self.criterion(ll_x_g_y, labels) accuracy = (labels == ll_y_g_x.argmax(-1)).sum() / ll_y_g_x.shape[0] return loss, accuracy diff --git a/requirements.txt b/requirements.txt index c9d4e85..1f5d818 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,16 @@ fast_pytorch_kmeans==0.1.6 -icecream==2.1.2 -matplotlib==3.5.1 -numpy==1.22.3 -Pillow==9.1.1 -pytorch_lightning==1.6.1 -rich==12.4.4 +hydra-core==1.3.1 +icecream==2.1.3 +matplotlib==3.6.2 +numpy==1.23.5 +omegaconf==2.3.0 +Pillow==9.4.0 +pytorch_lightning==1.8.6 +rich==13.0.1 rtpt==0.0.4 -scipy==1.8.0 -torch==1.11.0 -torchvision==0.12.0 -tqdm==4.62.3 +scikit_learn==1.2.0 +scipy==1.10.0 +setuptools==65.6.3 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.64.1 diff --git a/simple_einet/data.py b/simple_einet/data.py index 88e1222..b372404 100644 --- a/simple_einet/data.py +++ b/simple_einet/data.py @@ -19,6 +19,7 @@ from simple_einet.distributions import RatNormal from simple_einet.distributions.binomial import Binomial +from simple_einet.distributions.normal import Normal @dataclass @@ -183,6 +184,9 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]: kwargs = dict(root=cfg.data_dir, download=True, transform=transform) + # Custom split generator with fixed seed + split_generator = torch.Generator().manual_seed(1) + # Select the datasets if "synth" in dataset_name: # Train @@ -218,7 +222,8 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]: N_train = round(N * 0.9) N_val = N - N_train lenghts = [N_train, N_val] - dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts) + + dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts, generator=split_generator) elif dataset_name == "fmnist" or dataset_name == "fmnist-28": if normalize: @@ -232,7 +237,8 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]: N_train = round(N * 0.9) N_val = N - N_train lenghts = [N_train, N_val] - dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts) + + dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts, generator=split_generator) elif "celeba" in dataset_name: if normalize: @@ -251,7 +257,8 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]: N_train = round(N * 0.9) N_val = N - N_train lenghts = [N_train, N_val] - dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts) + + dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts, generator=split_generator) dataset_test = CIFAR10(**kwargs, train=False) elif "svhn" in dataset_name: @@ -263,7 +270,8 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]: N = len(dataset_train.data) lenghts = 
[round(N * 0.9), round(N * 0.1)] - dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts) + + dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts, generator=split_generator) dataset_test = SVHN(**kwargs, split="test") if dataset_name == "svhn-extra": @@ -357,6 +365,7 @@ class Dist(str, Enum): """Enum for the distribution of the data.""" NORMAL = "normal" + NORMAL_RAT = "normal_rat" BINOMIAL = "binomial" @@ -375,6 +384,9 @@ def get_distribution(dist, min_sigma, max_sigma): """ if dist == Dist.NORMAL: + leaf_type = Normal + leaf_kwargs = {} + elif dist == Dist.NORMAL_RAT: leaf_type = RatNormal leaf_kwargs = {"min_sigma": min_sigma, "max_sigma": max_sigma} elif dist == Dist.BINOMIAL: diff --git a/simple_einet/einet.py b/simple_einet/einet.py index 36c990c..7d0092e 100644 --- a/simple_einet/einet.py +++ b/simple_einet/einet.py @@ -1,12 +1,14 @@ from collections import defaultdict import logging from dataclasses import dataclass -from typing import Any, Dict, List, Sequence, Type +from operator import xor +from typing import Any, Dict, List, Sequence, Tuple, Type import numpy as np import torch from fast_pytorch_kmeans import KMeans from torch import nn +from torch.utils.data import DataLoader from simple_einet.distributions import AbstractLeaf, RatNormal, truncated_normal_ from simple_einet.einsum_layer import EinsumLayer, EinsumMixingLayer, LinsumLayer, LinsumLayerLogWeights @@ -157,6 +159,76 @@ def forward(self, x: torch.Tensor, marginalization_mask: torch.Tensor = None) -> return x + def forward_dropout_inference( + self, x: torch.Tensor, marginalization_mask: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Inference pass for the Einet model. + + Args: + x (torch.Tensor): Input data of shape [N, C, D], where C is the number of input channels (useful for images) and D is the number of features/random variables (H*W for images). + marginalized_scope: torch.Tensor: (Default value = None) + + Returns: + Log-likelihood tensor of the input: p(X) or p(X | C) if number of classes > 1. + """ + + # Add channel dimension if not present + if x.dim() == 2: # [N, D] + x = x.unsqueeze(1) + + if x.dim() == 4: # [N, C, H, W] + x = x.view(x.shape[0], self.config.num_channels, -1) + + assert x.dim() == 3 + assert x.shape[1] == self.config.num_channels + + # Apply leaf distributions (replace marginalization indicators with 0.0 first) + x = self.leaf(x, marginalization_mask) + + # Pass through intermediate layers + log_exp, log_var = self._forward_layers_dropout_inference(x) + + # Merge results from the different repetitions into the channel dimension + batch_size, features, channels, repetitions = log_exp.size() + assert features == 1 # number of features should be 1 at this point + assert channels == self.config.num_classes + + # If model has multiple reptitions, perform repetition mixing + if self.config.num_repetitions > 1: + # Mix repetitions + log_exp, log_var = self.mixing.forward_dropout_inference(log_exp, log_var) + else: + # Remove repetition index + log_exp = log_exp.squeeze(-1) + log_var = log_var.squeeze(-1) + + # Remove feature dimension + log_exp = log_exp.squeeze(1) + log_var = log_var.squeeze(1) + + # Final shape check + assert log_exp.shape == (batch_size, self.config.num_classes) + assert log_var.shape == (batch_size, self.config.num_classes) + + return log_exp, log_var + + def _forward_layers_dropout_inference(self, log_exp): + """ + Forward pass through the inner sum and product layers. + + Args: + log_exp: Input expectations. 
+ + Returns: + torch.Tensor: Output of the last layer before the root layer. + """ + # Forward to inner product and sum layers + log_var = torch.ones_like(log_exp) * np.NINF + for layer in self.einsum_layers: + log_exp, log_var = layer.forward_dropout_inference(log_exp, log_var) + return log_exp, log_var + def _forward_layers(self, x): """ Forward pass through the inner sum and product layers. @@ -559,7 +631,21 @@ def __init__(self, n_components: int, einet_config: EinetConfig): self.centroids = nn.Parameter(torch.empty(n_components, einet_config.num_features), requires_grad=False) @torch.no_grad() - def initialize(self, data: torch.Tensor): + def initialize(self, data: torch.Tensor=None, dataloader: DataLoader=None, device=None): + assert xor(data is not None, dataloader is not None) + + if dataloader is not None: + # Collect data from dataloader + l = [] + for batch in dataloader: + x, y = batch + l.append(x) + if sum([d.shape[0] for d in l]) > 2000: + break + + data = torch.cat(l, dim=0).to(device) + + data = data.float() # input has to be [n, d] self._kmeans.fit(data.view(data.shape[0], -1)) diff --git a/simple_einet/einsum_layer.py b/simple_einet/einsum_layer.py index 029218e..44f12a7 100644 --- a/simple_einet/einsum_layer.py +++ b/simple_einet/einsum_layer.py @@ -9,6 +9,27 @@ from simple_einet.type_checks import check_valid from simple_einet.utils import SamplingContext, index_one_hot, diff_sample_one_hot +def logsumexp(left, right, mask=None): + """ + Source: https://github.com/pytorch/pytorch/issues/32097 + + Logsumexp with custom scalar mask to allow for negative values in the sum. + + Args: + tensor: + other: + mask: (Default value = None) + + Returns: + + """ + if mask is None: + mask = torch.tensor([1, 1]) + else: + assert mask.shape == (2,), "Invalid mask shape" + + maxes = torch.max(left, right) + return maxes + ((left - maxes).exp() * mask[0] + (right - maxes).exp() * mask[1]).log() class LinsumLayer(AbstractLayer): def __init__( @@ -56,6 +77,89 @@ def _init_weights(self): def _get_normalized_log_weights(self): return F.log_softmax(self.weights, dim=1) + def forward_dropout_inference(self, log_exp_ch: torch.Tensor, log_var_ch: torch.Tensor): + """ + Einsum layer dropout inference pass. + + Args: + log_exp_ch: Input expectations. + log_var_ch: Input variances. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Output expectations and variances + """ + + # Dimensions + N, D, C, R = log_exp_ch.size() + D_out = D // 2 + + ################# + # PRODUCT LAYER # + ################# + + #--------------------------------------------------- + #| 1. Product expectation (default log-likelihood) | + #--------------------------------------------------- + + # Get left and right partition probs + log_exp_left = log_exp_ch[:, 0::2] + log_exp_right = log_exp_ch[:, 1::2] + log_exp_prod = (log_exp_left + log_exp_right).unsqueeze(3) # N x D/2 x Sin x 1 x R + + + #----------------------- + #| 2. 
Product variance | + #----------------------- + + # Get left and right partition vars + log_var_left = log_var_ch[:, 0::2] + log_var_right = log_var_ch[:, 1::2] + + log_exp_sq_left = log_exp_left * 2 + log_exp_sq_right = log_exp_right * 2 + + log_var_right_term = log_exp_sq_left + log_exp_sq_right + + + log_var_left_term_left = logsumexp(log_var_left, log_exp_sq_left) + log_var_left_term_right = logsumexp(log_var_right, log_exp_sq_right) + + log_var_left_term = log_var_left_term_left + log_var_left_term_right + + mask = torch.tensor([1, -1]) + log_var_prod = logsumexp(log_var_left_term, log_var_right_term, mask=mask) + + + ############# + # SUM LAYER # + ############# + + # Prepare constants + log_q = np.log(1 - self.dropout) + log_p = np.log(self.dropout) + + # Get log weights + log_weights = self._get_normalized_log_weights().unsqueeze(0) + + #---------------------- + #| 3. Sum expectation | + #---------------------- + + log_exp_sum = log_p + torch.logsumexp(log_exp_prod + log_weights, dim=2) # N x D/2 x Sout x R + + #------------------- + #| 4. Sum variance | + #------------------- + + log_weights_sq = log_weights * 2 + log_exp_prod_sq = log_exp_prod * 2 + log_var_prod = log_var_prod.unsqueeze(3) + + log_var_plus_exp = torch.logsumexp(torch.stack((log_var_prod, log_exp_prod_sq + log_p), dim=-1), dim=-1) + log_var_sum = log_q + torch.logsumexp(log_weights_sq + log_var_plus_exp, dim=2) # dim=1? + + return log_exp_sum, log_var_sum + def forward(self, x: torch.Tensor): """ Einsum layer forward pass. @@ -85,9 +189,6 @@ def forward(self, x: torch.Tensor): # Get log weights log_weights = self._get_normalized_log_weights().unsqueeze(0) - # log_weights = F.log_softmax(self.weights, dim=1).unsqueeze( - # 0 - # ) # 1 x D/2 x Sin x Sout x R prob = torch.logsumexp(prod_output + log_weights, dim=2) # N x D/2 x Sout x R # Save input if input cache is enabled @@ -555,6 +656,7 @@ def forward(self, x): return lls + def sample( self, num_samples: int = None, @@ -635,3 +737,184 @@ def extra_repr(self): return "num_features={}, num_sums_in={}, num_sums_out={}".format( self.num_features, self.num_sums_in, self.num_sums_out, self.num_repetitions ) + +class MixingLayer(AbstractLayer): + def __init__( + self, + num_features: int, + num_sums_in: int, + num_sums_out: int, + dropout: float = 0.0, + ): + super().__init__(num_features, num_repetitions=1) + + self.num_sums_in = check_valid(num_sums_in, int, 1) + self.num_sums_out = check_valid(num_sums_out, int, 1) + self.out_features = num_features + + # Dropout + self.dropout = check_valid(dropout, expected_type=float, lower_bound=0.0, upper_bound=1.0) + self._bernoulli_dist = torch.distributions.Bernoulli(probs=self.dropout) + + # Weights, such that each sumnode has its own weights + ws = torch.randn( + self.num_features, + self.num_sums_out, + self.num_sums_in, + ) + self.weights = nn.Parameter(ws) + + # Necessary for sampling with evidence: Save input during forward pass. 
+ self._is_input_cache_enabled = False + self._input_cache_left = None + self._input_cache_right = None + + def _get_normalized_log_weights(self): + return F.log_softmax(self.weights, dim=2) + + def forward(self, x): + # Save input if input cache is enabled + if self._is_input_cache_enabled: + self._input_cache = x.clone() + + # Dimensions + N, D, IC, R = x.size() + + + # Apply dropout: Set random sum node children to 0 (-inf in log domain) + if self.dropout > 0.0 and self.training: + dropout_indices = self._bernoulli_dist.sample(x.shape).bool() + x[dropout_indices] = np.NINF + + # Get log weights + log_weights = self._get_normalized_log_weights().unsqueeze(0) + lls = torch.logsumexp(x + log_weights, dim=3) + + return lls + + def forward_dropout_inference(self, log_exp_ch, log_var_ch): + # Save input if input cache is enabled + if self._is_input_cache_enabled: + self._input_cache = log_exp_ch.clone() + + # Dimensions + N, D, IC, R = log_exp_ch.size() + + # Get log weights + log_weights = self._get_normalized_log_weights().unsqueeze(0) + + # Prepare constants + log_q = np.log(1 - self.dropout) + log_p = np.log(self.dropout) + + # Expectation + log_exp = log_p + torch.logsumexp(log_exp_ch + log_weights, dim=3) + + # Variance + log_weights_sq = log_weights * 2 + log_exp_ch_sq = log_exp_ch * 2 + log_var_ch = log_var_ch + + log_var_plus_exp = torch.logsumexp(torch.stack((log_var_ch, log_exp_ch_sq + log_p), dim=-1), dim=-1) + log_var = log_q + torch.logsumexp(log_weights_sq + log_var_plus_exp, dim=3) # dim=1? + + return log_exp, log_var + + + def sample( + self, + num_samples: int = None, + context: SamplingContext = None, + differentiable=False, + ) -> Union[SamplingContext, torch.Tensor]: + raise NotImplementedError("Not yet implemented for MixingLayer") + # Sum weights are of shape: [W, H, IC, OC, R] + # We now want to use `indices` to access one in_channel for each in_feature x num_sums_out block + # index is of size in_feature + weights = self.weights + + in_features, num_sums_out, num_sums_in = weights.shape + num_samples = context.num_samples + + self._check_indices_repetition(context) + + if not context.is_differentiable: + # Index with parent indices + weights = weights.unsqueeze(0) # make space for batch dim + weights = weights.expand(num_samples, -1, -1, -1) + p_idxs = context.indices_out.unsqueeze(-1).unsqueeze(-1) # make space for repetition dim + p_idxs = p_idxs.expand(-1, -1, -1, num_sums_in) + weights = weights.gather(dim=2, index=p_idxs) + # Drop dim which was selected via parent indices + weights = weights.squeeze(2) + else: + # Index with parent indices + weights = weights.unsqueeze(0) # make space for batch dim + p_idxs = context.indices_out.unsqueeze(-1) # make space for repetition dim + weights = index_one_hot(weights, index=p_idxs, dim=2) # TODO: is 2 correct? 
+ + # Check dimensions + assert weights.shape == (num_samples, in_features, num_sums_in) + + log_weights = F.log_softmax(weights, dim=2) + + # If evidence is given, adjust the weights with the likelihoods of the observed paths + if self._is_input_cache_enabled and self._input_cache is not None: + # TODO: parallelize this with torch.gather + for i in range(num_samples): + # Reweight the i-th samples weights by its likelihood values at the correct + # repetition + log_weights[i, :, :] += self._input_cache[i, :, :, context.indices_repetition[i]] + + if not context.is_differentiable: + if context.is_mpe: + indices = log_weights.argmax(dim=2) + else: + # Create categorical distribution to sample from + dist = torch.distributions.Categorical(logits=log_weights) + + indices = dist.sample() + else: + if context.is_mpe: + raise NotImplementedError + else: + indices = diff_sample_one_hot( + log_weights, + mode="sample", + dim=2, + hard=context.hard, + tau=context.tau, + ) + + context.indices_out = indices + return context + + def _check_indices_repetition(self, context: SamplingContext): + assert context.indices_repetition.shape[0] == context.indices_out.shape[0] + if self.num_repetitions > 1 and context.indices_repetition is None: + raise Exception( + f"Sum layer has multiple repetitions (num_repetitions=={self.num_repetitions}) but indices_repetition argument was None, expected a Long tensor size #samples." + ) + if self.num_repetitions == 1 and context.indices_repetition is None: + context.indices_repetition = torch.zeros(context.num_samples, dtype=torch.int, device=self.__device) + + def extra_repr(self): + return "num_features={}, num_sums_in={}, num_sums_out={}".format( + self.num_features, self.num_sums_in, self.num_sums_out, self.num_repetitions + ) + + +if __name__ == '__main__': + layer = LinsumLayer(num_features=2, num_sums_in=3, num_sums_out=10, num_repetitions=7, dropout=0.1) + mixing = MixingLayer(num_features=1, num_sums_in=7, num_sums_out=10) + + log_exp = torch.randn(1, 2, 3, 7) + log_var = torch.rand_like(log_exp) + + + # x = layer(log_exp) + # x = mixing(x) + + + log_exp, log_var = layer.forward_dropout_inference(log_exp, log_var) + log_exp, log_var = mixing.forward_dropout_inference(log_exp, log_var) diff --git a/simple_einet/utils.py b/simple_einet/utils.py index 42fe72b..2693d37 100644 --- a/simple_einet/utils.py +++ b/simple_einet/utils.py @@ -1,4 +1,7 @@ from contextlib import contextmanager, nullcontext +import logging +import os +import sys from dataclasses import dataclass from typing import Tuple @@ -356,3 +359,4 @@ def index_one_hot(tensor: torch.Tensor, index: torch.Tensor, dim: int) -> torch. return torch.einsum(einsum_str, tensor, index) else: return torch.sum(tensor * index, dim=dim) +
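Notes on the updated defaults in conf/config.yaml: data_dir and results_dir are now resolved from the DATA_DIR and RESULTS_DIR environment variables, Hydra writes each run into ${results_dir}/${dataset}/<timestamp>_<tag>, and preprocess_cfg maps K > 0 onto both I and S. With those variables exported, a run would be launched roughly as follows (the K value and tag are illustrative, not taken from this patch):

DATA_DIR=~/data RESULTS_DIR=~/results python main_pl.py dataset=fmnist tag=baseline K=20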
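The data-driven leaf initialization enabled by init_leaf_data: True (init_einet_stats in main_pl.py) averages per-batch means and standard deviations over the training loader and copies them into the Normal leaves (means, and their log into log_stds). Below is a minimal, self-contained sketch of just that accumulation, assuming only torch and a synthetic TensorDataset in place of the real dataloader; the Einet itself is left out.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for an image dataset: 1000 samples, 1 channel, 8x8 pixels,
# drawn from N(0.5, 2.0) so the recovered statistics are easy to check.
data = torch.randn(1000, 1, 8, 8) * 2.0 + 0.5
labels = torch.zeros(1000, dtype=torch.long)
loader = DataLoader(TensorDataset(data, labels), batch_size=64)

stats_mean, stats_std = None, None
for x, _ in loader:
    if stats_mean is None:
        stats_mean, stats_std = x.mean(dim=0), x.std(dim=0)
    else:
        stats_mean += x.mean(dim=0)
        stats_std += x.std(dim=0)

# Average of the per-batch statistics, mirroring init_einet_stats. This is an
# approximation (per-batch stds are averaged rather than pooled), which is
# sufficient for parameter initialization.
stats_mean /= len(loader)
stats_std /= len(loader)

print(stats_mean.mean().item())  # ~0.5
print(stats_std.mean().item())   # ~2.0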
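The logsumexp helper added to simple_einet/einsum_layer.py takes a two-element sign mask so that differences of exponentials can be evaluated in log space: with mask = [1, -1] it returns log(exp(left) - exp(right)). forward_dropout_inference uses it to form E[X^2] = Var(X) + E[X]^2 for the left and right children and then to evaluate the product variance Var(X*Y) = E[X^2]*E[Y^2] - (E[X]*E[Y])^2 of independent partitions. A small self-contained check of the helper, restated here outside the patch with scalar inputs as an assumption:

import torch

def logsumexp(left, right, mask=None):
    # Signed log-sum-exp: log(mask[0] * exp(left) + mask[1] * exp(right)).
    # With mask = [1, -1] this is log(exp(left) - exp(right)); the result is NaN
    # when the masked sum is negative (i.e. right > left).
    if mask is None:
        mask = torch.tensor([1.0, 1.0])
    maxes = torch.max(left, right)
    return maxes + ((left - maxes).exp() * mask[0] + (right - maxes).exp() * mask[1]).log()

a = torch.log(torch.tensor(3.0))
b = torch.log(torch.tensor(1.0))
assert torch.isclose(logsumexp(a, b), torch.log(torch.tensor(4.0)))  # log(3 + 1)
assert torch.isclose(logsumexp(a, b, mask=torch.tensor([1.0, -1.0])),
                     torch.log(torch.tensor(2.0)))                   # log(3 - 1)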