1616from sdv .sampling import BaseHierarchicalSampler
1717
1818LOGGER = logging .getLogger (__name__ )
19- MAX_NUMBER_OF_COLUMNS = 1000
20- DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
2119PERFORMANCE_ALERT_DISPLAY_CAP = 1_000_000
22-
23-
24- class _EarlyStopEstimation (Exception ):
25- pass
20+ DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
21+ MAX_NUMBER_OF_COLUMNS = 1000
2622
2723
2824class HMASynthesizer (BaseHierarchicalSampler , BaseMultiTableSynthesizer ):
@@ -107,7 +103,7 @@ def _get_num_extended_columns(
107103
108104 @classmethod
109105 def _estimate_columns_traversal (
110- cls , metadata , table_name , columns_per_table , visited , distributions = None , max_total = None
106+ cls , metadata , table_name , columns_per_table , visited , distributions = None
111107 ):
112108 """Given a table, estimate how many columns each parent will model.
113109
@@ -123,21 +119,19 @@ def _estimate_columns_traversal(
123119 """
124120 for child_name in metadata ._get_child_map ()[table_name ]:
125121 if child_name not in visited :
126- cls ._estimate_columns_traversal (
127- metadata , child_name , columns_per_table , visited , distributions , max_total
128- )
122+ cls ._estimate_columns_traversal (metadata , child_name , columns_per_table , visited )
129123
130124 columns_per_table [table_name ] += cls ._get_num_extended_columns (
131125 metadata , child_name , table_name , columns_per_table , distributions
132126 )
133127
134- if max_total is not None and sum (columns_per_table .values ()) > max_total :
135- raise _EarlyStopEstimation
128+ if sum (columns_per_table .values ()) > PERFORMANCE_ALERT_DISPLAY_CAP :
129+ return
136130
137131 visited .add (table_name )
138132
139133 @classmethod
140- def _estimate_num_columns (cls , metadata , distributions = None , max_total = None ):
134+ def _estimate_num_columns (cls , metadata , distributions = None ):
141135 """Estimate the number of columns that will be modeled for each table.
142136
143137 This method estimates how many extended columns will be generated during the
@@ -163,13 +157,10 @@ def _estimate_num_columns(cls, metadata, distributions=None, max_total=None):
163157 # each table will model
164158 visited = set ()
165159 for table_name in _get_root_tables (metadata .relationships ):
166- try :
167- cls ._estimate_columns_traversal (
168- metadata , table_name , columns_per_table , visited , distributions , max_total
169- )
170- except _EarlyStopEstimation :
171- break
172- if max_total is not None and sum (columns_per_table .values ()) > max_total :
160+ cls ._estimate_columns_traversal (
161+ metadata , table_name , columns_per_table , visited , distributions
162+ )
163+ if sum (columns_per_table .values ()) > PERFORMANCE_ALERT_DISPLAY_CAP :
173164 break
174165
175166 return columns_per_table
@@ -250,18 +241,14 @@ def _print_estimate_warning(self):
250241 metadata_columns = self ._get_num_data_columns (self .metadata )
251242 print_table = []
252243 distributions = self ._get_distributions ()
253- estimated_columns = self ._estimate_num_columns (
254- self .metadata , distributions , max_total = PERFORMANCE_ALERT_DISPLAY_CAP
255- )
244+ estimated_columns = self ._estimate_num_columns (self .metadata , distributions )
256245 for table , est_cols in estimated_columns .items ():
257246 entry = []
258247 entry .append (table )
259248 entry .append (metadata_columns [table ])
260249 total_est_cols += est_cols
261250 entry .append (est_cols )
262251 print_table .append (entry )
263- if total_est_cols > PERFORMANCE_ALERT_DISPLAY_CAP :
264- break
265252
266253 if total_est_cols > MAX_NUMBER_OF_COLUMNS :
267254 display_total = (
0 commit comments