Skip to content

feat(sq): algorithm optimization #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions code/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ name = "sqg"
version = "1.0.0"
description = "A robust and scalable alternative to existing K-means solvers."
authors = [
{name = "Vladimir Norkin", email = "[email protected]"},
{name = "Anton Kozyriev", email = "[email protected]"},
{ name = "Vladimir Norkin", email = "[email protected]" },
{ name = "Anton Kozyriev", email = "[email protected]" },
]
readme = "README.md"
license = {file = "LICENSE.code.md"}
license = { file = "LICENSE.code.md" }
dependencies = [
"numpy>=1.26.4,<2",
"scikit-learn>=1.5.1,<2",
Expand Down Expand Up @@ -47,3 +47,9 @@ classifiers = [
Homepage = "https://github.com/kaydotdev/stochastic-quantization"
Issues = "https://github.com/kaydotdev/stochastic-quantization/issues"
Repository = "https://github.com/kaydotdev/stochastic-quantization.git"

[project.optional-dependencies]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to group optional dependencies, something like this:

[project.optional-dependencies]
faiss-cpu = ["faiss-cpu>=1.10.0,<2"]
faiss-gpu = ["faiss-gpu>=1.10.0,<2"]
progress = ["tqdm>=4.66.0,<5"]
all = ["sqg[faiss-cpu,faiss-gpu,progress]"]

faiss-cpu = ["faiss-cpu>=1.10.0,<2"]
faiss-gpu = ["faiss-gpu>=1.10.0,<2"]
progress = ["tqdm>=4.66.0,<5"]
all = ["sqg[faiss-cpu,faiss-gpu,progress]"]
7 changes: 6 additions & 1 deletion code/setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from setuptools import setup, find_packages


if __name__ == "__main__":
setup(
name="sqg",
Expand Down Expand Up @@ -46,4 +45,10 @@
"numpy>=1.26.4,<2",
"scikit-learn>=1.5.1,<2",
],
extras_require={
"faiss-cpu": ["faiss-cpu>=1.10.0,<2"],
"faiss-gpu": ["faiss-gpu>=1.10.0,<2"],
"progress": ["tqdm>=4.66.0,<5"],
"all": ["sqg[faiss-cpu,faiss-gpu,progress]"]
},
)
12 changes: 12 additions & 0 deletions code/sqg/centroids_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .factory import CentroidStorage, CentroidStorageFactory, StorageBackendType
from .faiss_storage import FaissIndexBasedCentroidStorage
from .numpy_storage import NumpyCentroidStorage, NumpyMemmapCentroidStorage

__all__ = [
"CentroidStorage",
"CentroidStorageFactory",
"StorageBackendType",
"NumpyCentroidStorage",
"NumpyMemmapCentroidStorage",
"FaissIndexBasedCentroidStorage",
]
263 changes: 263 additions & 0 deletions code/sqg/centroids_storage/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
import abc
import tempfile
from copy import deepcopy
from typing import Literal, Callable

import numpy as np


class CentroidStorage(abc.ABC):
"""
Abstract base class for centroid storage implementations.

Parameters
----------
n_clusters : int
The number of clusters (centroids) to initialize.
init : str or np.ndarray, optional
Method for initialization. Can be 'k-means++' or an ndarray of initial centroids.
"""

def __init__(self, n_clusters: int, init: str | np.ndarray = "k-means++", *args, **kwargs):
self._n_clusters = n_clusters
self._init = init

@property
def n_clusters(self) -> int:
"""
Returns the number of clusters.

Returns
-------
int
The number of clusters.
"""
return self._n_clusters

@property
@abc.abstractmethod
def centroids(self):
"""
Returns the centroids.

Returns
-------
np.ndarray
The centroids.

Raises
------
ValueError
If the centroids have not been initialized yet.
"""
raise NotImplementedError()

@property
@abc.abstractmethod
def name(self) -> str:
"""
Returns the name of the centroid storage implementation.

Returns
-------
str
The name of the centroid storage implementation.

Raises
------
NotImplementedError
If the method is not implemented by the subclass.
"""
raise NotImplementedError

@abc.abstractmethod
def init_centroids(self, X: np.ndarray, random_state: np.random.RandomState):
"""
Initializes the centroids.

Parameters
----------
X : np.ndarray
The data to initialize the centroids.
random_state : np.random.RandomState
The random state for reproducibility.

Raises
------
NotImplementedError
If the method is not implemented by the subclass.
"""
raise NotImplementedError

@abc.abstractmethod
def find_nearest_centroid(self, target: np.ndarray) -> tuple[np.ndarray, np.uint]:
"""
Finds the nearest centroid to the target.

Parameters
----------
target : np.ndarray
The target data point.

Returns
-------
tuple[np.ndarray, np.uint]
The nearest centroid and its index.

Raises
------
NotImplementedError
If the method is not implemented by the subclass.
"""
raise NotImplementedError

@abc.abstractmethod
def update_centroid(self, index: np.uint, delta: np.ndarray):
"""
Updates the centroid at the given index.

Parameters
----------
index : np.uint
The index of the centroid to update.
delta : np.ndarray
The change to apply to the centroid.

Raises
------
NotImplementedError
If the method is not implemented by the subclass.
"""
raise NotImplementedError

def calculate_loss(self, X: np.ndarray) -> np.float64:
"""Calculates stochastic Wasserstein (or Kantorovich–Rubinstein) distance between distributions ξ and y:

F(y) = Σᵢ₌₁ᴵ pᵢ min₁≤k≤K d(ξᵢ, yₖ)ʳ

Parameters
----------
xi : np.ndarray
The original distribution ξ with shape (N, D, ...).
y : np.ndarray
The quantized distribution y with shape (M, D, ...).

Returns
-------
np.float64
The calculated stochastic Wasserstein distance between distributions ξ and y.

Raises
------
ValueError
If one of the distributions ξ or y is empty.
ValueError
If there is a shape mismatch between individual elements in distribution ξ and y.

Notes
-----
The function assumes uniform weights (pᵢ = 1) for all elements in the original distribution.
The exponent r in the formula is implicitly set to 1 in this implementation.
"""

if X.size == 0 or self.centroids.size == 0:
raise ValueError("One of the distributions `X` or `centroids` is empty.")

if X.shape[1:] != self.centroids.shape[1:]:
raise ValueError(
"The dimensions of individual elements in distribution `X` and `centroids` must match. Elements in "
f"`X` have shape {X.shape[1:]}, but y elements have shape {self.centroids.shape[1:]}."
)

distances = [
np.linalg.norm(self.find_nearest_centroid(ksi)[0] - ksi) for ksi in X
]

return np.sum(distances)


StorageBackendType = Literal["numpy", "numpy_memmap", "faiss"] | CentroidStorage


class CentroidStorageFactory:
"""
Factory class for creating centroid storage instances.
"""
_implementations: dict[str, tuple[type[CentroidStorage], bool]] = dict()

@classmethod
def register(cls, requires_filepath: bool = False):
"""
Registers a new centroid storage implementation.

Parameters
----------
storage_type : type[CentroidStorage]
The centroid storage implementation to register.

Returns
-------
type[CentroidStorage]
The registered centroid storage implementation.
"""

def decorator(storage_type: type[CentroidStorage]):
cls._implementations[storage_type.name] = (storage_type, requires_filepath)
return storage_type

return decorator

@classmethod
def create(
cls,
storage_type: StorageBackendType,
n_clusters: int,
init: str | np.ndarray = "k-means++",
**kwargs
) -> tuple[CentroidStorage, Callable[[], None]]:
"""
Creates a centroid storage instance.

Parameters
----------
storage_type : StorageBackendType
The type of storage backend to use.
n_clusters : int
The number of clusters (centroids) to initialize.
init : str or np.ndarray, optional
Method for initialization. Can be 'k-means++' or an ndarray of initial centroids.
**kwargs
Additional keyword arguments for the storage implementation.

Returns
-------
CentroidStorage
The created centroid storage instance.

Raises
------
ValueError
If the storage type is unknown.
"""
if isinstance(storage_type, CentroidStorage):
return storage_type, lambda: None
if isinstance(storage_type, str) and storage_type in cls._implementations:
kwargs = deepcopy(kwargs)
storage_implementation, requires_filepath = cls._implementations[storage_type]
memory_file = None
if requires_filepath and "filepath" not in kwargs:
memory_file = tempfile.NamedTemporaryFile()
kwargs["filepath"] = memory_file.name
print(kwargs["filepath"])

def cleanup():
if memory_file is not None and not kwargs.get("keep_filepath"):
memory_file.close()

return storage_implementation(
n_clusters=n_clusters, init=init, **kwargs
), cleanup
raise ValueError(
f"Unknown storage type: {storage_type}, supported types are: {sorted(cls._implementations.keys())} "
f"or {CentroidStorage.__name__} instance"
)
Loading