Commit ff7ed55

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 871b0e3 commit ff7ed55

23 files changed: 62 additions, 62 deletions

src/torchgmm/base/data/collation.py

Lines changed: 2 additions & 4 deletions

@@ -1,9 +1,7 @@
-from typing import Tuple
-
 import torch
 
 
-def collate_tuple(batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+def collate_tuple(batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
     """
     Collate a tuple of batch items by returning the input tuple.
 
@@ -13,7 +11,7 @@ def collate_tuple(batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
     return batch
 
 
-def collate_tensor(batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
+def collate_tensor(batch: tuple[torch.Tensor, ...]) -> torch.Tensor:
     """
     Collates a tuple of batch items into the first tensor.
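
Note: this hunk is the first instance of a pattern repeated throughout the commit, replacing the deprecated typing.Tuple alias with the built-in tuple generic (PEP 585, subscriptable since Python 3.9). A minimal illustrative sketch, assuming a Python 3.9+ interpreter; it is not part of the torchgmm sources:

import torch


def collate_tuple(batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
    # The annotation uses the built-in tuple; behaviour is identical to typing.Tuple.
    return batch


batch = (torch.zeros(2), torch.ones(2))
assert collate_tuple(batch) is batch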

src/torchgmm/base/data/loader.py

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
-from typing import Any, Iterator, TypeVar
+from collections.abc import Iterator
+from typing import Any, TypeVar
 
 try:
     from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper as IndexBatchSamplerWrapper
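
Note: the other recurring fix in this commit is importing abstract-collection types such as Iterator, Sequence, and Callable from collections.abc instead of typing, since the typing aliases are deprecated by PEP 585. A small sketch under that assumption; the generator below is purely illustrative:

from collections.abc import Iterator


def count_up(limit: int) -> Iterator[int]:
    # collections.abc.Iterator is subscriptable at runtime on Python >= 3.9.
    yield from range(limit)


assert list(count_up(3)) == [0, 1, 2]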

src/torchgmm/base/data/sampler.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import math
-from typing import Iterator
+from collections.abc import Iterator
 
 from torch.utils.data import Sampler
 from torch.utils.data.sampler import SequentialSampler

src/torchgmm/base/data/types.py

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
-from typing import Sequence, TypeVar, Union
+from collections.abc import Sequence
+from typing import TypeVar, Union
 
 import numpy as np
 import numpy.typing as npt

src/torchgmm/base/nn/_protocols.py

Lines changed: 5 additions & 3 deletions

@@ -1,5 +1,7 @@
 # pylint: disable=missing-class-docstring,missing-function-docstring
-from typing import Generic, Iterator, OrderedDict, Protocol, Tuple, Type, TypeVar
+from collections import OrderedDict
+from collections.abc import Iterator
+from typing import Generic, Protocol, TypeVar
 
 import torch
 from torch import nn
@@ -15,13 +17,13 @@ class ConfigurableModule(Protocol, Generic[C_co]):
     def config(self) -> C_co: ...
 
     @classmethod
-    def load(cls: Type[M], path: PathType) -> M: ...
+    def load(cls: type[M], path: PathType) -> M: ...
 
     def save(self, path: PathType, compile_model: bool = False) -> None: ...
 
     def save_config(self, path: PathType) -> None: ...
 
-    def named_children(self) -> Iterator[Tuple[str, nn.Module]]: ...
+    def named_children(self) -> Iterator[tuple[str, nn.Module]]: ...
 
     def state_dict(self) -> OrderedDict[str, torch.Tensor]: ...
 
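
Note: the Type[M] → type[M] change is the same PEP 585 pattern applied to the built-in type. A hedged sketch of how a classmethod annotated this way behaves; the Config and GMMConfig classes here are illustrative only and not part of torchgmm:

from typing import TypeVar

M = TypeVar("M", bound="Config")


class Config:
    @classmethod
    def load(cls: type[M]) -> M:
        # Each subclass gets back an instance of itself.
        return cls()


class GMMConfig(Config):
    pass


assert isinstance(GMMConfig.load(), GMMConfig)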

src/torchgmm/base/utils/generics.py

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
-from typing import Any, Type, get_args, get_origin
+from typing import Any, get_args, get_origin
 
 
-def get_generic_type(cls: Type[Any], origin: Type[Any], index: int = 0) -> Type[Any]:
+def get_generic_type(cls: type[Any], origin: type[Any], index: int = 0) -> type[Any]:
     """
     Returns the ``index``-th generic type of the superclass ``origin``.
 

src/torchgmm/bayes/core/_jit.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ def jit_log_normal(
         # Precision shape is `[num_components, dim, dim]`.
         log_prob = x.new_empty((x.size(0), means.size(0)))
         # We loop here to not blow up the size of intermediate matrices
-        for k, (mu, prec_chol) in enumerate(zip(means, precisions_cholesky)):
+        for k, (mu, prec_chol) in enumerate(zip(means, precisions_cholesky, strict=False)):
             inner = x.matmul(prec_chol) - mu.matmul(prec_chol)
             log_prob[:, k] = inner.square().sum(1)
     elif covariance_type == "tied":
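
Note: this hunk adds the strict keyword that zip gained in Python 3.10. With strict=False the behaviour is unchanged (extra items in the longer iterable are silently dropped); strict=True would raise if the iterables differ in length. A minimal illustration:

pairs = list(zip([1, 2, 3], ["a", "b"], strict=False))
assert pairs == [(1, "a"), (2, "b")]  # extra items in the longer iterable are dropped

try:
    list(zip([1, 2, 3], ["a", "b"], strict=True))
except ValueError:
    pass  # strict=True raises ValueError on a length mismatch (Python >= 3.10)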

src/torchgmm/bayes/gmm/estimator.py

Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, List, Tuple, cast
+from typing import Any, cast
 
 import numpy as np
 import torch
@@ -272,7 +272,7 @@ def score_samples(self, data: TensorLike) -> torch.Tensor:
             collate_fn=collate_tensor,
         )
         result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
-        return torch.stack([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
+        return torch.stack([x[1] for x in cast(list[tuple[torch.Tensor, torch.Tensor]], result)])
 
     def predict(self, data: TensorLike) -> torch.Tensor:
         """
@@ -321,4 +321,4 @@ def predict_proba(self, data: TensorLike) -> torch.Tensor:
             collate_fn=collate_tensor,
         )
         result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
-        return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
+        return torch.cat([x[0] for x in cast(list[tuple[torch.Tensor, torch.Tensor]], result)])
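
Note: typing.cast accepts the built-in generics directly, so List[Tuple[...]] becomes list[tuple[...]] with no runtime change; cast simply returns its second argument. A small sketch with dummy tensors standing in for the predict output (not torchgmm code):

from typing import cast

import torch

result = [(torch.zeros(3), torch.ones(3))]  # stand-in for what predict() returns
typed = cast(list[tuple[torch.Tensor, torch.Tensor]], result)  # no runtime effect
stacked = torch.stack([x[1] for x in typed])
assert stacked.shape == (1, 3)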

src/torchgmm/bayes/gmm/metrics.py

Lines changed: 5 additions & 4 deletions

@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any
 
 import torch
 from torchmetrics import Metric
@@ -17,7 +18,7 @@ def __init__(
         self,
         num_components: int,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -44,7 +45,7 @@ def __init__(
         num_components: int,
         num_features: int,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -78,7 +79,7 @@ def __init__(
         covariance_type: CovarianceType,
         reg: float,
         *,
-        dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+        dist_sync_fn: Callable[[Any], Any] | None = None,
     ):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
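
Note: the Optional[X] → X | None rewrite uses PEP 604 union syntax, which is evaluated at runtime on Python 3.10+ (or on earlier versions with `from __future__ import annotations`). A minimal sketch under that assumption; the DummyMetric class is illustrative only:

from collections.abc import Callable
from typing import Any


class DummyMetric:
    def __init__(self, *, dist_sync_fn: Callable[[Any], Any] | None = None):
        # `X | None` is the PEP 604 spelling of Optional[X]; None remains the default.
        self.dist_sync_fn = dist_sync_fn


assert DummyMetric().dist_sync_fn is None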

src/torchgmm/bayes/gmm/model.py

Lines changed: 1 addition & 2 deletions

@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Tuple
 
 import numpy as np
 import torch
@@ -90,7 +89,7 @@ def reset_parameters(self) -> None:
         if self.covariance_type in ("full", "tied"):
             self.precisions_cholesky.tril_()
 
-    def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Computes the log-probability of observing each of the provided datapoints for each of the
         GMM's components.
