4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -11,15 +11,15 @@ repos:
     hooks:
       - id: prettier
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.12.7
+    rev: v0.14.4
    hooks:
       - id: ruff
         types_or: [python, pyi, jupyter]
         args: [--fix, --exit-non-zero-on-fix]
       - id: ruff-format
         types_or: [python, pyi, jupyter]
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v6.0.0
    hooks:
       - id: detect-private-key
       - id: check-ast
6 changes: 2 additions & 4 deletions src/torchgmm/base/data/collation.py
@@ -1,9 +1,7 @@
-from typing import Tuple
-
 import torch
 
 
-def collate_tuple(batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+def collate_tuple(batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
     """
     Collate a tuple of batch items by returning the input tuple.
 
@@ -13,7 +11,7 @@ def collate_tuple(batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
     return batch
 
 
-def collate_tensor(batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
+def collate_tensor(batch: tuple[torch.Tensor, ...]) -> torch.Tensor:
     """
     Collates a tuple of batch items into the first tensor.
 
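
These signature changes follow PEP 585: since Python 3.9 the builtin `tuple`, `list`, and `type` are subscriptable in annotations, so the deprecated `typing.Tuple`-style aliases can be dropped. A minimal sketch of the resulting pattern (the tensors are made up for illustration):

```python
# PEP 585: builtin generics in annotations; no `from typing import Tuple`.
import torch


def collate_tuple(batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
    """Identity collation: the batch tuple is returned unchanged."""
    return batch


batch = (torch.zeros(4, 3), torch.ones(4))
assert collate_tuple(batch) is batch
```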
3 changes: 2 additions & 1 deletion src/torchgmm/base/data/loader.py
@@ -1,4 +1,5 @@
-from typing import Any, Iterator, TypeVar
+from collections.abc import Iterator
+from typing import Any, TypeVar
 
 try:
     from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper as IndexBatchSamplerWrapper
2 changes: 1 addition & 1 deletion src/torchgmm/base/data/sampler.py
@@ -1,5 +1,5 @@
 import math
-from typing import Iterator
+from collections.abc import Iterator
 
 from torch.utils.data import Sampler
 from torch.utils.data.sampler import SequentialSampler
3 changes: 2 additions & 1 deletion src/torchgmm/base/data/types.py
@@ -1,4 +1,5 @@
-from typing import Sequence, TypeVar, Union
+from collections.abc import Sequence
+from typing import TypeVar, Union
 
 import numpy as np
 import numpy.typing as npt
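
The same motivation drives the import moves in `loader.py`, `sampler.py`, and `types.py`: since Python 3.9, `typing.Iterator`, `typing.Sequence`, and friends are deprecated aliases for the `collections.abc` classes, which are themselves subscriptable in annotations. A short sketch of the target style:

```python
# collections.abc types are generic at runtime since Python 3.9, so the
# deprecated typing aliases are no longer needed.
from collections.abc import Iterator, Sequence


def sliding_pairs(xs: Sequence[int]) -> Iterator[tuple[int, int]]:
    """Yield consecutive pairs from a sequence."""
    return zip(xs, xs[1:])


assert list(sliding_pairs([1, 2, 3])) == [(1, 2), (2, 3)]
```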
8 changes: 5 additions & 3 deletions src/torchgmm/base/nn/_protocols.py
@@ -1,5 +1,7 @@
 # pylint: disable=missing-class-docstring,missing-function-docstring
-from typing import Generic, Iterator, OrderedDict, Protocol, Tuple, Type, TypeVar
+from collections import OrderedDict
+from collections.abc import Iterator
+from typing import Generic, Protocol, TypeVar
 
 import torch
 from torch import nn
@@ -15,13 +17,13 @@ class ConfigurableModule(Protocol, Generic[C_co]):
     def config(self) -> C_co: ...
 
     @classmethod
-    def load(cls: Type[M], path: PathType) -> M: ...
+    def load(cls: type[M], path: PathType) -> M: ...
 
     def save(self, path: PathType, compile_model: bool = False) -> None: ...
 
     def save_config(self, path: PathType) -> None: ...
 
-    def named_children(self) -> Iterator[Tuple[str, nn.Module]]: ...
+    def named_children(self) -> Iterator[tuple[str, nn.Module]]: ...
 
     def state_dict(self) -> OrderedDict[str, torch.Tensor]: ...
 
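
Two details here: `OrderedDict` now comes from `collections` (the `typing.OrderedDict` alias is deprecated, and `collections.OrderedDict` is subscriptable since 3.9), and annotating `cls` as `type[M]` keeps `load` returning the concrete subclass it is called on. A sketch of that classmethod pattern, with hypothetical names rather than the protocol's real implementers:

```python
# Annotating cls as type[M] makes load() return the calling subclass.
# Saveable/Model are hypothetical stand-ins, not torchgmm classes.
from typing import TypeVar

M = TypeVar("M", bound="Saveable")


class Saveable:
    @classmethod
    def load(cls: type[M], path: str) -> M:
        instance = cls()
        # ... restore state from `path` here ...
        return instance


class Model(Saveable):
    pass


model: Model = Model.load("checkpoint.pt")  # checked as Model, not Saveable
```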
4 changes: 2 additions & 2 deletions src/torchgmm/base/utils/generics.py
@@ -1,7 +1,7 @@
-from typing import Any, Type, get_args, get_origin
+from typing import Any, get_args, get_origin
 
 
-def get_generic_type(cls: Type[Any], origin: Type[Any], index: int = 0) -> Type[Any]:
+def get_generic_type(cls: type[Any], origin: type[Any], index: int = 0) -> type[Any]:
     """
     Returns the ``index``-th generic type of the superclass ``origin``.
 
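
The body of `get_generic_type` is collapsed above; a helper like this is typically built from `get_origin`/`get_args` over `__orig_bases__`. A hypothetical sketch of the idea, not necessarily torchgmm's implementation:

```python
from typing import Any, Generic, TypeVar, get_args, get_origin

T = TypeVar("T")


class Base(Generic[T]):
    pass


class IntBase(Base[int]):
    pass


def first_generic_arg(cls: type[Any], origin: type[Any]) -> type[Any]:
    # Scan the class's parameterized bases for `origin` and return its
    # first type argument.
    for base in getattr(cls, "__orig_bases__", ()):
        if get_origin(base) is origin:
            return get_args(base)[0]
    raise ValueError(f"{cls.__name__} does not parameterize {origin.__name__}")


assert first_generic_arg(IntBase, Base) is int
```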
6 changes: 3 additions & 3 deletions src/torchgmm/bayes/gmm/estimator.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, List, Tuple, cast
+from typing import Any, cast
 
 import numpy as np
 import torch
@@ -272,7 +272,7 @@ def score_samples(self, data: TensorLike) -> torch.Tensor:
             collate_fn=collate_tensor,
         )
         result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
-        return torch.stack([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
+        return torch.stack([x[1] for x in cast(list[tuple[torch.Tensor, torch.Tensor]], result)])
 
     def predict(self, data: TensorLike) -> torch.Tensor:
         """
@@ -321,4 +321,4 @@ def predict_proba(self, data: TensorLike) -> torch.Tensor:
             collate_fn=collate_tensor,
         )
         result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
-        return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
+        return torch.cat([x[0] for x in cast(list[tuple[torch.Tensor, torch.Tensor]], result)])
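
The `cast` calls only change what the type checker sees: `Trainer.predict` is loosely typed, and `cast` is a runtime no-op that narrows the result so the comprehension type-checks. Illustrated with stand-in data:

```python
# typing.cast returns its argument unchanged; it exists purely for the
# type checker. `result` stands in for Trainer.predict's output here.
from typing import Any, cast

import torch

result: Any = [(torch.zeros(2, 3), torch.ones(2))]
pairs = cast(list[tuple[torch.Tensor, torch.Tensor]], result)
assert pairs is result  # same object, no conversion happened
scores = torch.stack([x[1] for x in pairs])
```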
9 changes: 5 additions & 4 deletions src/torchgmm/bayes/gmm/metrics.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any
 
 import torch
 from torchmetrics import Metric
@@ -17,7 +18,7 @@ def __init__(
         self,
         num_components: int,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -44,7 +45,7 @@ def __init__(
         num_components: int,
         num_features: int,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -78,7 +79,7 @@ def __init__(
         covariance_type: CovarianceType,
         reg: float,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
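
`Callable[[Any], Any] | None` is the PEP 604 spelling of `Optional[Callable[[Any], Any]]`. One caveat worth noting: in a plain signature the union is evaluated when the function is defined, so this spelling needs Python >= 3.10 at runtime unless the module defers evaluation with `from __future__ import annotations`; the PR presumably relies on one of the two (an assumption, not stated in the diff). A small sketch:

```python
# PEP 604 union in a default-None keyword argument (Python >= 3.10, or add
# `from __future__ import annotations` at the top of the module).
from collections.abc import Callable
from typing import Any


def make_metric(*, dist_sync_fn: Callable[[Any], Any] | None = None) -> str:
    return "synced" if dist_sync_fn is not None else "unsynced"


assert make_metric() == "unsynced"
assert make_metric(dist_sync_fn=lambda x: x) == "synced"
```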
3 changes: 1 addition & 2 deletions src/torchgmm/bayes/gmm/model.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Tuple
 
 import numpy as np
 import torch
@@ -90,7 +89,7 @@ def reset_parameters(self) -> None:
         if self.covariance_type in ("full", "tied"):
             self.precisions_cholesky.tril_()
 
-    def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Computes the log-probability of observing each of the provided datapoints for each of the
         GMM's components.
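
The truncated `forward` above returns per-component log-probabilities. For readers unfamiliar with the computation, here is a generic sketch of per-component Gaussian log-densities via `torch.distributions`; shapes and parameters are illustrative, not the model's internals:

```python
# Per-component log-probabilities log N(x | mu_k, Sigma_k) for a GMM,
# sketched with torch.distributions; not torchgmm's actual forward().
import torch
from torch.distributions import MultivariateNormal

K, D, N = 3, 2, 5
means = torch.randn(K, D)
covs = torch.eye(D).expand(K, D, D)
log_weights = torch.log_softmax(torch.zeros(K), dim=0)

data = torch.randn(N, D)
components = MultivariateNormal(means, covariance_matrix=covs)
log_probs = components.log_prob(data.unsqueeze(1)) + log_weights  # [N, K]
log_likelihood = torch.logsumexp(log_probs, dim=-1)  # [N]
```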
8 changes: 4 additions & 4 deletions src/torchgmm/clustering/kmeans/estimator.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, List, cast
+from typing import Any, cast
 
 import numpy as np
 import torch
@@ -188,7 +188,7 @@ def predict(self, data: TensorLike) -> torch.Tensor:
             collate_fn=collate_tensor,
         )
         result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="assignments"), loader)
-        return torch.cat(cast(List[torch.Tensor], result))
+        return torch.cat(cast(list[torch.Tensor], result))
 
     def score(self, data: TensorLike) -> float:
         """
@@ -236,7 +236,7 @@ def score_samples(self, data: TensorLike) -> torch.Tensor:
             collate_fn=collate_tensor,
         )
         result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="inertias"), loader)
-        return torch.cat(cast(List[torch.Tensor], result))
+        return torch.cat(cast(list[torch.Tensor], result))
 
     def transform(self, data: TensorLike) -> torch.Tensor:
         """
@@ -262,4 +262,4 @@ def transform(self, data: TensorLike) -> torch.Tensor:
             collate_fn=collate_tensor,
         )
         result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="distances"), loader)
-        return torch.cat(cast(List[torch.Tensor], result))
+        return torch.cat(cast(list[torch.Tensor], result))
4 changes: 2 additions & 2 deletions src/torchgmm/clustering/kmeans/lightning_module.py
@@ -1,6 +1,6 @@
 # pylint: disable=abstract-method
 import math
-from typing import List, Literal
+from typing import Literal
 
 import pytorch_lightning as pl
 import torch
@@ -56,7 +56,7 @@ def __init__(
         # Initialize metrics
         self.metric_inertia = MeanMetric()
 
-    def configure_callbacks(self) -> List[pl.Callback]:
+    def configure_callbacks(self) -> list[pl.Callback]:
         if self.convergence_tolerance == 0:
             return []
         early_stopping = EarlyStopping(
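
`configure_callbacks` (truncated above) turns the convergence tolerance into an `EarlyStopping` callback rather than a hand-rolled loop. A hedged sketch of that construction; the monitor key and values are illustrative guesses, not the module's actual settings:

```python
# Convergence-as-callback: stop when the monitored quantity improves by
# less than the tolerance. The monitor name here is hypothetical.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping


def configure_callbacks(convergence_tolerance: float) -> list[pl.Callback]:
    if convergence_tolerance == 0:
        return []
    return [
        EarlyStopping(
            monitor="inertia",
            min_delta=convergence_tolerance,
            patience=1,
            mode="min",
            check_on_train_epoch_end=True,
        )
    ]
```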
13 changes: 7 additions & 6 deletions src/torchgmm/clustering/kmeans/metrics.py
@@ -1,5 +1,6 @@
 import random
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any
 
 import torch
 from torchmetrics import Metric
@@ -17,7 +18,7 @@ def __init__(
         num_clusters: int,
         num_features: int,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -61,7 +62,7 @@ def __init__(
         num_choices: int,
         num_features: int,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -123,7 +124,7 @@ def __init__(
         num_choices: int,
         num_features: int,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -173,7 +174,7 @@ class BatchSummer(Metric):
 
     full_state_update = True
 
-    def __init__(self, num_values: int, *, dist_sync_fn: Optional[Callable[[Any], Any]] = None):
+    def __init__(self, num_values: int, *, dist_sync_fn: Callable[[Any], Any] | None = None):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
         self.sums: torch.Tensor
@@ -198,7 +199,7 @@ def __init__(
         num_values: int,
         for_variance: bool,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
3 changes: 1 addition & 2 deletions src/torchgmm/clustering/kmeans/model.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Tuple
 
 import torch
 from torch import jit, nn
@@ -52,7 +51,7 @@ def reset_parameters(self) -> None:
         """
         nn.init.normal_(self.centroids)
 
-    def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Computes the distance of each datapoint to each centroid as well as the "inertia", the
         squared distance of each datapoint to its closest centroid.
3 changes: 1 addition & 2 deletions src/torchgmm/utils/lightning_module.py
@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from typing import List
 
 import pytorch_lightning as pl
 import torch
@@ -32,7 +31,7 @@ def on_train_epoch_end(self) -> None:
             self.nonparametric_training_epoch_end()
     else:
 
-        def training_epoch_end(self, outputs: List[torch.Tensor]) -> None:
+        def training_epoch_end(self, outputs: list[torch.Tensor]) -> None:
             """Training epoch end hook for PyTorch Lightning < 2.0.0."""
             self.nonparametric_training_epoch_end()
 
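
The `else:` branch above defines `training_epoch_end` only when the surrounding version check selects the pre-2.0 Lightning API. A minimal sketch of that class-body gating pattern; the version guard below is a simplified assumption, not the module's exact check:

```python
# Class-body version gating: define whichever epoch-end hook the installed
# pytorch_lightning expects. The guard below is a simplified assumption.
import pytorch_lightning as pl
import torch

_PL_GE_2 = tuple(int(p) for p in pl.__version__.split(".")[:2]) >= (2, 0)


class EpochEndCompat(pl.LightningModule):
    def _shared_epoch_end(self) -> None:
        ...  # common logic runs here under either API

    if _PL_GE_2:

        def on_train_epoch_end(self) -> None:
            self._shared_epoch_end()

    else:

        def training_epoch_end(self, outputs: list[torch.Tensor]) -> None:
            self._shared_epoch_end()
```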
3 changes: 1 addition & 2 deletions tests/_data/gmm.py
@@ -1,5 +1,4 @@
 # pylint: disable=missing-function-docstring
-from typing import Tuple
 
 import torch
 
@@ -9,7 +8,7 @@
 
 def sample_gmm(
     num_datapoints: int, num_features: int, num_components: int, covariance_type: CovarianceType
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     config = GaussianMixtureModelConfig(num_components, num_features, covariance_type)
     model = GaussianMixtureModel(config)
 
11 changes: 5 additions & 6 deletions tests/_data/normal.py
@@ -1,26 +1,25 @@
 # pylint: disable=missing-function-docstring
-from typing import List
 
 import torch
 
 
-def sample_data(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+def sample_data(counts: list[int], dims: list[int]) -> list[torch.Tensor]:
     return [torch.randn(count, dim) for count, dim in zip(counts, dims)]
 
 
-def sample_means(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+def sample_means(counts: list[int], dims: list[int]) -> list[torch.Tensor]:
     return [torch.randn(count, dim) for count, dim in zip(counts, dims)]
 
 
-def sample_spherical_covars(counts: List[int]) -> List[torch.Tensor]:
+def sample_spherical_covars(counts: list[int]) -> list[torch.Tensor]:
     return [torch.rand(count) for count in counts]
 
 
-def sample_diag_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+def sample_diag_covars(counts: list[int], dims: list[int]) -> list[torch.Tensor]:
     return [torch.rand(count, dim).squeeze() for count, dim in zip(counts, dims)]
 
 
-def sample_full_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+def sample_full_covars(counts: list[int], dims: list[int]) -> list[torch.Tensor]:
     result = []
     for count, dim in zip(counts, dims):
         A = torch.rand(count, dim * 10, dim)
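
The loop body of `sample_full_covars` is cut off above, but the `A` of shape `(count, dim * 10, dim)` points at the standard Gram-matrix recipe: `AᵀA` is symmetric and almost surely positive definite when `A` has many more rows than columns. A sketch of that recipe, not necessarily the file's exact code:

```python
# Random SPD covariance matrices via batched Gram matrices A^T A.
# This mirrors the truncated loop above but is a guess at its body.
import torch


def random_full_covars(count: int, dim: int) -> torch.Tensor:
    A = torch.rand(count, dim * 10, dim)
    return A.transpose(1, 2) @ A / (dim * 10)  # [count, dim, dim]


covars = random_full_covars(3, 4)
assert (torch.linalg.eigvalsh(covars) > 0).all()
```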
5 changes: 2 additions & 3 deletions tests/bayes/gmm/benchmark_gmm_estimator.py
@@ -1,5 +1,4 @@
 # pylint: disable=missing-function-docstring
-from typing import Optional
 
 import pytest
 import pytorch_lightning as pl
@@ -72,7 +71,7 @@ def test_torchgmm(
     num_features: int,
     num_components: int,
     covariance_type: CovarianceType,
-    batch_size: Optional[int],
+    batch_size: int | None,
 ):
     pl.seed_everything(0)
     data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type)
@@ -115,7 +114,7 @@ def test_torchgmm_gpu(
     num_features: int,
     num_components: int,
     covariance_type: CovarianceType,
-    batch_size: Optional[int],
+    batch_size: int | None,
 ):
     # Initialize GPU
     torch.empty(1, device="cuda:0")
5 changes: 2 additions & 3 deletions tests/bayes/gmm/test_gmm_estimator.py
@@ -1,6 +1,5 @@
 # pylint: disable=missing-function-docstring
 import math
-from typing import Optional
 
 import pytest
 import torch
@@ -23,7 +22,7 @@ def test_fit_model_config():
 
 
 @pytest.mark.parametrize("batch_size", [2, None])
-def test_fit_num_iter(batch_size: Optional[int]):
+def test_fit_num_iter(batch_size: int | None):
     # For the following data, K-means will find centroids [0.5, 3.5]. The estimator first computes
     # the NLL (first iteration), afterwards there is no improvement in the NLL (second iteration).
     data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]])
@@ -41,7 +40,7 @@ def test_fit_num_iter(batch_size: Optional[int]):
     ("batch_size", "max_epochs", "converged"),
     [(2, 1, False), (2, 3, True), (None, 1, False), (None, 3, True)],
 )
-def test_fit_converged(batch_size: Optional[int], max_epochs: int, converged: bool):
+def test_fit_converged(batch_size: int | None, max_epochs: int, converged: bool):
     data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]])
 
     estimator = GaussianMixture(