Skip to content

Commit f1b03c4

Browse files
committed
Add EqualizedOddsImprovement
1 parent 48e7ee5 commit f1b03c4

File tree

5 files changed

+1041
-25
lines changed

5 files changed

+1041
-25
lines changed

sdmetrics/single_table/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
7878
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
7979
from sdmetrics.single_table.table_structure import TableStructure
80+
from sdmetrics.single_table.equalized_odds import EqualizedOddsImprovement
8081

8182
__all__ = [
8283
'bayesian_network',
@@ -140,4 +141,5 @@
140141
'TableStructure',
141142
'DCRBaselineProtection',
142143
'DCROverfittingProtection',
144+
'EqualizedOddsImprovement',
143145
]

sdmetrics/single_table/data_augmentation/utils.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,11 @@
33
import pandas as pd
44

55
from sdmetrics._utils_metadata import _process_data_with_metadata, _validate_single_table_metadata
6-
7-
8-
def _validate_tables(real_training_data, synthetic_data, real_validation_data):
9-
"""Validate the tables of the Data Augmentation metrics."""
10-
tables = [real_training_data, synthetic_data, real_validation_data]
11-
if any(not isinstance(table, pd.DataFrame) for table in tables):
12-
raise ValueError(
13-
'`real_training_data`, `synthetic_data` and `real_validation_data` must be '
14-
'pandas DataFrames.'
15-
)
16-
17-
18-
def _validate_prediction_column_name(prediction_column_name):
19-
"""Validate the prediction column name of the Data Augmentation metrics."""
20-
if not isinstance(prediction_column_name, str):
21-
raise TypeError('`prediction_column_name` must be a string.')
22-
23-
24-
def _validate_classifier(classifier):
25-
"""Validate the classifier of the Data Augmentation metrics."""
26-
if classifier is not None and not isinstance(classifier, str):
27-
raise TypeError('`classifier` must be a string or None.')
28-
29-
if classifier != 'XGBoost':
30-
raise ValueError('Currently only `XGBoost` is supported as classifier.')
6+
from sdmetrics.single_table.utils import (
7+
_validate_classifier,
8+
_validate_prediction_column_name,
9+
_validate_tables,
10+
)
3111

3212

3313
def _validate_fixed_recall_value(fixed_recall_value):

0 commit comments

Comments
 (0)