1212import numpy as np
1313from sklearn .metrics import pairwise_distances
1414from sklearn .base import BaseEstimator , ClusterMixin
15- from sklearn .utils .validation import check_array
15+ from sklearn .utils .validation import check_array , validate_data , check_random_state
1616
1717from radius_clustering .utils ._emos import py_emos_main
1818from radius_clustering .utils ._mds_approx import solve_mds
1919
2020DIR_PATH = os .path .dirname (os .path .realpath (__file__ ))
2121
2222
23- class RadiusClustering (BaseEstimator , ClusterMixin ):
23+ class RadiusClustering (ClusterMixin , BaseEstimator ):
2424 """
2525 Radius Clustering algorithm.
2626
@@ -42,29 +42,56 @@ class RadiusClustering(BaseEstimator, ClusterMixin):
4242 The indices of the cluster centers.
4343 labels\_ : array-like, shape (n_samples,)
4444 The cluster labels for each point in the input data.
45- effective_radius : float
45+ effective_radius\_ : float
4646 The maximum distance between any point and its assigned cluster center.
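47+ random_state\_ : numpy.random.RandomState
48+ The random number generator built from `random_state` (42 is used when `random_state` is None). It is only set by the approximate method and is used to draw the seed passed to the MDS solver.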
49+
50+ .. note::
51+ The `random_state_` attribute is not used when the `manner` is set to "exact".
52+
53+ .. versionadded:: 1.3.0
54+ The *random_state* parameter was added to allow reproducibility in the approximate method.
55+
56+ .. versionchanged:: 1.3.0
57+ All publicly accessible attributes are now suffixed with an underscore (e.g., `centers_`, `labels_`).
58+ This is particularly useful for compatibility with scikit-learn's API.
4759 """
4860
49- def __init__ (self , manner = "approx" , threshold = 0.5 ):
61+ _estimator_type = "clusterer"
62+
63+ def __init__ (self , manner : str = "approx" , threshold : float = 0.5 , random_state : int | None = None ) -> None :
5064 self .manner = manner
5165 self .threshold = threshold
66+ self .random_state = random_state
5267
53- def _check_symmetric (self , a , tol = 1e-8 ):
68+ def _check_symmetric (self , a : np . ndarray , tol : float = 1e-8 ) -> bool :
5469 if a .ndim != 2 :
5570 raise ValueError ("Input must be a 2D array." )
5671 if a .shape [0 ] != a .shape [1 ]:
5772 return False
5873 return np .allclose (a , a .T , atol = tol )
5974
60- def fit (self , X , y = None ):
75+ def fit (self , X : np . ndarray , y : None = None ) -> "RadiusClustering" :
6176 """
6277 Fit the MDS clustering model to the input data.
6378
79+ This method computes the distance matrix if the input is a feature matrix,
80+ or uses the provided distance matrix directly if the input is already a distance matrix.
81+
82+ .. note::
83+ If the input is a distance matrix, it should be symmetric and square.
84+ If the input is a feature matrix, the distance matrix will be computed using Euclidean distance.
85+
86+ .. tip::
87+ Next version will support providing different metrics or even custom callables to compute the distance matrix.
88+
6489 Parameters:
6590 -----------
6691 X : array-like, shape (n_samples, n_features)
67- The input data to cluster.
92+ The input data to cluster. X should be a 2D array-like structure. It can be either:
93+ - A distance matrix (symmetric, square) with shape (n_samples, n_samples).
94+ - A feature matrix with shape (n_samples, n_features) where the distance matrix will be computed.
6895 y : Ignored
6996 Not used, present here for API consistency by convention.
7097
@@ -91,38 +118,43 @@ def fit(self, X, y=None):
91118 For examples on common datasets and differences with kmeans,
92119 see :ref:`sphx_glr_auto_examples_plot_iris_example.py`
93120 """
94- self .X = check_array ( X )
121+ self .X_checked_ = validate_data ( self , X )
95122
96123 # Create dist and adj matrices
97- if not self ._check_symmetric (self .X ):
98- dist_mat = pairwise_distances (self .X , metric = "euclidean" )
124+ if not self ._check_symmetric (self .X_checked_ ):
125+ dist_mat = pairwise_distances (self .X_checked_ , metric = "euclidean" )
99126 else :
100- dist_mat = self .X
127+ dist_mat = self .X_checked_
101128 adj_mask = np .triu ((dist_mat <= self .threshold ), k = 1 )
102- self .nb_edges = np .sum (adj_mask )
103- if self .nb_edges == 0 :
104- self .centers_ = list (range (self .X .shape [0 ]))
105- self .labels_ = self .centers_
106- self .effective_radius = 0
107- self ._mds_exec_time = 0
129+ self .nb_edges_ = np .sum (adj_mask )
130+ if self .nb_edges_ == 0 :
131+ self .centers_ = list (range (self .X_checked_ .shape [0 ]))
132+ self .labels_ = np . array ( self .centers_ )
133+ self .effective_radius_ = 0
134+ self .mds_exec_time_ = 0
108135 return self
109- self .edges = np .argwhere (adj_mask ).astype (np .uint32 ) #TODO: changer en uint32
110- self .dist_mat = dist_mat
136+ self .edges_ = np .argwhere (adj_mask ).astype (np .uint32 ) # Edges in the adjacency matrix
137+ # uint32 is used to save memory; sample (node) indices must therefore not exceed 2^32-1
138+ self .dist_mat_ = dist_mat
111139
112140 self ._clustering ()
113141 self ._compute_effective_radius ()
114142 self ._compute_labels ()
115143
116144 return self
117145
118- def fit_predict (self , X , y = None ):
146+ def fit_predict (self , X : np . ndarray , y : None = None ) -> np . ndarray :
119147 """
120148 Fit the model and return the cluster labels.
121149
150+ This method is a convenience function that combines `fit` and `predict`.
151+
122152 Parameters:
123153 -----------
124154 X : array-like, shape (n_samples, n_features)
125- The input data to cluster.
155+ The input data to cluster. X should be a 2D array-like structure. It can be either:
156+ - A distance matrix (symmetric, square) with shape (n_samples, n_samples).
157+ - A feature matrix with shape (n_samples, n_features) where the distance matrix will be computed.
126158 y : Ignored
127159 Not used, present here for API consistency by convention.
128160
@@ -138,13 +170,13 @@ def _clustering(self):
138170 """
139171 Perform the clustering using either the exact or approximate MDS method.
140172 """
141- n = self .X .shape [0 ]
173+ n = self .X_checked_ .shape [0 ]
142174 if self .manner == "exact" :
143175 self ._clustering_exact (n )
144176 else :
145177 self ._clustering_approx (n )
146178
147- def _clustering_exact (self , n ) :
179+ def _clustering_exact (self , n : int ) -> None :
148180 """
149181 Perform exact MDS clustering.
150182
@@ -158,13 +190,26 @@ def _clustering_exact(self, n):
158190 This function uses the EMOS algorithm to solve the MDS problem.
159191 See: [jiang]_ for more details.
160192 """
161- self .centers_ , self ._mds_exec_time = py_emos_main (
162- self .edges .flatten (), n , self .nb_edges
193+ self .centers_ , self .mds_exec_time_ = py_emos_main (
194+ self .edges_ .flatten (), n , self .nb_edges_
163195 )
164196
165- def _clustering_approx (self , n ) :
197+ def _clustering_approx (self , n : int ) -> None :
166198 """
167- Perform approximate MDS clustering.
199+ Perform approximate MDS clustering. This method uses a small trick to seed the random number generator of the C++ code of the MDS solver.
200+
201+ .. tip::
202+ The random state is used to ensure reproducibility of the results when using the approximate method.
203+ If `random_state` is None, a default value of 42 is used.
204+
205+ .. important::
206+ :collapsible: closed
207+ The trick to set the random state is :
208+ 1. Use the `check_random_state` function to get a `RandomState` instance seeded with the provided `random_state`.
209+ 2. Use the `randint` method of the `RandomState` instance to generate a random integer.
210+ 3. Use this random integer as the seed for the C++ code of the MDS solver.
211+
212+ This ensures that the seed passed to the C++ code is always an integer, which is required by the MDS solver, and allows for reproducibility of the results.
168213
169214 Parameters:
170215 -----------
@@ -176,9 +221,13 @@ def _clustering_approx(self, n):
176221 This function uses the approximation method to solve the MDS problem.
177222 See [casado]_ for more details.
178223 """
179- result = solve_mds (n , self .edges .flatten ().astype (np .int32 ), self .nb_edges , "test" )
224+ if self .random_state is None :
225+ self .random_state = 42
226+ self .random_state_ = check_random_state (self .random_state )
227+ seed = self .random_state_ .randint (np .iinfo (np .int32 ).max )
228+ result = solve_mds (n , self .edges_ .flatten ().astype (np .int32 ), self .nb_edges_ , seed )
180229 self .centers_ = [x for x in result ["solution_set" ]]
181- self ._mds_exec_time = result ["Time" ]
230+ self .mds_exec_time_ = result ["Time" ]
182231
183232 def _compute_effective_radius (self ):
184233 """
@@ -187,13 +236,13 @@ def _compute_effective_radius(self):
187236 The effective radius is the maximum radius among all clusters.
188237 That means EffRad = max(R(C_i)) for all i.
189238 """
190- self .effective_radius = np .min (self .dist_mat [:, self .centers_ ], axis = 1 ).max ()
239+ self .effective_radius_ = np .min (self .dist_mat_ [:, self .centers_ ], axis = 1 ).max ()
191240
192241 def _compute_labels (self ):
193242 """
194243 Compute the cluster labels for each point in the dataset.
195244 """
196- distances = self .dist_mat [:, self .centers_ ]
245+ distances = self .dist_mat_ [:, self .centers_ ]
197246 self .labels_ = np .argmin (distances , axis = 1 )
198247
199248 min_dist = np .min (distances , axis = 1 )
0 commit comments