|
7 | 7 | from dislib.data.array import Array
|
8 | 8 | from dislib.neighbors import NearestNeighbors
|
9 | 9 | from sklearn.metrics import accuracy_score
|
| 10 | +from sklearn.neighbors import NearestNeighbors as SKNeighbors |
| 11 | +from sklearn.neighbors import KDTree |
10 | 12 |
|
11 | 13 | from collections import defaultdict
|
12 | 14 |
|
| 15 | +import os |
| 16 | +import json |
| 17 | +import dislib.data.util.model as utilmodel |
| 18 | +import pickle |
| 19 | +from dislib.data.util import sync_obj, decoder_helper, encoder_helper |
| 20 | + |
13 | 21 |
|
14 | 22 | class KNeighborsClassifier(BaseEstimator):
|
15 | 23 | """Classifier implementing the k-nearest neighbors vote.
|
@@ -136,6 +144,104 @@ def score(self, q: Array, y: Array, collect=False):
|
136 | 144 |
|
137 | 145 | return compss_wait_on(score) if collect else score
|
138 | 146 |
|
| 147 | + def save_model(self, filepath, overwrite=True, save_format="json"): |
| 148 | + """Saves a model to a file. |
| 149 | + The model is synchronized before saving and can be reinstantiated |
| 150 | + in the exact same state, without any of the code used for model |
| 151 | + definition or fitting. |
| 152 | + Parameters |
| 153 | + ---------- |
| 154 | + filepath : str |
| 155 | + Path where to save the model |
| 156 | + overwrite : bool, optional (default=True) |
| 157 | + Whether any existing model at the target |
| 158 | + location should be overwritten. |
| 159 | + save_format : str, optional (default='json) |
| 160 | + Format used to save the models. |
| 161 | + Examples |
| 162 | + -------- |
| 163 | + >>> from dislib.classification import KNeighborsClassifier |
| 164 | + >>> import numpy as np |
| 165 | + >>> import dislib as ds |
| 166 | + >>> data = np.array([[0, 0, 5], [3, 0, 5], [3, 1, 2]]) |
| 167 | + >>> y_data = np.array([2, 1, 1, 2, 0]) |
| 168 | + >>> train = ds.array(x=ratings, block_size=(1, 1)) |
| 169 | + >>> knn = KNeighborsClassifier() |
| 170 | + >>> knn.fit(train) |
| 171 | + >>> knn.save_model("./model_KNN") |
| 172 | + """ |
| 173 | + |
| 174 | + # Check overwrite |
| 175 | + if not overwrite and os.path.isfile(filepath): |
| 176 | + return |
| 177 | + |
| 178 | + sync_obj(self.__dict__) |
| 179 | + model_metadata = self.__dict__ |
| 180 | + model_metadata["model_name"] = "knn" |
| 181 | + |
| 182 | + # Save model |
| 183 | + if save_format == "json": |
| 184 | + with open(filepath, "w") as f: |
| 185 | + json.dump(model_metadata, f, default=_encode_helper) |
| 186 | + elif save_format == "cbor": |
| 187 | + if utilmodel.cbor2 is None: |
| 188 | + raise ModuleNotFoundError("No module named 'cbor2'") |
| 189 | + with open(filepath, "wb") as f: |
| 190 | + utilmodel.cbor2.dump(model_metadata, f, |
| 191 | + default=_encode_helper_cbor) |
| 192 | + elif save_format == "pickle": |
| 193 | + with open(filepath, "wb") as f: |
| 194 | + pickle.dump(model_metadata, f) |
| 195 | + else: |
| 196 | + raise ValueError("Wrong save format.") |
| 197 | + |
| 198 | + def load_model(self, filepath, load_format="json"): |
| 199 | + """Loads a model from a file. |
| 200 | + The model is reinstantiated in the exact same state in which it was |
| 201 | + saved, without any of the code used for model definition or fitting. |
| 202 | + Parameters |
| 203 | + ---------- |
| 204 | + filepath : str |
| 205 | + Path of the saved the model |
| 206 | + load_format : str, optional (default='json') |
| 207 | + Format used to load the model. |
| 208 | + Examples |
| 209 | + -------- |
| 210 | + >>> from dislib.clasiffication import KNeighborsClassifier |
| 211 | + >>> import numpy as np |
| 212 | + >>> import dislib as ds |
| 213 | + >>> x_data = np.array([[1, 2], [2, 0], [3, 1], [4, 4], [5, 3]]) |
| 214 | + >>> y_data = np.array([2, 1, 1, 2, 0]) |
| 215 | + >>> x_test_m = np.array([[3, 2], [4, 4], [1, 3]]) |
| 216 | + >>> bn, bm = 2, 2 |
| 217 | + >>> x = ds.array(x=x_data, block_size=(bn, bm)) |
| 218 | + >>> y = ds.array(x=y_data, block_size=(bn, 1)) |
| 219 | + >>> test_data_m = ds.array(x=x_test_m, block_size=(bn, bm)) |
| 220 | + >>> knn = KNeighborsClassifier() |
| 221 | + >>> knn.fit(x, y) |
| 222 | + >>> knn.save_model("./model_KNN") |
| 223 | + >>> knn_loaded = KNeighborsClassifier() |
| 224 | + >>> knn_loaded.load_model("./model_KNN") |
| 225 | + >>> pred = knn_loaded.predict(test_data).collect() |
| 226 | + """ |
| 227 | + # Load model |
| 228 | + if load_format == "json": |
| 229 | + with open(filepath, "r") as f: |
| 230 | + model_metadata = json.load(f, object_hook=_decode_helper) |
| 231 | + elif load_format == "cbor": |
| 232 | + if utilmodel.cbor2 is None: |
| 233 | + raise ModuleNotFoundError("No module named 'cbor2'") |
| 234 | + with open(filepath, "rb") as f: |
| 235 | + model_metadata = utilmodel.cbor2. \ |
| 236 | + load(f, object_hook=_decode_helper_cbor) |
| 237 | + elif load_format == "pickle": |
| 238 | + with open(filepath, "rb") as f: |
| 239 | + model_metadata = pickle.load(f) |
| 240 | + else: |
| 241 | + raise ValueError("Wrong load format.") |
| 242 | + for key, val in model_metadata.items(): |
| 243 | + setattr(self, key, val) |
| 244 | + |
139 | 245 |
|
140 | 246 | @constraint(computing_units="${ComputingUnits}")
|
141 | 247 | @task(ind_blocks={Type: COLLECTION_IN, Depth: 2},
|
@@ -180,3 +286,59 @@ def _get_score(y_blocks, ypred_blocks):
|
180 | 286 | y_pred = Array._merge_blocks(ypred_blocks).flatten()
|
181 | 287 |
|
182 | 288 | return accuracy_score(y, y_pred)
|
| 289 | + |
| 290 | + |
| 291 | +def _decode_helper_cbor(decoder, obj): |
| 292 | + """Special decoder wrapper for dislib using cbor2.""" |
| 293 | + return _decode_helper(obj) |
| 294 | + |
| 295 | + |
| 296 | +def _decode_helper(obj): |
| 297 | + if isinstance(obj, dict) and "class_name" in obj: |
| 298 | + class_name = obj["class_name"] |
| 299 | + if class_name == "NearestNeighbors": |
| 300 | + nn = NearestNeighbors(obj["n_neighbors"]) |
| 301 | + nn.__setstate__(_decode_helper(obj["items"])) |
| 302 | + return nn |
| 303 | + elif class_name == "SKNeighbors": |
| 304 | + dict_ = _decode_helper(obj["items"]) |
| 305 | + model = SKNeighbors() |
| 306 | + model.__setstate__(dict_) |
| 307 | + return model |
| 308 | + elif class_name == "KDTree": |
| 309 | + dict_ = _decode_helper(obj["items"]) |
| 310 | + model = KDTree(dict_[0]) |
| 311 | + return model |
| 312 | + else: |
| 313 | + decoded = decoder_helper(class_name, obj) |
| 314 | + if decoded is not None: |
| 315 | + return decoded |
| 316 | + return obj |
| 317 | + |
| 318 | + |
| 319 | +def _encode_helper_cbor(encoder, obj): |
| 320 | + encoder.encode(_encode_helper(obj)) |
| 321 | + |
| 322 | + |
| 323 | +def _encode_helper(obj): |
| 324 | + encoded = encoder_helper(obj) |
| 325 | + if encoded is not None: |
| 326 | + return encoded |
| 327 | + elif isinstance(obj, SKNeighbors): |
| 328 | + return { |
| 329 | + "class_name": "SKNeighbors", |
| 330 | + "n_neighbors": obj.n_neighbors, |
| 331 | + "radius": obj.radius, |
| 332 | + "items": obj.__getstate__(), |
| 333 | + } |
| 334 | + elif isinstance(obj, KDTree): |
| 335 | + return { |
| 336 | + "class_name": "KDTree", |
| 337 | + "items": obj.__getstate__(), |
| 338 | + } |
| 339 | + elif isinstance(obj, NearestNeighbors): |
| 340 | + return { |
| 341 | + "class_name": obj.__class__.__name__, |
| 342 | + "n_neighbors": obj.n_neighbors, |
| 343 | + "items": obj.__getstate__(), |
| 344 | + } |
0 commit comments