Skip to content

Commit d78f4f8

Browse files
committed
Add EqualizedOddsImprovement
1 parent 48e7ee5 commit d78f4f8

File tree

11 files changed

+1235
-88
lines changed

11 files changed

+1235
-88
lines changed

sdmetrics/single_table/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
7878
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
7979
from sdmetrics.single_table.table_structure import TableStructure
80+
from sdmetrics.single_table.equalized_odds import EqualizedOddsImprovement
8081

8182
__all__ = [
8283
'bayesian_network',
@@ -140,4 +141,5 @@
140141
'TableStructure',
141142
'DCRBaselineProtection',
142143
'DCROverfittingProtection',
144+
'EqualizedOddsImprovement',
143145
]

sdmetrics/single_table/data_augmentation/utils.py

Lines changed: 16 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,12 @@
11
"""Utils method for data augmentation metrics."""
22

3-
import pandas as pd
4-
53
from sdmetrics._utils_metadata import _process_data_with_metadata, _validate_single_table_metadata
6-
7-
8-
def _validate_tables(real_training_data, synthetic_data, real_validation_data):
9-
"""Validate the tables of the Data Augmentation metrics."""
10-
tables = [real_training_data, synthetic_data, real_validation_data]
11-
if any(not isinstance(table, pd.DataFrame) for table in tables):
12-
raise ValueError(
13-
'`real_training_data`, `synthetic_data` and `real_validation_data` must be '
14-
'pandas DataFrames.'
15-
)
16-
17-
18-
def _validate_prediction_column_name(prediction_column_name):
19-
"""Validate the prediction column name of the Data Augmentation metrics."""
20-
if not isinstance(prediction_column_name, str):
21-
raise TypeError('`prediction_column_name` must be a string.')
22-
23-
24-
def _validate_classifier(classifier):
25-
"""Validate the classifier of the Data Augmentation metrics."""
26-
if classifier is not None and not isinstance(classifier, str):
27-
raise TypeError('`classifier` must be a string or None.')
28-
29-
if classifier != 'XGBoost':
30-
raise ValueError('Currently only `XGBoost` is supported as classifier.')
4+
from sdmetrics.single_table.utils import (
5+
_validate_classifier,
6+
_validate_data_and_metadata,
7+
_validate_prediction_column_name,
8+
_validate_tables,
9+
)
3110

3211

3312
def _validate_fixed_recall_value(fixed_recall_value):
@@ -53,51 +32,6 @@ def _validate_parameters(
5332
_validate_fixed_recall_value(fixed_recall_value)
5433

5534

56-
def _validate_data_and_metadata(
57-
real_training_data,
58-
synthetic_data,
59-
real_validation_data,
60-
metadata,
61-
prediction_column_name,
62-
minority_class_label,
63-
):
64-
"""Validate the data and metadata of the Data Augmentation metrics."""
65-
if prediction_column_name not in metadata['columns']:
66-
raise ValueError(
67-
f'The column `{prediction_column_name}` is not described in the metadata.'
68-
' Please update your metadata.'
69-
)
70-
71-
if metadata['columns'][prediction_column_name]['sdtype'] not in ('categorical', 'boolean'):
72-
raise ValueError(
73-
f'The column `{prediction_column_name}` must be either categorical or boolean.'
74-
' Please update your metadata.'
75-
)
76-
77-
if minority_class_label not in real_training_data[prediction_column_name].unique():
78-
raise ValueError(
79-
f'The value `{minority_class_label}` is not present in the column '
80-
f'`{prediction_column_name}` for the real training data.'
81-
)
82-
83-
if minority_class_label not in real_validation_data[prediction_column_name].unique():
84-
raise ValueError(
85-
f"The metric can't be computed because the value `{minority_class_label}` "
86-
f'is not present in the column `{prediction_column_name}` for the real validation data.'
87-
' The `precision` and `recall` are undefined for this case.'
88-
)
89-
90-
synthetic_labels = set(synthetic_data[prediction_column_name].unique())
91-
real_labels = set(real_training_data[prediction_column_name].unique())
92-
if not synthetic_labels.issubset(real_labels):
93-
to_print = "', '".join(sorted(synthetic_labels - real_labels))
94-
raise ValueError(
95-
f'The ``{prediction_column_name}`` column must have the same values in the real '
96-
'and synthetic data. The following values are present in the synthetic data and'
97-
f" not the real data: '{to_print}'"
98-
)
99-
100-
10135
def _validate_inputs(
10236
real_training_data,
10337
synthetic_data,
@@ -127,6 +61,16 @@ def _validate_inputs(
12761
minority_class_label,
12862
)
12963

64+
synthetic_labels = set(synthetic_data[prediction_column_name].unique())
65+
real_labels = set(real_training_data[prediction_column_name].unique())
66+
if not synthetic_labels.issubset(real_labels):
67+
to_print = "', '".join(sorted(synthetic_labels - real_labels))
68+
raise ValueError(
69+
f'The `{prediction_column_name}` column must have the same values in the real '
70+
'and synthetic data. The following values are present in the synthetic data and'
71+
f" not the real data: '{to_print}'"
72+
)
73+
13074

13175
def _process_data_with_metadata_ml_efficacy_metrics(
13276
real_training_data, synthetic_data, real_validation_data, metadata

0 commit comments

Comments
 (0)