LOGGER = logging.getLogger(__name__)

# Total estimated column count above which the PerformanceAlert is printed.
MAX_NUMBER_OF_COLUMNS = 1000
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
# Passed as ``max_total`` when estimating columns for the PerformanceAlert; totals
# above this value are displayed as '<cap>+' instead of an exact count.
PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000
22+
23+
class _EarlyStopEstimation(Exception):
    """Signal that the running column estimate exceeded the requested cap.

    Raised by ``_estimate_columns_traversal`` when ``max_total`` is set and the summed
    estimate goes above it; caught by ``_estimate_num_columns`` to stop estimating early.
    """
2126
2227
2328class HMASynthesizer (BaseHierarchicalSampler , BaseMultiTableSynthesizer ):
@@ -102,7 +107,7 @@ def _get_num_extended_columns(
102107
@classmethod
def _estimate_columns_traversal(
    cls, metadata, table_name, columns_per_table, visited, distributions=None, max_total=None
):
    """Given a table, estimate how many columns each parent will model.

    The children of ``table_name`` are processed depth-first: a child that has not
    been visited yet is traversed before its extended columns are added to the
    count of ``table_name``.

    Args:
        metadata:
            Metadata object; its ``_get_child_map()`` is used to find the children
            of each table.
        table_name (str):
            Name of the table whose children are traversed.
        columns_per_table (dict):
            Running column count per table name. Updated in place.
        visited (set):
            Names of the tables whose traversal has finished. Updated in place.
        distributions:
            Forwarded unchanged to ``_get_num_extended_columns`` and to every
            recursive call. Defaults to ``None``.
        max_total (int or None):
            If not ``None``, the traversal is aborted as soon as the sum of all
            values in ``columns_per_table`` exceeds this number.

    Raises:
        _EarlyStopEstimation:
            If ``max_total`` is set and the summed estimate exceeds it.
    """
    for child_name in metadata._get_child_map()[table_name]:
        if child_name not in visited:
            # Forward ``distributions`` and ``max_total`` so that nested children
            # are estimated with the same settings as the top-level call.
            cls._estimate_columns_traversal(
                metadata, child_name, columns_per_table, visited, distributions, max_total
            )

        columns_per_table[table_name] += cls._get_num_extended_columns(
            metadata, child_name, table_name, columns_per_table, distributions
        )

        # Stop once the cap is exceeded; the caller only needs to know that
        # the total is above ``max_total``.
        if max_total is not None and sum(columns_per_table.values()) > max_total:
            raise _EarlyStopEstimation

    visited.add(table_name)
128138
129139 @classmethod
130- def _estimate_num_columns (cls , metadata , distributions = None ):
140+ def _estimate_num_columns (cls , metadata , distributions = None , max_total = None ):
131141 """Estimate the number of columns that will be modeled for each table.
132142
133143 This method estimates how many extended columns will be generated during the
@@ -153,9 +163,14 @@ def _estimate_num_columns(cls, metadata, distributions=None):
153163 # each table will model
154164 visited = set ()
155165 for table_name in _get_root_tables (metadata .relationships ):
156- cls ._estimate_columns_traversal (
157- metadata , table_name , columns_per_table , visited , distributions
158- )
166+ try :
167+ cls ._estimate_columns_traversal (
168+ metadata , table_name , columns_per_table , visited , distributions , max_total
169+ )
170+ except _EarlyStopEstimation :
171+ break
172+ if max_total is not None and sum (columns_per_table .values ()) > max_total :
173+ break
159174
160175 return columns_per_table
161176
@@ -235,19 +250,29 @@ def _print_estimate_warning(self):
235250 metadata_columns = self ._get_num_data_columns (self .metadata )
236251 print_table = []
237252 distributions = self ._get_distributions ()
238- for table , est_cols in self ._estimate_num_columns (self .metadata , distributions ).items ():
253+ estimated_columns = self ._estimate_num_columns (
254+ self .metadata , distributions , max_total = PERFORMANCE_ALERT_DISPLAY_CAP
255+ )
256+ for table , est_cols in estimated_columns .items ():
239257 entry = []
240258 entry .append (table )
241259 entry .append (metadata_columns [table ])
242260 total_est_cols += est_cols
243261 entry .append (est_cols )
244262 print_table .append (entry )
263+ if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP :
264+ break
245265
246266 if total_est_cols > MAX_NUMBER_OF_COLUMNS :
267+ display_total = (
268+ f'{ PERFORMANCE_ALERT_DISPLAY_CAP } +'
269+ if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP
270+ else f'{ total_est_cols } '
271+ )
247272 self ._print (
248273 'PerformanceAlert: Using the HMASynthesizer on this metadata '
249274 'schema is not recommended. To model this data, HMA will '
250- f'generate a large number of columns. ({ total_est_cols } columns)\n \n '
275+ f'generate a large number of columns. ({ display_total } columns)\n \n '
251276 )
252277 self ._print (
253278 pd .DataFrame (
0 commit comments