Skip to content

Commit 5791165

Browse files
committed
Update msg
1 parent 926c0b4 commit 5791165

File tree

2 files changed: +78 -9 lines changed

sdv/multi_table/hma.py

Lines changed: 33 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,11 @@
1818
LOGGER = logging.getLogger(__name__)
1919
MAX_NUMBER_OF_COLUMNS = 1000
2020
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
21+
PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000
22+
23+
24+
class _EarlyStopEstimation(Exception):
25+
pass
2126

2227

2328
class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
@@ -102,7 +107,7 @@ def _get_num_extended_columns(
102107

103108
@classmethod
104109
def _estimate_columns_traversal(
105-
cls, metadata, table_name, columns_per_table, visited, distributions=None
110+
cls, metadata, table_name, columns_per_table, visited, distributions=None, max_total=None
106111
):
107112
"""Given a table, estimate how many columns each parent will model.
108113
@@ -118,16 +123,21 @@ def _estimate_columns_traversal(
118123
"""
119124
for child_name in metadata._get_child_map()[table_name]:
120125
if child_name not in visited:
121-
cls._estimate_columns_traversal(metadata, child_name, columns_per_table, visited)
126+
cls._estimate_columns_traversal(
127+
metadata, child_name, columns_per_table, visited, distributions, max_total
128+
)
122129

123130
columns_per_table[table_name] += cls._get_num_extended_columns(
124131
metadata, child_name, table_name, columns_per_table, distributions
125132
)
126133

134+
if max_total is not None and sum(columns_per_table.values()) > max_total:
135+
raise _EarlyStopEstimation
136+
127137
visited.add(table_name)
128138

129139
@classmethod
130-
def _estimate_num_columns(cls, metadata, distributions=None):
140+
def _estimate_num_columns(cls, metadata, distributions=None, max_total=None):
131141
"""Estimate the number of columns that will be modeled for each table.
132142
133143
This method estimates how many extended columns will be generated during the
@@ -153,9 +163,14 @@ def _estimate_num_columns(cls, metadata, distributions=None):
153163
# each table will model
154164
visited = set()
155165
for table_name in _get_root_tables(metadata.relationships):
156-
cls._estimate_columns_traversal(
157-
metadata, table_name, columns_per_table, visited, distributions
158-
)
166+
try:
167+
cls._estimate_columns_traversal(
168+
metadata, table_name, columns_per_table, visited, distributions, max_total
169+
)
170+
except _EarlyStopEstimation:
171+
break
172+
if max_total is not None and sum(columns_per_table.values()) > max_total:
173+
break
159174

160175
return columns_per_table
161176

@@ -235,19 +250,29 @@ def _print_estimate_warning(self):
235250
metadata_columns = self._get_num_data_columns(self.metadata)
236251
print_table = []
237252
distributions = self._get_distributions()
238-
for table, est_cols in self._estimate_num_columns(self.metadata, distributions).items():
253+
estimated_columns = self._estimate_num_columns(
254+
self.metadata, distributions, max_total=PERFORMANCE_ALERT_DISPLAY_CAP
255+
)
256+
for table, est_cols in estimated_columns.items():
239257
entry = []
240258
entry.append(table)
241259
entry.append(metadata_columns[table])
242260
total_est_cols += est_cols
243261
entry.append(est_cols)
244262
print_table.append(entry)
263+
if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP:
264+
break
245265

246266
if total_est_cols > MAX_NUMBER_OF_COLUMNS:
267+
display_total = (
268+
f'{PERFORMANCE_ALERT_DISPLAY_CAP}+'
269+
if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP
270+
else f'{total_est_cols}'
271+
)
247272
self._print(
248273
'PerformanceAlert: Using the HMASynthesizer on this metadata '
249274
'schema is not recommended. To model this data, HMA will '
250-
f'generate a large number of columns. ({total_est_cols} columns)\n\n'
275+
f'generate a large number of columns. ({display_total} columns)\n\n'
251276
)
252277
self._print(
253278
pd.DataFrame(

tests/unit/multi_table/test_hma.py

Lines changed: 45 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,11 @@
77

88
from sdv.errors import SynthesizerInputError
99
from sdv.metadata.metadata import Metadata
10-
from sdv.multi_table.hma import HMASynthesizer
10+
from sdv.multi_table.hma import (
11+
PERFORMANCE_ALERT_DISPLAY_CAP,
12+
HMASynthesizer,
13+
_EarlyStopEstimation,
14+
)
1115
from sdv.single_table.copulas import GaussianCopulaSynthesizer
1216
from tests.utils import get_multi_table_data, get_multi_table_metadata
1317

@@ -1285,3 +1289,43 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self):
12851289
num_table_cols -= 1
12861290

12871291
assert num_table_cols == estimated_num_columns[table_name]
1292+
1293+
@patch('sdv.multi_table.hma.HMASynthesizer._estimate_num_columns')
1294+
@patch('sdv.multi_table.hma.HMASynthesizer._get_distributions')
1295+
def test__print_estimate_warning_capped_display_and_break(
1296+
self, get_distributions_mock, estimate_mock, capsys
1297+
):
1298+
"""When exceeding the cap, display 1_000_000+ and stop listing rows."""
1299+
# Setup
1300+
metadata = get_multi_table_metadata()
1301+
estimate_mock.return_value = {
1302+
'nesreca': PERFORMANCE_ALERT_DISPLAY_CAP + 5,
1303+
'oseba': 10,
1304+
}
1305+
1306+
# Run
1307+
HMASynthesizer(metadata)
1308+
captured = capsys.readouterr()
1309+
1310+
# Assert
1311+
assert 'PerformanceAlert:' in captured.out
1312+
assert f'{PERFORMANCE_ALERT_DISPLAY_CAP}+' in captured.out
1313+
assert 'nesreca' in captured.out
1314+
assert 'oseba' not in captured.out
1315+
1316+
def test__estimate_num_columns_early_stop_exception_is_handled(self):
1317+
"""_estimate_num_columns should handle internal early-stop and return partial results."""
1318+
# Setup
1319+
metadata = get_multi_table_metadata()
1320+
1321+
with patch('sdv.multi_table.hma.HMASynthesizer._estimate_columns_traversal') as tr_mock:
1322+
tr_mock.side_effect = _EarlyStopEstimation
1323+
1324+
# Run
1325+
result = HMASynthesizer._estimate_num_columns(metadata, max_total=1)
1326+
1327+
# Assert
1328+
assert isinstance(result, dict)
1329+
for table_name in ['nesreca', 'oseba', 'upravna_enota']:
1330+
assert table_name in result
1331+
assert isinstance(result[table_name], int)

0 commit comments

Comments
 (0)