Skip to content

Commit 093707a

Browse files
committed
Fix test
1 parent 5791165 commit 093707a

File tree

2 files changed: +36 additions, -67 deletions

sdv/multi_table/hma.py

Lines changed: 12 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,9 @@
1616
from sdv.sampling import BaseHierarchicalSampler
1717

1818
LOGGER = logging.getLogger(__name__)
19-
MAX_NUMBER_OF_COLUMNS = 1000
20-
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
2119
PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000
22-
23-
24-
class _EarlyStopEstimation(Exception):
25-
pass
20+
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
21+
MAX_NUMBER_OF_COLUMNS = 1000
2622

2723

2824
class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
@@ -107,7 +103,7 @@ def _get_num_extended_columns(
107103

108104
@classmethod
109105
def _estimate_columns_traversal(
110-
cls, metadata, table_name, columns_per_table, visited, distributions=None, max_total=None
106+
cls, metadata, table_name, columns_per_table, visited, distributions=None
111107
):
112108
"""Given a table, estimate how many columns each parent will model.
113109
@@ -123,21 +119,19 @@ def _estimate_columns_traversal(
123119
"""
124120
for child_name in metadata._get_child_map()[table_name]:
125121
if child_name not in visited:
126-
cls._estimate_columns_traversal(
127-
metadata, child_name, columns_per_table, visited, distributions, max_total
128-
)
122+
cls._estimate_columns_traversal(metadata, child_name, columns_per_table, visited)
129123

130124
columns_per_table[table_name] += cls._get_num_extended_columns(
131125
metadata, child_name, table_name, columns_per_table, distributions
132126
)
133127

134-
if max_total is not None and sum(columns_per_table.values()) > max_total:
135-
raise _EarlyStopEstimation
128+
if sum(columns_per_table.values()) > PERFORMANCE_ALERT_DISPLAY_CAP:
129+
return
136130

137131
visited.add(table_name)
138132

139133
@classmethod
140-
def _estimate_num_columns(cls, metadata, distributions=None, max_total=None):
134+
def _estimate_num_columns(cls, metadata, distributions=None):
141135
"""Estimate the number of columns that will be modeled for each table.
142136
143137
This method estimates how many extended columns will be generated during the
@@ -163,13 +157,10 @@ def _estimate_num_columns(cls, metadata, distributions=None, max_total=None):
163157
# each table will model
164158
visited = set()
165159
for table_name in _get_root_tables(metadata.relationships):
166-
try:
167-
cls._estimate_columns_traversal(
168-
metadata, table_name, columns_per_table, visited, distributions, max_total
169-
)
170-
except _EarlyStopEstimation:
171-
break
172-
if max_total is not None and sum(columns_per_table.values()) > max_total:
160+
cls._estimate_columns_traversal(
161+
metadata, table_name, columns_per_table, visited, distributions
162+
)
163+
if sum(columns_per_table.values()) > PERFORMANCE_ALERT_DISPLAY_CAP:
173164
break
174165

175166
return columns_per_table
@@ -250,18 +241,14 @@ def _print_estimate_warning(self):
250241
metadata_columns = self._get_num_data_columns(self.metadata)
251242
print_table = []
252243
distributions = self._get_distributions()
253-
estimated_columns = self._estimate_num_columns(
254-
self.metadata, distributions, max_total=PERFORMANCE_ALERT_DISPLAY_CAP
255-
)
244+
estimated_columns = self._estimate_num_columns(self.metadata, distributions)
256245
for table, est_cols in estimated_columns.items():
257246
entry = []
258247
entry.append(table)
259248
entry.append(metadata_columns[table])
260249
total_est_cols += est_cols
261250
entry.append(est_cols)
262251
print_table.append(entry)
263-
if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP:
264-
break
265252

266253
if total_est_cols > MAX_NUMBER_OF_COLUMNS:
267254
display_total = (

tests/unit/multi_table/test_hma.py

Lines changed: 24 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,7 @@
88
from sdv.errors import SynthesizerInputError
99
from sdv.metadata.metadata import Metadata
1010
from sdv.multi_table.hma import (
11-
PERFORMANCE_ALERT_DISPLAY_CAP,
1211
HMASynthesizer,
13-
_EarlyStopEstimation,
1412
)
1513
from sdv.single_table.copulas import GaussianCopulaSynthesizer
1614
from tests.utils import get_multi_table_data, get_multi_table_metadata
@@ -135,6 +133,30 @@ def test__print_estimate_warning(self, get_distributions_mock, estimate_mock, ca
135133
match = re.search(constraint, captured.out + captured.err)
136134
assert match is None
137135

136+
@patch('sdv.multi_table.hma.HMASynthesizer._estimate_num_columns')
137+
@patch('sdv.multi_table.hma.HMASynthesizer._get_distributions')
138+
def test__print_estimate_warning_many_cols(self, get_distributions_mock, estimate_mock, capsys):
139+
"""Test that a warning appears if there are more than 1_000_000 expected columns"""
140+
# Setup
141+
metadata = get_multi_table_metadata()
142+
estimate_mock.side_effect = [{'nesreca': 1_000_010}, {'nesreca': 10}]
143+
144+
# Run
145+
HMASynthesizer(metadata)
146+
captured = capsys.readouterr()
147+
148+
# Assert
149+
expected_output = (
150+
'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended.'
151+
' To model this data, HMA will generate a large number of columns. (1000000+ columns)\n'
152+
'\n\nTable Name # Columns in Metadata Est # Columns\n'
153+
' nesreca 1 1000010\n\n'
154+
"We recommend simplifying your metadata schema using 'sdv.utils.poc.simplify_schema'."
155+
'\nIf this is not possible, please visit datacebo.com and reach out to us for '
156+
'enterprise solutions.\n\n'
157+
)
158+
assert captured.out == expected_output
159+
138160
def test__get_extension_foreign_key_only(self):
139161
"""Test the ``_get_extension`` method.
140162
@@ -1289,43 +1311,3 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self):
12891311
num_table_cols -= 1
12901312

