diff --git a/src/skprometheus/preprocessing.py b/src/skprometheus/preprocessing.py index 1dd7711..034d338 100644 --- a/src/skprometheus/preprocessing.py +++ b/src/skprometheus/preprocessing.py @@ -31,7 +31,7 @@ def transform(self, X): for idx, row in enumerate(categories.T): for category in row: - if not category: + if category is None: category = "missing" MetricRegistry.model_categorical(feature=str(features[idx]), category=str(category)).inc() diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 864b5d7..f1169b6 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -37,6 +37,7 @@ def test_OneHotEncoder(): assert REGISTRY.get_sample_value('skprom_model_categorical_total', {'feature': '2', 'category': '4'}) == 2 assert REGISTRY.get_sample_value('skprom_model_categorical_total', {'feature': '3', 'category': '9'}) == 1 + assert REGISTRY.get_sample_value('skprom_model_categorical_total', {'feature': '3', 'category': '0'}) == 1 def test_OneHotEncoder_pandas():