Commit ae92983

general: Refactor code and add documentation

braun-steven committed Sep 22, 2023
1 parent c6ef47b commit ae92983
Showing 23 changed files with 1,023 additions and 652 deletions.
3 changes: 3 additions & 0 deletions .envrc
@@ -0,0 +1,3 @@
+export PYTHONPATH=./
+export RESULTS_DIR=${HOME}/results
+export DATA_DIR=${HOME}/data
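The new `.envrc` follows the direnv convention, so these variables are exported automatically when entering the repository (assuming direnv is installed and allowed). A minimal sketch of how the scripts might pick them up on the Python side; the fallback defaults here are an assumption, not part of the commit:

```python
import os
from pathlib import Path

# Read the experiment directories from the environment, falling back to
# the same defaults the .envrc sets (~/results and ~/data) when unset.
RESULTS_DIR = Path(os.environ.get("RESULTS_DIR", str(Path.home() / "results")))
DATA_DIR = Path(os.environ.get("DATA_DIR", str(Path.home() / "data")))
```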
2 changes: 1 addition & 1 deletion LICENSE.md
@@ -1,6 +1,6 @@
The MIT License (MIT)

-Copyright (c) 2022 Steven Lang
+Copyright (c) 2023 Steven Braun

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
30 changes: 15 additions & 15 deletions README.md
@@ -1,6 +1,6 @@
# Simple EinsumNetworks Implementation

-This repository contains code for my personal, simplistic, EinsumNetworks implementation.
+This repository contains code for my personal, simplistic EinsumNetworks implementation. Well, as it so often happens with code, the implementation is not as simplistic as it used to be, but at least that was the initial intention.

For a speed benchmark comparison against the official EinsumNetworks implementation, check out [benchmark.md](./benchmark/benchmark.md) (in short: simple-einet is faster in all dimensions except the input-channel size but scales similarly to EinsumNetworks).

@@ -17,7 +17,7 @@ python main_pl.py dataset=mnist batch_size=128 epochs=100 dist=normal D=5 I=32 S
<img src="./res/mnist_classification.png" width=400px><img src="./res/mnist_train_val_test_acc.png" width=400px>


-Generative learning on MNIST:
+Generative training on MNIST:

``` sh
python main_pl.py dataset=mnist D=3 I=10 R=1 S=10 lr=0.1 dist=binomial epochs=10 batch_size=128
@@ -55,7 +55,7 @@ out_features = 3
x = torch.randn(batchsize, in_features)

# Construct Einet
-einet = Einet(EinetConfig(num_features=in_features, depth=2, num_sums=2, num_channels=1, num_leaves=3, num_repetitions=3, num_classes=out_features, dropout=0.0, leaf_type=RatNormal, leaf_kwargs={"min_sigma": 1e-5, "max_sigma": 1.0},))
+einet = Einet(EinetConfig(num_features=in_features, depth=2, num_sums=2, num_channels=1, num_leaves=3, num_repetitions=3, num_classes=out_features, dropout=0.0, leaf_type=Normal))

# Compute log-likelihoods
lls = einet(x)
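Because the config sets `num_classes=out_features`, the log-likelihoods returned above have one entry per class. A short sketch of turning them into class posteriors and predictions; the `[batchsize, out_features]` output shape is an assumption based on the config, not stated in this diff:

```python
import torch

# lls: [batchsize, out_features] log-likelihoods, one column per class.
# Normalizing over the class dimension yields log-posteriors log p(y | x).
log_posteriors = torch.log_softmax(lls, dim=-1)
predictions = log_posteriors.argmax(dim=-1)  # [batchsize]
```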
@@ -86,7 +86,18 @@ print(f"samples: \n{samples}")

## Citing EinsumNetworks

-If you use EinsumNetworks in your publications, please cite the official EinsumNetworks paper.
+If you use this software, please cite it as below.
+
+```bibtex
+@software{braun2021simple-einet,
+  author = {Braun, Steven},
+  title = {{Simple-einet: An EinsumNetworks Implementation}},
+  url = {https://github.com/braun-steven/simple-einet},
+  version = {0.0.1},
+}
+```
+
+If you use EinsumNetworks as a model in your publications, please cite our official EinsumNetworks paper.

```bibtex
@inproceedings{pmlr-v119-peharz20a,
@@ -105,14 +116,3 @@ If you use EinsumNetworks in your publications, please cite the official EinsumN
code = {https://github.com/cambridge-mlg/EinsumNetworks},
}
```

-If you use this software, please cite it as below.
-
-``` bibtex
-@software{braun2021simple-einet,
-  author = {Braun, Steven},
-  title = {{Simple-einet: An EinsumNetworks Implementation}},
-  url = {https://github.com/braun-steven/simple-einet},
-  version = {0.0.1},
-}
-```
189 changes: 37 additions & 152 deletions exp_utils.py
@@ -32,6 +32,11 @@
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor



+from matplotlib.cm import tab10
+from matplotlib import cm

from simple_einet.data import build_dataloader, get_data_shape, Shape, generate_data


@@ -399,43 +404,17 @@ def catch_kb_interrupt(output_directory):
shutil.move(src, dst)


-@torch.no_grad()
-def print_num_params(model: nn.Module):
-    """
-    Compute the number of parameters and separate into Flow/SPN parts.
-    Args:
-        model (nn.Module): Model with parameters.
-    """
-    if type(model) == DistributedDataParallel:
-        model = model.module
-
-    # Count all parameteres
-    sum_params = count_params(model)
-
-    # Count SPN parameters
-    spn_params = sum_params
-
-    # Print
-    logger.info(f"Number of parameters:")
-    # logger.info(f"- Total: {sum_params / 1e6: >8.3f}M")
-    logger.info(
-        f"- SPN: {spn_params / 1e6: >8.3f}M ({spn_params / sum_params * 100:.1f}%)"
-    )
-    # logger.info(f"- NN: {nn_params / 1e6: >8.3f}M ({nn_params / sum_params * 100:.1f}%)")
-
def preprocess(
x: torch.Tensor,
n_bits: int,
) -> torch.Tensor:
"""Preprocess the image."""
x = reduce_bits(x, n_bits)
-# x = x.long()
return x


def reduce_bits(image: torch.Tensor, n_bits: int) -> torch.Tensor:
"""Reduce the number of bits of the image."""
image = image * 255
if n_bits < 8:
image = torch.floor(image / 2 ** (8 - n_bits))
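A worked sketch of the requantization `reduce_bits` performs, assuming inputs are scaled to [0, 1] as the `image * 255` step suggests: with `n_bits=5`, the 256 original levels collapse to 2^5 = 32 levels.

```python
import torch

# Mirror of reduce_bits above on a toy input.
image = torch.tensor([0.0, 0.5, 1.0])
n_bits = 5
x = image * 255                          # scale to [0, 255]
x = torch.floor(x / 2 ** (8 - n_bits))   # floor(x / 8) -> tensor([ 0., 15., 31.])
```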
@@ -449,7 +428,16 @@ def xor(a: bool, b: bool) -> bool:


def loss_dict_to_str(running_loss_dict: Dict[str, float], logging_period: int) -> str:
"""Create a joined string from a dictionary mapping str->float."""
"""
Create a joined string from a dictionary mapping str->float.
Args:
running_loss_dict (Dict[str, float]): Dictionary mapping str->float.
logging_period (int): Logging period.
Returns:
str: Joined string.
"""
loss_str = ", ".join(
[
f"{key}: {value / logging_period:.2f}"
@@ -460,6 +448,11 @@ def loss_dict_to_str(running_loss_dict: Dict[str, float], logging_period: int) -


def plot_tensor(x: torch.Tensor):
"""Plot a tensor as an image.
Args:
x (torch.Tensor): Tensor to plot.
"""

plt.figure()
if x.dim() == 4:
@@ -469,121 +462,6 @@ def plot_tensor(x: torch.Tensor):
plt.close()


-def build_tensorboard_writer(results_dir):
-    """
-    Build a tensorboard writer.
-    Args:
-        results_dir: Directory where to save the tensorboard files.
-    Returns:
-        A tensorboard writer.
-    """
-    return SummaryWriter(os.path.join(results_dir, "tensorboard"))
-
-
-def setup_experiment(
-    name: str,
-    args: argparse.Namespace,
-    remove_if_exists: bool = True,
-    with_tensorboard=True,
-):
-    """
-    Sets up the experiment.
-    Args:
-        name: The name of the experiment.
-    """
-    print(f"Arguments: {args}")
-
-    #
-    if args.dataset == "celeba":
-        args.dataset = "celeba-small"
-
-    # Check if we want to restore from a finished experiment
-    if args.load_and_eval is not None:
-        # Load args
-        old_dir: pathlib.Path = args.load_and_eval.expanduser()
-        args_file = os.path.join(old_dir, "args.json")
-        old_args = argparse.Namespace(**json.load(open(args_file)))
-        old_args.load_and_eval = args.load_and_eval
-        old_args.gpu = args.gpu
-
-        print("Loading from existing directory:", old_dir)
-        print("Loading with existing args:", pprint.pformat(old_args))
-
-        results_dir = old_dir
-        args = old_args
-    else:
-        # Create result directory
-        results_dir = make_results_dir(
-            results_dir=args.results_dir,
-            experiment_name=name,
-            tag=args.tag,
-            dataset_name=args.dataset,
-            remove_if_exists=remove_if_exists,
-        )
-        # Save args to file
-        save_args(results_dir, args)
-    print(f"Results directory: {results_dir}")
-    # Setup tensorboard
-    if with_tensorboard:
-        writer = build_tensorboard_writer(results_dir)
-    else:
-        writer = None
-
-    if torch.cuda.is_available():
-        device = torch.device("cuda:" + str(args.gpu))
-        print("Using GPU device", torch.cuda.current_device())
-    else:
-        device = torch.device("cpu")
-    print("Using device:", device)
-    seed_all_rng(args.seed)
-    cudnn.benchmark = True
-
-    # Image shape
-    image_shape: Shape = get_data_shape(args.dataset)
-
-    # Create RTPT object
-    rtpt = RTPT(
-        name_initials="SL",
-        experiment_name=name + "_" + str(args.tag),
-        max_iterations=args.epochs,
-    )
-
-    # Start the RTPT tracking
-    rtpt.start()
-
-    return (
-        args,
-        results_dir,
-        writer,
-        device,
-        image_shape,
-        rtpt,
-    )
-
-def setup_experiment(name: str, cfg: DictConfig, remove_if_exists: bool = False):
-    """
-    Sets up the experiment.
-    Args:
-        name: The name of the experiment.
-    """
-    # Create result directory
-    results_dir = make_results_dir(
-        results_dir=cfg.results_dir,
-        experiment_name=name,
-        tag=cfg.tag,
-        dataset_name=cfg.dataset,
-        remove_if_exists=remove_if_exists,
-    )
-    # Save args to file
-    # save_args(results_dir, cfg)
-
-    # Save args to file
-    print(f"Results directory: {results_dir}")
-    seed_all_rng(cfg.seed)
-    cudnn.benchmark = True
-    return results_dir, cfg

def anneal_tau(epoch, max_epochs):
    """Anneal the softmax temperature tau based on the epoch progress."""
    return max(0.5, np.exp(-1 / max_epochs * epoch))
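A quick look at the schedule `anneal_tau` implements: tau starts at exp(0) = 1 and decays exponentially in the epoch progress, clamped from below at 0.5; since exp(-t) drops below 0.5 at t = ln 2 ≈ 0.69, the clamp takes over after roughly 69% of training.

```python
import numpy as np

# Values of the annealing schedule above for a 10-epoch run.
max_epochs = 10
taus = [max(0.5, np.exp(-1 / max_epochs * epoch)) for epoch in range(max_epochs + 1)]
# taus ~= [1.0, 0.90, 0.82, 0.74, 0.67, 0.61, 0.55, 0.5, 0.5, 0.5, 0.5]
```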
@@ -614,13 +492,6 @@ def save_samples(generate_samples, samples_dir, num_samples, nrow):
torchvision.utils.save_image(grid, os.path.join(samples_dir, f"{i}.png"))


-from matplotlib.cm import tab10
-from matplotlib import cm
-
-TEXTWIDTH = 5.78853
-LINEWIDTH = 0.75
-ARROW_HEADWIDTH = 5
-colors = tab10.colors


def get_figsize(scale: float, aspect_ratio=0.8) -> Tuple[float, float]:
Expand All @@ -635,12 +506,16 @@ def get_figsize(scale: float, aspect_ratio=0.8) -> Tuple[float, float]:
Tuple: Tuple containing (width, height) of the figure.
"""
+    TEXTWIDTH = 5.78853
    height = aspect_ratio * TEXTWIDTH
    width = TEXTWIDTH
    return (scale * width, scale * height)


def set_style():
"""
Sets the style of the matplotlib plots to use LaTeX fonts and the SciencePlots package.
"""
matplotlib.use("pgf")
plt.style.use(["science", "grid"]) # Need SciencePlots pip package
matplotlib.rcParams.update(
@@ -654,6 +529,16 @@ def set_style():


def plot_distribution(model, dataset_name, logger_wandb: WandbLogger = None):
"""
Plots the learned probability density function (PDF) represented by the given model
on a 2D grid of points sampled from the specified dataset.
Args:
model (nn.Module): The model to use for generating the PDF.
dataset_name (str): The name of the dataset to sample points from.
logger_wandb (WandbLogger, optional): The logger to use for logging the plot image to WandB.
Defaults to None.
"""
with torch.no_grad():
data, targets = generate_data(dataset_name, n_samples=1000)
fig = plt.figure(figsize=get_figsize(1.0))
@@ -685,7 +570,7 @@ def plot_distribution(model, dataset_name, logger_wandb: WandbLogger = None):
lw=0.5,
s=10,
alpha=0.5,
-color=colors[1],
+color=tab10.colors[1],
)

plt.xlabel("$X_0$")
6 changes: 4 additions & 2 deletions main.py
@@ -161,9 +161,11 @@ def test(model, device, loader, tag):
num_leaves=args.I,
num_repetitions=args.R,
num_classes=num_classes,
-leaf_type=Binomial,
-leaf_kwargs={"total_count": n_bins - 1},
+leaf_type=Normal,
+leaf_kwargs={},
+# leaf_kwargs={"total_count": n_bins - 1},
dropout=0.0,
+cross_product=True,
)
model = Einet(config).to(device)
print(
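The change above swaps the Binomial leaf (which models discrete values and therefore required `total_count=n_bins - 1`) for a Normal leaf, which needs no extra keyword arguments. A hedged sketch of the two variants, using only the `EinetConfig` fields visible in this diff; the import paths and the concrete values are assumptions:

```python
from simple_einet.einet import Einet, EinetConfig  # import path is an assumption
from simple_einet.distributions import Binomial, Normal  # likewise an assumption

shared = dict(num_features=784, depth=3, num_sums=10, num_leaves=10,
              num_repetitions=1, num_classes=10, dropout=0.0)

# Binomial leaves: discrete values in {0, ..., total_count}.
config_binomial = EinetConfig(leaf_type=Binomial, leaf_kwargs={"total_count": 255}, **shared)

# Normal leaves: continuous values, no extra kwargs needed.
config_normal = EinetConfig(leaf_type=Normal, leaf_kwargs={}, **shared)
```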