11from __future__ import annotations
22
33import logging
4- from typing import Any , List , cast
4+ from typing import Any , cast
55
66import numpy as np
77import torch
@@ -188,7 +188,7 @@ def predict(self, data: TensorLike) -> torch.Tensor:
188188 collate_fn = collate_tensor ,
189189 )
190190 result = self .trainer ().predict (KMeansLightningModule (self .model_ , predict_target = "assignments" ), loader )
191- return torch .cat (cast (List [torch .Tensor ], result ))
191+ return torch .cat (cast (list [torch .Tensor ], result ))
192192
193193 def score (self , data : TensorLike ) -> float :
194194 """
@@ -236,7 +236,7 @@ def score_samples(self, data: TensorLike) -> torch.Tensor:
236236 collate_fn = collate_tensor ,
237237 )
238238 result = self .trainer ().predict (KMeansLightningModule (self .model_ , predict_target = "inertias" ), loader )
239- return torch .cat (cast (List [torch .Tensor ], result ))
239+ return torch .cat (cast (list [torch .Tensor ], result ))
240240
241241 def transform (self , data : TensorLike ) -> torch .Tensor :
242242 """
@@ -262,4 +262,4 @@ def transform(self, data: TensorLike) -> torch.Tensor:
262262 collate_fn = collate_tensor ,
263263 )
264264 result = self .trainer ().predict (KMeansLightningModule (self .model_ , predict_target = "distances" ), loader )
265- return torch .cat (cast (List [torch .Tensor ], result ))
265+ return torch .cat (cast (list [torch .Tensor ], result ))
0 commit comments