From ae92983fa1341ecdcd5fb80b884424e6e3bce3ef Mon Sep 17 00:00:00 2001 From: Steven Braun Date: Fri, 22 Sep 2023 11:59:08 +0200 Subject: [PATCH] general: Refactor code and add documentation --- .envrc | 3 + LICENSE.md | 2 +- README.md | 30 +- exp_utils.py | 189 +++---------- main.py | 6 +- main_pl.py | 96 ++----- models_pl.py | 127 +++++---- simple_einet/distributions/__init__.py | 5 +- simple_einet/distributions/abstract_leaf.py | 60 +++- simple_einet/distributions/bernoulli.py | 16 +- simple_einet/distributions/binomial.py | 95 ++++++- simple_einet/distributions/mixture.py | 33 ++- .../distributions/multidistribution.py | 3 +- .../distributions/multivariate_normal.py | 2 +- simple_einet/distributions/normal.py | 92 +++--- simple_einet/distributions/utils.py | 10 +- simple_einet/einet.py | 5 +- simple_einet/einsum_layer.py | 186 +++++++++--- simple_einet/factorized_leaf_layer.py | 82 +++--- simple_einet/layers.py | 57 +++- simple_einet/sampling_utils.py | 267 ++++++++++++++++++ simple_einet/type_checks.py | 98 ++++++- simple_einet/utils.py | 211 +------------- 23 files changed, 1023 insertions(+), 652 deletions(-) create mode 100644 .envrc create mode 100644 simple_einet/sampling_utils.py diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..2076468 --- /dev/null +++ b/.envrc @@ -0,0 +1,3 @@ +export PYTHONPATH=./ +export RESULTS_DIR=${HOME}/results +export DATA_DIR=${HOME}/data diff --git a/LICENSE.md b/LICENSE.md index 0f06c03..241b9b1 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2022 Steven Lang +Copyright (c) 2023 Steven Braun Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the diff --git a/README.md b/README.md index b629f09..8199363 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Simple EinsumNetworks Implementation -This repository contains code for my personal, simplistic, EinsumNetworks implementation. +This repository contains code for my personal, simplistic, EinsumNetworks implementation. Well, as it happens with code all the time, the implementation is not as simplistic as it used to be but at least this was the initial intention. For a speed benchmark comparison against the official EinsumNetworks implementation, check out [benchmark.md](./benchmark/benchmark.md) (short: simple-einet is faster in all dimensions except the input-channel size but scales similar to EinsumNetworks). @@ -17,7 +17,7 @@ python main_pl.py dataset=mnist batch_size=128 epochs=100 dist=normal D=5 I=32 S -Generative learning on MNIST: +Generative training on MNIST: ``` sh python main_pl.py dataset=mnist D=3 I=10 R=1 S=10 lr=0.1 dist=binomial epochs=10 batch_size=128 @@ -55,7 +55,7 @@ out_features = 3 x = torch.randn(batchsize, in_features) # Construct Einet -einet = Einet(EinetConfig(num_features=in_features, depth=2, num_sums=2, num_channels=1, num_leaves=3, num_repetitions=3, num_classes=out_features, dropout=0.0, leaf_type=RatNormal, leaf_kwargs={"min_sigma": 1e-5, "max_sigma": 1.0},)) +einet = Einet(EinetConfig(num_features=in_features, depth=2, num_sums=2, num_channels=1, num_leaves=3, num_repetitions=3, num_classes=out_features, dropout=0.0, leaf_type=Normal)) # Compute log-likelihoods lls = einet(x) @@ -86,7 +86,18 @@ print(f"samples: \n{samples}") ## Citing EinsumNetworks -If you use EinsumNetworks in your publications, please cite the official EinsumNetworks paper. +If you use this software, please cite it as below. + +```bibtex +@software{braun2021simple-einet, +author = {Braun, Steven}, +title = {{Simple-einet: An EinsumNetworks Implementation}}, +url = {https://github.com/braun-steven/simple-einet}, +version = {0.0.1}, +} +``` + +If you use EinsumNetworks as a model in your publications, please cite our official EinsumNetworks paper. ```bibtex @inproceedings{pmlr-v119-peharz20a, @@ -105,14 +116,3 @@ If you use EinsumNetworks in your publications, please cite the official EinsumN code = {https://github.com/cambridge-mlg/EinsumNetworks}, } ``` - -If you use this software, please cite it as below. - -``` bibtex -@software{braun2021simple-einet, -author = {Braun, Steven}, -title = {{Simple-einet: An EinsumNetworks Implementation}}, -url = {https://github.com/braun-steven/simple-einet}, -version = {0.0.1}, -} -``` diff --git a/exp_utils.py b/exp_utils.py index f0981da..037273c 100644 --- a/exp_utils.py +++ b/exp_utils.py @@ -32,6 +32,11 @@ from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import ToTensor + + +from matplotlib.cm import tab10 +from matplotlib import cm + from simple_einet.data import build_dataloader, get_data_shape, Shape, generate_data @@ -399,43 +404,17 @@ def catch_kb_interrupt(output_directory): shutil.move(src, dst) -@torch.no_grad() -def print_num_params(model: nn.Module): - """ - Compute the number of parameters and separate into Flow/SPN parts. - - Args: - model (nn.Module): Model with parameters. - - """ - if type(model) == DistributedDataParallel: - model = model.module - - # Count all parameteres - sum_params = count_params(model) - - # Count SPN parameters - spn_params = sum_params - - # Print - logger.info(f"Number of parameters:") - # logger.info(f"- Total: {sum_params / 1e6: >8.3f}M") - logger.info( - f"- SPN: {spn_params / 1e6: >8.3f}M ({spn_params / sum_params * 100:.1f}%)" - ) - # logger.info(f"- NN: {nn_params / 1e6: >8.3f}M ({nn_params / sum_params * 100:.1f}%)") - - def preprocess( x: torch.Tensor, n_bits: int, ) -> torch.Tensor: + """Preprocess the image.""" x = reduce_bits(x, n_bits) - # x = x.long() return x def reduce_bits(image: torch.Tensor, n_bits: int) -> torch.Tensor: + """Reduce the number of bits of the image.""" image = image * 255 if n_bits < 8: image = torch.floor(image / 2 ** (8 - n_bits)) @@ -449,7 +428,16 @@ def xor(a: bool, b: bool) -> bool: def loss_dict_to_str(running_loss_dict: Dict[str, float], logging_period: int) -> str: - """Create a joined string from a dictionary mapping str->float.""" + """ + Create a joined string from a dictionary mapping str->float. + + Args: + running_loss_dict (Dict[str, float]): Dictionary mapping str->float. + logging_period (int): Logging period. + + Returns: + str: Joined string. + """ loss_str = ", ".join( [ f"{key}: {value / logging_period:.2f}" @@ -460,6 +448,11 @@ def loss_dict_to_str(running_loss_dict: Dict[str, float], logging_period: int) - def plot_tensor(x: torch.Tensor): + """Plot a tensor as an image. + + Args: + x (torch.Tensor): Tensor to plot. + """ plt.figure() if x.dim() == 4: @@ -469,121 +462,6 @@ def plot_tensor(x: torch.Tensor): plt.close() -def build_tensorboard_writer(results_dir): - """ - Build a tensorboard writer. - Args: - results_dir: Directory where to save the tensorboard files. - - Returns: - A tensorboard writer. - """ - return SummaryWriter(os.path.join(results_dir, "tensorboard")) - - -def setup_experiment( - name: str, - args: argparse.Namespace, - remove_if_exists: bool = True, - with_tensorboard=True, -): - """ - Sets up the experiment. - Args: - name: The name of the experiment. - """ - print(f"Arguments: {args}") - - # - if args.dataset == "celeba": - args.dataset = "celeba-small" - - # Check if we want to restore from a finished experiment - if args.load_and_eval is not None: - # Load args - old_dir: pathlib.Path = args.load_and_eval.expanduser() - args_file = os.path.join(old_dir, "args.json") - old_args = argparse.Namespace(**json.load(open(args_file))) - old_args.load_and_eval = args.load_and_eval - old_args.gpu = args.gpu - - print("Loading from existing directory:", old_dir) - print("Loading with existing args:", pprint.pformat(old_args)) - - results_dir = old_dir - args = old_args - else: - # Create result directory - results_dir = make_results_dir( - results_dir=args.results_dir, - experiment_name=name, - tag=args.tag, - dataset_name=args.dataset, - remove_if_exists=remove_if_exists, - ) - # Save args to file - save_args(results_dir, args) - print(f"Results directory: {results_dir}") - # Setup tensorboard - if with_tensorboard: - writer = build_tensorboard_writer(results_dir) - else: - writer = None - - if torch.cuda.is_available(): - device = torch.device("cuda:" + str(args.gpu)) - print("Using GPU device", torch.cuda.current_device()) - else: - device = torch.device("cpu") - print("Using device:", device) - seed_all_rng(args.seed) - cudnn.benchmark = True - - # Image shape - image_shape: Shape = get_data_shape(args.dataset) - - # Create RTPT object - rtpt = RTPT( - name_initials="SL", - experiment_name=name + "_" + str(args.tag), - max_iterations=args.epochs, - ) - - # Start the RTPT tracking - rtpt.start() - - return ( - args, - results_dir, - writer, - device, - image_shape, - rtpt, - ) - -def setup_experiment(name: str, cfg: DictConfig, remove_if_exists: bool = False): - """ - Sets up the experiment. - Args: - name: The name of the experiment. - """ - # Create result directory - results_dir = make_results_dir( - results_dir=cfg.results_dir, - experiment_name=name, - tag=cfg.tag, - dataset_name=cfg.dataset, - remove_if_exists=remove_if_exists, - ) - # Save args to file - # save_args(results_dir, cfg) - - # Save args to file - print(f"Results directory: {results_dir}") - seed_all_rng(cfg.seed) - cudnn.benchmark = True - return results_dir, cfg - def anneal_tau(epoch, max_epochs): """Anneal the softmax temperature tau based on the epoch progress.""" return max(0.5, np.exp(-1 / max_epochs * epoch)) @@ -614,13 +492,6 @@ def save_samples(generate_samples, samples_dir, num_samples, nrow): torchvision.utils.save_image(grid, os.path.join(samples_dir, f"{i}.png")) -from matplotlib.cm import tab10 -from matplotlib import cm - -TEXTWIDTH = 5.78853 -LINEWIDTH = 0.75 -ARROW_HEADWIDTH = 5 -colors = tab10.colors def get_figsize(scale: float, aspect_ratio=0.8) -> Tuple[float, float]: @@ -635,12 +506,16 @@ def get_figsize(scale: float, aspect_ratio=0.8) -> Tuple[float, float]: Tuple: Tuple containing (width, height) of the figure. """ + TEXTWIDTH = 5.78853 height = aspect_ratio * TEXTWIDTH widht = TEXTWIDTH return (scale * widht, scale * height) def set_style(): + """ + Sets the style of the matplotlib plots to use LaTeX fonts and the SciencePlots package. + """ matplotlib.use("pgf") plt.style.use(["science", "grid"]) # Need SciencePlots pip package matplotlib.rcParams.update( @@ -654,6 +529,16 @@ def set_style(): def plot_distribution(model, dataset_name, logger_wandb: WandbLogger = None): + """ + Plots the learned probability density function (PDF) represented by the given model + on a 2D grid of points sampled from the specified dataset. + + Args: + model (nn.Module): The model to use for generating the PDF. + dataset_name (str): The name of the dataset to sample points from. + logger_wandb (WandbLogger, optional): The logger to use for logging the plot image to WandB. + Defaults to None. + """ with torch.no_grad(): data, targets = generate_data(dataset_name, n_samples=1000) fig = plt.figure(figsize=get_figsize(1.0)) @@ -685,7 +570,7 @@ def plot_distribution(model, dataset_name, logger_wandb: WandbLogger = None): lw=0.5, s=10, alpha=0.5, - color=colors[1], + color=tab10.colors[1], ) plt.xlabel("$X_0$") diff --git a/main.py b/main.py index 56fc6e6..137c7be 100644 --- a/main.py +++ b/main.py @@ -161,9 +161,11 @@ def test(model, device, loader, tag): num_leaves=args.I, num_repetitions=args.R, num_classes=num_classes, - leaf_type=Binomial, - leaf_kwargs={"total_count": n_bins - 1}, + leaf_type=Normal, + leaf_kwargs={}, + # leaf_kwargs={"total_count": n_bins - 1}, dropout=0.0, + cross_product=True, ) model = Einet(config).to(device) print( diff --git a/main_pl.py b/main_pl.py index 2508cf2..62932bb 100644 --- a/main_pl.py +++ b/main_pl.py @@ -1,6 +1,4 @@ #!/usr/bin/env python -from simple_einet.distributions.normal import Normal -from simple_einet.einet import Einet import omegaconf import time import wandb @@ -8,7 +6,6 @@ import logging from omegaconf import DictConfig, OmegaConf, open_dict import os -import sys from rich.traceback import install install() @@ -24,62 +21,31 @@ ) from exp_utils import ( - setup_experiment, load_from_checkpoint, plot_distribution, ) from models_pl import SpnDiscriminative, SpnGenerative from simple_einet.data import Dist from simple_einet.data import build_dataloader -from tqdm import tqdm +from simple_einet.sampling_utils import init_einet_stats # A logger for this file logger = logging.getLogger(__name__) -def init_einet_stats(einet: Einet, dataloader: torch.utils.data.DataLoader): - stats_mean = None - stats_std = None - for batch in tqdm(dataloader, desc="Leaf Parameter Initialization"): - data, label = batch - if stats_mean == None: - stats_mean = data.mean(dim=0) - stats_std = data.std(dim=0) - else: - stats_mean += data.mean(dim=0) - stats_std += data.std(dim=0) - - stats_mean /= len(dataloader) - stats_std /= len(dataloader) - - if einet.config.leaf_type == Normal: - if type(einet) == Einet: - einets = [einet] - else: - einets = einet.einets - - stats_mean_v = ( - stats_mean.view(-1, 1, 1) - .repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions) - .view_as(einets[0].leaf.base_leaf.means) - ) - stats_std_v = ( - stats_std.view(-1, 1, 1) - .repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions) - .view_as(einets[0].leaf.base_leaf.log_stds) - ) - for net in einets: - net.leaf.base_leaf.means.data = stats_mean_v - - net.leaf.base_leaf.log_stds.data = torch.log(stats_std_v + 1e-3) - - def main(cfg: DictConfig): + """ + Main function for training and evaluating an Einet. + + Args: + cfg: Config file. + """ preprocess_cfg(cfg) + # Get hydra config hydra_cfg = HydraConfig.get() run_dir = hydra_cfg.runtime.output_dir - print("Working directory : {}".format(os.getcwd())) + logger.info("Working directory : {}".format(os.getcwd())) # Save config with open(os.path.join(run_dir, "config.yaml"), "w") as f: @@ -97,6 +63,7 @@ def main(cfg: DictConfig): if not cfg.wandb: os.environ["WANDB_MODE"] = "offline" + # Ensure that everything is properly seeded seed_everything(cfg.seed, workers=True) # Setup devices @@ -108,16 +75,16 @@ def main(cfg: DictConfig): devices = [int(g) for g in cfg.gpu] else: accelerator = "cpu" - devices = None + devices = 1 - print("Training model...") + logger.info("Training model...") # Create dataloader normalize = cfg.dist == Dist.NORMAL train_loader, val_loader, test_loader = build_dataloader( dataset_name=cfg.dataset, data_dir=cfg.data_dir, batch_size=cfg.batch_size, - num_workers=cfg.num_workers, + num_workers=min(cfg.num_workers, os.cpu_count()), loop=False, normalize=normalize, ) @@ -130,7 +97,7 @@ def main(cfg: DictConfig): group=cfg.group_tag, offline=not cfg.wandb, config=cfg_container, - reinit=False, + reinit=True, save_dir=run_dir, settings=wandb.Settings(start_method="thread"), ) @@ -148,7 +115,14 @@ def main(cfg: DictConfig): else: model = SpnGenerative(cfg, steps_per_epoch=len(train_loader)) + if cfg.torch_compile: # Doesn't seem to work with einsum yet + # Rase an error since einsum doesn't seem to work with compilation yet + # model = torch.compile(model) + raise NotImplementedError("Torch compilation not yet supported with einsum.") + if cfg.einet_mixture: + # If we chose a mixture of einets, we need to initialize the mixture weights + logger.info("Initializing Einet mixture weights") model.spn.initialize(dataloader=train_loader, device=devices[0]) if cfg.init_leaf_data: @@ -157,17 +131,10 @@ def main(cfg: DictConfig): # Store number of model parameters summary = ModelSummary(model, max_depth=-1) - print("Model:") - print(model) - print("Summary:") - print(summary) - # logger_wandb.experiment.config["trainable_parameters"] = summary.trainable_parameters - # logger_wandb.experiment.config["trainable_parameters_leaf"] = summary.param_nums[ - # summary.layer_names.index("spn.leaf") - # ] - # logger_wandb.experiment.config["trainable_parameters_sums"] = summary.param_nums[ - # summary.layer_names.index("spn.einsum_layers") - # ] + logger.info("Model:") + logger.info(model) + logger.info("Summary:") + logger.info(summary) # Setup callbacks callbacks = [] @@ -185,7 +152,6 @@ def main(cfg: DictConfig): # Create trainer trainer = pl.Trainer( max_epochs=cfg.epochs, - # max_steps=cfg.max_steps, logger=logger_wandb, accelerator=accelerator, devices=devices, @@ -195,22 +161,21 @@ def main(cfg: DictConfig): profiler=cfg.profiler, default_root_dir=run_dir, enable_checkpointing=False, - detect_anomaly=True + detect_anomaly=True, ) if not cfg.load_and_eval: # Fit model trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader) - print("Evaluating model...") + logger.info("Evaluating model...") if "synth" in cfg.dataset and not cfg.classification: plot_distribution(model=model.spn, dataset_name=cfg.dataset, logger_wandb=logger_wandb) # Evaluate spn reconstruction error trainer.test(model=model, dataloaders=[train_loader, val_loader, test_loader], verbose=True) - - print("Finished evaluation...") + logger.info("Finished evaluation...") # Save checkpoint in general models directory to be used across experiments chpt_path = os.path.join(run_dir, "model.pt") @@ -252,10 +217,7 @@ def preprocess_cfg(cfg: DictConfig): if "seed" not in cfg: cfg.seed = int(time.time()) - if cfg.K > 0: - cfg.I = cfg.K - cfg.S = cfg.K - + # Convert dist string to enum cfg.dist = Dist[cfg.dist.upper()] diff --git a/models_pl.py b/models_pl.py index d341e33..345197e 100644 --- a/models_pl.py +++ b/models_pl.py @@ -1,47 +1,32 @@ from abc import ABC -from torch.optim.lr_scheduler import OneCycleLR -import argparse -import os -from argparse import Namespace -from typing import Dict, Any, Tuple -import numpy as np -from omegaconf import DictConfig +from typing import Tuple +import pytorch_lightning as pl import torch -from torch import nn import torch.nn.parallel import torch.utils.data -import pytorch_lightning as pl import torchvision +from omegaconf import DictConfig from rtpt import RTPT +from torch import nn -from torch.nn import functional as F -from args import parse_args from simple_einet.data import get_data_shape, Dist, get_distribution -from exp_utils import ( - load_from_checkpoint, -) -from pytorch_lightning import seed_everything -from pytorch_lightning.loggers import TensorBoardLogger - from simple_einet.einet import EinetConfig, Einet, EinetMixture -from simple_einet.einsum_layer import logsumexp -from simple_einet.distributions.binomial import Binomial - # Translate the dataloader index to the dataset name DATALOADER_ID_TO_SET_NAME = {0: "train", 1: "val", 2: "test"} -def make_einet(cfg, num_classes: int = 1) -> Einet: +def make_einet(cfg, num_classes: int = 1) -> EinetMixture | Einet: """ - Make an EinsumNetworks model based off the given arguments. + Make an Einet model based off the given arguments. Args: cfg: Arguments parsed from argparse. + num_classes: Number of classes to model. Returns: - EinsumNetworks model. + Einet model. """ image_shape = get_data_shape(cfg.dataset) @@ -69,6 +54,22 @@ def make_einet(cfg, num_classes: int = 1) -> Einet: class LitModel(pl.LightningModule, ABC): + """ + LightningModule for training a model using PyTorch Lightning. + + Args: + cfg (DictConfig): Configuration dictionary. + name (str): Name of the model. + steps_per_epoch (int): Number of steps per epoch. + + Attributes: + cfg (DictConfig): Configuration dictionary. + image_shape (ImageShape): Shape of the input data. + rtpt (RTPT): RTPT logger. + steps_per_epoch (int): Number of steps per epoch. + + """ + def __init__(self, cfg: DictConfig, name: str, steps_per_epoch: int) -> None: super().__init__() self.cfg = cfg @@ -82,28 +83,22 @@ def __init__(self, cfg: DictConfig, name: str, steps_per_epoch: int) -> None: self.steps_per_epoch = steps_per_epoch def preprocess(self, data: torch.Tensor): + """Preprocess data before passing it to the model.""" if self.cfg.dist == Dist.BINOMIAL: data *= 255.0 return data def configure_optimizers(self): + """ + Configure the optimizer and learning rate scheduler. + """ optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.lr) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=[int(0.7 * self.cfg.epochs), int(0.9 * self.cfg.epochs)], gamma=0.1, ) - - # lr_scheduler = { - # "scheduler": OneCycleLR( - # optimizer, - # max_lr=self.cfg.lr, - # total_steps=self.cfg.epochs * self.steps_per_epoch - # + 1, # +1 b/c 1cycle has a bug in its last step where it upticks the lr again - # ), - # "interval": "step", - # } return [optimizer], [lr_scheduler] def on_train_start(self) -> None: @@ -114,6 +109,17 @@ def on_train_epoch_end(self) -> None: class SpnGenerative(LitModel): + """ + A class representing a generative model based on Sum-Product Networks (SPNs). + + Args: + cfg (DictConfig): A configuration dictionary. + steps_per_epoch (int): The number of steps per epoch. + + Attributes: + spn (einet.EinSumProductNetwork): The SPN model. + """ + def __init__(self, cfg: DictConfig, steps_per_epoch: int): super().__init__(cfg=cfg, name="gen", steps_per_epoch=steps_per_epoch) self.spn = make_einet(cfg) @@ -122,50 +128,74 @@ def training_step(self, train_batch, batch_idx): data, labels = train_batch data = self.preprocess(data) nll = self.negative_log_likelihood(data) - self.log("Train/loss", nll) + self.log("Train/loss", nll, prog_bar=True) return nll def validation_step(self, val_batch, batch_idx): data, labels = val_batch data = self.preprocess(data) nll = self.negative_log_likelihood(data) - self.log("Val/loss", nll) + self.log("Val/loss", nll, prog_bar=True) return nll - def negative_log_likelihood(self, data): + def negative_log_likelihood(self, data, reduction="mean"): """ Compute negative log likelihood of data. Args: data: Data to compute negative log likelihood of. + reduction: Reduction method. Returns: Negative log likelihood of data. """ - nll = -1 * self.spn(data).mean() - return nll + nll = -1 * self.spn(data) + if reduction == "mean": + return nll.mean() + elif reduction == "sum": + return nll.sum() + else: + raise ValueError(f"Unknown reduction {reduction}") + + def generate_samples(self, num_samples: int, differentiable: bool): + """ + Generates a batch of samples from the model. + + Args: + num_samples (int): The number of samples to generate. + differentiable (bool): Whether to use a differentiable sampling method. - def generate_samples(self, num_samples: int): - samples = self.spn.sample(num_samples=num_samples, mpe_at_leaves=True).view(-1, *self.image_shape) + Returns: + torch.Tensor: A tensor of shape (num_samples, *self.image_shape) containing the generated samples. + """ + if not differentiable: + samples = self.spn.sample(num_samples=num_samples, mpe_at_leaves=True).view(-1, *self.image_shape) + else: + samples = self.spn.sample_differentiable(num_samples=num_samples, mpe_at_leaves=True, hard=True).view( + -1, *self.image_shape + ) samples = samples / 255.0 return samples def on_train_epoch_end(self): - with torch.no_grad(): - samples = self.generate_samples(num_samples=64) + samples = self.generate_samples(num_samples=64, differentiable=False) grid = torchvision.utils.make_grid(samples.data[:64], nrow=8, pad_value=0.0, normalize=True) self.logger.log_image(key="samples", images=[grid]) + samples_diff = self.generate_samples(num_samples=64, differentiable=True) + grid_diff = torchvision.utils.make_grid(samples_diff.data[:64], nrow=8, pad_value=0.0, normalize=True) + self.logger.log_image(key="samples_diff", images=[grid_diff]) + super().on_train_epoch_end() - def test_step(self, batch, batch_idx, dataloader_id=0): + def test_step(self, batch, batch_idx, dataloader_idx=0): data, labels = batch data = self.preprocess(data) nll = self.negative_log_likelihood(data) - set_name = DATALOADER_ID_TO_SET_NAME[dataloader_id] + set_name = DATALOADER_ID_TO_SET_NAME[dataloader_idx] self.log(f"Test/{set_name}_nll", nll, add_dataloader_idx=False) @@ -185,8 +215,8 @@ def __init__(self, cfg: DictConfig, steps_per_epoch: int): def training_step(self, train_batch, batch_idx): loss, accuracy = self._get_cross_entropy_and_accuracy(train_batch) - self.log("Train/accuracy", accuracy, prog_bar=True) - self.log("Train/loss", loss) + self.log("Train/accuracy", accuracy, on_step=True, prog_bar=True) + self.log("Train/loss", loss, on_step=True) return loss def validation_step(self, val_batch, batch_idx): @@ -195,9 +225,9 @@ def validation_step(self, val_batch, batch_idx): self.log("Val/loss", loss) return loss - def test_step(self, batch, batch_idx, dataloader_id=0): + def test_step(self, batch, batch_idx, dataloader_idx=0): loss, accuracy = self._get_cross_entropy_and_accuracy(batch) - set_name = DATALOADER_ID_TO_SET_NAME[dataloader_id] + set_name = DATALOADER_ID_TO_SET_NAME[dataloader_idx] self.log(f"Test/{set_name}_accuracy", accuracy, add_dataloader_idx=False) def _get_cross_entropy_and_accuracy(self, batch) -> Tuple[torch.Tensor, torch.Tensor]: @@ -218,6 +248,5 @@ def _get_cross_entropy_and_accuracy(self, batch) -> Tuple[torch.Tensor, torch.Te # NOTE: Don't use nn.CrossEntropyLoss because it expects unnormalized logits # and applies LogSoftmax first loss = self.criterion(ll_y_g_x, labels) - # loss = self.criterion(ll_x_g_y, labels) accuracy = (labels == ll_y_g_x.argmax(-1)).sum() / ll_y_g_x.shape[0] return loss, accuracy diff --git a/simple_einet/distributions/__init__.py b/simple_einet/distributions/__init__.py index 54cec59..67c6b57 100644 --- a/simple_einet/distributions/__init__.py +++ b/simple_einet/distributions/__init__.py @@ -3,6 +3,7 @@ """ -from simple_einet.distributions.utils import * from simple_einet.distributions.abstract_leaf import AbstractLeaf -from simple_einet.distributions.normal import RatNormal, CustomNormal +from simple_einet.distributions.utils import * +from simple_einet.distributions.binomial import Binomial +from simple_einet.distributions.normal import Normal, RatNormal, CustomNormal \ No newline at end of file diff --git a/simple_einet/distributions/abstract_leaf.py b/simple_einet/distributions/abstract_leaf.py index 5fc93e3..7b0a0ed 100644 --- a/simple_einet/distributions/abstract_leaf.py +++ b/simple_einet/distributions/abstract_leaf.py @@ -4,7 +4,7 @@ import logging from abc import ABC, abstractmethod from simple_einet.layers import AbstractLayer -from simple_einet.utils import SamplingContext, index_one_hot +from simple_einet.sampling_utils import SamplingContext, index_one_hot from typing import List from torch import distributions as dist, nn import torch @@ -84,9 +84,11 @@ def dist_sample(distribution: dist.Distribution, context: SamplingContext = None Sample n samples from a given distribution. Args: - indices_repetition: Indices into the repetition axis. - distribution (dists.Distribution): Base distribution to sample from. - indices_out (torch.Tensor): Tensor of indexes that point to specific representations of single features/scopes. + distribution: Leaf distribution from which to sample from. + context: Sampling context. + + Returns: + torch.Tensor: Samples from the given distribution. """ # Sample from the specified distribution @@ -136,6 +138,13 @@ class AbstractLeaf(AbstractLayer, ABC): representation, e.g. Gaussians. Implementing layers shall be valid distributions. + + Attributes: + num_features: Number of input features. + num_channels: Number of input features. + num_leaves: Number of parallel representations for each input feature. + num_repetitions: Number of parallel repetitions of this layer. + cardinality: Number of random variables covered by a single leaf. """ def __init__( @@ -170,7 +179,17 @@ def __init__( self.marginalization_constant = nn.Parameter(torch.zeros(1), requires_grad=False) def _apply_dropout(self, x: torch.Tensor) -> torch.Tensor: - # Apply dropout sampled from a bernoulli during training (model.train() has been called) + """ + Applies dropout to the input tensor `x` according to the dropout probability + `self.dropout`. Dropout is only applied during training (when `model.train()` + has been called). + + Args: + x (torch.Tensor): The input tensor to apply dropout to. + + Returns: + torch.Tensor: The input tensor with dropout applied. + """ if self.dropout > 0.0 and self.training: dropout_indices = self._bernoulli_dist.sample( x.shape, @@ -179,9 +198,18 @@ def _apply_dropout(self, x: torch.Tensor) -> torch.Tensor: return x def _marginalize_input(self, x: torch.Tensor, marginalized_scopes: List[int]) -> torch.Tensor: + """ + Marginalizes the input tensor `x` along the dimensions specified in `marginalized_scopes`. + + Args: + x (torch.Tensor): The input tensor to be marginalized. + marginalized_scopes (List[int]): A list of dimensions to be marginalized. + + Returns: + torch.Tensor: The marginalized tensor. + """ # Marginalize nans set by user if marginalized_scopes is not None: - # Transform to tensor if type(marginalized_scopes) != torch.Tensor: s = torch.tensor(marginalized_scopes) @@ -196,6 +224,16 @@ def _marginalize_input(self, x: torch.Tensor, marginalized_scopes: List[int]) -> return x def forward(self, x, marginalized_scopes: List[int]): + """ + Forward pass through the distribution. + + Args: + x (torch.Tensor): Input tensor. + marginalized_scopes (List[int]): List of scopes to marginalize. + + Returns: + torch.Tensor: Output tensor after marginalization. + """ # Forward through base distribution d = self._get_base_distribution() x = dist_forward(d, x) @@ -211,8 +249,14 @@ def _get_base_distribution(self, context: SamplingContext = None) -> dist.Distri def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor: """ - Perform sampling, given indices from the parent layer that indicate which of the multiple representations - for each input shall be used. + Sample from the distribution represented by this leaf node. + + Args: + num_samples (int, optional): The number of samples to draw from the distribution. If None, a single sample is drawn. + context (SamplingContext, optional): The sampling context to use when drawing samples. + + Returns: + torch.Tensor: A tensor of shape (num_samples,) or (1,) containing the drawn samples. """ d = self._get_base_distribution(context) samples = dist_sample(distribution=d, context=context) diff --git a/simple_einet/distributions/bernoulli.py b/simple_einet/distributions/bernoulli.py index 2e82ae4..5dea603 100644 --- a/simple_einet/distributions/bernoulli.py +++ b/simple_einet/distributions/bernoulli.py @@ -5,16 +5,20 @@ class Bernoulli(AbstractLeaf): - """Bernoulli layer. Maps each input feature to its gaussian log likelihood.""" + """Bernoulli layer. Maps each input feature to its bernoulli log likelihood. + + Probabilities are modeled as unconstrained parameters and are transformed via a sigmoid function into [0, 1] when needed. + """ def __init__(self, num_features: int, num_channels: int, num_leaves: int, num_repetitions: int): - """Creat a gaussian layer. + """ + Initializes a Bernoulli distribution with the given parameters. Args: - out_channels: Number of parallel representations for each input feature. - in_features: Number of input features. - num_repetitions: Number of parallel repetitions of this layer. - + num_features (int): The number of features in the input data. + num_channels (int): The number of channels in the input data. + num_leaves (int): The number of leaves in the tree structure. + num_repetitions (int): The number of repetitions for each leaf. """ super().__init__(num_features, num_channels, num_leaves, num_repetitions) diff --git a/simple_einet/distributions/binomial.py b/simple_einet/distributions/binomial.py index 16e51fd..4b44109 100644 --- a/simple_einet/distributions/binomial.py +++ b/simple_einet/distributions/binomial.py @@ -1,7 +1,7 @@ from torch import distributions as dist, nn import numpy as np import torchvision.models as models -from simple_einet.utils import SamplingContext +from simple_einet.sampling_utils import SamplingContext from typing import List, Tuple, Union import torch from torch import nn @@ -17,6 +17,8 @@ class Binomial(AbstractLeaf): + """Binomial layer. Maps each input feature to its binomial log likelihood.""" + def __init__( self, num_features: int, @@ -25,6 +27,16 @@ def __init__( num_repetitions: int, total_count: int, ): + """ + Initializes a Binomial distribution with the given parameters. + + Args: + num_features (int): The number of features in the input data. + num_channels (int): The number of channels in the input data. + num_leaves (int): The number of leaves in the tree. + num_repetitions (int): The number of repetitions for each leaf. + total_count (int): The total number of trials for the Binomial distribution. + """ super().__init__( num_features=num_features, num_channels=num_channels, @@ -46,11 +58,32 @@ def _get_base_distribution(self, context: SamplingContext = None): class CustomBinomial: + """ + A custom implementation of the Binomial distribution, with differentiable sampling. + + Args: + probs (torch.Tensor): The probability of success for each trial. Should have shape (batch_size,). + total_count (int): The total number of trials. + + Attributes: + probs (torch.Tensor): The probability of success for each trial. Should have shape (batch_size,). + total_count (int): The total number of trials. + """ + def __init__(self, probs, total_count): self.probs = probs self.total_count = total_count def sample(self, sample_shape: Tuple[int]): + """ + Draws samples from the distribution using a normal distribution as approximation. + + Args: + sample_shape (Tuple[int]): The shape of the desired sample. + + Returns: + torch.Tensor: A tensor of shape (sample_shape[0], batch_size), containing the drawn samples. + """ # Normal approximation to be differentiable mu = self.total_count * self.probs sigma = mu * (1 - self.probs) @@ -62,10 +95,41 @@ def sample(self, sample_shape: Tuple[int]): return samples def log_prob(self, x): + """ + Computes the log-probability of a given value under the distribution. + + Args: + x (torch.Tensor): The value(s) for which to compute the log-probability. Should have shape (batch_size,). + + Returns: + torch.Tensor: A tensor of shape (batch_size,), containing the log-probabilities. + """ return dist.Binomial(probs=self.probs, total_count=self.total_count).log_prob(x) class ConditionalBinomial(AbstractLeaf): + """ + A class representing a conditional binomial distribution. + + Allows a conditional function to be used to condition the binomial distribution. + + Args: + num_features (int): The number of features in the input tensor. + num_channels (int): The number of channels in the input tensor. + num_leaves (int): The number of leaves in the tree. + num_repetitions (int): The number of repetitions. + total_count (int): The total count of the binomial distribution. + cond_fn (nn.Module): The module used to condition the binomial distribution. + cond_idxs (Union[List[int], torch.Tensor]): The indices of the conditioned input. + + Attributes: + total_count (int): The total count of the binomial distribution. + cond_fn (nn.Module): The module used to condition the binomial distribution. + cond_idxs (Union[List[int], torch.Tensor]): The indices of the conditioned input. + probs_conditioned_base (nn.Parameter): The base parameters for the conditioned binomial distribution. + probs_unconditioned (nn.Parameter): The parameters for the unconditioned binomial distribution. + """ + def __init__( self, num_features: int, @@ -76,6 +140,9 @@ def __init__( cond_fn: nn.Module, cond_idxs: Union[List[int], torch.Tensor], ): + """ + Initializes the ConditionalBinomial class. + """ super().__init__( num_features=num_features, num_channels=num_channels, @@ -129,6 +196,16 @@ def get_conditioned_distribution(self, x_cond: torch.Tensor): return d def forward(self, x, marginalized_scopes: List[int]): + """ + Computes the forward pass of the ConditionalBinomial class. + + Args: + x (torch.Tensor): The input tensor. + marginalized_scopes (List[int]): The marginalized scopes. + + Returns: + The output tensor. + """ # Get conditional input (TODO: make this flexible with an index array defined during construction) x_cond = x[:, :, self.cond_idxs, None, None] d = self.get_conditioned_distribution(x_cond) @@ -142,6 +219,16 @@ def forward(self, x, marginalized_scopes: List[int]): return x def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor: + """ + Samples from the ConditionalBinomial distribution. + + Args: + num_samples (int): The number of samples to generate. + context (SamplingContext): The sampling context. + + Returns: + The generated samples. + """ ev = context.evidence x_cond = ev[:, :, self.cond_idxs, None, None] d = self.get_conditioned_distribution(x_cond) @@ -174,4 +261,10 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> to return samples def _get_base_distribution(self) -> dist.Distribution: + """ + Gets the base distribution. + + Returns: + The base distribution. + """ raise NotImplementedError("This should not happen.") diff --git a/simple_einet/distributions/mixture.py b/simple_einet/distributions/mixture.py index 5a0f707..32db479 100644 --- a/simple_einet/distributions/mixture.py +++ b/simple_einet/distributions/mixture.py @@ -8,7 +8,8 @@ from torch import nn from torch.nn import functional as F -from simple_einet.utils import SamplingContext, invert_permutation +from simple_einet.utils import invert_permutation +from simple_einet.sampling_utils import SamplingContext from simple_einet.layers import AbstractLayer, Sum from simple_einet.type_checks import check_valid from simple_einet.distributions.abstract_leaf import AbstractLeaf, dist_mode @@ -18,7 +19,7 @@ class Mixture(AbstractLeaf): def __init__( self, distributions, - in_features: int, + num_features: int, out_channels, num_repetitions, dropout=0.0, @@ -30,15 +31,17 @@ def __init__( distributions: List of possible distributions to represent the feature with. out_channels: out_channels of how many nodes each distribution is assigned to. in_features: Number of input features. + num_repetitions: Number of times to repeat the layer. + dropout: Dropout probability. """ - super().__init__(in_features, out_channels, num_repetitions, dropout) + super().__init__(num_features, out_channels, num_repetitions, dropout) # Build different layers for each distribution specified - reprs = [distr(in_features, out_channels, num_repetitions, dropout) for distr in distributions] + reprs = [distr(num_features, out_channels, num_repetitions, dropout) for distr in distributions] self.representations = nn.ModuleList(reprs) # Build sum layer as mixture of distributions self.sumlayer = Sum( - num_features=in_features, + num_features=num_features, num_sums_in=len(distributions) * out_channels, num_sums_out=out_channels, num_repetitions=num_repetitions, @@ -48,6 +51,16 @@ def _get_base_distribution(self): raise Exception("Not implemented") def forward(self, x, marginalized_scopes: List[int]): + """ + Forward pass of the Mixture layer. + + Args: + x: Input tensor. + marginalized_scopes: List of marginalized scopes. + + Returns: + Output tensor. + """ results = [d(x) for d in self.representations] # Stack along output channel dimension @@ -58,6 +71,16 @@ def forward(self, x, marginalized_scopes: List[int]): return x def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor: + """ + Sample from the Mixture layer. + + Args: + num_samples: Number of samples to generate. + context: Sampling context. + + Returns: + Sampled tensor. + """ # Sample from sum mixture layer context = self.sumlayer.sample(context=context) diff --git a/simple_einet/distributions/multidistribution.py b/simple_einet/distributions/multidistribution.py index 38794ae..d167288 100644 --- a/simple_einet/distributions/multidistribution.py +++ b/simple_einet/distributions/multidistribution.py @@ -2,7 +2,8 @@ import torch from simple_einet.distributions.abstract_leaf import AbstractLeaf -from simple_einet.utils import SamplingContext, invert_permutation +from simple_einet.utils import invert_permutation +from simple_einet.sampling_utils import SamplingContext from torch import distributions as dist from torch import nn diff --git a/simple_einet/distributions/multivariate_normal.py b/simple_einet/distributions/multivariate_normal.py index 5b6e253..85ac93a 100644 --- a/simple_einet/distributions/multivariate_normal.py +++ b/simple_einet/distributions/multivariate_normal.py @@ -4,7 +4,7 @@ import torch from simple_einet.distributions.abstract_leaf import AbstractLeaf, dist_mode from simple_einet.type_checks import check_valid -from simple_einet.utils import SamplingContext +from simple_einet.sampling_utils import SamplingContext from torch import distributions as dist from torch import nn from torch.nn import functional as F diff --git a/simple_einet/distributions/normal.py b/simple_einet/distributions/normal.py index 84471ac..acffc66 100644 --- a/simple_einet/distributions/normal.py +++ b/simple_einet/distributions/normal.py @@ -6,7 +6,7 @@ from torch import distributions as dist from torch import nn -from simple_einet.utils import SamplingContext +from simple_einet.sampling_utils import SamplingContext class Normal(AbstractLeaf): @@ -19,13 +19,14 @@ def __init__( num_leaves: int, num_repetitions: int, ): - """Creat a gaussian layer. + """ + Initializes a Normal distribution with the given parameters. Args: - out_channels: Number of parallel representations for each input feature. - in_features: Number of input features. - num_repetitions: Number of parallel repetitions of this layer. - + num_features (int): The number of features in the input tensor. + num_channels (int): The number of channels in the input tensor. + num_leaves (int): The number of leaves in the tree structure. + num_repetitions (int): The number of repetitions of the tree structure. """ super().__init__(num_features, num_channels, num_leaves, num_repetitions) @@ -40,7 +41,11 @@ def _get_base_distribution(self, context: SamplingContext = None): class RatNormal(AbstractLeaf): """Implementation as in RAT-SPN - Gaussian layer. Maps each input feature to its gaussian log likelihood.""" + Gaussian layer. Maps each input feature to its gaussian log likelihood. + + Sigmas are constrained to be in min_sigma and max_sigma. + Means are constrained to be in min_mean and max_mean. + """ def __init__( self, @@ -53,35 +58,19 @@ def __init__( min_mean: float = None, max_mean: float = None, ): - """Creat a gaussian layer. + """ + Initializes a Normal distribution with learnable parameters for the means and standard deviations. Args: - out_channels: Number of parallel representations for each input feature. - in_features: Number of input features. - + num_features (int): The number of features in the input tensor. + num_leaves (int): The number of leaves in the tree structure. + num_channels (int): The number of channels in the input tensor. + num_repetitions (int, optional): The number of repetitions for each feature. Defaults to 1. + min_sigma (float, optional): The minimum value for the standard deviation. Defaults to 0.1. + max_sigma (float, optional): The maximum value for the standard deviation. Defaults to 1.0. + min_mean (float, optional): The minimum value for the mean. Defaults to None. + max_mean (float, optional): The maximum value for the mean. Defaults to None. """ - super().__init__( - num_features=num_features, - num_leaves=num_leaves, - num_repetitions=num_repetitions, - num_channels=num_channels, - ) - - # Create gaussian means and stds - self.means = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)) - - if min_sigma is not None and max_sigma is not None: - # Init from normal - self.stds = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)) - else: - # Init uniform between 0 and 1 - self.stds = nn.Parameter(torch.rand(1, num_channels, num_features, num_leaves, num_repetitions)) - - self.min_sigma = check_valid(min_sigma, float, 0.0, max_sigma) - self.max_sigma = check_valid(max_sigma, float, min_sigma) - self.min_mean = check_valid(min_mean, float, upper_bound=max_mean, allow_none=True) - self.max_mean = check_valid(max_mean, float, min_mean, allow_none=True) - def _get_base_distribution(self, context: SamplingContext = None) -> "CustomNormal": if self.min_sigma < self.max_sigma: sigma_ratio = torch.sigmoid(self.stds) @@ -101,17 +90,44 @@ def _get_base_distribution(self, context: SamplingContext = None) -> "CustomNorm class CustomNormal: - """Basic Normal class that can sample given mu and sigma.""" + """ + A custom implementation of the Normal distribution. + + This class allows to sample from a Normal distribution with mean `mu` and standard deviation `sigma`. + The `sample` method returns a tensor of samples from the distribution, with shape `sample_shape + mu.shape`. + The `log_prob` method returns the log probability density/mass function evaluated at `x`. + Args: + mu (torch.Tensor): The mean of the Normal distribution. + sigma (torch.Tensor): The standard deviation of the Normal distribution. + """ def __init__(self, mu: torch.Tensor, sigma: torch.Tensor): self.mu = mu self.sigma = sigma def sample(self, sample_shape: Tuple[int]): - num_samples = sample_shape[0] - eps = torch.randn((num_samples,) + self.mu.shape, dtype=self.mu.dtype, device=self.mu.device) - samples = self.mu.unsqueeze(0) + self.sigma.unsqueeze(0) * eps - return samples + """ + Generates random samples from the normal distribution with mean `mu` and standard deviation `sigma`. + + Args: + sample_shape (Tuple[int]): The shape of the desired output tensor. + + Returns: + samples (torch.Tensor): A tensor of shape `sample_shape` containing random samples from the normal distribution. + """ + num_samples = sample_shape[0] + eps = torch.randn((num_samples,) + self.mu.shape, dtype=self.mu.dtype, device=self.mu.device) + samples = self.mu.unsqueeze(0) + self.sigma.unsqueeze(0) * eps + return samples def log_prob(self, x): + """ + Computes the log probability density of the normal distribution at the given value. + + Args: + x (torch.Tensor): The value(s) at which to evaluate the log probability density. + + Returns: + torch.Tensor: The log probability density of the normal distribution at the given value(s). + """ return dist.Normal(self.mu, self.sigma).log_prob(x) diff --git a/simple_einet/distributions/utils.py b/simple_einet/distributions/utils.py index 5c5dfc8..f321623 100644 --- a/simple_einet/distributions/utils.py +++ b/simple_einet/distributions/utils.py @@ -1,13 +1,15 @@ def truncated_normal_(tensor, mean=0, std=0.1): """ Truncated normal from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15 + + Args: + tensor (torch.Tensor): The tensor to initialize. + mean (float): The mean of the normal distribution. + std (float): The standard deviation of the normal distribution. """ size = tensor.shape tmp = tensor.new_empty(size + (4,)).normal_() valid = (tmp < 2) & (tmp > -2) ind = valid.max(-1, keepdim=True)[1] tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) - tensor.data.mul_(std).add_(mean) - -if __name__ == '__main__': - pass + tensor.data.mul_(std).add_(mean) \ No newline at end of file diff --git a/simple_einet/einet.py b/simple_einet/einet.py index d3e5d53..6f7f5e4 100644 --- a/simple_einet/einet.py +++ b/simple_einet/einet.py @@ -22,7 +22,7 @@ from simple_einet.factorized_leaf_layer import FactorizedLeaf from simple_einet.layers import Sum from simple_einet.type_checks import check_valid -from simple_einet.utils import SamplingContext, provide_evidence +from simple_einet.sampling_utils import provide_evidence, SamplingContext logger = logging.getLogger(__name__) @@ -587,7 +587,7 @@ def sample_differentiable( with provide_evidence(self, evidence, marginalized_scopes, requires_grad=True): # If class is given, use it as base index - # Start sampling one of the C root nodes TODO: check what happens if C=1 + # Start sampling one of the C root nodes ctx = SamplingContext( num_samples=num_samples, is_mpe=is_mpe, @@ -600,7 +600,6 @@ def sample_differentiable( tau=tau, mpe_at_leaves=mpe_at_leaves, ) - # ctx = self._sampling_root.sample(context=ctx) ctx.indices_out = torch.ones( num_samples, 1, diff --git a/simple_einet/einsum_layer.py b/simple_einet/einsum_layer.py index 65a1cab..f9d3fbf 100644 --- a/simple_einet/einsum_layer.py +++ b/simple_einet/einsum_layer.py @@ -8,7 +8,7 @@ from simple_einet.layers import AbstractLayer from simple_einet.type_checks import check_valid -from simple_einet.utils import SamplingContext, index_one_hot, diff_sample_one_hot +from simple_einet.sampling_utils import SamplingContext, diff_sample_one_hot, index_one_hot def logsumexp(tensors, mask=None, dim=-1): @@ -18,11 +18,12 @@ def logsumexp(tensors, mask=None, dim=-1): Logsumexp with custom scalar mask to allow for negative values in the sum. Args: - tensors: - mask: (Default value = None) + tensors (torch.Tensor, List[torch.Tensor]): The tensors to sum. + mask (torch.Tensor, optional): The mask to apply to the sum. Defaults to None. + dim (int, optional): The dimension to sum over. Defaults to -1. Returns: - + torch.Tensor: The summed tensor. """ # Ensure that everything is a tensor if type(tensors) == list or type(tensors) == tuple: @@ -39,6 +40,23 @@ def logsumexp(tensors, mask=None, dim=-1): class LinsumLayer(AbstractLayer): + """ + Similar to Einsum but with a linear combination of the input channels for each output channel compared to + the cross-product combination that is applied in an EinsumLayer. + + Attributes: + num_sums_in (int): The number of input sums. + num_sums_out (int): The number of output sums. + num_repetitions (int): The number of repetitions of the layer. + weights (torch.Tensor): The weights of the layer. + dropout (float): The dropout probability. + _bernoulli_dist (torch.distributions.Bernoulli): The Bernoulli distribution used for dropout. + _is_input_cache_enabled (bool): Whether the input cache is enabled. + _input_cache_left (torch.Tensor): The left input cache. + _input_cache_right (torch.Tensor): The right input cache. + cardinality (int): The cardinality of the layer. + """ + def __init__( self, num_features: int, @@ -47,6 +65,16 @@ def __init__( num_repetitions: int = 1, dropout: float = 0.0, ): + """ + Initializes a LinsumLayer instance. + + Args: + num_features (int): The number of input features. + num_sums_in (int): The number of input sums. + num_sums_out (int): The number of output sums. + num_repetitions (int, optional): The number of times to repeat the layer. Defaults to 1. + dropout (float, optional): The dropout probability. Defaults to 0.0. + """ super().__init__(num_features, num_repetitions) self.num_sums_in = check_valid(num_sums_in, int, 1) @@ -54,7 +82,6 @@ def __init__( cardinality = 2 # Fixed to binary graphs for now self.cardinality = check_valid(cardinality, int, 2, num_features + 1) self.num_features_out = np.ceil(self.num_features / self.cardinality).astype(int) - self._pad = 0 ws = self._init_weights() @@ -73,6 +100,9 @@ def __init__( self.out_shape = f"(N, {self.num_features_out}, {self.num_sums_out}, {self.num_repetitions})" def _init_weights(self): + """ + Initializes the weights of the layer. + """ # Weights, such that each sumnode has its own weights ws = torch.randn( self.num_features // self.cardinality, @@ -138,7 +168,6 @@ def forward_tdi(self, log_exp_ch: torch.Tensor, log_var_ch: torch.Tensor, dropou # SUM LAYER # ############# - # Prepare constants # If dropout at inference time is set, use this instead if dropout_inference is not None: @@ -202,7 +231,6 @@ def forward(self, x: torch.Tensor): dropout_indices = torch.log(1 - dropout_indices) prod_output = prod_output + dropout_indices - # Get log weights log_weights = self._get_normalized_log_weights().unsqueeze(0) prob = torch.logsumexp(prod_output + log_weights, dim=2) # N x D/2 x Sout x R @@ -214,10 +242,19 @@ def forward(self, x: torch.Tensor): return prob - def sample( - self, num_samples: int, context: SamplingContext, differentiable=False - ) -> Union[SamplingContext, torch.Tensor]: + def sample(self, num_samples: int, context: SamplingContext) -> SamplingContext: + """ + Samples from the weights of the EinsumLayer and returns a SamplingContext object + containing the sampled indices. + Args: + num_samples (int): The number of samples to generate. + context (SamplingContext): The SamplingContext object containing the indices + used for sampling. + + Returns: + SamplingContext: The SamplingContext object containing the sampled indices. + """ # Sum weights are of shape: [D, IC//2, IC//2, OC, R] # We now want to use `indices` to access one in_channel for each in_feature x # out_channels block index is of size in_feature @@ -382,6 +419,16 @@ def __init__( num_repetitions: int = 1, dropout: float = 0.0, ): + """ + EinsumLayer is a PyTorch module that implements the Einsum layer for the Einet model. + + Args: + num_features (int): The number of input features. + num_sums_in (int): The number of input sum nodes. + num_sums_out (int): The number of output sum nodes. + num_repetitions (int, optional): The number of repetitions. Defaults to 1. + dropout (float, optional): The dropout probability. Defaults to 0.0. + """ super().__init__(num_features, num_repetitions) self.num_sums_in = check_valid(num_sums_in, int, 1) @@ -389,7 +436,6 @@ def __init__( cardinality = 2 # Fixed to binary graphs for now self.cardinality = check_valid(cardinality, int, 2, num_features + 1) self.num_features_out = np.ceil(self.num_features / self.cardinality).astype(int) - self._pad = 0 # Weights, such that each sumnode has its own weights ws = torch.randn( @@ -484,10 +530,17 @@ def forward(self, x: torch.Tensor): return prob - def sample( - self, num_samples: int, context: SamplingContext, differentiable=False - ) -> Union[SamplingContext, torch.Tensor]: + def sample(self, num_samples: int, context: SamplingContext) -> SamplingContext: + """ + Samples from the weights of the EinsumLayer using the provided SamplingContext. + + Args: + num_samples (int): The number of samples to generate. + context (SamplingContext): The SamplingContext to use for sampling. + Returns: + SamplingContext: The updated SamplingContext. + """ # Sum weights are of shape: [D, IC//2, IC//2, OC, R] # We now want to use `indices` to access one in_channel for each in_feature x # out_channels block index is of size in_feature @@ -542,9 +595,6 @@ def sample( weights = weights.view(num_samples, out_features, in_channels**2) log_weights = F.log_softmax(weights * context.temperature_sums, dim=2) - # Clone for discrete validity check - # log_weights_disc = log_weights.clone() - # If evidence is given, adjust the weights with the likelihoods of the observed paths if self._is_input_cache_enabled and self._input_cache_left is not None: # Index repetition @@ -628,12 +678,30 @@ def extra_repr(self): class EinsumMixingLayer(AbstractLayer): + """ + A PyTorch module that implements a mixing layer using the Einstein summation convention. + + Attributes: + weights (nn.Parameter): The learnable weights of the layer. + num_sums_in (int): The number of input summation nodes. + num_sums_out (int): The number of output summation nodes. + out_features (int): The number of output features. + """ + def __init__( self, num_features: int, num_sums_in: int, num_sums_out: int, ): + """ + Creates a new EinsumMixingLayer. + + Args: + num_features (int): The number of input features. + num_sums_in (int): The number of input summation nodes. + num_sums_out (int): The number of output summation nodes. + """ super().__init__(num_features, num_repetitions=1) self.num_sums_in = check_valid(num_sums_in, int, 1) @@ -654,6 +722,7 @@ def __init__( self._input_cache_right = None def forward(self, x): + """Forward pass of the layer.""" # Save input if input cache is enabled if self._is_input_cache_enabled: self._input_cache = x.clone() @@ -676,8 +745,23 @@ def sample( self, num_samples: int = None, context: SamplingContext = None, - differentiable=False, - ) -> Union[SamplingContext, torch.Tensor]: + ) -> SamplingContext: + """ + Samples from the EinsumLayer. + + Args: + num_samples (int, optional): The number of samples to generate. Defaults to None. + context (SamplingContext, optional): The sampling context. Defaults to None. + + Returns: + SamplingContext: The updated sampling context. + """ + + def sample( + self, + num_samples: int = None, + context: SamplingContext = None, + ) -> SamplingContext: # Sum weights are of shape: [W, H, IC, OC, R] # We now want to use `indices` to access one in_channel for each in_feature x num_sums_out block # index is of size in_feature @@ -755,6 +839,20 @@ def extra_repr(self): class MixingLayer(AbstractLayer): + """ + A PyTorch module that implements a linear mixing layer. + + Attributes: + weights (nn.Parameter): The learnable weights of the layer. + num_sums_in (int): The number of input summation nodes. + num_sums_out (int): The number of output summation nodes. + out_features (int): The number of output features. + _is_input_cache_enabled (bool): Whether the input cache is enabled. + _input_cache_left (torch.Tensor): The left input cache. + _input_cache_right (torch.Tensor): The right input cache. + _bernoulli_dist (torch.distributions.Bernoulli): The Bernoulli distribution. + """ + def __init__( self, num_features: int, @@ -762,6 +860,15 @@ def __init__( num_sums_out: int, dropout: float = 0.0, ): + """ + Initializes an EinsumLayer instance. + + Args: + num_features (int): Number of input and output features. + num_sums_in (int): Number of input sum nodes. + num_sums_out (int): Number of output sum nodes. + dropout (float, optional): Dropout probability. Defaults to 0.0. + """ super().__init__(num_features, num_repetitions=1) self.num_sums_in = check_valid(num_sums_in, int, 1) @@ -790,6 +897,7 @@ def _get_normalized_log_weights(self): return F.log_softmax(self.weights, dim=2) def forward(self, x): + """Forward pass of the layer.""" # Save input if input cache is enabled if self._is_input_cache_enabled: self._input_cache = x.clone() @@ -797,7 +905,6 @@ def forward(self, x): # Dimensions N, D, IC, R = x.size() - # Apply dropout: Set random sum node children to 0 (-inf in log domain) if self.dropout > 0.0 and self.training: dropout_indices = self._bernoulli_dist.sample(x.shape) @@ -809,7 +916,6 @@ def forward(self, x): dropout_indices = torch.log(1 - dropout_indices) x = x + dropout_indices - # Get log weights log_weights = self._get_normalized_log_weights().unsqueeze(0) lls = torch.logsumexp(x + log_weights, dim=3) @@ -853,9 +959,18 @@ def sample( self, num_samples: int = None, context: SamplingContext = None, - differentiable=False, - ) -> Union[SamplingContext, torch.Tensor]: - raise NotImplementedError("Not yet implemented for MixingLayer") + ) -> SamplingContext: + """ + Samples from the mixing layer. + + Args: + num_samples (int): The number of samples to generate. + context (SamplingContext): The sampling context. + + Returns: + SamplingContext: The updated sampling context. + """ + # raise NotImplementedError("Not yet implemented for MixingLayer") # Sum weights are of shape: [W, H, IC, OC, R] # We now want to use `indices` to access one in_channel for each in_feature x num_sums_out block # index is of size in_feature @@ -929,25 +1044,4 @@ def _check_indices_repetition(self, context: SamplingContext): def extra_repr(self): return "num_features={}, num_sums_in={}, num_sums_out={}".format( self.num_features, self.num_sums_in, self.num_sums_out, self.num_repetitions - ) - - -if __name__ == "__main__": - x = torch.rand(1, 3, 4) - y = torch.rand(1, 3, 4) - z = torch.rand(1, 3, 4) - - print(logsumexp((logsumexp((x, y), mask=[1, 1]), z), mask=[1, -1])) - print(logsumexp(torch.stack([x, y, z], dim=-1), mask=[1, 1, -1])) - - # layer = LinsumLayer(num_features=2, num_sums_in=3, num_sums_out=10, num_repetitions=7, dropout=0.1) - # mixing = MixingLayer(num_features=1, num_sums_in=7, num_sums_out=10) - - # log_exp = torch.randn(1, 2, 3, 7) - # log_var = torch.rand_like(log_exp) - - # # x = layer(log_exp) - # # x = mixing(x) - - # log_exp, log_var = layer.forward_tdi(log_exp, log_var) - # log_exp, log_var = mixing.forward_tdi(log_exp, log_var) + ) \ No newline at end of file diff --git a/simple_einet/factorized_leaf_layer.py b/simple_einet/factorized_leaf_layer.py index a697c92..ae74e75 100644 --- a/simple_einet/factorized_leaf_layer.py +++ b/simple_einet/factorized_leaf_layer.py @@ -2,7 +2,7 @@ import torch import numpy as np -from simple_einet.utils import SamplingContext, index_one_hot +from simple_einet.sampling_utils import SamplingContext, index_one_hot from simple_einet.layers import AbstractLayer from simple_einet.distributions import AbstractLeaf @@ -12,11 +12,11 @@ class FactorizedLeaf(AbstractLayer): A 'meta'-leaf layer that combines multiple scopes of a base-leaf layer via naive factorization. Attributes: - base_leaf: Base leaf layer that contains the actual leaf distribution. - in_features: Number of input features/RVs. - out_features: Number of output features/RVs. This determines the factorization group size (round(in_features / out_features)) - scopes: One-hot mapping from which in_features correspond to which out_features. - + num_features (int): Number of input features. + num_features_out (int): Number of output features. + num_repetitions (int): Number of repetitions. + base_leaf (AbstractLeaf): The base leaf layer. + scopes (torch.Tensor): The scopes of the factorized groups of RVs. """ def __init__( @@ -28,10 +28,10 @@ def __init__( ): """ Args: - in_features (int): Number of input features/RVs. - out_features (int): Number of output features/RVs. + num_features (int): Number of input features. + num_features_out (int): Number of output features. num_repetitions (int): Number of repetitions. - base_leaf (Leaf): Base leaf distribution object. + base_leaf (AbstractLeaf): The base leaf layer. """ super().__init__(num_features, num_repetitions=num_repetitions) @@ -56,24 +56,48 @@ def __init__( self.register_buffer("scopes", scopes) def forward(self, x: torch.Tensor, marginalized_scopes: List[int]): - # Forward through base leaf - x = self.base_leaf(x, marginalized_scopes) + """ + Forward pass through the factorized leaf layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, num_input_channels, num_leaves, num_repetitions). + marginalized_scopes (List[int]): List of integers representing the marginalized scopes. + + Returns: + torch.Tensor: Output tensor of shape (batch_size, num_output_channels, num_leaves, num_repetitions). + """ + # Forward through base leaf + x = self.base_leaf(x, marginalized_scopes) - # Factorize input channels - x = x.sum(dim=1) + # Factorize input channels + x = x.sum(dim=1) - # Merge scopes by naive factorization - x = torch.einsum("bicr,ior->bocr", x, self.scopes) + # Merge scopes by naive factorization + x = torch.einsum("bicr,ior->bocr", x, self.scopes) - assert x.shape == ( - x.shape[0], - self.num_features_out, - self.base_leaf.num_leaves, - self.num_repetitions, - ) - return x + assert x.shape == ( + x.shape[0], + self.num_features_out, + self.base_leaf.num_leaves, + self.num_repetitions, + ) + return x def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor: + """ + Samples the factorized leaf layer by generating `num_samples` samples from the base leaf layer, + and then mapping them to the factorized leaf layer using the indices specified in the `context` + argument. If `context.is_differentiable` is True, the mapping is done using one-hot indexing. + + Args: + num_samples (int, optional): The number of samples to generate. If None, defaults to the + `num_samples` attribute of the `context` argument. Defaults to None. + context (SamplingContext, optional): The sampling context to use. Defaults to None. + + Returns: + torch.Tensor: A tensor of shape `(num_samples, self.num_features_out, self.num_leaves)`, + representing the samples generated from the factorized leaf layer. + """ # Save original indices_out and set context indices_out to none, such that the out_channel # are not filtered in the base_leaf sampling procedure indices_out = context.indices_out @@ -110,18 +134,6 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> to samples = samples.gather(dim=-1, index=indices_in_gather) samples.squeeze_(-1) # Remove num_leaves dimension else: - # i_rep = context.indices_repetition.argmax(-1).long() - # i_out = indices_out.argmax(-1).long() - # scopes_orig = self.scopes[..., i_rep].permute(2, 0, 1) - # rnge_in = torch.arange(self.num_features_out, device=samples.device) - # scopes_orig = (scopes_orig * rnge_in).sum(-1).long() - # indices_in_gather_orig = i_out.gather(dim=1, index=scopes_orig) - # indices_in_gather_orig = indices_in_gather_orig.view(num_samples, 1, -1, 1) - # - # indices_in_gather_orig = indices_in_gather_orig.expand(-1, samples.shape[1], -1, -1) - # samples_orig = samples.gather(dim=-1, index=indices_in_gather_orig) - # samples_orig.squeeze_(-1) # Remove num_leaves dimension - scopes = self.scopes.unsqueeze(0) # make space for batch dim r_idx = context.indices_repetition.view(context.num_samples, 1, 1, -1) scopes = index_one_hot(scopes, index=r_idx, dim=-1) @@ -131,8 +143,6 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> to indices_in = indices_in.unsqueeze(1) # make space for channel dim samples = index_one_hot(samples, index=indices_in, dim=-1) - # assert (samples - samples_orig).sum() < 1e-4 - return samples def extra_repr(self): diff --git a/simple_einet/layers.py b/simple_einet/layers.py index 7f7eef0..aa1df78 100644 --- a/simple_einet/layers.py +++ b/simple_einet/layers.py @@ -8,13 +8,23 @@ from torch.nn import functional as F from simple_einet.type_checks import check_valid -from simple_einet.utils import SamplingContext +from simple_einet.sampling_utils import SamplingContext logger = logging.getLogger(__name__) class AbstractLayer(nn.Module, ABC): + """ + This is the abstract base class for all layers in the SPN. + """ def __init__(self, num_features: int, num_repetitions: int = 1): + """ + Create an abstract layer. + + Args: + num_features (int): Number of input features. + num_repetitions (int, optional): Number of layer repetitions in parallel. Defaults to 1. + """ super().__init__() self.num_features = check_valid(num_features, int, 1) self.num_repetitions = check_valid(num_repetitions, int, 1) @@ -23,17 +33,31 @@ def __init__(self, num_features: int, num_repetitions: int = 1): def sample(self, num_samples: int = None, context: SamplingContext = None) -> Union[SamplingContext, torch.Tensor]: """ Sample from this layer. + Args: - num_samples: Number of samples. - context: Sampling context. + num_samples (int, optional): Number of samples. Defaults to None. + context (SamplingContext, optional): Sampling context. Defaults to None. Returns: - torch.Tensor: Generated samples. + Union[SamplingContext, torch.Tensor]: Generated samples. """ pass class Sum(AbstractLayer): + """ + Sum Node Layer that sums over all children in a scope set. + + Attributes: + num_sums_in (int): Number of input sum nodes. + num_sums_out (int): Multiplicity of a sum node for a given scope set. + dropout (torch.nn.Parameter): Dropout percentage. + weights (torch.nn.Parameter): Weights, such that each sumnode has its own weights. + out_shape (str): Output shape of the layer. + _bernoulli_dist (torch.distributions.Bernoulli): Bernoulli distribution for dropout. + _is_input_cache_enabled (bool): Whether input cache is enabled. + _input_cache (torch.Tensor): Cached input tensor. + """ def __init__( self, num_sums_in: int, @@ -49,11 +73,11 @@ def __init__( Output will be of shape [n, d, oc, r]. Args: - in_channels (int): Number of output channels from the previous layer. - in_features (int): Number of input features. - out_channels (int): Multiplicity of a sum node for a given scope set. - num_repetitions(int): Number of layer repetitions in parallel. - dropout (float, optional): Dropout percentage. + num_sums_in (int): Number of input sum nodes. + num_features (int): Number of input features. + num_sums_out (int): Multiplicity of a sum node for a given scope set. + num_repetitions (int, optional): Number of layer repetitions in parallel. Defaults to 1. + dropout (float, optional): Dropout percentage. Defaults to 0.0. """ super().__init__(num_features, num_repetitions) @@ -220,6 +244,13 @@ def extra_repr(self): class Product(AbstractLayer): """ Product Node Layer that chooses k scopes as children for a product node. + + Attributes: + cardinality (int): Number of random children for each product node. + _conv_weights (torch.nn.Parameter): Convolution weights. + _pad (int): Padding to the next power of 2. + _out_features (int): Number of output features. + out_shape (str): Output shape of the layer. """ def __init__(self, in_features: int, cardinality: int, num_repetitions: int = 1): @@ -229,6 +260,7 @@ def __init__(self, in_features: int, cardinality: int, num_repetitions: int = 1) Args: in_features (int): Number of input features. cardinality (int): Number of random children for each product node. + num_repetitions (int, optional): Number of layer repetitions in parallel. Defaults to 1. """ super().__init__(in_features, num_repetitions) @@ -333,6 +365,13 @@ class CrossProduct(AbstractLayer): res += [n1 * n2] TODO: Generalize to k regions (cardinality = k). + + Attributes: + cardinality (int): Number of random children for each product node. + _pad (int): Padding to the next power of 2. + _out_features (int): Number of output features. + out_shape (str): Output shape of the layer. + _scopes (List[List[int]]): List of scopes for each product child. """ def __init__(self, in_features: int, in_channels: int, num_repetitions: int = 1): diff --git a/simple_einet/sampling_utils.py b/simple_einet/sampling_utils.py new file mode 100644 index 0000000..8f69920 --- /dev/null +++ b/simple_einet/sampling_utils.py @@ -0,0 +1,267 @@ +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn import functional as F +from tqdm import tqdm + +import simple_einet + +from simple_einet.utils import __HAS_EINSUM_BROADCASTING + + +@contextmanager +def provide_evidence( + spn: nn.Module, + evidence: torch.Tensor = None, + marginalized_scopes: torch.Tensor = None, + requires_grad=False, +): + """ + Context manager for sampling with evidence. In this context, the SPN graph is reweighted with the likelihoods + computed using the given evidence. + + Args: + spn: SPN that is being used to perform the sampling. + evidence: Provided evidence. The SPN will perform a forward pass prior to entering this contex. + requires_grad: If False, runs in torch.no_grad() context. (default: False) + """ + # If no gradients are required, run in no_grad context + if not requires_grad: + context = torch.no_grad + else: + # Else provide null context + context = nullcontext + + # Run forward pass in given context + with context(): + + if evidence is not None: + # Enter + for module in spn.modules(): + if hasattr(module, "_enable_input_cache"): + module._enable_input_cache() + + _ = spn(evidence, marginalized_scopes) + + # Run in context (nothing needs to be yielded) + yield + + # Exit + if evidence is not None: + for module in spn.modules(): + if hasattr(module, "_enable_input_cache"): + module._disable_input_cache() + + +@dataclass +class SamplingContext: + """Dataclass for representing the context in which sampling operations occur.""" + + # Number of samples + num_samples: int = None + + # Indices into the out_channels dimension + indices_out: torch.Tensor = None + + # Indices into the repetition dimension + indices_repetition: torch.Tensor = None + + # MPE flag, if true, will perform most probable explanation sampling + is_mpe: bool = False + + # Temperature for sampling at the leaves + temperature_leaves: float = 1.0 + + # Temperature for sampling at the einsumlayers + temperature_sums: float = 1.0 + + # Number of repetitions + num_repetitions: int = None + + # Evidence + evidence: torch.Tensor = None + + # Differentiable + is_differentiable: bool = False + + # Flag for hard or soft differentiable sampling + hard: bool = False + + # Temperature for differentiable sampling + tau: float = 1.0 + + # Do MPE at leaves + mpe_at_leaves: bool = False + + def __setattr__(self, key, value): + if hasattr(self, key): + super().__setattr__(key, value) + else: + raise AttributeError(f"SamplingContext object has no attribute {key}") + + @property + def is_root(self): + return self.indices_out == None and self.indices_repetition == None + + +def get_context(differentiable): + """ + Get a noop context if differentiable, else torch.no_grad. + + Args: + differentiable: If the context should allow gradients or not. + + Returns: + nullcontext if differentialbe=False, else torch.no_grad + + """ + if differentiable: + return nullcontext() + else: + return torch.no_grad() + + +def diff_sample_one_hot(logits: torch.Tensor, dim: int, mode: str, hard: bool, tau: float) -> torch.Tensor: + """ + Perform differentiable sampling/mpe on the given input along a specific dimension. + + Modes: + - "sample": Perform sampling + - "argmax": Perform mpe + + Args: + logits(torch.Tensor): Logits from which the sampling should be done. + dim(int): Dimension along which to sample from. + mode(str): Mode as described above. + hard(bool): Whether to perform hard or soft sampling. + tau(float): Temperature for soft sampling. + + Returns: + torch.Tensor: Indices encoded as one-hot tensor along the given dimension `dim`. + + """ + if mode == "sample": + return F.gumbel_softmax(logits=logits, hard=hard, tau=tau, dim=dim) + elif mode == "argmax": + # Differentiable argmax (see gumbel softmax trick code) + y_soft = logits.softmax(dim) + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + return ret + else: + raise Exception(f"Invalid mode option (got {mode}). Must be either 'sample' or 'argmax'.") + + +def index_one_hot(tensor: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: + """ + Index into a given tensor unsing a one-hot encoded index tensor at a specific dimension. + + Example: + + Given array "x = [3 7 5]" and index "2", "x[2]" should return "5". + + Here, index "2" should be one-hot encoded: "2 = [0 0 1]" which we use to + elementwise multiply the original tensor and then sum up the masked result. + + sum([3 7 5] * [0 1 0]) == sum([0 0 5]) == 5 + + The implementation is equivalent to + + torch.sum(tensor * index, dim) + + but uses the einsum operation to reduce the number of operations from two to one. + + Args: + tensor(torch.Tensor): Tensor which shall be indexed. + index(torch.Tensor): Indexing tensor. + dim(int): Dimension at which the tensor should be used index. + + Returns: + torch.Tensor: Indexed tensor. + + """ + assert ( + tensor.shape[dim] == index.shape[dim] + ), f"Tensor and index at indexing dimension must be the same size but was tensor.shape[{dim}]={tensor.shape[dim]} and index.shape[{dim}]={index.shape[dim]}" + + assert ( + tensor.dim() == index.dim() + ), f"Tensor and index number of dimensions must be the same but was tensor.dim()={tensor.dim()} and index.dim()={index.dim()}" + + if __HAS_EINSUM_BROADCASTING and False: + num_dims = tensor.dim() + dims = "abcdefghijklmnopqrstuvwxyz"[:num_dims] + dims_without = dims[:dim] + dims[dim + 1 :] + einsum_str = f"{dims},{dims}->{dims_without}" + # print(f"tensor.shape: {tensor.shape}") + # print(f"index.shape: {index.shape}") + # print(f"einsum_str: {einsum_str}") + # print(f"dim={dim}") + return torch.einsum(einsum_str, tensor, index) + else: + return torch.sum(tensor * index, dim=dim) + + +def init_einet_stats(einet: "Einet", dataloader: torch.utils.data.DataLoader): + """ + Initializes the statistics of the Einet model using the given dataloader. + + Args: + - einet (Einet): The Einet model to initialize. + - dataloader (torch.utils.data.DataLoader): The dataloader to use for initialization. + + Returns: None + """ + stats_mean = None + stats_std = None + + # Compute mean and std + for batch in tqdm(dataloader, desc="Leaf Parameter Initialization"): + data, label = batch + if stats_mean == None: + stats_mean = data.mean(dim=0) + stats_std = data.std(dim=0) + else: + stats_mean += data.mean(dim=0) + stats_std += data.std(dim=0) + + # Normalize + stats_mean /= len(dataloader) + stats_std /= len(dataloader) + + + + from simple_einet.distributions.normal import Normal + from simple_einet.einet import Einet, EinetMixture + + + # Set leaf parameters for normal distribution + if einet.config.leaf_type == Normal: + + if type(einet) == Einet: + einets = [einet] + elif type(einet) == EinetMixture: + einets = einet.einets + else: + raise ValueError(f"Invalid einet type: {type(einet)} -- must be Einet or EinetMixture.") + + # Reshape to match leaf parameters + stats_mean_v = ( + stats_mean.view(-1, 1, 1) + .repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions) + .view_as(einets[0].leaf.base_leaf.means) + ) + stats_std_v = ( + stats_std.view(-1, 1, 1) + .repeat(1, einets[0].config.num_leaves, einets[0].config.num_repetitions) + .view_as(einets[0].leaf.base_leaf.log_stds) + ) + + # Set leaf parameters + for net in einets: + # Add noise to ensure that values are not completely equal along repetitions and einets + net.leaf.base_leaf.means.data = stats_mean_v + 0.1 * torch.normal(torch.zeros_like(stats_mean_v), torch.std(stats_mean_v)) + net.leaf.base_leaf.log_stds.data = torch.log(stats_std_v + 1e-3 + torch.clamp(0.1 * torch.normal(torch.zeros_like(stats_std_v), torch.std(stats_std_v)), min=0.0)) diff --git a/simple_einet/type_checks.py b/simple_einet/type_checks.py index 5df8354..6249d96 100644 --- a/simple_einet/type_checks.py +++ b/simple_einet/type_checks.py @@ -4,25 +4,48 @@ class OutOfBoundsException(Exception): + """ + Exception raised when a value is not within the specified bounds. + + Attributes: + value -- the value that was out of bounds + lower_bound -- the lower bound of the valid range + upper_bound -- the upper bound of the valid range + """ + def __init__(self, value, lower_bound, upper_bound): super().__init__(f"Value {value} was not in bounds: [{lower_bound}, {upper_bound}).") class InvalidTypeException(Exception): + """ + Exception raised when a value is of an invalid type. + + Attributes: + value -- the value that was of an invalid type + expected_type -- the expected type (or a subclass of this type) + """ + def __init__(self, value, expected_type): super().__init__( f"Value {value} was of type {type(value)} but expected to be of type {expected_type} (or a subclass of this type) ." ) -class InvalidStackedSpnConfigurationException(Exception): - def __init__(self, expected, observed, parameter_name): - super().__init__( - f"The StackedSpn has received an invalid configuration: Expected {parameter_name}={expected} but got {parameter_name}={observed}." - ) - def _check_bounds(value: Any, expected_type, lower_bound=None, upper_bound=None): + """ + Check if a given value is within the specified bounds. + + Args: + value (Any): The value to check. + expected_type (type): The expected type of the value. + lower_bound (Any, optional): The lower bound of the value. Defaults to None. + upper_bound (Any, optional): The upper bound of the value. Defaults to None. + + Raises: + OutOfBoundsException: If the value is outside the specified bounds. + """ # Check lower bound if lower_bound: if not value >= expected_type(lower_bound): @@ -34,7 +57,18 @@ def _check_bounds(value: Any, expected_type, lower_bound=None, upper_bound=None) raise OutOfBoundsException(value, lower_bound, upper_bound) + def _check_type(value: Any, expected_type): + """ + Checks if the type of the given value matches the expected type. + + Args: + value (Any): The value to check the type of. + expected_type (type): The expected type of the value. + + Raises: + Exception: If the type of the value is not supported for typecheck. + """ # Check if type is from torch if isinstance(value, torch.Tensor): _check_type_torch(value, expected_type) @@ -48,14 +82,37 @@ def _check_type(value: Any, expected_type): raise Exception(f"Unsupported type ({type(value)}) for typecheck.") -def _check_type_core(value: Any, expected_type): + +def _check_type_core(value: Any, expected_type: type) -> None: + """ + Check if the given value is of the expected type. + + Args: + value (Any): The value to check the type of. + expected_type (type): The expected type of the value. + + Raises: + InvalidTypeException: If the value is not of the expected type. + """ if expected_type == float and not isinstance(value, float): raise InvalidTypeException(value, expected_type) elif expected_type == int and not isinstance(value, int): raise InvalidTypeException(value, expected_type) + def _check_type_numpy(value: Any, expected_type): + """ + Check if the given value is of the expected type using numpy's data types. + + Args: + value (Any): The value to check the type of. + expected_type (type): The expected type of the value. + + Raises: + InvalidTypeException: If the value is not of the expected type. + + """ # Check float if expected_type == float: if not isinstance(value, np.floating): @@ -69,6 +126,17 @@ def _check_type_numpy(value: Any, expected_type): def _check_type_torch(value: torch.Tensor, expected_type): + """ + Check if the data type of a torch tensor matches the expected type. + + Args: + value (torch.Tensor): The tensor to check the data type of. + expected_type (type): The expected data type. Must be either `int` or `float`. + + Raises: + InvalidTypeException: If the data type of the tensor does not match the expected type. + + """ # Get torch data type dtype = value.dtype @@ -83,9 +151,23 @@ def _check_type_torch(value: torch.Tensor, expected_type): raise Exception(f"Unexpected data type, must be either int or float, but was {expected_type}") + def check_valid(value: Any, expected_type, lower_bound=None, upper_bound=None, allow_none: bool = False): """ - Check if a value is of a certain type and in given bounds. + Check if a given value is valid based on its expected type and optional bounds. + + Args: + value (Any): The value to check. + expected_type (type): The expected type of the value. + lower_bound (Any, optional): The lower bound for the value. Defaults to None. + upper_bound (Any, optional): The upper bound for the value. Defaults to None. + allow_none (bool, optional): Whether to allow None as a valid value. Defaults to False. + + Returns: + Any: The value, if it is valid. + + Raises: + Exception: If the value is not valid based on the expected type and/or bounds. """ if allow_none and value is None: return value diff --git a/simple_einet/utils.py b/simple_einet/utils.py index 2693d37..1125f34 100644 --- a/simple_einet/utils.py +++ b/simple_einet/utils.py @@ -1,16 +1,9 @@ -from contextlib import contextmanager, nullcontext -import logging -import os -import sys -from dataclasses import dataclass from typing import Tuple -from torch.nn import functional as F import numpy as np import torch from scipy.stats import rankdata -from torch import Tensor, nn -from tqdm.std import tqdm +from torch import Tensor # Assert that torch.einsum broadcasting is available check for torch version >= 1.8.0 try: @@ -27,99 +20,6 @@ __HAS_EINSUM_BROADCASTING = False -@contextmanager -def provide_evidence( - spn: nn.Module, - evidence: torch.Tensor = None, - marginalized_scopes: torch.Tensor = None, - requires_grad=False, -): - """ - Context manager for sampling with evidence. In this context, the SPN graph is reweighted with the likelihoods - computed using the given evidence. - - Args: - spn: SPN that is being used to perform the sampling. - evidence: Provided evidence. The SPN will perform a forward pass prior to entering this contex. - requires_grad: If False, runs in torch.no_grad() context. (default: False) - """ - # If no gradients are required, run in no_grad context - if not requires_grad: - context = torch.no_grad - else: - # Else provide null context - context = nullcontext - - # Run forward pass in given context - with context(): - - if evidence is not None: - # Enter - for module in spn.modules(): - if hasattr(module, "_enable_input_cache"): - module._enable_input_cache() - - _ = spn(evidence, marginalized_scopes) - - # Run in context (nothing needs to be yielded) - yield - - # Exit - if evidence is not None: - for module in spn.modules(): - if hasattr(module, "_enable_input_cache"): - module._disable_input_cache() - - -@dataclass -class SamplingContext: - # Number of samples - num_samples: int = None - - # Indices into the out_channels dimension - indices_out: torch.Tensor = None - - # Indices into the repetition dimension - indices_repetition: torch.Tensor = None - - # MPE flag, if true, will perform most probable explanation sampling - is_mpe: bool = False - - # Temperature for sampling at the leaves - temperature_leaves: float = 1.0 - - # Temperature for sampling at the einsumlayers - temperature_sums: float = 1.0 - - # Number of repetitions - num_repetitions: int = None - - # Evidence - evidence: torch.Tensor = None - - # Differentiable - is_differentiable: bool = False - - # Flag for hard or soft differentiable sampling - hard: bool = False - - # Temperature for differentiable sampling - tau: float = 1.0 - - # Do MPE at leaves - mpe_at_leaves: bool = False - - def __setattr__(self, key, value): - if hasattr(self, key): - super().__setattr__(key, value) - else: - raise AttributeError(f"SamplingContext object has no attribute {key}") - - @property - def is_root(self): - return self.indices_out == None and self.indices_repetition == None - - def invert_permutation(p: torch.Tensor): """ The argument p is assumed to be some permutation of 0, 1, ..., len(p)-1. @@ -132,6 +32,18 @@ def invert_permutation(p: torch.Tensor): def calc_bpd(log_p: Tensor, image_shape: Tuple[int, int, int], has_gauss_dist: bool, n_bins: int) -> float: + """ + Calculates the bits per dimension (BPD) for a given log probability tensor. + + Args: + log_p (Tensor): The log probability tensor. + image_shape (Tuple[int, int, int]): The shape of the image. + has_gauss_dist (bool): Whether the distribution is Gaussian or not. + n_bins (int): The number of bins. + + Returns: + float: The bits per dimension (BPD) value. + """ n_pixels = np.prod(image_shape) if has_gauss_dist: @@ -262,101 +174,4 @@ def rdc(x, y, f=np.sin, k=20, s=1 / 6.0, n=1): return np.sqrt(np.max(eigs)) -def get_context(differentiable): - """ - Get a noop context if differentiable, else torch.no_grad. - - Args: - differentiable: If the context should allow gradients or not. - - Returns: - nullcontext if differentialbe=False, else torch.no_grad - - """ - if differentiable: - return nullcontext() - else: - return torch.no_grad() - - -def diff_sample_one_hot(logits: torch.Tensor, dim: int, mode: str, hard: bool, tau: float) -> torch.Tensor: - """ - Perform differentiable sampling/mpe on the given input along a specific dimension. - - Modes: - - "sample": Perform sampling - - "argmax": Perform mpe - - Args: - logits(torch.Tensor): Logits from which the sampling should be done. - dim(int): Dimension along which to sample from. - mode(str): Mode as described above. - hard(bool): Whether to perform hard or soft sampling. - tau(float): Temperature for soft sampling. - - Returns: - torch.Tensor: Indices encoded as one-hot tensor along the given dimension `dim`. - - """ - if mode == "sample": - return F.gumbel_softmax(logits=logits, hard=hard, tau=tau, dim=dim) - elif mode == "argmax": - # Differentiable argmax (see gumbel softmax trick code) - y_soft = logits.softmax(dim) - index = y_soft.max(dim, keepdim=True)[1] - y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) - ret = y_hard - y_soft.detach() + y_soft - return ret - else: - raise Exception(f"Invalid mode option (got {mode}). Must be either 'sample' or 'argmax'.") - - -def index_one_hot(tensor: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: - """ - Index into a given tensor unsing a one-hot encoded index tensor at a specific dimension. - - Example: - - Given array "x = [3 7 5]" and index "2", "x[2]" should return "5". - - Here, index "2" should be one-hot encoded: "2 = [0 0 1]" which we use to - elementwise multiply the original tensor and then sum up the masked result. - - sum([3 7 5] * [0 1 0]) == sum([0 0 5]) == 5 - - The implementation is equivalent to - - torch.sum(tensor * index, dim) - - but uses the einsum operation to reduce the number of operations from two to one. - - Args: - tensor(torch.Tensor): Tensor which shall be indexed. - index(torch.Tensor): Indexing tensor. - dim(int): Dimension at which the tensor should be used index. - - Returns: - torch.Tensor: Indexed tensor. - - """ - assert ( - tensor.shape[dim] == index.shape[dim] - ), f"Tensor and index at indexing dimension must be the same size but was tensor.shape[{dim}]={tensor.shape[dim]} and index.shape[{dim}]={index.shape[dim]}" - - assert ( - tensor.dim() == index.dim() - ), f"Tensor and index number of dimensions must be the same but was tensor.dim()={tensor.dim()} and index.dim()={index.dim()}" - - if __HAS_EINSUM_BROADCASTING and False: - num_dims = tensor.dim() - dims = "abcdefghijklmnopqrstuvwxyz"[:num_dims] - dims_without = dims[:dim] + dims[dim + 1 :] - einsum_str = f"{dims},{dims}->{dims_without}" - # print(f"tensor.shape: {tensor.shape}") - # print(f"index.shape: {index.shape}") - # print(f"einsum_str: {einsum_str}") - # print(f"dim={dim}") - return torch.einsum(einsum_str, tensor, index) - else: - return torch.sum(tensor * index, dim=dim)