Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from sdv.sampling import BaseHierarchicalSampler

LOGGER = logging.getLogger(__name__)
MAX_NUMBER_OF_COLUMNS = 1000
PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
MAX_NUMBER_OF_COLUMNS = 1000


class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
Expand Down Expand Up @@ -139,6 +140,10 @@ def _estimate_columns_traversal(
metadata, child_name, table_name, columns_per_table, distributions
)

total_cols = sum(columns_list[1] for columns_list in columns_per_table.values())
if total_cols > PERFORMANCE_ALERT_DISPLAY_CAP:
return

visited.add(table_name)

@classmethod
Expand Down Expand Up @@ -171,6 +176,9 @@ def _estimate_num_columns(cls, metadata, distributions=None):
cls._estimate_columns_traversal(
metadata, table_name, columns_per_table, visited, distributions
)
total_cols = sum(columns_list[1] for columns_list in columns_per_table.values())
if total_cols > PERFORMANCE_ALERT_DISPLAY_CAP:
break

return {
table_name: sum(columns_list) for table_name, columns_list in columns_per_table.items()
Expand Down Expand Up @@ -257,7 +265,8 @@ def _print_estimate_warning(self):
metadata_columns = self._get_num_data_columns(self.metadata)
print_table = []
distributions = self._get_distributions()
for table, est_cols in self._estimate_num_columns(self.metadata, distributions).items():
estimated_columns = self._estimate_num_columns(self.metadata, distributions)
for table, est_cols in estimated_columns.items():
entry = []
entry.append(table)
entry.append(sum(metadata_columns[table]))
Expand All @@ -266,10 +275,15 @@ def _print_estimate_warning(self):
print_table.append(entry)

if total_est_cols > MAX_NUMBER_OF_COLUMNS:
display_total = (
f'{PERFORMANCE_ALERT_DISPLAY_CAP}+'
if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP
else f'{total_est_cols}'
)
self._print(
'PerformanceAlert: Using the HMASynthesizer on this metadata '
'schema is not recommended. To model this data, HMA will '
f'generate a large number of columns. ({total_est_cols} columns)\n\n'
f'generate a large number of columns. ({display_total} columns)\n\n'
)
self._print(
pd.DataFrame(
Expand Down
28 changes: 27 additions & 1 deletion tests/unit/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from sdv.errors import SynthesizerInputError
from sdv.metadata.metadata import Metadata
from sdv.multi_table.hma import HMASynthesizer
from sdv.multi_table.hma import (
HMASynthesizer,
)
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from tests.utils import get_multi_table_data, get_multi_table_metadata

Expand Down Expand Up @@ -129,6 +131,30 @@ def test__print_estimate_warning(self, get_distributions_mock, estimate_mock, ca
match = re.search(constraint, captured.out + captured.err)
assert match is None

@patch('sdv.multi_table.hma.HMASynthesizer._estimate_num_columns')
@patch('sdv.multi_table.hma.HMASynthesizer._get_distributions')
def test__print_estimate_warning_many_cols(self, get_distributions_mock, estimate_mock, capsys):
"""Test that a warning appears if there are more than 1_000_000 expected columns"""
# Setup
metadata = get_multi_table_metadata()
estimate_mock.side_effect = [{'nesreca': 1_000_010}, {'nesreca': 10}]

# Run
HMASynthesizer(metadata)
captured = capsys.readouterr()

# Assert
expected_output = (
'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended.'
' To model this data, HMA will generate a large number of columns. (1000000+ columns)\n'
'\n\nTable Name # Columns in Metadata Est # Columns\n'
' nesreca 1 1000010\n\n'
"We recommend simplifying your metadata schema using 'sdv.utils.poc.simplify_schema'."
'\nIf this is not possible, please visit datacebo.com and reach out to us for '
'enterprise solutions.\n\n'
)
assert captured.out == expected_output

def test__get_extension_foreign_key_only(self):
"""Test the ``_get_extension`` method.

Expand Down