From 732dbfc58d8e13765cef440f2616cbaf246d80ee Mon Sep 17 00:00:00 2001 From: Ostap Bodnar Date: Thu, 1 May 2025 23:12:59 +0300 Subject: [PATCH 1/7] feat(sq): algorithm optimization Multiple implementations added (faiss, numpy, numpy+memmap) Nearest centroid search optimized Refactored StochasticQuantization class Added parallel execution support Removed required history log (used only when turned on) Updated verbosity to support levels Integrated tqdm for progress tracking Model export/saving and loading added --- code/pyproject.toml | 11 +- code/setup.py | 6 +- code/sqg/centroids_storage/__init__.py | 12 + code/sqg/centroids_storage/factory.py | 257 ++++++++++++++++ code/sqg/centroids_storage/faiss_storage.py | 85 ++++++ code/sqg/{ => centroids_storage}/init.py | 40 +++ code/sqg/centroids_storage/numpy_storage.py | 86 ++++++ code/sqg/metric.py | 95 ------ code/sqg/optim.py | 27 +- code/sqg/progress_tracking/__init__.py | 3 + code/sqg/progress_tracking/tqdm_joblib.py | 39 +++ code/sqg/quantization.py | 323 +++++++++++++------- code/sqg/utils.py | 10 + 13 files changed, 763 insertions(+), 231 deletions(-) create mode 100644 code/sqg/centroids_storage/__init__.py create mode 100644 code/sqg/centroids_storage/factory.py create mode 100644 code/sqg/centroids_storage/faiss_storage.py rename code/sqg/{ => centroids_storage}/init.py (83%) create mode 100644 code/sqg/centroids_storage/numpy_storage.py delete mode 100644 code/sqg/metric.py create mode 100644 code/sqg/progress_tracking/__init__.py create mode 100644 code/sqg/progress_tracking/tqdm_joblib.py create mode 100644 code/sqg/utils.py diff --git a/code/pyproject.toml b/code/pyproject.toml index f3a9d3d..4b5c166 100644 --- a/code/pyproject.toml +++ b/code/pyproject.toml @@ -10,14 +10,15 @@ name = "sqg" version = "1.0.0" description = "A robust and scalable alternative to existing K-means solvers." 
authors = [ - {name = "Vladimir Norkin", email = "v.norkin@kpi.ua"}, - {name = "Anton Kozyriev", email = "a.kozyriev@kpi.ua"}, + { name = "Vladimir Norkin", email = "v.norkin@kpi.ua" }, + { name = "Anton Kozyriev", email = "a.kozyriev@kpi.ua" }, ] readme = "README.md" -license = {file = "LICENSE.code.md"} +license = { file = "LICENSE.code.md" } dependencies = [ "numpy>=1.26.4,<2", "scikit-learn>=1.5.1,<2", + "tqdm>=4.66.0,<5", ] requires-python = ">=3.8" keywords = [ @@ -47,3 +48,7 @@ classifiers = [ Homepage = "https://github.com/kaydotdev/stochastic-quantization" Issues = "https://github.com/kaydotdev/stochastic-quantization/issues" Repository = "https://github.com/kaydotdev/stochastic-quantization.git" + +[project.optional-dependencies] +faiss-cpu = ["faiss-cpu>=1.10.0,<2"] +faiss-gpu = ["faiss-gpu>=1.10.0,<2"] diff --git a/code/setup.py b/code/setup.py index 7ae861f..4ecd5cc 100644 --- a/code/setup.py +++ b/code/setup.py @@ -1,6 +1,5 @@ from setuptools import setup, find_packages - if __name__ == "__main__": setup( name="sqg", @@ -45,5 +44,10 @@ install_requires=[ "numpy>=1.26.4,<2", "scikit-learn>=1.5.1,<2", + "tqdm>=4.66.0,<5", ], + extras_require={ + "faiss-cpu": ["faiss-cpu>=1.10.0,<2"], + "faiss-gpu": ["faiss-gpu>=1.10.0,<2"] + }, ) diff --git a/code/sqg/centroids_storage/__init__.py b/code/sqg/centroids_storage/__init__.py new file mode 100644 index 0000000..2a7ddcb --- /dev/null +++ b/code/sqg/centroids_storage/__init__.py @@ -0,0 +1,12 @@ +from centroids_storage.factory import CentroidStorage, CentroidStorageFactory, StorageBackendType +from centroids_storage.faiss_storage import FaissIndexBasedCentroidStorage +from centroids_storage.numpy_storage import NumpyCentroidStorage, NumpyMemmapCentroidStorage + +__all__ = [ + "CentroidStorage", + "CentroidStorageFactory", + "StorageBackendType", + "NumpyCentroidStorage", + "NumpyMemmapCentroidStorage", + "FaissIndexBasedCentroidStorage", +] diff --git a/code/sqg/centroids_storage/factory.py 
b/code/sqg/centroids_storage/factory.py new file mode 100644 index 0000000..9c872e3 --- /dev/null +++ b/code/sqg/centroids_storage/factory.py @@ -0,0 +1,257 @@ +import abc +import tempfile +from copy import deepcopy +from typing import Literal, Callable + +import numpy as np + + +class CentroidStorage(abc.ABC): + """ + Abstract base class for centroid storage implementations. + + Parameters + ---------- + n_clusters : int + The number of clusters (centroids) to initialize. + init : str or np.ndarray, optional + Method for initialization. Can be 'k-means++' or an ndarray of initial centroids. + """ + + def __init__(self, n_clusters: int, init: str | np.ndarray = "k-means++", *args, **kwargs): + self._n_clusters = n_clusters + self._init = init + + @property + def n_clusters(self) -> int: + """ + Returns the number of clusters. + + Returns + ------- + int + The number of clusters. + """ + return self._n_clusters + + @property + @abc.abstractmethod + def centroids(self): + """ + Returns the centroids. + + Returns + ------- + np.ndarray + The centroids. + + Raises + ------ + ValueError + If the centroids have not been initialized yet. + """ + raise NotImplementedError() + + @property + @abc.abstractmethod + def name(self) -> str: + """ + Returns the name of the centroid storage implementation. + + Returns + ------- + str + The name of the centroid storage implementation. + + Raises + ------ + NotImplementedError + If the method is not implemented by the subclass. + """ + raise NotImplementedError + + @abc.abstractmethod + def init_centroids(self, X: np.ndarray, random_state: np.random.RandomState): + """ + Initializes the centroids. + + Parameters + ---------- + X : np.ndarray + The data to initialize the centroids. + random_state : np.random.RandomState + The random state for reproducibility. + + Raises + ------ + NotImplementedError + If the method is not implemented by the subclass. 
+ """ + raise NotImplementedError + + @abc.abstractmethod + def find_nearest_centroid(self, target: np.ndarray) -> tuple[np.ndarray, np.uint]: + """ + Finds the nearest centroid to the target. + + Parameters + ---------- + target : np.ndarray + The target data point. + + Returns + ------- + tuple[np.ndarray, np.uint] + The nearest centroid and its index. + + Raises + ------ + NotImplementedError + If the method is not implemented by the subclass. + """ + raise NotImplementedError + + @abc.abstractmethod + def update_centroid(self, index: np.uint, delta: np.ndarray): + """ + Updates the centroid at the given index. + + Parameters + ---------- + index : np.uint + The index of the centroid to update. + delta : np.ndarray + The change to apply to the centroid. + + Raises + ------ + NotImplementedError + If the method is not implemented by the subclass. + """ + raise NotImplementedError + + def calculate_loss(self, X: np.ndarray) -> np.float64: + """Calculates stochastic Wasserstein (or Kantorovich–Rubinstein) distance between distributions ξ and y: + + F(y) = Σᵢ₌₁ᴵ pᵢ min₁≤k≤K d(ξᵢ, yₖ)ʳ + + Parameters + ---------- + xi : np.ndarray + The original distribution ξ with shape (N, D, ...). + y : np.ndarray + The quantized distribution y with shape (M, D, ...). + + Returns + ------- + np.float64 + The calculated stochastic Wasserstein distance between distributions ξ and y. + + Raises + ------ + ValueError + If one of the distributions ξ or y is empty. + ValueError + If there is a shape mismatch between individual elements in distribution ξ and y. + + Notes + ----- + The function assumes uniform weights (pᵢ = 1) for all elements in the original distribution. + The exponent r in the formula is implicitly set to 1 in this implementation. 
+ """ + + if X.size == 0 or self.centroids.size == 0: + raise ValueError("One of the distributions `X` or `centroids` is empty.") + + if X.shape[1:] != self.centroids.shape[1:]: + raise ValueError( + "The dimensions of individual elements in distribution `X` and `centroids` must match. Elements in " + f"`X` have shape {X.shape[1:]}, but y elements have shape {self.centroids.shape[1:]}." + ) + + distances = [ + np.linalg.norm(self.find_nearest_centroid(ksi)[0] - ksi) for ksi in X + ] + + return np.sum(distances) + + +StorageBackendType = Literal["numpy", "numpy_memmap", "faiss"] | CentroidStorage + + +class CentroidStorageFactory: + """ + Factory class for creating centroid storage instances. + """ + _implementations: dict[str, tuple[type[CentroidStorage], bool]] = dict() + + @classmethod + def register(cls, requires_filepath: bool = False): + """ + Registers a new centroid storage implementation. + + Parameters + ---------- + storage_type : type[CentroidStorage] + The centroid storage implementation to register. + + Returns + ------- + type[CentroidStorage] + The registered centroid storage implementation. + """ + + def decorator(storage_type: type[CentroidStorage]): + cls._implementations[storage_type.name] = (storage_type, requires_filepath) + return storage_type + + return decorator + + @classmethod + def create( + cls, + storage_type: StorageBackendType, + n_clusters: int, + init: str | np.ndarray = "k-means++", + **kwargs + ) -> tuple[CentroidStorage, Callable[[], None]]: + """ + Creates a centroid storage instance. + + Parameters + ---------- + storage_type : StorageBackendType + The type of storage backend to use. + n_clusters : int + The number of clusters (centroids) to initialize. + init : str or np.ndarray, optional + Method for initialization. Can be 'k-means++' or an ndarray of initial centroids. + **kwargs + Additional keyword arguments for the storage implementation. + + Returns + ------- + CentroidStorage + The created centroid storage instance. 
+ + Raises + ------ + ValueError + If the storage type is unknown. + """ + if isinstance(storage_type, CentroidStorage): + return storage_type, lambda: None + if isinstance(storage_type, str) and storage_type in cls._implementations: + kwargs = deepcopy(kwargs) + storage_implementation, requires_filepath = cls._implementations[storage_type] + if requires_filepath and "filepath" not in kwargs: + memory_file = tempfile.NamedTemporaryFile() + kwargs["filepath"] = memory_file.name + print(kwargs["filepath"]) + return storage_implementation( + n_clusters=n_clusters, init=init, **kwargs + ), (lambda: memory_file.close()) if not kwargs.get("keep_filepath") else lambda: None + raise ValueError( + f"Unknown storage type: {storage_type}, supported types are: {sorted(cls._implementations.keys())} " + f"or {CentroidStorage.__name__} instance" + ) diff --git a/code/sqg/centroids_storage/faiss_storage.py b/code/sqg/centroids_storage/faiss_storage.py new file mode 100644 index 0000000..536ab48 --- /dev/null +++ b/code/sqg/centroids_storage/faiss_storage.py @@ -0,0 +1,85 @@ +import numpy as np + +from centroids_storage.factory import CentroidStorageFactory +from centroids_storage.numpy_storage import NumpyMemmapCentroidStorage + + +def _load_faiss(): + try: + import faiss + except ImportError as e: + raise ImportError( + "Faiss is not installed. Please install it using extras `pip install sq[faiss]` or `pip install faiss-cpu`." + ) from e + return faiss + + +@CentroidStorageFactory.register(requires_filepath=True) +class FaissIndexBasedCentroidStorage(NumpyMemmapCentroidStorage): + name = "faiss" + + def __init__( + self, + filepath: str, + n_clusters: int, + init: str | np.ndarray = "k-means++", + voronoi_cell_size: int = 100, + *args, **kwargs + ): + """ + FaissIndexBasedCentroidStorage is a centroid storage class that uses FAISS for efficient nearest neighbor + search and centroid updates. It is designed to work with large datasets. 
+ + Parameters + ---------- + n_clusters : int + The number of clusters (centroids) to initialize. + init : str or np.ndarray, optional + Method for initialization. Can be 'k-means++' or an ndarray of initial centroids. + filepath : str or None, optional + Path to the file where centroids are stored. If None, centroids are stored in memory. + voronoi_cell_size : int, optional + The size of the Voronoi cells for the FAISS index. Default is 100. + """ + super().__init__(filepath, n_clusters, init, *args, **kwargs) + _faiss = _load_faiss() + self.index: _faiss.Index | None = None + self._dim = None + self._voronoi_cell_size = voronoi_cell_size + self._faiss = _faiss + + def init_centroids(self, X: np.ndarray, random_state: np.random.RandomState): + x_len, x_dims = X.shape + self._dim = x_dims + quantizer = self._faiss.IndexFlatL2(self._dim) + index = self._faiss.IndexIVFFlat(quantizer, self._dim, self._voronoi_cell_size, self._faiss.METRIC_L2) + super().init_centroids(X, random_state) + index.train(self._centroids) + index.add_with_ids(self._centroids, np.arange(0, self.n_clusters)) + self.index = index + + def find_nearest_centroid(self, target: np.ndarray) -> tuple[np.ndarray, np.uint]: + distances, indices = self.index.search(np.atleast_2d(target), 1) + nearest_index = indices[0, 0] + nearest_centroid = self._centroids[nearest_index] + return nearest_centroid, np.int64(nearest_index) + + def update_centroid(self, index: np.uint, delta: np.ndarray): + value = self._centroids[index] + index_array = np.array([index]) + self.index.remove_ids(index_array) + updated_value = np.atleast_2d(value - delta) + self._centroids[index] = updated_value + self.index.add_with_ids(updated_value, index_array) + + def __getstate__(self): + state = super().__getstate__() + state.pop('_faiss', None) + state['faiss_index'] = self._faiss.serialize_index(self.index) + return state + + def __setstate__(self, state): + self._faiss = _load_faiss() + faiss_index = state.pop('faiss_index') + 
super().__setstate__(state) + self.index = self._faiss.deserialize_index(faiss_index) diff --git a/code/sqg/init.py b/code/sqg/centroids_storage/init.py similarity index 83% rename from code/sqg/init.py rename to code/sqg/centroids_storage/init.py index 2c53aca..2c35c61 100644 --- a/code/sqg/init.py +++ b/code/sqg/centroids_storage/init.py @@ -6,6 +6,46 @@ from sklearn import metrics +def init_centroids(init: np.ndarray | str, n_clusters: int, x: np.ndarray, random_state: np.random.RandomState): + x_len, x_dims = x.shape + match init: + case _ if isinstance(init, np.ndarray): + init_len, init_dims = init.shape + + if init_dims != x_dims: + raise ValueError( + f"The dimensions of initial quantized distribution ({init_len}) and input tensor " + f"({x_dims}) must match." + ) + + if init_len != n_clusters: + raise ValueError( + f"The number of elements in the initial quantized distribution ({init_len}) should match the " + f"given number of optimal quants ({n_clusters})." + ) + + return init.copy() + case "sample": + random_indices = random_state.choice( + x_len, size=n_clusters, replace=False + ) + return x[random_indices] + case "random": + return random_state.rand(n_clusters, x_dims) + case "milp": + return _milp(x, n_clusters) + case "k-means++" | None: + return _kmeans_plus_plus( + x, n_clusters, random_state + ) + case _: + raise ValueError( + f"Initialization strategy ‘{init}’ is not a valid option. Supported options are " + "{‘sample’, ‘random’, ‘k-means++’, ‘milp’}." 
+ ) + + + def _kmeans_plus_plus( X: np.ndarray, n_clusters: Union[int, np.uint] = 2, diff --git a/code/sqg/centroids_storage/numpy_storage.py b/code/sqg/centroids_storage/numpy_storage.py new file mode 100644 index 0000000..f51fda8 --- /dev/null +++ b/code/sqg/centroids_storage/numpy_storage.py @@ -0,0 +1,86 @@ +import os +import tempfile + +import numpy as np + +from centroids_storage.factory import CentroidStorageFactory, CentroidStorage +from centroids_storage.init import init_centroids + + +@CentroidStorageFactory.register() +class NumpyCentroidStorage(CentroidStorage): + name = "numpy" + + def __init__(self, n_clusters: int, init: str | np.ndarray = "k-means++", *args, **kwargs): + """ + Initializes the NumpyCentroidStorage class. + + Parameters + ---------- + n_clusters : int + The number of clusters (centroids) to initialize. + init : str or np.ndarray, optional + Method for initialization. Can be 'k-means++' or an ndarray of initial centroids. + """ + super().__init__(n_clusters, init, *args, **kwargs) + self._centroids = None + + @property + def centroids(self) -> np.ndarray: + if self._centroids is None: + raise ValueError("Centroids have not been initialized yet.") + return self._centroids + + def init_centroids(self, x: np.ndarray, random_state: np.random.RandomState): + self._centroids = init_centroids(self._init, self._n_clusters, x, random_state) + + def find_nearest_centroid(self, target: np.ndarray) -> tuple[np.ndarray, np.uint]: + y = self.centroids + + if y.size == 0 or target.size == 0: + raise ValueError("Either the `y` input tensor or the `target` tensor is empty.") + + if y.shape[1:] != target.shape: + raise ValueError( + "The dimensions of individual elements in `y` and `target` must match. Elements in `y` have " + f"shape {y.shape[1:]}, but `target` tensor has shape {target.shape}." 
+ ) + + distance = np.linalg.norm(target - y, axis=1) + nearest_index = np.argmin(distance) + + return y[nearest_index, :], nearest_index + + def update_centroid(self, index: np.uint, delta: np.ndarray): + self._centroids[index] -= delta + + +@CentroidStorageFactory.register(requires_filepath=True) +class NumpyMemmapCentroidStorage(NumpyCentroidStorage): + name = "numpy_memmap" + + def __init__(self, filepath: str, n_clusters: int, init: str | np.ndarray = "k-means++", *args, **kwargs): + """ + Initializes the NumpyCentroidStorage class. + + Parameters + ---------- + n_clusters : int + The number of clusters (centroids) to initialize. + init : str or np.ndarray, optional + Method for initialization. Can be 'k-means++' or an ndarray of initial centroids. + """ + super().__init__(n_clusters, init, *args, **kwargs) + self._filepath = filepath + + def init_centroids(self, x: np.ndarray, random_state: np.random.RandomState): + _, x_dims = x.shape + self._centroids = np.memmap(self._filepath, dtype=x.dtype, mode='w+', shape=(self._n_clusters, x_dims)) + self._centroids[:] = init_centroids(self._init, self._n_clusters, x, random_state) + + def __getstate__(self): + state = self.__dict__.copy() + return state + + def __setstate__(self, state): + self.__dict__.update(state) diff --git a/code/sqg/metric.py b/code/sqg/metric.py deleted file mode 100644 index 46e65d3..0000000 --- a/code/sqg/metric.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Tuple - -import numpy as np - - -def _calculate_loss(xi: np.ndarray, y: np.ndarray) -> np.float64: - """Calculates stochastic Wasserstein (or Kantorovich–Rubinstein) distance between distributions ξ and y: - - F(y) = Σᵢ₌₁ᴵ pᵢ min₁≤k≤K d(ξᵢ, yₖ)ʳ - - Parameters - ---------- - xi : np.ndarray - The original distribution ξ with shape (N, D, ...). - y : np.ndarray - The quantized distribution y with shape (M, D, ...). - - Returns - ------- - np.float64 - The calculated stochastic Wasserstein distance between distributions ξ and y. 
- - Raises - ------ - ValueError - If one of the distributions ξ or y is empty. - ValueError - If there is a shape mismatch between individual elements in distribution ξ and y. - - Notes - ----- - The function assumes uniform weights (pᵢ = 1) for all elements in the original distribution. - The exponent r in the formula is implicitly set to 1 in this implementation. - """ - - if xi.size == 0 or y.size == 0: - raise ValueError("One of the distributions `xi` or `y` is empty.") - - if xi.shape[1:] != y.shape[1:]: - raise ValueError( - "The dimensions of individual elements in distribution `xi` and `y` must match. Elements in " - f"`xi` have shape {xi.shape[1:]}, but y elements have shape {y.shape[1:]}." - ) - - pairwise_distance = np.linalg.norm(xi[:, np.newaxis] - y, axis=-1) - min_distance = np.min(pairwise_distance, axis=-1) - - return np.sum(min_distance) - - -def _find_nearest_element( - y: np.ndarray, target: np.ndarray -) -> Tuple[np.ndarray, np.uint]: - """Searches for the nearest element in `y` to `target` based on Euclidean distance. This function computes the - Euclidean distance between each element in `y` and the `target`, then returns the element from `y` that has the - smallest distance to `target`, along with its index. The shape of an individual element of `y` must match the - shape of `target`. - - Parameters - ---------- - y : np.ndarray - The input tensor containing multiple elements to search from with shape (N, D, ...). - - target : np.ndarray - The target tensor with shape (D, ...). - - Returns - ------- - Tuple[np.ndarray, np.uint] - A tuple containing two elements: - - 1. np.ndarray: The nearest element found in `y` to the `target` with shape (D, ...). - 2. np.uint: The index of the nearest element in `y`. - - Raises - ------ - ValueError - If either the `y` input tensor or the `target` tensor is empty. - ValueError - If there is a shape mismatch between individual elements in `y` and `target`. 
- """ - - if y.size == 0 or target.size == 0: - raise ValueError("Either the `y` input tensor or the `target` tensor is empty.") - - if y.shape[1:] != target.shape: - raise ValueError( - "The dimensions of individual elements in `y` and `target` must match. Elements in `y` have " - f"shape {y.shape[1:]}, but `target` tensor has shape {target.shape}." - ) - - distance = np.linalg.norm(target - y, axis=1) - nearest_index = np.argmin(distance) - - return y[nearest_index, :], nearest_index diff --git a/code/sqg/optim.py b/code/sqg/optim.py index baa87e3..b1d5643 100644 --- a/code/sqg/optim.py +++ b/code/sqg/optim.py @@ -47,7 +47,7 @@ def step( Returns ------- parameters : np.ndarray - Updated parameter values. + Update delta value. Raises ------- @@ -90,7 +90,7 @@ def step( x: np.ndarray, learning_rate: np.float64, ) -> np.ndarray: - return x - learning_rate * grad_fn(x) + return learning_rate * grad_fn(x) def reset(self) -> None: pass @@ -140,13 +140,13 @@ def step( learning_rate: np.float64, ) -> np.ndarray: if self.momentum_term is None: - self.momentum_term = np.zeros(shape=(1, x.size)) + self.momentum_term = np.zeros(shape=x.size) self.momentum_term = self.gamma * self.momentum_term + learning_rate * grad_fn( x ) - return x - self.momentum_term + return self.momentum_term def reset(self) -> None: self.momentum_term = None @@ -201,13 +201,13 @@ def step( learning_rate: np.float64, ) -> np.ndarray: if self.momentum_term is None: - self.momentum_term = np.zeros(shape=(1, x.size)) + self.momentum_term = np.zeros(shape=x.size) self.momentum_term = self.gamma * self.momentum_term + learning_rate * grad_fn( x - self.gamma * self.momentum_term ) - return x - self.momentum_term + return self.momentum_term def reset(self) -> None: self.momentum_term = None @@ -253,13 +253,13 @@ def step( learning_rate: np.float64, ) -> np.ndarray: if self.grad_term is None: - self.grad_term = np.zeros(shape=(1, x.size)) + self.grad_term = np.zeros(shape=x.size) grad_x = grad_fn(x) 
self.grad_term += grad_x**2 - return x - (learning_rate / np.sqrt(self.grad_term + self.var_eps)) * grad_x + return (learning_rate / np.sqrt(self.grad_term + self.var_eps)) * grad_x def reset(self) -> None: self.grad_term = None @@ -315,13 +315,13 @@ def step( learning_rate: np.float64, ) -> np.ndarray: if self.grad_term is None: - self.grad_term = np.zeros(shape=(1, x.size)) + self.grad_term = np.zeros(shape=x.size) grad_x = grad_fn(x) self.grad_term = self.beta * self.grad_term + (1 - self.beta) * grad_x**2 - return x - (learning_rate / np.sqrt(self.grad_term + self.var_eps)) * grad_x + return (learning_rate / np.sqrt(self.grad_term + self.var_eps)) * grad_x def reset(self) -> None: self.grad_term = None @@ -385,8 +385,8 @@ def step( beta1, beta2 = self.betas if self.momentum_term is None and self.variance_term is None: - self.momentum_term = np.zeros(shape=(1, x.size)) - self.variance_term = np.zeros(shape=(1, x.size)) + self.momentum_term = np.zeros(shape=x.size) + self.variance_term = np.zeros(shape=x.size) grad_x = grad_fn(x) @@ -394,8 +394,7 @@ def step( self.variance_term = beta2 * self.variance_term + (1 - beta2) * grad_x**2 return ( - x - - (learning_rate / np.sqrt(self.variance_term + self.var_eps)) + (learning_rate / np.sqrt(self.variance_term + self.var_eps)) * self.momentum_term ) diff --git a/code/sqg/progress_tracking/__init__.py b/code/sqg/progress_tracking/__init__.py new file mode 100644 index 0000000..2be92eb --- /dev/null +++ b/code/sqg/progress_tracking/__init__.py @@ -0,0 +1,3 @@ +from progress_tracking.tqdm_joblib import tqdm_joblib + +__all__ = ["tqdm_joblib"] diff --git a/code/sqg/progress_tracking/tqdm_joblib.py b/code/sqg/progress_tracking/tqdm_joblib.py new file mode 100644 index 0000000..801f84e --- /dev/null +++ b/code/sqg/progress_tracking/tqdm_joblib.py @@ -0,0 +1,39 @@ +# https://github.com/louisabraham/tqdm_joblib +import contextlib + +import joblib +from tqdm.autonotebook import tqdm + + +@contextlib.contextmanager +def 
tqdm_joblib(*args, **kwargs): + """Context manager to patch joblib to report into tqdm progress bar + given as argument""" + + tqdm_object = tqdm(*args, **kwargs) + + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() + + +def ParallelPbar(desc=None, **tqdm_kwargs): + class Parallel(joblib.Parallel): + def __call__(self, it): + it = list(it) + with tqdm_joblib(total=len(it), desc=desc, **tqdm_kwargs): + return super().__call__(it) + + return Parallel diff --git a/code/sqg/quantization.py b/code/sqg/quantization.py index bbe3420..a743d37 100644 --- a/code/sqg/quantization.py +++ b/code/sqg/quantization.py @@ -1,12 +1,14 @@ -from typing import Optional, Union +from typing import Optional, Union, Callable +import joblib import numpy as np - from sklearn.base import BaseEstimator, ClusterMixin from sklearn.utils.validation import check_is_fitted, check_random_state +from tqdm.autonotebook import tqdm -from .init import _kmeans_plus_plus, _milp -from .metric import _calculate_loss, _find_nearest_element +from centroids_storage.factory import CentroidStorage, StorageBackendType, CentroidStorageFactory +from progress_tracking import tqdm_joblib +from utils import batched_iterable from .optim import BaseOptimizer @@ -30,24 +32,25 @@ class StochasticQuantization(BaseEstimator, ClusterMixin): n_iter_ : int Number of iterations until declaring convergence. - cluster_centers_ : ndarray - The optimal set of quantized points {y₁, …, yₖ}. 
""" def __init__( - self, - optim: BaseOptimizer, - *, - n_clusters: Union[int, np.uint] = 2, - max_iter: Union[int, np.uint] = 1, - learning_rate: Union[float, np.float64] = 0.001, - rank: Union[int, np.uint] = 3, - verbose: Union[int, np.uint] = 0, - element_selection_method: Optional[str] = None, - init: Optional[Union[str, np.ndarray]] = None, - tol: Optional[Union[float, np.float64]] = None, - log_step: Optional[Union[int, np.uint]] = None, - random_state: Optional[np.random.RandomState] = None, + self, + optim: BaseOptimizer, + *, + n_clusters: Union[int, np.uint] = 2, + max_iter: Union[int, np.uint] = 1, + learning_rate: Union[float, np.float64] = 0.001, + rank: Union[int, np.uint] = 3, + verbose: Union[int, np.uint] = 0, + backend: StorageBackendType = "numpy", + element_selection_method: Optional[str] = None, + init: Optional[Union[str, np.ndarray]] = None, + tol: Optional[Union[float, np.float64]] = None, + log_step: Optional[Union[int, np.uint]] = None, + random_state: Optional[np.random.RandomState] = None, + backend_kwargs: dict = None, + **kwargs ): """Initialize Stochastic Quantization solver with provided hyperparameters. @@ -71,7 +74,14 @@ def __init__( The degree of the norm (rank) r. Must be greater than or equal to 3. verbose : int or np.uint, default=0 - Verbosity mode (0 - silent mode, 1 - logs progress to STDOUT). + Verbosity mode (0 - silent mode, 1 - progress input with tqdm, 2 - additional log info like loss). + tqdm progress can be turned off by setting `use_tqdm=False` in `kwargs`. And loss is only printed when + `log_step` is set. + + backend : StorageBackendType, default='numpy' + The backend storage type for the centroid storage. Supported options are 'numpy' and 'numpy_memmap', + 'faiss' or any other CommandStorage implementation. Faiss backend requires the faiss library to be + installed. 
element_selection_method : {‘permutation’, ‘sample’}, optional Method used to select elements uniformly from {ξᵢ} during each iteration: @@ -109,21 +119,82 @@ def __init__( random_state : np.random.RandomState, optional Random state for reproducibility. + + backend_kwargs : dict, optional + Additional keyword arguments for the centroid storage backend. + `keep_filepath` - does not remove filepath on model cleanup + + Notes + ----- """ + self.n_iter_ = 0 + self.n_step_ = 0 self._optim = optim - self._n_clusters = n_clusters self._max_iter = max_iter self._element_selection_method = element_selection_method - self._init = init self._learning_rate = learning_rate self._rank = rank self._tol = tol self._log_step = log_step self._random_state = random_state self._verbose = verbose + self._verbose_details = verbose > 1 + self._verbose_progress = verbose > 0 and kwargs.get("use_tqdm", True) + self.loss_history_ = [] + self.iteration_loss_history_ = [] + storage, deferred_action = CentroidStorageFactory.create( + backend, n_clusters, init, **(backend_kwargs or {}) + ) + self._centroid_storage: CentroidStorage = storage + self._deferred_action: Callable[[], None] = deferred_action + self._kwargs = kwargs + + def __del__(self): + self._deferred_action() + + @property + def centroids(self) -> np.ndarray: + return self._centroid_storage.centroids + + def _shuffle_ksi(self, X: np.ndarray, random_state: np.random.RandomState): + """Shuffle the input tensor {ξᵢ} based on the specified element selection method. + Parameters + ---------- + X : np.ndarray + The input tensor containing training element {ξᵢ}. + random_state : np.random.RandomState + Random state for reproducibility. + Returns + ------- + ksi : generator + A generator that yields shuffled elements from the input tensor {ξᵢ}. 
+ """ + X_len, _ = X.shape + match self._element_selection_method: + case "permutation" | None: + ksi = (ksi_j for ksi_j in random_state.permutation(X)) + case "sample": + ksi = ( + X[j] + for j in random_state.choice(X_len, size=X_len, replace=True) + ) + case _: + raise ValueError( + f"Element selection method ‘{self._element_selection_method}’ is not a valid option. Supported " + "options are {‘permutation’, ‘sample’}." + ) + return ksi + + def reset(self): + """Reset the Stochastic Quantization solver to its initial state.""" + self.n_iter_ = 0 + self.n_step_ = 0 + self.loss_history_ = [] + self.iteration_loss_history_ = [] + self._optim.reset() - def fit(self, X: np.ndarray, y=None): + def fit(self, X: np.ndarray, y=None, n_jobs: int = 1): """Search optimal values of {yₖ} using numeric iterative sequence, that updates parameters {yₖ} based on the calculated gradient value of a norm between sampled ξᵢ and the nearest element yₖ: @@ -140,6 +211,8 @@ def fit(self, X: np.ndarray, y=None): y : None Ignored. This parameter exists only for compatibility with estimator interface. + n_jobs : int, default=1 + The number of jobs to run in parallel. If -1, use all processors. Returns ------- @@ -165,116 +238,121 @@ def fit(self, X: np.ndarray, y=None): if not X_len: raise ValueError("The input tensor X should not be empty.") - match self._init: - case _ if isinstance(self._init, np.ndarray): - init_len, init_dims = self._init.shape - - if init_dims != X_dims: - raise ValueError( - f"The dimensions of initial quantized distribution ({init_len}) and input tensor " - f"({X_dims}) must match." - ) - - if init_len != self._n_clusters: - raise ValueError( - f"The number of elements in the initial quantized distribution ({init_len}) should match the " - f"given number of optimal quants ({self._n_clusters})." 
- ) - - self.cluster_centers_ = self._init.copy() - case "sample": - random_indices = random_state.choice( - X_len, size=self._n_clusters, replace=False - ) - self.cluster_centers_ = X[random_indices] - case "random": - self.cluster_centers_ = random_state.rand(self._n_clusters, X_dims) - case "milp": - self.cluster_centers_ = _milp(X, self._n_clusters) - case "k-means++" | None: - self.cluster_centers_ = _kmeans_plus_plus( - X, self._n_clusters, random_state - ) - case _: - raise ValueError( - f"Initialization strategy ‘{self._init}’ is not a valid option. Supported options are " - "{‘sample’, ‘random’, ‘k-means++’, ‘milp’}." - ) - - if self._verbose: + self._centroid_storage.init_centroids(X, random_state) + if self._verbose_details: print("Initialization complete") - self.n_iter_ = 0 - self.n_step_ = 0 - self._log_step = self._log_step or X_len - self.loss_history_ = [_calculate_loss(X, self.cluster_centers_)] - self._optim.reset() + self.reset() + if self._log_step or self._tol: + initial_loss = self._centroid_storage.calculate_loss(X) + self.iteration_loss_history_.append(initial_loss) + self.loss_history_.append(initial_loss) + if self._verbose_details: + print("Initial loss:", initial_loss) for i in range(self._max_iter): self.n_iter_ += 1 - match self._element_selection_method: - case "permutation" | None: - ksi = (ksi_j for ksi_j in random_state.permutation(X)) - case "sample": - ksi = ( - X[j] - for j in random_state.choice(X_len, size=X_len, replace=True) - ) - case _: - raise ValueError( - f"Element selection method ‘{self._element_selection_method}’ is not a valid option. Supported " - "options are {‘permutation’, ‘sample’}." 
- ) - - for ksi_j in ksi: - self.n_step_ += 1 - - nearest_quant, quant_ind = _find_nearest_element( - self.cluster_centers_, ksi_j - ) - - grad_fn = ( - lambda x: self._rank - * np.linalg.norm(ksi_j - x, ord=2) ** (self._rank - 2) - * (x - ksi_j) - ) - - self.cluster_centers_[quant_ind, :] = self._optim.step( - grad_fn, nearest_quant, self._learning_rate - ) - - if not self.n_step_ % self._log_step: - current_loss = _calculate_loss(X, self.cluster_centers_) - - self.loss_history_.append(current_loss) - - if self._verbose: - print( - f"Gradient step [{self.n_step_}/{self._max_iter * X_len}]: loss={current_loss}" + ksi = self._shuffle_ksi(X, random_state) + + if n_jobs == 1: + for ksi_j in tqdm( + ksi, total=X_len, desc="Performing cluster optimization", disable=not self._verbose + ): + self._optimize(self._centroid_storage, self._optim, ksi_j, self._rank, self._learning_rate) + self.n_step_ += 1 + self.__log_step(X, X_len) + else: + with tqdm_joblib( + total=X_len, desc="Performing cluster optimization", disable=not self._verbose): + size = self._log_step or X_len + for ksi_batch in batched_iterable(ksi, size): + joblib.Parallel(n_jobs=n_jobs, max_nbytes=self._kwargs.get('joblib_max_nbytes', '50M'))( + joblib.delayed(self._optimize)( + self._centroid_storage, + self._optim, + ksi_j, + self._rank, + self._learning_rate, + ) for ksi_j in ksi_batch ) + self.n_step_ += size + self.__log_step(X, X_len) - current_loss = _calculate_loss(X, self.cluster_centers_) - - if ( - self._tol is not None - and self.loss_history_[-1] - current_loss < self._tol - ): - if self._verbose: + if self._tol is not None and self._early_stop(X): + if self._verbose_details: print( - f"Converged (small optimal quants change) at step [{self.n_step_}/{self._max_iter * X_len}]" + f"Converged (small optimal quants change) at step [{self.n_step_}/{self._max_iter * X_len}] " + f"with loss={self.iteration_loss_history_[-1]} (iteration {self.n_iter_}, step " + f"{self.n_step_ % X_len})" ) break - return 
self - def predict(self, X: np.ndarray): + def __log_step(self, X: np.ndarray, X_len: int): + """Log the objective function value at the specified step. + Parameters + ---------- + X : np.ndarray + The input tensor containing training element {ξᵢ}. + """ + if self._log_step and self.n_step_ % self._log_step == 0: + current_loss = self._centroid_storage.calculate_loss(X) + self.loss_history_.append(current_loss) + + if self._verbose_details: + print( + f"Gradient step [{self.n_step_}/{self._max_iter * X_len}]: loss={current_loss} " + f"(iter: {self.n_iter_}, step: {self.n_step_ % X_len})" + ) + + def _early_stop(self, X: np.ndarray) -> bool: + """Early stop the Stochastic Quantization solver + based on the relative difference between the last two objective function values. + Parameters + ---------- + X : np.ndarray + The input tensor containing training element {ξᵢ}. + Returns + ------- + bool + True if the relative difference is less than the tolerance, False otherwise. + """ + current_loss = self._centroid_storage.calculate_loss(X) + self.iteration_loss_history_.append(current_loss) + return self.iteration_loss_history_[-2] - current_loss < self._tol + + @staticmethod + def _optimize(centroid_storage: CentroidStorage, optim: BaseOptimizer, ksi_j: np.array, rank: int, + learning_rate: Union[float, np.float64]): + """Perform optimization step for a single sample. + Parameters + ---------- + centroid_storage : CentroidStorage + Centroid storage object containing the current quantized distribution. + optim : BaseOptimizer + Optimizer object used for gradient descent. + ksi_j : np.ndarray + Sample from the input tensor {ξᵢ}. 
+ """ + nearest_quant, quant_ind = centroid_storage.find_nearest_centroid(ksi_j) + + grad_fn = ( + lambda x: rank * np.linalg.norm(ksi_j - x, ord=2) ** (rank - 2) * (x - ksi_j) + ) + + delta = optim.step(grad_fn, nearest_quant, learning_rate) + centroid_storage.update_centroid(quant_ind, delta) + + def predict(self, X: np.ndarray, n_jobs: int = 1): """Predict the closest optimal quant {yₖ} each sample in X belongs to. Parameters ---------- X : np.ndarray New data to predict. + n_jobs : int, default=1 + The number of jobs to run in parallel. If -1, use all processors. Returns ------- @@ -289,8 +367,17 @@ def predict(self, X: np.ndarray): check_is_fitted(self) - pairwise_distance = np.linalg.norm( - X[:, np.newaxis] - self.cluster_centers_, axis=-1 - ) + _predict = lambda storage, target: storage.find_nearest_centroid(target)[1] + + if n_jobs == 1: + clusters = [ + _predict(self._centroid_storage, target) + for target in tqdm(X, desc="Prediction of the closet cluster", disable=not self._verbose) + ] + else: + with tqdm_joblib(total=len(X), desc="Prediction of the closet cluster", disable=not self._verbose): + clusters = joblib.Parallel(n_jobs=n_jobs)( + joblib.delayed(_predict)(self._centroid_storage, target) for target in X + ) - return np.argmin(pairwise_distance, axis=-1) + return clusters diff --git a/code/sqg/utils.py b/code/sqg/utils.py new file mode 100644 index 0000000..675e0ff --- /dev/null +++ b/code/sqg/utils.py @@ -0,0 +1,10 @@ +from itertools import islice + + +def batched_iterable(iterable, batch_size): + iterator = iter(iterable) + while True: + batch = list(islice(iterator, batch_size + 1)) + if not batch: + break + yield batch From 5c0db3d66352176a22648e8151e51d7ae987e2d9 Mon Sep 17 00:00:00 2001 From: Ostap Bodnar Date: Sat, 7 Jun 2025 17:13:35 +0300 Subject: [PATCH 2/7] fix: update imports after rebasing --- code/sqg/centroids_storage/__init__.py | 6 +++--- code/sqg/centroids_storage/faiss_storage.py | 4 ++-- 
code/sqg/centroids_storage/numpy_storage.py | 7 ++----- code/sqg/progress_tracking/__init__.py | 2 +- code/sqg/quantization.py | 6 +++--- 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/code/sqg/centroids_storage/__init__.py b/code/sqg/centroids_storage/__init__.py index 2a7ddcb..5ecf61a 100644 --- a/code/sqg/centroids_storage/__init__.py +++ b/code/sqg/centroids_storage/__init__.py @@ -1,6 +1,6 @@ -from centroids_storage.factory import CentroidStorage, CentroidStorageFactory, StorageBackendType -from centroids_storage.faiss_storage import FaissIndexBasedCentroidStorage -from centroids_storage.numpy_storage import NumpyCentroidStorage, NumpyMemmapCentroidStorage +from .factory import CentroidStorage, CentroidStorageFactory, StorageBackendType +from .faiss_storage import FaissIndexBasedCentroidStorage +from .numpy_storage import NumpyCentroidStorage, NumpyMemmapCentroidStorage __all__ = [ "CentroidStorage", diff --git a/code/sqg/centroids_storage/faiss_storage.py b/code/sqg/centroids_storage/faiss_storage.py index 536ab48..300d487 100644 --- a/code/sqg/centroids_storage/faiss_storage.py +++ b/code/sqg/centroids_storage/faiss_storage.py @@ -1,7 +1,7 @@ import numpy as np -from centroids_storage.factory import CentroidStorageFactory -from centroids_storage.numpy_storage import NumpyMemmapCentroidStorage +from .factory import CentroidStorageFactory +from .numpy_storage import NumpyMemmapCentroidStorage def _load_faiss(): diff --git a/code/sqg/centroids_storage/numpy_storage.py b/code/sqg/centroids_storage/numpy_storage.py index f51fda8..a3e8c8b 100644 --- a/code/sqg/centroids_storage/numpy_storage.py +++ b/code/sqg/centroids_storage/numpy_storage.py @@ -1,10 +1,7 @@ -import os -import tempfile - import numpy as np -from centroids_storage.factory import CentroidStorageFactory, CentroidStorage -from centroids_storage.init import init_centroids +from .factory import CentroidStorageFactory, CentroidStorage +from .init import init_centroids 
@CentroidStorageFactory.register() diff --git a/code/sqg/progress_tracking/__init__.py b/code/sqg/progress_tracking/__init__.py index 2be92eb..f492197 100644 --- a/code/sqg/progress_tracking/__init__.py +++ b/code/sqg/progress_tracking/__init__.py @@ -1,3 +1,3 @@ -from progress_tracking.tqdm_joblib import tqdm_joblib +from .tqdm_joblib import tqdm_joblib __all__ = ["tqdm_joblib"] diff --git a/code/sqg/quantization.py b/code/sqg/quantization.py index a743d37..715310d 100644 --- a/code/sqg/quantization.py +++ b/code/sqg/quantization.py @@ -6,9 +6,9 @@ from sklearn.utils.validation import check_is_fitted, check_random_state from tqdm.autonotebook import tqdm -from centroids_storage.factory import CentroidStorage, StorageBackendType, CentroidStorageFactory -from progress_tracking import tqdm_joblib -from utils import batched_iterable +from .centroids_storage.factory import CentroidStorage, StorageBackendType, CentroidStorageFactory +from .progress_tracking import tqdm_joblib +from .utils import batched_iterable from .optim import BaseOptimizer From 97aff9e8308a7fca6f7068c8584ee068f814eb86 Mon Sep 17 00:00:00 2001 From: Ostap Bodnar Date: Sun, 13 Jul 2025 19:50:04 +0300 Subject: [PATCH 3/7] fix: make tqdm optional --- code/pyproject.toml | 3 ++- code/setup.py | 5 ++-- code/sqg/progress_tracking/__init__.py | 3 ++- code/sqg/progress_tracking/tqdm_joblib.py | 3 ++- code/sqg/progress_tracking/tqdm_wrapper.py | 27 ++++++++++++++++++++++ code/sqg/quantization.py | 11 ++++----- 6 files changed, 41 insertions(+), 11 deletions(-) create mode 100644 code/sqg/progress_tracking/tqdm_wrapper.py diff --git a/code/pyproject.toml b/code/pyproject.toml index 4b5c166..1b6757e 100644 --- a/code/pyproject.toml +++ b/code/pyproject.toml @@ -18,7 +18,6 @@ license = { file = "LICENSE.code.md" } dependencies = [ "numpy>=1.26.4,<2", "scikit-learn>=1.5.1,<2", - "tqdm>=4.66.0,<5", ] requires-python = ">=3.8" keywords = [ @@ -52,3 +51,5 @@ Repository = 
"https://github.com/kaydotdev/stochastic-quantization.git" [project.optional-dependencies] faiss-cpu = ["faiss-cpu>=1.10.0,<2"] faiss-gpu = ["faiss-gpu>=1.10.0,<2"] +progress = ["tqdm>=4.66.0,<5"] +all = ["sqg[faiss-cpu,faiss-gpu,progress]"] diff --git a/code/setup.py b/code/setup.py index 4ecd5cc..d72c0c9 100644 --- a/code/setup.py +++ b/code/setup.py @@ -44,10 +44,11 @@ install_requires=[ "numpy>=1.26.4,<2", "scikit-learn>=1.5.1,<2", - "tqdm>=4.66.0,<5", ], extras_require={ "faiss-cpu": ["faiss-cpu>=1.10.0,<2"], - "faiss-gpu": ["faiss-gpu>=1.10.0,<2"] + "faiss-gpu": ["faiss-gpu>=1.10.0,<2"], + "progress": ["tqdm>=4.66.0,<5"], + "all": ["sqg[faiss-cpu,faiss-gpu,progress]"] }, ) diff --git a/code/sqg/progress_tracking/__init__.py b/code/sqg/progress_tracking/__init__.py index f492197..182b35c 100644 --- a/code/sqg/progress_tracking/__init__.py +++ b/code/sqg/progress_tracking/__init__.py @@ -1,3 +1,4 @@ +from .tqdm_wrapper import tqdm from .tqdm_joblib import tqdm_joblib -__all__ = ["tqdm_joblib"] +__all__ = ["tqdm_joblib", "tqdm"] diff --git a/code/sqg/progress_tracking/tqdm_joblib.py b/code/sqg/progress_tracking/tqdm_joblib.py index 801f84e..950054a 100644 --- a/code/sqg/progress_tracking/tqdm_joblib.py +++ b/code/sqg/progress_tracking/tqdm_joblib.py @@ -2,7 +2,8 @@ import contextlib import joblib -from tqdm.autonotebook import tqdm + +from .tqdm_wrapper import tqdm @contextlib.contextmanager diff --git a/code/sqg/progress_tracking/tqdm_wrapper.py b/code/sqg/progress_tracking/tqdm_wrapper.py new file mode 100644 index 0000000..44c352b --- /dev/null +++ b/code/sqg/progress_tracking/tqdm_wrapper.py @@ -0,0 +1,27 @@ +try: + from tqdm.autonotebook import tqdm as _real_tqdm + _TQDM_AVAILABLE = True +except ImportError: + _TQDM_AVAILABLE = False + + +def tqdm(*args, **kwargs): + """Wrapper for tqdm that handles optional dependency. + + If tqdm is available, calls the real tqdm function. + If not available, returns a dummy iterator with a warning (unless disabled). 
+ """ + if _TQDM_AVAILABLE: + return _real_tqdm(*args, **kwargs) + else: + if kwargs.get('disable', False): + return args[0] if args else [] + + import warnings + warnings.warn( + "tqdm is not installed. Progress bars are disabled. " + "Install with: pip install sqg[progress]", + UserWarning, + stacklevel=2 + ) + return args[0] if args else [] diff --git a/code/sqg/quantization.py b/code/sqg/quantization.py index 715310d..9aba3db 100644 --- a/code/sqg/quantization.py +++ b/code/sqg/quantization.py @@ -4,10 +4,9 @@ import numpy as np from sklearn.base import BaseEstimator, ClusterMixin from sklearn.utils.validation import check_is_fitted, check_random_state -from tqdm.autonotebook import tqdm from .centroids_storage.factory import CentroidStorage, StorageBackendType, CentroidStorageFactory -from .progress_tracking import tqdm_joblib +from .progress_tracking import tqdm_joblib, tqdm from .utils import batched_iterable from .optim import BaseOptimizer @@ -257,14 +256,14 @@ def fit(self, X: np.ndarray, y=None, n_jobs: int = 1): if n_jobs == 1: for ksi_j in tqdm( - ksi, total=X_len, desc="Performing cluster optimization", disable=not self._verbose + ksi, total=X_len, desc="Performing cluster optimization", disable=not self._verbose_progress ): self._optimize(self._centroid_storage, self._optim, ksi_j, self._rank, self._learning_rate) self.n_step_ += 1 self.__log_step(X, X_len) else: with tqdm_joblib( - total=X_len, desc="Performing cluster optimization", disable=not self._verbose): + total=X_len, desc="Performing cluster optimization", disable=not self._verbose_progress): size = self._log_step or X_len for ksi_batch in batched_iterable(ksi, size): joblib.Parallel(n_jobs=n_jobs, max_nbytes=self._kwargs.get('joblib_max_nbytes', '50M'))( @@ -372,10 +371,10 @@ def predict(self, X: np.ndarray, n_jobs: int = 1): if n_jobs == 1: clusters = [ _predict(self._centroid_storage, target) - for target in tqdm(X, desc="Prediction of the closet cluster", disable=not self._verbose) + 
for target in tqdm(X, desc="Prediction of the closet cluster", disable=not self._verbose_progress) ] else: - with tqdm_joblib(total=len(X), desc="Prediction of the closet cluster", disable=not self._verbose): + with tqdm_joblib(total=len(X), desc="Prediction of the closet cluster", disable=not self._verbose_progress): clusters = joblib.Parallel(n_jobs=n_jobs)( joblib.delayed(_predict)(self._centroid_storage, target) for target in X ) From 3234e34918dcc9dd072a2b27aecb746428e27744 Mon Sep 17 00:00:00 2001 From: Ostap Bodnar Date: Sun, 13 Jul 2025 19:50:30 +0300 Subject: [PATCH 4/7] fix: use built-in-batch for python 3.12+ --- code/sqg/utils.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/code/sqg/utils.py b/code/sqg/utils.py index 675e0ff..9102791 100644 --- a/code/sqg/utils.py +++ b/code/sqg/utils.py @@ -1,10 +1,30 @@ -from itertools import islice +from sys import version_info + +import itertools + + +if version_info >= (3, 12) and hasattr(itertools, "batched"): + # Built-in since 3.12 (returns tuples) + batched = itertools.batched # type: ignore[attr-defined] +else: + def batched(iterable, n, *, strict=False): + """Back-port of itertools.batched for Py < 3.12 (returns tuples).""" + + if n < 1: + raise ValueError("n must be >= 1") + + it = iter(iterable) + while (chunk := tuple(itertools.islice(it, n))): + if strict and len(chunk) != n: + raise ValueError("last batch smaller than n") + + yield chunk def batched_iterable(iterable, batch_size): iterator = iter(iterable) while True: - batch = list(islice(iterator, batch_size + 1)) + batch = list(itertools.islice(iterator, batch_size + 1)) if not batch: break yield batch From 17e65ae55458bc2998ce69ad25747bac3651e9ea Mon Sep 17 00:00:00 2001 From: Ostap Bodnar Date: Sun, 13 Jul 2025 20:02:41 +0300 Subject: [PATCH 5/7] fix: file cleanup issue --- code/sqg/centroids_storage/factory.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git 
a/code/sqg/centroids_storage/factory.py b/code/sqg/centroids_storage/factory.py index 9c872e3..c55f47a 100644 --- a/code/sqg/centroids_storage/factory.py +++ b/code/sqg/centroids_storage/factory.py @@ -244,13 +244,19 @@ def create( if isinstance(storage_type, str) and storage_type in cls._implementations: kwargs = deepcopy(kwargs) storage_implementation, requires_filepath = cls._implementations[storage_type] + memory_file = None if requires_filepath and "filepath" not in kwargs: memory_file = tempfile.NamedTemporaryFile() kwargs["filepath"] = memory_file.name print(kwargs["filepath"]) + + def cleanup(): + if memory_file is not None and not kwargs.get("keep_filepath"): + memory_file.close() + return storage_implementation( n_clusters=n_clusters, init=init, **kwargs - ), (lambda: memory_file.close()) if not kwargs.get("keep_filepath") else lambda: None + ), cleanup raise ValueError( f"Unknown storage type: {storage_type}, supported types are: {sorted(cls._implementations.keys())} " f"or {CentroidStorage.__name__} instance" From 0be6790b0d149cbe9848fa76e0c938bc347ba0e6 Mon Sep 17 00:00:00 2001 From: Ostap Bodnar Date: Sun, 13 Jul 2025 20:20:32 +0300 Subject: [PATCH 6/7] fix: tests --- code/sqg/centroids_storage/numpy_storage.py | 3 +- code/sqg/quantization.py | 11 ++++ code/tests/test_calculate_loss.py | 68 +++++++++++++-------- code/tests/test_find_nearest_element.py | 42 +++++++++---- code/tests/test_stochastic_quantization.py | 9 ++- 5 files changed, 93 insertions(+), 40 deletions(-) diff --git a/code/sqg/centroids_storage/numpy_storage.py b/code/sqg/centroids_storage/numpy_storage.py index a3e8c8b..5af7646 100644 --- a/code/sqg/centroids_storage/numpy_storage.py +++ b/code/sqg/centroids_storage/numpy_storage.py @@ -1,4 +1,5 @@ import numpy as np +from sklearn.exceptions import NotFittedError from .factory import CentroidStorageFactory, CentroidStorage from .init import init_centroids @@ -25,7 +26,7 @@ def __init__(self, n_clusters: int, init: str | np.ndarray = 
"k-means++", *args, @property def centroids(self) -> np.ndarray: if self._centroids is None: - raise ValueError("Centroids have not been initialized yet.") + raise NotFittedError("Centroids have not been initialized yet.") return self._centroids def init_centroids(self, x: np.ndarray, random_state: np.random.RandomState): diff --git a/code/sqg/quantization.py b/code/sqg/quantization.py index 9aba3db..0542dde 100644 --- a/code/sqg/quantization.py +++ b/code/sqg/quantization.py @@ -156,6 +156,17 @@ def __del__(self): def centroids(self) -> np.ndarray: return self._centroid_storage.centroids + @property + def cluster_centers_(self) -> np.ndarray: + """Returns the cluster centers (centroids). + + Returns + ------- + np.ndarray + The cluster centers. + """ + return self._centroid_storage.centroids + def _shuffle_ksi(self, X: np.ndarray, random_state: np.random.RandomState): """Shuffle the input tensor {ξᵢ} based on the specified element selection method. Parameters diff --git a/code/tests/test_calculate_loss.py b/code/tests/test_calculate_loss.py index 669581b..af3c889 100644 --- a/code/tests/test_calculate_loss.py +++ b/code/tests/test_calculate_loss.py @@ -1,8 +1,7 @@ import unittest import numpy as np - -from sqg.quantization import _calculate_loss +from sqg.centroids_storage.factory import CentroidStorageFactory class TestCalculateLoss(unittest.TestCase): @@ -11,64 +10,77 @@ def setUp(self): def test_should_raise_value_error_if_shape_mismatch(self): # arrange - x = self.random_state.random((10, 2, 2)) - y = self.random_state.random((10, 1, 2)) + x = self.random_state.random((10, 2)) + y = self.random_state.random((10, 3)) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=10, init=y) # assert with self.assertRaises(ValueError): # act - _calculate_loss(x, y) + storage.init_centroids(x, self.random_state) + cleanup() def test_should_raise_value_error_if_different_axis(self): # arrange - x = self.random_state.random((1, 2, 2)) - y = 
self.random_state.random((1, 2, 2, 2)) + x = self.random_state.random((1, 2)) + y = np.array([[[1.0, 2.0], [3.0, 4.0]]]) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=1, init=y) # assert with self.assertRaises(ValueError): # act - _calculate_loss(x, y) + storage.init_centroids(x, self.random_state) + cleanup() def test_should_raise_value_error_if_one_of_distributions_is_empty(self): # arrange - x = self.random_state.random((1, 2, 2)) + x = self.random_state.random((1, 2)) y = np.array([]) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=1, init=y) # assert with self.assertRaises(ValueError): # act - _calculate_loss(x, y) + storage.calculate_loss(x) + cleanup() def test_should_return_distance_for_distributions_with_different_size(self): # arrange - expected_distance = 7.17435203 + expected_distance = 4.141985 # Updated for 2D data x = np.array( [ - [[0.37454012, 0.95071431], [0.73199394, 0.59865848]], - [[0.15601864, 0.15599452], [0.05808361, 0.86617615]], - [[0.60111501, 0.70807258], [0.02058449, 0.96990985]], - [[0.83244264, 0.21233911], [0.18182497, 0.18340451]], - [[0.30424224, 0.52475643], [0.43194502, 0.29122914]], - [[0.61185289, 0.13949386], [0.29214465, 0.36636184]], - [[0.45606998, 0.78517596], [0.19967378, 0.51423444]], - [[0.59241457, 0.04645041], [0.60754485, 0.17052412]], - [[0.06505159, 0.94888554], [0.96563203, 0.80839735]], - [[0.30461377, 0.09767211], [0.68423303, 0.44015249]], + [0.37454012, 0.95071431], + [0.15601864, 0.15599452], + [0.60111501, 0.70807258], + [0.83244264, 0.21233911], + [0.30424224, 0.52475643], + [0.61185289, 0.13949386], + [0.45606998, 0.78517596], + [0.59241457, 0.04645041], + [0.06505159, 0.94888554], + [0.30461377, 0.09767211], ] ) y = np.array( [ - [[0.12203823, 0.49517691], [0.03438852, 0.9093204]], - [[0.25877998, 0.66252228], [0.31171108, 0.52006802]], + [0.12203823, 0.49517691], + [0.25877998, 0.66252228], ] ) + + storage, cleanup = 
CentroidStorageFactory.create("numpy", n_clusters=2, init=y) + storage.init_centroids(x, self.random_state) # act - actual_distance = _calculate_loss(x, y) + actual_distance = storage.calculate_loss(x) # assert - self.assertAlmostEqual(actual_distance, expected_distance, places=7) + self.assertAlmostEqual(actual_distance, expected_distance, places=6) + cleanup() def test_should_return_zero_distance_for_identical_tensors(self): # arrange @@ -76,12 +88,16 @@ def test_should_return_zero_distance_for_identical_tensors(self): x = np.array([[1.0, 1.0], [1.0, 1.0]]) y = np.array([[1.0, 1.0], [1.0, 1.0]]) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=2, init=y) + storage.init_centroids(x, self.random_state) # act - actual_distance = _calculate_loss(x, y) + actual_distance = storage.calculate_loss(x) # assert self.assertAlmostEqual(actual_distance, expected_distance, places=7) + cleanup() if __name__ == "__main__": diff --git a/code/tests/test_find_nearest_element.py b/code/tests/test_find_nearest_element.py index 51cf27e..e9d37fb 100644 --- a/code/tests/test_find_nearest_element.py +++ b/code/tests/test_find_nearest_element.py @@ -1,8 +1,7 @@ import unittest import numpy as np - -from sqg.quantization import _find_nearest_element +from sqg.centroids_storage.factory import CentroidStorageFactory class TestFindNearestElement(unittest.TestCase): @@ -11,33 +10,44 @@ def setUp(self): def test_should_raise_value_error_if_shape_mismatch(self): # arrange - x = self.random_state.random((10, 2, 2)) - y = self.random_state.random((1, 2)) + x = self.random_state.random((10, 2)) + y = self.random_state.random((1, 3)) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=10, init=x) + storage.init_centroids(x, self.random_state) # assert with self.assertRaises(ValueError): # act - _find_nearest_element(x, y) + storage.find_nearest_centroid(y) + cleanup() def test_should_raise_value_error_if_different_axis(self): # arrange - x = 
self.random_state.random((1, 2, 2)) - y = self.random_state.random((1, 2, 2, 2)) + x = self.random_state.random((1, 2)) + y = np.array([[[1.0, 2.0], [3.0, 4.0]]]) # 3D array for init + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=1, init=y) # assert with self.assertRaises(ValueError): # act - _find_nearest_element(x, y) + storage.init_centroids(x, self.random_state) + cleanup() def test_should_raise_value_error_if_one_of_distributions_is_empty(self): # arrange - x = self.random_state.random((1, 2, 2)) + x = self.random_state.random((1, 2)) y = np.array([]) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=1, init=x) + storage.init_centroids(x, self.random_state) # assert with self.assertRaises(ValueError): # act - _find_nearest_element(x, y) + storage.find_nearest_centroid(y) + cleanup() def test_should_return_nearest_element_with_index(self): # arrange @@ -57,15 +67,19 @@ def test_should_return_nearest_element_with_index(self): ] ) y = np.array([0.0, 0.0]) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=8, init=x) + storage.init_centroids(x, self.random_state) # act - actual_element, actual_index = _find_nearest_element(x, y) + actual_element, actual_index = storage.find_nearest_centroid(y) # assert np.testing.assert_allclose( actual_element, expected_element, rtol=1e-3, atol=1e-3 ) self.assertEqual(actual_index, expected_index) + cleanup() def test_should_return_first_index_of_multiple_nearest_elements(self): # arrange @@ -85,15 +99,19 @@ def test_should_return_first_index_of_multiple_nearest_elements(self): ] ) y = np.array([0.0, 0.0]) + + storage, cleanup = CentroidStorageFactory.create("numpy", n_clusters=8, init=x) + storage.init_centroids(x, self.random_state) # act - actual_element, actual_index = _find_nearest_element(x, y) + actual_element, actual_index = storage.find_nearest_centroid(y) # assert np.testing.assert_allclose( actual_element, expected_element, rtol=1e-3, atol=1e-3 ) 
self.assertEqual(actual_index, expected_index) + cleanup() if __name__ == "__main__": diff --git a/code/tests/test_stochastic_quantization.py b/code/tests/test_stochastic_quantization.py index 77c9d67..1a3ae9f 100644 --- a/code/tests/test_stochastic_quantization.py +++ b/code/tests/test_stochastic_quantization.py @@ -11,7 +11,7 @@ class TestStochasticQuantization(unittest.TestCase): def setUp(self): self.random_state = np.random.RandomState(seed=42) self.algorithm = StochasticQuantization( - SGDOptimizer(), n_clusters=2, max_iter=1, random_state=self.random_state + SGDOptimizer(), n_clusters=2, max_iter=1, random_state=self.random_state, log_step=100 ) self.X = np.array( [ @@ -54,6 +54,7 @@ def test_should_raise_value_error_if_initial_distribution_size_and_cluster_numbe max_iter=1, random_state=self.random_state, init=np.array([[[0.0, 0.0], [0.0, 0.0]]]), + log_step=1, ) # assert @@ -71,6 +72,7 @@ def test_should_raise_value_error_if_dimensions_of_quantized_distribution_and_in max_iter=1, random_state=self.random_state, init=np.array([[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + log_step=1, ) # assert @@ -88,6 +90,7 @@ def test_should_return_init_quants_if_input_tensor_contains_single_element_with_ max_iter=1, random_state=self.random_state, init="k-means++", + log_step=1, ) X = np.array( @@ -113,6 +116,7 @@ def test_should_return_init_quants_if_max_iteration_is_zero_with_kmeans_plus_plu max_iter=0, random_state=self.random_state, init="k-means++", + log_step=1, ) X = np.array( @@ -146,6 +150,7 @@ def test_should_return_optimal_quants_for_sampling_strategy(self): max_iter=1, random_state=self.random_state, init="sample", + log_step=100 ) expected_cluster_centers = np.array( @@ -178,6 +183,7 @@ def test_should_return_optimal_quants_for_uniformly_distributed_quants_strategy( max_iter=1, random_state=self.random_state, init="random", + log_step=100 ) expected_cluster_centers = np.array( @@ -210,6 +216,7 @@ def 
test_should_return_optimal_quants_for_kmeans_plus_plus_strategy( max_iter=1, random_state=self.random_state, init="k-means++", + log_step=100 ) expected_cluster_centers = np.array( From 49e8645efc11b76d9c8f8d368ba00965e3cd20bb Mon Sep 17 00:00:00 2001 From: Ostap Bodnar Date: Sun, 13 Jul 2025 20:46:08 +0300 Subject: [PATCH 7/7] fix: tqdm issues --- code/sqg/progress_tracking/tqdm_joblib.py | 6 +++++- code/sqg/progress_tracking/tqdm_wrapper.py | 25 +++++++++++++--------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/code/sqg/progress_tracking/tqdm_joblib.py b/code/sqg/progress_tracking/tqdm_joblib.py index 950054a..8f3317e 100644 --- a/code/sqg/progress_tracking/tqdm_joblib.py +++ b/code/sqg/progress_tracking/tqdm_joblib.py @@ -3,7 +3,7 @@ import joblib -from .tqdm_wrapper import tqdm +from .tqdm_wrapper import tqdm, _TQDM_AVAILABLE @contextlib.contextmanager @@ -13,6 +13,10 @@ def tqdm_joblib(*args, **kwargs): tqdm_object = tqdm(*args, **kwargs) + if not _TQDM_AVAILABLE: + yield tqdm_object + return + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/code/sqg/progress_tracking/tqdm_wrapper.py b/code/sqg/progress_tracking/tqdm_wrapper.py index 44c352b..56379cd 100644 --- a/code/sqg/progress_tracking/tqdm_wrapper.py +++ b/code/sqg/progress_tracking/tqdm_wrapper.py @@ -1,27 +1,32 @@ +import warnings + try: from tqdm.autonotebook import tqdm as _real_tqdm _TQDM_AVAILABLE = True + _WARNING_SHOWN = False except ImportError: _TQDM_AVAILABLE = False + _WARNING_SHOWN = False def tqdm(*args, **kwargs): """Wrapper for tqdm that handles optional dependency. If tqdm is available, calls the real tqdm function. - If not available, returns a dummy iterator with a warning (unless disabled). + If not available, returns the original iterable with a warning (shown only once). 
""" + global _WARNING_SHOWN + if _TQDM_AVAILABLE: return _real_tqdm(*args, **kwargs) else: - if kwargs.get('disable', False): - return args[0] if args else [] + if not _WARNING_SHOWN: + warnings.warn( + "tqdm is not installed. Progress bars are disabled. " + "Install with: pip install sqg[progress]", + UserWarning, + stacklevel=2 + ) + _WARNING_SHOWN = True - import warnings - warnings.warn( - "tqdm is not installed. Progress bars are disabled. " - "Install with: pip install sqg[progress]", - UserWarning, - stacklevel=2 - ) return args[0] if args else []