1
1
"""Utils method for data augmentation metrics."""
2
2
3
- import pandas as pd
4
-
5
3
from sdmetrics ._utils_metadata import _process_data_with_metadata , _validate_single_table_metadata
6
-
7
-
8
- def _validate_tables (real_training_data , synthetic_data , real_validation_data ):
9
- """Validate the tables of the Data Augmentation metrics."""
10
- tables = [real_training_data , synthetic_data , real_validation_data ]
11
- if any (not isinstance (table , pd .DataFrame ) for table in tables ):
12
- raise ValueError (
13
- '`real_training_data`, `synthetic_data` and `real_validation_data` must be '
14
- 'pandas DataFrames.'
15
- )
16
-
17
-
18
- def _validate_prediction_column_name (prediction_column_name ):
19
- """Validate the prediction column name of the Data Augmentation metrics."""
20
- if not isinstance (prediction_column_name , str ):
21
- raise TypeError ('`prediction_column_name` must be a string.' )
22
-
23
-
24
- def _validate_classifier (classifier ):
25
- """Validate the classifier of the Data Augmentation metrics."""
26
- if classifier is not None and not isinstance (classifier , str ):
27
- raise TypeError ('`classifier` must be a string or None.' )
28
-
29
- if classifier != 'XGBoost' :
30
- raise ValueError ('Currently only `XGBoost` is supported as classifier.' )
4
+ from sdmetrics .single_table .utils import (
5
+ _validate_classifier ,
6
+ _validate_data_and_metadata ,
7
+ _validate_prediction_column_name ,
8
+ _validate_tables ,
9
+ )
31
10
32
11
33
12
def _validate_fixed_recall_value (fixed_recall_value ):
@@ -53,51 +32,6 @@ def _validate_parameters(
53
32
_validate_fixed_recall_value (fixed_recall_value )
54
33
55
34
56
- def _validate_data_and_metadata (
57
- real_training_data ,
58
- synthetic_data ,
59
- real_validation_data ,
60
- metadata ,
61
- prediction_column_name ,
62
- minority_class_label ,
63
- ):
64
- """Validate the data and metadata of the Data Augmentation metrics."""
65
- if prediction_column_name not in metadata ['columns' ]:
66
- raise ValueError (
67
- f'The column `{ prediction_column_name } ` is not described in the metadata.'
68
- ' Please update your metadata.'
69
- )
70
-
71
- if metadata ['columns' ][prediction_column_name ]['sdtype' ] not in ('categorical' , 'boolean' ):
72
- raise ValueError (
73
- f'The column `{ prediction_column_name } ` must be either categorical or boolean.'
74
- ' Please update your metadata.'
75
- )
76
-
77
- if minority_class_label not in real_training_data [prediction_column_name ].unique ():
78
- raise ValueError (
79
- f'The value `{ minority_class_label } ` is not present in the column '
80
- f'`{ prediction_column_name } ` for the real training data.'
81
- )
82
-
83
- if minority_class_label not in real_validation_data [prediction_column_name ].unique ():
84
- raise ValueError (
85
- f"The metric can't be computed because the value `{ minority_class_label } ` "
86
- f'is not present in the column `{ prediction_column_name } ` for the real validation data.'
87
- ' The `precision` and `recall` are undefined for this case.'
88
- )
89
-
90
- synthetic_labels = set (synthetic_data [prediction_column_name ].unique ())
91
- real_labels = set (real_training_data [prediction_column_name ].unique ())
92
- if not synthetic_labels .issubset (real_labels ):
93
- to_print = "', '" .join (sorted (synthetic_labels - real_labels ))
94
- raise ValueError (
95
- f'The ``{ prediction_column_name } `` column must have the same values in the real '
96
- 'and synthetic data. The following values are present in the synthetic data and'
97
- f" not the real data: '{ to_print } '"
98
- )
99
-
100
-
101
35
def _validate_inputs (
102
36
real_training_data ,
103
37
synthetic_data ,
@@ -127,6 +61,16 @@ def _validate_inputs(
127
61
minority_class_label ,
128
62
)
129
63
64
+ synthetic_labels = set (synthetic_data [prediction_column_name ].unique ())
65
+ real_labels = set (real_training_data [prediction_column_name ].unique ())
66
+ if not synthetic_labels .issubset (real_labels ):
67
+ to_print = "', '" .join (sorted (synthetic_labels - real_labels ))
68
+ raise ValueError (
69
+ f'The `{ prediction_column_name } ` column must have the same values in the real '
70
+ 'and synthetic data. The following values are present in the synthetic data and'
71
+ f" not the real data: '{ to_print } '"
72
+ )
73
+
130
74
131
75
def _process_data_with_metadata_ml_efficacy_metrics (
132
76
real_training_data , synthetic_data , real_validation_data , metadata
0 commit comments