-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathbase.py
More file actions
108 lines (85 loc) · 3.9 KB
/
base.py
File metadata and controls
108 lines (85 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Base class for Machine Learning Detection metrics for single table datasets."""
import logging
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sdmetrics.errors import IncomputableMetricError
from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
from sdmetrics.utils import HyperTransformer
LOGGER = logging.getLogger(__name__)
class DetectionMetric(SingleTableMetric):
    """Base class for Machine Learning Detection based metrics on single tables.

    These metrics build a Machine Learning Classifier that learns to tell the synthetic
    data apart from the real data, which later on is evaluated using Cross Validation.

    The output of ``compute`` is the average ROC AUC score obtained across the
    Cross Validation folds, with each fold's score floored at 0.5.

    Subclasses must override ``_fit_predict`` with a concrete classifier.

    Attributes:
        name (str):
            Name to use when reports about this metric are printed.
        goal (sdmetrics.goal.Goal):
            The goal of this metric.
        min_value (Union[float, tuple[float]]):
            Minimum value or values that this metric can take.
        max_value (Union[float, tuple[float]]):
            Maximum value or values that this metric can take.
    """

    name = 'SingleTable Detection'
    goal = Goal.MINIMIZE
    # NOTE(review): ``compute`` floors every fold score at 0.5, so it never returns
    # less than 0.5, yet ``min_value`` is 0.0. ``normalize`` rescales to [0, 1] before
    # delegating to the parent class -- confirm these bounds are the intended ones.
    min_value = 0.0
    max_value = 1.0

    @staticmethod
    def _fit_predict(X_train, y_train, X_test):
        """Fit a classifier and then use it to predict.

        Args:
            X_train:
                Feature rows of the training fold.
            y_train:
                Labels of the training fold (1 for real rows, 0 for synthetic rows).
            X_test:
                Feature rows of the held-out fold.

        Returns:
            The predictions for ``X_test``; ``compute`` passes them to
            ``roc_auc_score`` as the score argument.

        Raises:
            NotImplementedError:
                Always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def compute(cls, real_data, synthetic_data, metadata=None):
        """Compute this metric.

        This builds a Machine Learning Classifier that learns to tell the synthetic
        data apart from the real data, which later on is evaluated using a 3-fold
        stratified Cross Validation.

        Each fold's ROC AUC is floored at 0.5 and the output of the metric is the
        average of those floored scores. The folds are shuffled without a fixed
        seed, so repeated calls on the same data may return different values.

        Args:
            real_data (Union[numpy.ndarray, pandas.DataFrame]):
                The values from the real dataset.
            synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
                The values from the synthetic dataset.
            metadata (dict):
                Table metadata dict. If not passed, it is built based on the
                real_data fields and dtypes.

        Returns:
            float:
                The average ROC AUC Cross Validation score obtained by the
                classifier, with each fold floored at 0.5.

        Raises:
            IncomputableMetricError:
                If a ``ValueError`` is raised while splitting, fitting or scoring.
        """
        real_data, synthetic_data, metadata = cls._validate_inputs(
            real_data, synthetic_data, metadata)

        # The transformer is fitted on the real data only and then reused on the
        # synthetic data, so both tables share the same encoding.
        ht = HyperTransformer()
        transformed_real_data = ht.fit_transform(real_data).to_numpy()
        transformed_synthetic_data = ht.transform(synthetic_data).to_numpy()

        # Real rows are labelled 1 and synthetic rows 0.
        X = np.concatenate([transformed_real_data, transformed_synthetic_data])
        y = np.hstack([
            np.ones(len(transformed_real_data)), np.zeros(len(transformed_synthetic_data))
        ])

        # Replace +/-inf with NaN; the mask is computed once and reused.
        infinite_mask = np.isin(X, [np.inf, -np.inf])
        if infinite_mask.any():
            X[infinite_mask] = np.nan

        try:
            scores = []
            kf = StratifiedKFold(n_splits=3, shuffle=True)
            for train_index, test_index in kf.split(X, y):
                y_pred = cls._fit_predict(X[train_index], y[train_index], X[test_index])
                roc_auc = roc_auc_score(y[test_index], y_pred)
                # A fold scoring below 0.5 is counted as 0.5.
                scores.append(max(0.5, roc_auc))

            return np.mean(scores)
        except ValueError as err:
            # Chain the original error so its traceback is not lost.
            raise IncomputableMetricError(
                f'DetectionMetric: Unable to be fit with error {err}'
            ) from err

    @classmethod
    def normalize(cls, raw_score):
        """Normalize the ``raw_score`` returned by ``compute``.

        The raw score, expected in [0.5, 1], is first rescaled to [0, 1] as
        ``2 * (1 - raw_score)`` and the result is then passed to the parent
        class ``normalize``.

        NOTE(review): the parent ``normalize`` is not visible from this class;
        confirm that applying it on top of this rescaling, with
        ``goal = Goal.MINIMIZE``, yields the intended higher-is-better value.

        Args:
            raw_score (float):
                The value of the metric from `compute`.

        Returns:
            float:
                The parent class normalization of ``2 * (1 - raw_score)``.

        Raises:
            AssertionError:
                If ``raw_score`` is not greater than or equal to 0.5.
        """
        # Raised explicitly rather than with ``assert`` so the check still runs
        # under ``python -O``. ``not (x >= 0.5)`` also rejects NaN.
        if not (raw_score >= 0.5):
            raise AssertionError("raw auc score should be in [0.5,1]")

        score = 2 * (1 - raw_score)
        return super().normalize(score)