diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9595bf6..3ca77d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.7 + rev: v0.14.4 hooks: - id: ruff types_or: [python, pyi, jupyter] @@ -19,7 +19,7 @@ repos: - id: ruff-format types_or: [python, pyi, jupyter] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: detect-private-key - id: check-ast diff --git a/src/torchgmm/base/data/collation.py b/src/torchgmm/base/data/collation.py index 395856a..2199de7 100644 --- a/src/torchgmm/base/data/collation.py +++ b/src/torchgmm/base/data/collation.py @@ -1,9 +1,7 @@ -from typing import Tuple - import torch -def collate_tuple(batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: +def collate_tuple(batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: """ Collate a tuple of batch items by returning the input tuple. @@ -13,7 +11,7 @@ def collate_tuple(batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: return batch -def collate_tensor(batch: Tuple[torch.Tensor, ...]) -> torch.Tensor: +def collate_tensor(batch: tuple[torch.Tensor, ...]) -> torch.Tensor: """ Collates a tuple of batch items into the first tensor. diff --git a/src/torchgmm/base/data/loader.py b/src/torchgmm/base/data/loader.py index 3a5f88d..e680517 100644 --- a/src/torchgmm/base/data/loader.py +++ b/src/torchgmm/base/data/loader.py @@ -1,4 +1,5 @@ -from typing import Any, Iterator, TypeVar +from collections.abc import Iterator +from typing import Any, TypeVar try: from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper as IndexBatchSamplerWrapper diff --git a/src/torchgmm/base/data/sampler.py b/src/torchgmm/base/data/sampler.py index 7040326..e8fefbc 100644 --- a/src/torchgmm/base/data/sampler.py +++ b/src/torchgmm/base/data/sampler.py @@ -1,5 +1,5 @@ import math -from typing import Iterator +from collections.abc import Iterator from torch.utils.data import Sampler from torch.utils.data.sampler import SequentialSampler diff --git a/src/torchgmm/base/data/types.py b/src/torchgmm/base/data/types.py index eda3709..9498f9c 100644 --- a/src/torchgmm/base/data/types.py +++ b/src/torchgmm/base/data/types.py @@ -1,4 +1,5 @@ -from typing import Sequence, TypeVar, Union +from collections.abc import Sequence +from typing import TypeVar, Union import numpy as np import numpy.typing as npt diff --git a/src/torchgmm/base/nn/_protocols.py b/src/torchgmm/base/nn/_protocols.py index a11fee5..34c58d7 100644 --- a/src/torchgmm/base/nn/_protocols.py +++ b/src/torchgmm/base/nn/_protocols.py @@ -1,5 +1,7 @@ # pylint: disable=missing-class-docstring,missing-function-docstring -from typing import Generic, Iterator, OrderedDict, Protocol, Tuple, Type, TypeVar +from collections import OrderedDict +from collections.abc import Iterator +from typing import Generic, Protocol, TypeVar import torch from torch import nn @@ -15,13 +17,13 @@ class ConfigurableModule(Protocol, Generic[C_co]): def config(self) -> C_co: ... @classmethod - def load(cls: Type[M], path: PathType) -> M: ... + def load(cls: type[M], path: PathType) -> M: ... def save(self, path: PathType, compile_model: bool = False) -> None: ... def save_config(self, path: PathType) -> None: ... - def named_children(self) -> Iterator[Tuple[str, nn.Module]]: ... + def named_children(self) -> Iterator[tuple[str, nn.Module]]: ... 
def state_dict(self) -> OrderedDict[str, torch.Tensor]: ... diff --git a/src/torchgmm/base/utils/generics.py b/src/torchgmm/base/utils/generics.py index 4ce7266..1aeae66 100644 --- a/src/torchgmm/base/utils/generics.py +++ b/src/torchgmm/base/utils/generics.py @@ -1,7 +1,7 @@ -from typing import Any, Type, get_args, get_origin +from typing import Any, get_args, get_origin -def get_generic_type(cls: Type[Any], origin: Type[Any], index: int = 0) -> Type[Any]: +def get_generic_type(cls: type[Any], origin: type[Any], index: int = 0) -> type[Any]: """ Returns the ``index``-th generic type of the superclass ``origin``. diff --git a/src/torchgmm/bayes/gmm/estimator.py b/src/torchgmm/bayes/gmm/estimator.py index a655250..7777b0f 100644 --- a/src/torchgmm/bayes/gmm/estimator.py +++ b/src/torchgmm/bayes/gmm/estimator.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, List, Tuple, cast +from typing import Any, cast import numpy as np import torch @@ -272,7 +272,7 @@ def score_samples(self, data: TensorLike) -> torch.Tensor: collate_fn=collate_tensor, ) result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader) - return torch.stack([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)]) + return torch.stack([x[1] for x in cast(list[tuple[torch.Tensor, torch.Tensor]], result)]) def predict(self, data: TensorLike) -> torch.Tensor: """ @@ -321,4 +321,4 @@ def predict_proba(self, data: TensorLike) -> torch.Tensor: collate_fn=collate_tensor, ) result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader) - return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)]) + return torch.cat([x[0] for x in cast(list[tuple[torch.Tensor, torch.Tensor]], result)]) diff --git a/src/torchgmm/bayes/gmm/metrics.py b/src/torchgmm/bayes/gmm/metrics.py index 0752cba..7235314 100644 --- a/src/torchgmm/bayes/gmm/metrics.py +++ b/src/torchgmm/bayes/gmm/metrics.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch from torchmetrics import Metric @@ -17,7 +18,7 @@ def __init__( self, num_components: int, *, - dist_sync_fn: Optional[Callable[[Any], Any]] = None, + dist_sync_fn: Callable[[Any], Any] | None = None, ): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore @@ -44,7 +45,7 @@ def __init__( num_components: int, num_features: int, *, - dist_sync_fn: Optional[Callable[[Any], Any]] = None, + dist_sync_fn: Callable[[Any], Any] | None = None, ): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore @@ -78,7 +79,7 @@ def __init__( covariance_type: CovarianceType, reg: float, *, - dist_sync_fn: Optional[Callable[[Any], Any]] = None, + dist_sync_fn: Callable[[Any], Any] | None = None, ): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore diff --git a/src/torchgmm/bayes/gmm/model.py b/src/torchgmm/bayes/gmm/model.py index 37b205a..a3974e0 100644 --- a/src/torchgmm/bayes/gmm/model.py +++ b/src/torchgmm/bayes/gmm/model.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Tuple import numpy as np import torch @@ -90,7 +89,7 @@ def reset_parameters(self) -> None: if self.covariance_type in ("full", "tied"): self.precisions_cholesky.tril_() - def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Computes the log-probability of observing each of the provided datapoints for each of 
the GMM's components. diff --git a/src/torchgmm/clustering/kmeans/estimator.py b/src/torchgmm/clustering/kmeans/estimator.py index 011a5fc..11e022d 100644 --- a/src/torchgmm/clustering/kmeans/estimator.py +++ b/src/torchgmm/clustering/kmeans/estimator.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, List, cast +from typing import Any, cast import numpy as np import torch @@ -188,7 +188,7 @@ def predict(self, data: TensorLike) -> torch.Tensor: collate_fn=collate_tensor, ) result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="assignments"), loader) - return torch.cat(cast(List[torch.Tensor], result)) + return torch.cat(cast(list[torch.Tensor], result)) def score(self, data: TensorLike) -> float: """ @@ -236,7 +236,7 @@ def score_samples(self, data: TensorLike) -> torch.Tensor: collate_fn=collate_tensor, ) result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="inertias"), loader) - return torch.cat(cast(List[torch.Tensor], result)) + return torch.cat(cast(list[torch.Tensor], result)) def transform(self, data: TensorLike) -> torch.Tensor: """ @@ -262,4 +262,4 @@ def transform(self, data: TensorLike) -> torch.Tensor: collate_fn=collate_tensor, ) result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="distances"), loader) - return torch.cat(cast(List[torch.Tensor], result)) + return torch.cat(cast(list[torch.Tensor], result)) diff --git a/src/torchgmm/clustering/kmeans/lightning_module.py b/src/torchgmm/clustering/kmeans/lightning_module.py index c4e969d..eb708ee 100644 --- a/src/torchgmm/clustering/kmeans/lightning_module.py +++ b/src/torchgmm/clustering/kmeans/lightning_module.py @@ -1,6 +1,6 @@ # pylint: disable=abstract-method import math -from typing import List, Literal +from typing import Literal import pytorch_lightning as pl import torch @@ -56,7 +56,7 @@ def __init__( # Initialize metrics self.metric_inertia = MeanMetric() - def configure_callbacks(self) -> List[pl.Callback]: + def configure_callbacks(self) -> list[pl.Callback]: if self.convergence_tolerance == 0: return [] early_stopping = EarlyStopping( diff --git a/src/torchgmm/clustering/kmeans/metrics.py b/src/torchgmm/clustering/kmeans/metrics.py index dd05ee3..64995b5 100644 --- a/src/torchgmm/clustering/kmeans/metrics.py +++ b/src/torchgmm/clustering/kmeans/metrics.py @@ -1,5 +1,6 @@ import random -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch from torchmetrics import Metric @@ -17,7 +18,7 @@ def __init__( num_clusters: int, num_features: int, *, - dist_sync_fn: Optional[Callable[[Any], Any]] = None, + dist_sync_fn: Callable[[Any], Any] | None = None, ): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore @@ -61,7 +62,7 @@ def __init__( num_choices: int, num_features: int, *, - dist_sync_fn: Optional[Callable[[Any], Any]] = None, + dist_sync_fn: Callable[[Any], Any] | None = None, ): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore @@ -123,7 +124,7 @@ def __init__( num_choices: int, num_features: int, *, - dist_sync_fn: Optional[Callable[[Any], Any]] = None, + dist_sync_fn: Callable[[Any], Any] | None = None, ): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore @@ -173,7 +174,7 @@ class BatchSummer(Metric): full_state_update = True - def __init__(self, num_values: int, *, dist_sync_fn: Optional[Callable[[Any], Any]] = None): + def __init__(self, num_values: int, *, dist_sync_fn: Callable[[Any], Any] | 
None = None): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore self.sums: torch.Tensor @@ -198,7 +199,7 @@ def __init__( num_values: int, for_variance: bool, *, - dist_sync_fn: Optional[Callable[[Any], Any]] = None, + dist_sync_fn: Callable[[Any], Any] | None = None, ): super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore diff --git a/src/torchgmm/clustering/kmeans/model.py b/src/torchgmm/clustering/kmeans/model.py index d264be6..e2a363c 100644 --- a/src/torchgmm/clustering/kmeans/model.py +++ b/src/torchgmm/clustering/kmeans/model.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Tuple import torch from torch import jit, nn @@ -52,7 +51,7 @@ def reset_parameters(self) -> None: """ nn.init.normal_(self.centroids) - def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the distance of each datapoint to each centroid as well as the "inertia", the squared distance of each datapoint to its closest centroid. diff --git a/src/torchgmm/utils/lightning_module.py b/src/torchgmm/utils/lightning_module.py index 32ea630..281db6a 100644 --- a/src/torchgmm/utils/lightning_module.py +++ b/src/torchgmm/utils/lightning_module.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import List import pytorch_lightning as pl import torch @@ -32,7 +31,7 @@ def on_train_epoch_end(self) -> None: self.nonparametric_training_epoch_end() else: - def training_epoch_end(self, outputs: List[torch.Tensor]) -> None: + def training_epoch_end(self, outputs: list[torch.Tensor]) -> None: """Training epoch end hook for PyTorch Lightning < 2.0.0.""" self.nonparametric_training_epoch_end() diff --git a/tests/_data/gmm.py b/tests/_data/gmm.py index 7d4e13e..45044b9 100644 --- a/tests/_data/gmm.py +++ b/tests/_data/gmm.py @@ -1,5 +1,4 @@ # pylint: disable=missing-function-docstring -from typing import Tuple import torch @@ -9,7 +8,7 @@ def sample_gmm( num_datapoints: int, num_features: int, num_components: int, covariance_type: CovarianceType -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: config = GaussianMixtureModelConfig(num_components, num_features, covariance_type) model = GaussianMixtureModel(config) diff --git a/tests/_data/normal.py b/tests/_data/normal.py index a094cba..cbad992 100644 --- a/tests/_data/normal.py +++ b/tests/_data/normal.py @@ -1,26 +1,25 @@ # pylint: disable=missing-function-docstring -from typing import List import torch -def sample_data(counts: List[int], dims: List[int]) -> List[torch.Tensor]: +def sample_data(counts: list[int], dims: list[int]) -> list[torch.Tensor]: return [torch.randn(count, dim) for count, dim in zip(counts, dims)] -def sample_means(counts: List[int], dims: List[int]) -> List[torch.Tensor]: +def sample_means(counts: list[int], dims: list[int]) -> list[torch.Tensor]: return [torch.randn(count, dim) for count, dim in zip(counts, dims)] -def sample_spherical_covars(counts: List[int]) -> List[torch.Tensor]: +def sample_spherical_covars(counts: list[int]) -> list[torch.Tensor]: return [torch.rand(count) for count in counts] -def sample_diag_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]: +def sample_diag_covars(counts: list[int], dims: list[int]) -> list[torch.Tensor]: return [torch.rand(count, dim).squeeze() for count, dim in zip(counts, dims)] -def sample_full_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]: +def 
sample_full_covars(counts: list[int], dims: list[int]) -> list[torch.Tensor]: result = [] for count, dim in zip(counts, dims): A = torch.rand(count, dim * 10, dim) diff --git a/tests/bayes/gmm/benchmark_gmm_estimator.py b/tests/bayes/gmm/benchmark_gmm_estimator.py index 9140964..b370da4 100644 --- a/tests/bayes/gmm/benchmark_gmm_estimator.py +++ b/tests/bayes/gmm/benchmark_gmm_estimator.py @@ -1,5 +1,4 @@ # pylint: disable=missing-function-docstring -from typing import Optional import pytest import pytorch_lightning as pl @@ -72,7 +71,7 @@ def test_torchgmm( num_features: int, num_components: int, covariance_type: CovarianceType, - batch_size: Optional[int], + batch_size: int | None, ): pl.seed_everything(0) data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type) @@ -115,7 +114,7 @@ def test_torchgmm_gpu( num_features: int, num_components: int, covariance_type: CovarianceType, - batch_size: Optional[int], + batch_size: int | None, ): # Initialize GPU torch.empty(1, device="cuda:0") diff --git a/tests/bayes/gmm/test_gmm_estimator.py b/tests/bayes/gmm/test_gmm_estimator.py index 701d9c2..230a52d 100644 --- a/tests/bayes/gmm/test_gmm_estimator.py +++ b/tests/bayes/gmm/test_gmm_estimator.py @@ -1,6 +1,5 @@ # pylint: disable=missing-function-docstring import math -from typing import Optional import pytest import torch @@ -23,7 +22,7 @@ def test_fit_model_config(): @pytest.mark.parametrize("batch_size", [2, None]) -def test_fit_num_iter(batch_size: Optional[int]): +def test_fit_num_iter(batch_size: int | None): # For the following data, K-means will find centroids [0.5, 3.5]. The estimator first computes # the NLL (first iteration), afterwards there is no improvement in the NLL (second iteration). data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]]) @@ -41,7 +40,7 @@ def test_fit_num_iter(batch_size: Optional[int]): ("batch_size", "max_epochs", "converged"), [(2, 1, False), (2, 3, True), (None, 1, False), (None, 3, True)], ) -def test_fit_converged(batch_size: Optional[int], max_epochs: int, converged: bool): +def test_fit_converged(batch_size: int | None, max_epochs: int, converged: bool): data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]]) estimator = GaussianMixture( diff --git a/tests/bayes/gmm/test_gmm_metrics.py b/tests/bayes/gmm/test_gmm_metrics.py index 8b28fcc..48c239d 100644 --- a/tests/bayes/gmm/test_gmm_metrics.py +++ b/tests/bayes/gmm/test_gmm_metrics.py @@ -1,5 +1,6 @@ # pylint: disable=protected-access,missing-function-docstring -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import numpy as np import sklearn.mixture._gaussian_mixture as skgmm # type: ignore diff --git a/tests/clustering/kmeans/benchmark_kmeans_estimator.py b/tests/clustering/kmeans/benchmark_kmeans_estimator.py index e95218a..8b4ab73 100644 --- a/tests/clustering/kmeans/benchmark_kmeans_estimator.py +++ b/tests/clustering/kmeans/benchmark_kmeans_estimator.py @@ -1,5 +1,4 @@ # pylint: disable=missing-function-docstring -from typing import Optional import pytest import pytorch_lightning as pl @@ -66,7 +65,7 @@ def test_sklearn( def test_torchgmm( benchmark: BenchmarkFixture, num_datapoints: int, - batch_size: Optional[int], + batch_size: int | None, num_features: int, num_centroids: int, init_strategy: KMeansInitStrategy, @@ -106,7 +105,7 @@ def test_torchgmm( def test_torchgmm_gpu( benchmark: BenchmarkFixture, num_datapoints: int, - batch_size: Optional[int], + batch_size: int | None, num_features: int, num_centroids: int,
init_strategy: KMeansInitStrategy, diff --git a/tests/clustering/kmeans/test_kmeans_estimator.py b/tests/clustering/kmeans/test_kmeans_estimator.py index 3afb631..0214362 100644 --- a/tests/clustering/kmeans/test_kmeans_estimator.py +++ b/tests/clustering/kmeans/test_kmeans_estimator.py @@ -1,6 +1,5 @@ # pylint: disable=missing-function-docstring import math -from typing import Optional import pytest import torch @@ -60,7 +59,7 @@ def test_fit_converged(num_epochs: int, converged: bool): ) def test_fit_inertia( num_datapoints: int, - batch_size: Optional[int], + batch_size: int | None, num_features: int, num_centroids: int, ):
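Note on the pattern this diff applies throughout: PEP 585 (Python 3.9) made the built-in containers subscriptable, so typing.List/typing.Tuple/typing.Type become list/tuple/type; PEP 604 (Python 3.10) allows writing typing.Optional[X] as X | None; and the abstract container types (Iterator, Sequence, Callable) are imported from collections.abc, their typing aliases having been deprecated since 3.9. A minimal before/after sketch of the style change — the function name `batched` and its body are illustrative, not taken from this PR:

from __future__ import annotations  # lets the new annotation syntax parse on Python < 3.10

from collections.abc import Callable, Iterator

# Before (the style this diff removes):
#   from typing import Callable, Iterator, List, Optional, Tuple
#   def batched(xs: List[int], key: Optional[Callable[[int], int]] = None) -> Iterator[Tuple[int, int]]: ...

# After (PEP 585 built-in generics, PEP 604 unions, collections.abc imports):
def batched(xs: list[int], key: Callable[[int], int] | None = None) -> Iterator[tuple[int, int]]:
    # Yield (index, key(x)) pairs, applying the identity when no key is given.
    for i, x in enumerate(xs):
        yield i, x if key is None else key(x)

Two version caveats: X | None in an annotation requires Python >= 3.10 unless the module has `from __future__ import annotations` (the estimator modules touched above already do), and runtime expressions such as cast(list[tuple[torch.Tensor, torch.Tensor]], result) require >= 3.9, where the builtins became subscriptable. A migration like this can usually be automated with ruff's pyupgrade rules, e.g. `ruff check --select UP --fix`, which fits the ruff-pre-commit bump at the top of this diff.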