refactor: Apply black (120 max-line)
braun-steven committed Dec 5, 2022
1 parent e72f1bc commit 32808c3
Showing 16 changed files with 163 additions and 423 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,2 +1,2 @@
[tool.black]
line-length = 100
line-length = 120
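
For context, black picks this value up from the [tool.black] table in pyproject.toml, so moving line-length from 100 to 120 is what drives every hunk below: statements that now fit within 120 characters are collapsed back onto a single line. The following sketch illustrates the effect through black's Python API (black.format_str / black.Mode); the sample source string and the function names inside it are illustrative only and not part of this commit.

import black

# Illustrative input (not part of this commit): the pre-commit formatting of a
# call like the one in simple_einet/distributions/abstract_leaf.py. black only
# parses this string, so the undefined names inside it do not matter.
src = (
    "def dist_sample(distribution, context):\n"
    "    distribution = dist.Normal(\n"
    "        loc=distribution.loc, scale=distribution.scale * context.temperature_leaves\n"
    "    )\n"
)

# At the old limit the joined statement would be roughly 107 characters, which
# exceeds 100, so black keeps the wrapped form.
print(black.format_str(src, mode=black.Mode(line_length=100)))

# At the new limit the statement fits within 120 characters and is collapsed
# onto a single line, matching the hunks in this commit.
print(black.format_str(src, mode=black.Mode(line_length=120)))

Running black over the repository after this change (e.g. black simple_einet/) would presumably reproduce the remaining hunks automatically, since black discovers pyproject.toml on its own.
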
30 changes: 7 additions & 23 deletions simple_einet/data.py
@@ -36,15 +36,11 @@ def __getitem__(self, index: int):

def downscale(self, scale):
"""Downscale this shape by the given scale. Only changes height/width."""
return Shape(
self.channels, round(self.height / scale), round(self.width / scale)
)
return Shape(self.channels, round(self.height / scale), round(self.width / scale))

def upscale(self, scale):
"""Upscale this shape by the given scale. Only changes height/width."""
return Shape(
self.channels, round(self.height * scale), round(self.width * scale)
)
return Shape(self.channels, round(self.height * scale), round(self.width * scale))

@property
def num_pixels(self):
@@ -240,19 +236,15 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]:

elif "celeba" in dataset_name:
if normalize:
transform.transforms.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
)
transform.transforms.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

dataset_train = CelebA(**kwargs, split="train")
dataset_val = CelebA(**kwargs, split="valid")
dataset_test = CelebA(**kwargs, split="test")

elif dataset_name == "cifar":
if normalize:
transform.transforms.append(
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
)
transform.transforms.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
dataset_train = CIFAR10(**kwargs, train=True)

N = len(dataset_train.data)
@@ -264,15 +256,11 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]:

elif "svhn" in dataset_name:
if normalize:
transform.transforms.append(
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
)
transform.transforms.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))

# Load train
dataset_train = SVHN(**kwargs, split="train")



N = len(dataset_train.data)
lenghts = [round(N * 0.9), round(N * 0.1)]
dataset_train, dataset_val = random_split(dataset_train, lengths=lenghts)
@@ -289,9 +277,7 @@ def get_datasets(cfg, normalize: bool) -> Tuple[Dataset, Dataset, Dataset]:
return dataset_train, dataset_val, dataset_test


def build_dataloader(
cfg, loop: bool, normalize: bool
) -> Tuple[DataLoader, DataLoader, DataLoader]:
def build_dataloader(cfg, loop: bool, normalize: bool) -> Tuple[DataLoader, DataLoader, DataLoader]:
# Get dataset objects
dataset_train, dataset_val, dataset_test = get_datasets(cfg, normalize=normalize)

@@ -355,9 +341,7 @@ def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):

def __iter__(self):
start = self._rank
yield from itertools.islice(
self._infinite_indices(), start, None, self._world_size
)
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

def _infinite_indices(self):
g = torch.Generator()
2 changes: 1 addition & 1 deletion simple_einet/distributions/__init__.py
@@ -5,4 +5,4 @@

from simple_einet.distributions.utils import *
from simple_einet.distributions.abstract_leaf import AbstractLeaf
from simple_einet.distributions.normal import RatNormal, CustomNormal
from simple_einet.distributions.normal import RatNormal, CustomNormal
16 changes: 8 additions & 8 deletions simple_einet/distributions/abstract_leaf.py
@@ -57,6 +57,7 @@ def dist_mode(distribution: dist.Distribution, context: SamplingContext = None)
return distribution.mean.repeat(context.num_samples, 1, 1, 1, 1)
from simple_einet.distributions.normal import CustomNormal
from simple_einet.distributions.binomial import CustomBinomial

if isinstance(distribution, CustomNormal):
# Repeat the mode along the batch axis
return distribution.mu.repeat(context.num_samples, 1, 1, 1, 1)
@@ -94,14 +95,11 @@ def dist_sample(distribution: dist.Distribution, context: SamplingContext = None
samples = samples.unsqueeze(1)
else:
from simple_einet.distributions import CustomNormal

if type(distribution) == dist.Normal:
distribution = dist.Normal(
loc=distribution.loc, scale=distribution.scale * context.temperature_leaves
)
distribution = dist.Normal(loc=distribution.loc, scale=distribution.scale * context.temperature_leaves)
elif type(distribution) == CustomNormal:
distribution = CustomNormal(
mu=distribution.mu, sigma=distribution.sigma * context.temperature_leaves
)
distribution = CustomNormal(mu=distribution.mu, sigma=distribution.sigma * context.temperature_leaves)
samples = distribution.sample(sample_shape=(context.num_samples,))

assert (
@@ -174,7 +172,9 @@ def __init__(
def _apply_dropout(self, x: torch.Tensor) -> torch.Tensor:
# Apply dropout sampled from a bernoulli during training (model.train() has been called)
if self.dropout > 0.0 and self.training:
dropout_indices = self._bernoulli_dist.sample(x.shape, ).bool()
dropout_indices = self._bernoulli_dist.sample(
x.shape,
).bool()
x[dropout_indices] = 0.0
return x

@@ -219,4 +219,4 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> to
return samples

def extra_repr(self):
return f"num_features={self.num_features}, num_leaves={self.num_leaves}, out_shape={self.out_shape}"
return f"num_features={self.num_features}, num_leaves={self.num_leaves}, out_shape={self.out_shape}"
4 changes: 1 addition & 3 deletions simple_einet/distributions/bernoulli.py
@@ -19,9 +19,7 @@ def __init__(self, num_features: int, num_channels: int, num_leaves: int, num_re
super().__init__(num_features, num_channels, num_leaves, num_repetitions)

# Create bernoulli parameters
self.probs = nn.Parameter(
torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)
)
self.probs = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))

def _get_base_distribution(self):
# Use sigmoid to ensure, that probs are in valid range
34 changes: 8 additions & 26 deletions simple_einet/distributions/binomial.py
@@ -35,20 +35,14 @@ def __init__(
self.total_count = check_valid(total_count, int, lower_bound=1)

# Create binomial parameters
self.probs = nn.Parameter(
torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)
)
self.probs = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))

def _get_base_distribution(self, context: SamplingContext = None):
# Use sigmoid to ensure, that probs are in valid range
if context is not None and context.is_differentiable:
return CustomBinomial(
probs=self.probs.sigmoid(), total_count=self.total_count
)
return CustomBinomial(probs=self.probs.sigmoid(), total_count=self.total_count)
else:
return dist.Binomial(
probs=self.probs.sigmoid(), total_count=self.total_count
)
return dist.Binomial(probs=self.probs.sigmoid(), total_count=self.total_count)


class CustomBinomial:
@@ -94,18 +88,10 @@ def __init__(
self.cond_idxs = cond_idxs

self.probs_conditioned_base = nn.Parameter(
0.5
+ torch.rand(
1, num_channels, num_features // 2, num_leaves, num_repetitions
)
* 0.1
0.5 + torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) * 0.1
)
self.probs_unconditioned = nn.Parameter(
0.5
+ torch.rand(
1, num_channels, num_features // 2, num_leaves, num_repetitions
)
* 0.1
0.5 + torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) * 0.1
)

def get_conditioned_distribution(self, x_cond: torch.Tensor):
@@ -155,9 +141,7 @@ def forward(self, x, marginalized_scopes: List[int]):

return x

def sample(
self, num_samples: int = None, context: SamplingContext = None
) -> torch.Tensor:
def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor:
ev = context.evidence
x_cond = ev[:, :, self.cond_idxs, None, None]
d = self.get_conditioned_distribution(x_cond)
@@ -185,11 +169,9 @@ def sample(
# If parent index into out_channels are given
if context.indices_out is not None:
# Choose only specific samples for each feature/scope
samples = torch.gather(
samples, dim=2, index=context.indices_out.unsqueeze(-1)
).squeeze(-1)
samples = torch.gather(samples, dim=2, index=context.indices_out.unsqueeze(-1)).squeeze(-1)

return samples

def _get_base_distribution(self) -> dist.Distribution:
raise NotImplementedError("This should not happen.")
raise NotImplementedError("This should not happen.")
9 changes: 3 additions & 6 deletions simple_einet/distributions/mixture.py
@@ -13,6 +13,7 @@
from simple_einet.type_checks import check_valid
from simple_einet.distributions.abstract_leaf import AbstractLeaf, dist_mode


class Mixture(AbstractLeaf):
def __init__(
self,
@@ -32,9 +33,7 @@ def __init__(
"""
super().__init__(in_features, out_channels, num_repetitions, dropout)
# Build different layers for each distribution specified
reprs = [
distr(in_features, out_channels, num_repetitions, dropout) for distr in distributions
]
reprs = [distr(in_features, out_channels, num_repetitions, dropout) for distr in distributions]
self.representations = nn.ModuleList(reprs)

# Build sum layer as mixture of distributions
@@ -74,8 +73,6 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> to
# If parent index into out_channels are given
if context.indices_out is not None:
# Choose only specific samples for each feature/scope
samples = torch.gather(samples, dim=2, index=context.indices_out.unsqueeze(-1)).squeeze(
-1
)
samples = torch.gather(samples, dim=2, index=context.indices_out.unsqueeze(-1)).squeeze(-1)

return samples
4 changes: 1 addition & 3 deletions simple_einet/distributions/multidistribution.py
@@ -104,6 +104,4 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> to
return samples

def _get_base_distribution(self) -> dist.Distribution:
raise NotImplementedError(
"MultiDistributionLayer does not implement _get_base_distribution."
)
raise NotImplementedError("MultiDistributionLayer does not implement _get_base_distribution.")
20 changes: 5 additions & 15 deletions simple_einet/distributions/multivariate_normal.py
@@ -34,9 +34,7 @@ def __init__(
"""
# TODO: Fix for num_repetitions
super().__init__(in_features, out_channels, num_repetitions, dropout, cardinality)
raise NotImplementedError(
"MultivariateNormal has not been adapted to the new implementation yet - sorry."
)
raise NotImplementedError("MultivariateNormal has not been adapted to the new implementation yet - sorry.")
self._pad_value = in_features % cardinality
self.out_features = np.ceil(in_features / cardinality).astype(int)
self._n_dists = np.ceil(in_features / cardinality).astype(int)
@@ -45,15 +43,11 @@ def __init__(
self.max_sigma = check_valid(max_sigma, float, min_sigma)

# Create gaussian means and covs
self.means = nn.Parameter(
torch.randn(out_channels * self._n_dists * self.num_repetitions, cardinality)
)
self.means = nn.Parameter(torch.randn(out_channels * self._n_dists * self.num_repetitions, cardinality))

# Generate covariance matrix via the cholesky decomposition: s = A'A where A is a triangular matrix
# Further ensure, that diag(a) > 0 everywhere, such that A has full rank
rand = torch.zeros(
out_channels * self._n_dists * self.num_repetitions, cardinality, cardinality
)
rand = torch.zeros(out_channels * self._n_dists * self.num_repetitions, cardinality, cardinality)

for i in range(cardinality):
rand[:, i, i] = 1.0
@@ -62,9 +56,7 @@ def __init__(

# Make matrices triangular and remove diagonal entries
cov_tril_wo_diag = rand.tril(diagonal=-1)
cov_tril_wi_diag = torch.rand(
out_channels * self._n_dists * self.num_repetitions, cardinality, cardinality
)
cov_tril_wi_diag = torch.rand(out_channels * self._n_dists * self.num_repetitions, cardinality, cardinality)

self.cov_tril_wo_diag = nn.Parameter(cov_tril_wo_diag)
self.cov_tril_wi_diag = nn.Parameter(cov_tril_wi_diag)
@@ -100,9 +92,7 @@ def forward(self, x: torch.Tensor, marginalized_scopes: List[int]) -> torch.Tens
# Output shape: [n, out_channels, d / cardinality]
mv = self._get_base_distribution()
x = mv.log_prob(x) # [n, r * d/k * oc]
x = x.view(
batch_size, self.num_repetitions, self.num_leaves, self._n_dists
) # [n, r, oc, d/k]
x = x.view(batch_size, self.num_repetitions, self.num_leaves, self._n_dists) # [n, r, oc, d/k]
x = x.permute(0, 3, 2, 1) # [n, d/k, oc, r]

# Marginalize and apply dropout
28 changes: 7 additions & 21 deletions simple_einet/distributions/normal.py
@@ -30,12 +30,8 @@ def __init__(
super().__init__(num_features, num_channels, num_leaves, num_repetitions)

# Create gaussian means and stds
self.means = nn.Parameter(
torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)
)
self.log_stds = nn.Parameter(
torch.rand(1, num_channels, num_features, num_leaves, num_repetitions)
)
self.means = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))
self.log_stds = nn.Parameter(torch.rand(1, num_channels, num_features, num_leaves, num_repetitions))

def _get_base_distribution(self, context: SamplingContext = None):
return dist.Normal(loc=self.means, scale=self.log_stds.exp())
@@ -72,26 +68,18 @@ def __init__(
)

# Create gaussian means and stds
self.means = nn.Parameter(
torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)
)
self.means = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))

if min_sigma is not None and max_sigma is not None:
# Init from normal
self.stds = nn.Parameter(
torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)
)
self.stds = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))
else:
# Init uniform between 0 and 1
self.stds = nn.Parameter(
torch.rand(1, num_channels, num_features, num_leaves, num_repetitions)
)
self.stds = nn.Parameter(torch.rand(1, num_channels, num_features, num_leaves, num_repetitions))

self.min_sigma = check_valid(min_sigma, float, 0.0, max_sigma)
self.max_sigma = check_valid(max_sigma, float, min_sigma)
self.min_mean = check_valid(
min_mean, float, upper_bound=max_mean, allow_none=True
)
self.min_mean = check_valid(min_mean, float, upper_bound=max_mean, allow_none=True)
self.max_mean = check_valid(max_mean, float, min_mean, allow_none=True)

def _get_base_distribution(self, context: SamplingContext = None) -> "CustomNormal":
@@ -121,9 +109,7 @@ def __init__(self, mu: torch.Tensor, sigma: torch.Tensor):

def sample(self, sample_shape: Tuple[int]):
num_samples = sample_shape[0]
eps = torch.randn(
(num_samples,) + self.mu.shape, dtype=self.mu.dtype, device=self.mu.device
)
eps = torch.randn((num_samples,) + self.mu.shape, dtype=self.mu.dtype, device=self.mu.device)
samples = self.mu.unsqueeze(0) + self.sigma.unsqueeze(0) * eps
return samples
