Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/skprometheus/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def make_pipeline(*steps, memory=None, verbose=False):
class Pipeline(pipeline.Pipeline):
DEFAULT_LATENCY_BUCKETS = (0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25,
0.5, 0.75, 1., 2.5, 5., 7.5, 10., float('inf'))
DEFAULT_PROBA_BUCKETS = (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1)
DEFAULT_PROBA_BUCKETS = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)

"""
A pipeline that adds metrics to the prometheus metric registry.
Expand Down
10 changes: 7 additions & 3 deletions src/skprometheus/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@ class OneHotEncoder(preprocessing.OneHotEncoder):
"""
OneHotEncoder that adds metrics to the prometheus metric registry.
"""
@wraps(preprocessing.OneHotEncoder.__init__, assigned=["__signature__"])
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __new__(cls, *args, **kwargs):
MetricRegistry.add_counter(
"model_categorical",
"Counts category occurrence for each categorical feature.",
additional_labels=("feature", "category"),
)
return super(OneHotEncoder, cls).__new__(cls)
@wraps(preprocessing.OneHotEncoder.__init__, assigned=["__signature__"])
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


def transform(self, X):
"""
Expand Down
19 changes: 4 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,16 @@
from types import SimpleNamespace

import pytest
from prometheus_client import REGISTRY
from sklearn.utils import estimator_checks

from skprometheus.metrics import MetricRegistry
from tests.utils import pickle_load_populates_metric_registry, unregister_collectors


@pytest.fixture(autouse=True)
def unregister_collectors():
def _unregister_collectors():
"""
Fixture for cleaning registers before each test. Both prometheus_client.REGISTRY and
skprometheus.metrics.MetricRegistry are cleaned.
"""
collectors = list(REGISTRY._collector_to_names.keys())
for collector in collectors:
REGISTRY.unregister(collector)

# Resetting attributes of MetricRegistry to avoid state transfer between tests
# TODO: Maybe find less ugly solution in future?
MetricRegistry.metrics_initialized = False
MetricRegistry.current_labels = {}
MetricRegistry.labels = set()
MetricRegistry.metrics = SimpleNamespace()
unregister_collectors()


transformer_checks = (
Expand All @@ -32,6 +20,7 @@ def unregister_collectors():
)

general_checks = (
pickle_load_populates_metric_registry,
estimator_checks.check_fit2d_predict1d,
estimator_checks.check_methods_subset_invariance,
estimator_checks.check_fit2d_1sample,
Expand Down
27 changes: 27 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import pickle
import time
from copy import copy
from types import SimpleNamespace

import numpy as np
from prometheus_client import REGISTRY
from sklearn.base import BaseEstimator, ClassifierMixin

from skprometheus.metrics import MetricRegistry


class FixedLatencyClassifier(ClassifierMixin, BaseEstimator):
def __init__(self, latency):
Expand Down Expand Up @@ -44,3 +50,24 @@ def predict(self, X):

def metric_exists(metric_name, registry=REGISTRY):
...


def pickle_load_populates_metric_registry(name, estimator):
metrics = {m._name for m in REGISTRY._collector_to_names.keys()}
pkl = pickle.dumps(estimator)
unregister_collectors()
pickle.loads(pkl)
assert {m._name for m in REGISTRY._collector_to_names.keys()} == metrics


def unregister_collectors():
collectors = list(REGISTRY._collector_to_names.keys())
for collector in collectors:
REGISTRY.unregister(collector)

# Resetting attributes of MetricRegistry to avoid state transfer between tests
# TODO: Maybe find less ugly solution in future?
MetricRegistry.metrics_initialized = False
MetricRegistry.current_labels = {}
MetricRegistry.labels = set()
MetricRegistry.metrics = SimpleNamespace()