@@ -53,14 +53,16 @@ def _get_num_data_columns(metadata):
5353 columns_per_table = {}
5454 for table_name , table in metadata .tables .items ():
5555 key_columns = metadata ._get_all_keys (table_name )
56- columns_per_table [ table_name ] = sum ([
56+ num_data_columns = sum ([
5757 1
5858 for col_name , col_meta in table .columns .items ()
5959 if (
6060 col_meta ['sdtype' ] != 'id'
6161 or (col_name not in key_columns and col_meta .get ('pii' , False ) is False )
6262 )
6363 ])
64+ num_extended_columns = 0
65+ columns_per_table [table_name ] = [num_data_columns , num_extended_columns ]
6466
6567 return columns_per_table
6668
@@ -85,18 +87,29 @@ def _get_num_extended_columns(
8587 table_name , cls .DEFAULT_SYNTHESIZER_KWARGS ['default_distribution' ]
8688 )
8789
88- num_parameters = cls .DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS [distribution ]
89-
90+ num_params_data = cls .DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS [distribution ]
91+ num_params_extended = cls .DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS [
92+ DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION
93+ ]
9094 num_rows_columns = len (metadata ._get_foreign_keys (parent_table , table_name ))
9195
92- # no parameter columns are generated if there are no data columns
93- num_data_columns = columns_per_table [table_name ]
94- if num_data_columns == 0 :
96+ # no parameter columns are generated if there are no data or extended columns
97+ num_data_columns = columns_per_table [table_name ][0 ]
98+ num_extended_columns = columns_per_table [table_name ][1 ]
99+
100+ if (num_data_columns + num_extended_columns ) == 0 :
95101 return num_rows_columns
96102
97- num_parameters_columns = num_rows_columns * num_data_columns * num_parameters
103+ num_parameters_columns = (num_rows_columns * num_data_columns * num_params_data ) + (
104+ num_rows_columns * num_extended_columns * num_params_extended
105+ )
98106
99- num_correlation_columns = num_rows_columns * (num_data_columns - 1 ) * num_data_columns // 2
107+ num_correlation_columns = (
108+ num_rows_columns
109+ * (num_data_columns + num_extended_columns - 1 )
110+ * (num_data_columns + num_extended_columns )
111+ // 2
112+ )
100113
101114 return num_correlation_columns + num_rows_columns + num_parameters_columns
102115
@@ -118,9 +131,11 @@ def _estimate_columns_traversal(
118131 """
119132 for child_name in metadata ._get_child_map ()[table_name ]:
120133 if child_name not in visited :
121- cls ._estimate_columns_traversal (metadata , child_name , columns_per_table , visited )
134+ cls ._estimate_columns_traversal (
135+ metadata , child_name , columns_per_table , visited , distributions
136+ )
122137
123- columns_per_table [table_name ] += cls ._get_num_extended_columns (
138+ columns_per_table [table_name ][ 1 ] += cls ._get_num_extended_columns (
124139 metadata , child_name , table_name , columns_per_table , distributions
125140 )
126141
@@ -157,7 +172,9 @@ def _estimate_num_columns(cls, metadata, distributions=None):
157172 metadata , table_name , columns_per_table , visited , distributions
158173 )
159174
160- return columns_per_table
175+ return {
176+ table_name : sum (columns_list ) for table_name , columns_list in columns_per_table .items ()
177+ }
161178
162179 def __init__ (self , metadata , locales = ['en_US' ], verbose = True ):
163180 BaseMultiTableSynthesizer .__init__ (self , metadata , locales = locales )
@@ -173,6 +190,11 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
173190 BaseHierarchicalSampler .__init__ (
174191 self , self .metadata , self ._table_synthesizers , self ._table_sizes
175192 )
193+ child_tables = set ()
194+ for relationship in metadata .relationships :
195+ child_tables .add (relationship ['child_table_name' ])
196+ for child_table_name in child_tables :
197+ self .set_table_parameters (child_table_name , {'default_distribution' : 'norm' })
176198 self ._print_estimate_warning ()
177199
178200 def set_table_parameters (self , table_name , table_parameters ):
@@ -238,7 +260,7 @@ def _print_estimate_warning(self):
238260 for table , est_cols in self ._estimate_num_columns (self .metadata , distributions ).items ():
239261 entry = []
240262 entry .append (table )
241- entry .append (metadata_columns [table ])
263+ entry .append (sum ( metadata_columns [table ]) )
242264 total_est_cols += est_cols
243265 entry .append (est_cols )
244266 print_table .append (entry )
@@ -679,6 +701,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
679701 parameters = self ._extract_parameters (row , table_name , foreign_key )
680702 table_meta = self ._table_synthesizers [table_name ].get_metadata ()
681703 synthesizer = self ._synthesizer (table_meta , ** self ._table_parameters [table_name ])
704+ extended_columns = getattr (self , '_parent_extended_columns' , {}).get (table_name , [])
705+ if extended_columns :
706+ self ._set_extended_columns_distributions (synthesizer , table_name , extended_columns )
682707 synthesizer ._set_parameters (parameters )
683708 try :
684709 likelihoods [parent_id ] = synthesizer ._get_likelihood (table_rows )
0 commit comments