Skip to content

Commit 46989cb

Browse files
authored
HMA defaults to norm distribution for child tables (#2710)
1 parent 926c0b4 commit 46989cb

File tree

5 files changed

+52 -28 lines changed

5 files changed

+52 -28 lines changed

sdv/multi_table/base.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -404,6 +404,7 @@ def set_table_parameters(self, table_name, table_parameters):
404404
self._table_synthesizers[table_name] = self._synthesizer(
405405
metadata=table_metadata, **table_parameters
406406
)
407+
self._table_synthesizers[table_name]._data_processor.table_name = table_name
407408
self._table_parameters[table_name].update(deepcopy(table_parameters))
408409

409410
def _validate_all_tables(self, data):

sdv/multi_table/hma.py

Lines changed: 37 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -53,14 +53,16 @@ def _get_num_data_columns(metadata):
5353
columns_per_table = {}
5454
for table_name, table in metadata.tables.items():
5555
key_columns = metadata._get_all_keys(table_name)
56-
columns_per_table[table_name] = sum([
56+
num_data_columns = sum([
5757
1
5858
for col_name, col_meta in table.columns.items()
5959
if (
6060
col_meta['sdtype'] != 'id'
6161
or (col_name not in key_columns and col_meta.get('pii', False) is False)
6262
)
6363
])
64+
num_extended_columns = 0
65+
columns_per_table[table_name] = [num_data_columns, num_extended_columns]
6466

6567
return columns_per_table
6668

@@ -85,18 +87,29 @@ def _get_num_extended_columns(
8587
table_name, cls.DEFAULT_SYNTHESIZER_KWARGS['default_distribution']
8688
)
8789

88-
num_parameters = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[distribution]
89-
90+
num_params_data = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[distribution]
91+
num_params_extended = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[
92+
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION
93+
]
9094
num_rows_columns = len(metadata._get_foreign_keys(parent_table, table_name))
9195

92-
# no parameter columns are generated if there are no data columns
93-
num_data_columns = columns_per_table[table_name]
94-
if num_data_columns == 0:
96+
# no parameter columns are generated if there are no data or extended columns
97+
num_data_columns = columns_per_table[table_name][0]
98+
num_extended_columns = columns_per_table[table_name][1]
99+
100+
if (num_data_columns + num_extended_columns) == 0:
95101
return num_rows_columns
96102

97-
num_parameters_columns = num_rows_columns * num_data_columns * num_parameters
103+
num_parameters_columns = (num_rows_columns * num_data_columns * num_params_data) + (
104+
num_rows_columns * num_extended_columns * num_params_extended
105+
)
98106

99-
num_correlation_columns = num_rows_columns * (num_data_columns - 1) * num_data_columns // 2
107+
num_correlation_columns = (
108+
num_rows_columns
109+
* (num_data_columns + num_extended_columns - 1)
110+
* (num_data_columns + num_extended_columns)
111+
// 2
112+
)
100113

101114
return num_correlation_columns + num_rows_columns + num_parameters_columns
102115

@@ -118,9 +131,11 @@ def _estimate_columns_traversal(
118131
"""
119132
for child_name in metadata._get_child_map()[table_name]:
120133
if child_name not in visited:
121-
cls._estimate_columns_traversal(metadata, child_name, columns_per_table, visited)
134+
cls._estimate_columns_traversal(
135+
metadata, child_name, columns_per_table, visited, distributions
136+
)
122137

123-
columns_per_table[table_name] += cls._get_num_extended_columns(
138+
columns_per_table[table_name][1] += cls._get_num_extended_columns(
124139
metadata, child_name, table_name, columns_per_table, distributions
125140
)
126141

@@ -157,7 +172,9 @@ def _estimate_num_columns(cls, metadata, distributions=None):
157172
metadata, table_name, columns_per_table, visited, distributions
158173
)
159174

160-
return columns_per_table
175+
return {
176+
table_name: sum(columns_list) for table_name, columns_list in columns_per_table.items()
177+
}
161178

162179
def __init__(self, metadata, locales=['en_US'], verbose=True):
163180
BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales)
@@ -173,6 +190,11 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
173190
BaseHierarchicalSampler.__init__(
174191
self, self.metadata, self._table_synthesizers, self._table_sizes
175192
)
193+
child_tables = set()
194+
for relationship in metadata.relationships:
195+
child_tables.add(relationship['child_table_name'])
196+
for child_table_name in child_tables:
197+
self.set_table_parameters(child_table_name, {'default_distribution': 'norm'})
176198
self._print_estimate_warning()
177199

178200
def set_table_parameters(self, table_name, table_parameters):
@@ -238,7 +260,7 @@ def _print_estimate_warning(self):
238260
for table, est_cols in self._estimate_num_columns(self.metadata, distributions).items():
239261
entry = []
240262
entry.append(table)
241-
entry.append(metadata_columns[table])
263+
entry.append(sum(metadata_columns[table]))
242264
total_est_cols += est_cols
243265
entry.append(est_cols)
244266
print_table.append(entry)
@@ -679,6 +701,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
679701
parameters = self._extract_parameters(row, table_name, foreign_key)
680702
table_meta = self._table_synthesizers[table_name].get_metadata()
681703
synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name])
704+
extended_columns = getattr(self, '_parent_extended_columns', {}).get(table_name, [])
705+
if extended_columns:
706+
self._set_extended_columns_distributions(synthesizer, table_name, extended_columns)
682707
synthesizer._set_parameters(parameters)
683708
try:
684709
likelihoods[parent_id] = synthesizer._get_likelihood(table_rows)

tests/integration/multi_table/test_hma.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2610,9 +2610,10 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes():
26102610
})
26112611
synthesizer = HMASynthesizer(metadata)
26122612
synthesizer._finalize = Mock(return_value=data)
2613+
distributions = synthesizer._get_distributions()
26132614

26142615
# Run estimation
2615-
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
2616+
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)
26162617

26172618
# Run actual modeling
26182619
synthesizer.fit(data)

tests/integration/utils/test_poc.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -70,12 +70,12 @@ def test_simplify_schema(capsys):
7070
# Assert
7171
expected_message_before = re.compile(
7272
r'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended\.'
73-
r' To model this data, HMA will generate a large number of columns\. \(173818 columns\)\s+'
73+
r' To model this data, HMA will generate a large number of columns\. \(135934 columns\)\s+'
7474
r'Table Name\s*#\s*Columns in Metadata\s*Est # Columns\s*'
7575
r'match_stats\s*24\s*24\s*'
76-
r'matches\s*39\s*412\s*'
77-
r'players\s*5\s*378\s*'
78-
r'teams\s*1\s*173004\s*'
76+
r'matches\s*39\s*364\s*'
77+
r'players\s*5\s*330\s*'
78+
r'teams\s*1\s*135216\s*'
7979
r'We recommend simplifying your metadata schema using '
8080
r"'sdv.utils.poc.simplify_schema'\.\s*"
8181
r'If this is not possible, please visit '

tests/unit/multi_table/test_hma.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@ def test___init__(self):
2626
assert isinstance(instance._table_synthesizers['oseba'], GaussianCopulaSynthesizer)
2727
assert isinstance(instance._table_synthesizers['upravna_enota'], GaussianCopulaSynthesizer)
2828
assert instance._table_parameters == {
29-
'nesreca': {'default_distribution': 'beta'},
30-
'oseba': {'default_distribution': 'beta'},
29+
'nesreca': {'default_distribution': 'norm'},
30+
'oseba': {'default_distribution': 'norm'},
3131
'upravna_enota': {'default_distribution': 'beta'},
3232
}
3333
instance.metadata.validate.assert_called_once_with()
@@ -70,8 +70,6 @@ def test__get_extension(self):
7070

7171
# Assert
7272
expected = pd.DataFrame({
73-
'__nesreca__upravna_enota__univariates__id_nesreca__a': [1.0, 1.0, 1.0, 1.0],
74-
'__nesreca__upravna_enota__univariates__id_nesreca__b': [1.0, 1.0, 1.0, 1.0],
7573
'__nesreca__upravna_enota__univariates__id_nesreca__loc': [0.0, 1.0, 2.0, 3.0],
7674
'__nesreca__upravna_enota__univariates__id_nesreca__scale': [np.nan] * 4,
7775
'__nesreca__upravna_enota__num_rows': [1.0, 1.0, 1.0, 1.0],
@@ -187,12 +185,8 @@ def test__augment_table(self):
187185
'nesreca_val': [0, 1, 2, 3],
188186
'value': [0, 1, 2, 3],
189187
'__oseba__id_nesreca__correlation__0__0': [0.0] * 4,
190-
'__oseba__id_nesreca__univariates__oseba_val__a': [1.0] * 4,
191-
'__oseba__id_nesreca__univariates__oseba_val__b': [1.0] * 4,
192188
'__oseba__id_nesreca__univariates__oseba_val__loc': [0.0, 1.0, 2.0, 3.0],
193189
'__oseba__id_nesreca__univariates__oseba_val__scale': [1e-6] * 4,
194-
'__oseba__id_nesreca__univariates__oseba_value__a': [1.0] * 4,
195-
'__oseba__id_nesreca__univariates__oseba_value__b': [1.0] * 4,
196190
'__oseba__id_nesreca__univariates__oseba_value__loc': [0.0, 1.0, 2.0, 3.0],
197191
'__oseba__id_nesreca__univariates__oseba_value__scale': [1e-6] * 4,
198192
'__oseba__id_nesreca__num_rows': [1.0] * 4,
@@ -877,9 +871,10 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self):
877871
})
878872
synthesizer = HMASynthesizer(metadata)
879873
synthesizer._finalize = Mock(return_value=data)
874+
distributions = synthesizer._get_distributions()
880875

881876
# Run estimation
882-
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
877+
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)
883878

884879
# Run actual modeling
885880
synthesizer.fit(data)
@@ -1152,9 +1147,10 @@ def test__estimate_num_columns_to_be_modeled(self):
11521147
})
11531148
synthesizer = HMASynthesizer(metadata)
11541149
synthesizer._finalize = Mock(return_value=data)
1150+
distributions = synthesizer._get_distributions()
11551151

11561152
# Run estimation
1157-
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
1153+
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)
11581154

11591155
# Run actual modeling
11601156
synthesizer.fit(data)
@@ -1264,9 +1260,10 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self):
12641260
})
12651261
synthesizer = HMASynthesizer(metadata)
12661262
synthesizer._finalize = Mock(return_value=data)
1263+
distributions = synthesizer._get_distributions()
12671264

12681265
# Run estimation
1269-
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
1266+
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)
12701267

12711268
# Run actual modeling
12721269
synthesizer.fit(data)

0 commit comments

Comments
 (0)