12911313
assert num_table_cols == estimated_num_columns[table_name]
1292-
1293-
@patch('sdv.multi_table.hma.HMASynthesizer._estimate_num_columns')
1294-
@patch('sdv.multi_table.hma.HMASynthesizer._get_distributions')
1295-
def test__print_estimate_warning_capped_display_and_break(
1296-
self, get_distributions_mock, estimate_mock, capsys
1297-
):
1298-
"""When exceeding the cap, display 1_000_000+ and stop listing rows."""
1299-
# Setup
1300-
metadata = get_multi_table_metadata()
1301-
estimate_mock.return_value = {
1302-
'nesreca': PERFORMANCE_ALERT_DISPLAY_CAP + 5,
1303-
'oseba': 10,
1304-
}
1305-
1306-
# Run
1307-
HMASynthesizer(metadata)
1308-
captured = capsys.readouterr()
1309-
1310-
# Assert
1311-
assert 'PerformanceAlert:' in captured.out
1312-
assert f'{PERFORMANCE_ALERT_DISPLAY_CAP}+' in captured.out
1313-
assert 'nesreca' in captured.out
1314-
assert 'oseba' not in captured.out
1315-
1316-
def test__estimate_num_columns_early_stop_exception_is_handled(self):
1317-
"""_estimate_num_columns should handle internal early-stop and return partial results."""
1318-
# Setup
1319-
metadata = get_multi_table_metadata()
1320-
1321-
with patch('sdv.multi_table.hma.HMASynthesizer._estimate_columns_traversal') as tr_mock:
1322-
tr_mock.side_effect = _EarlyStopEstimation
1323-
1324-
# Run
1325-
result = HMASynthesizer._estimate_num_columns(metadata, max_total=1)
1326-
1327-
# Assert
1328-
assert isinstance(result, dict)
1329-
for table_name in ['nesreca', 'oseba', 'upravna_enota']:
1330-
assert table_name in result
1331-
assert isinstance(result[table_name], int)

0 commit comments

Comments
 (0)