14
14
InvalidParameterError ,
15
15
)
16
16
from ads .opctl .operator .lowcode .common .utils import merge_category_columns
17
+ from ads .opctl .operator .lowcode .forecast .operator_config import ForecastOperatorSpec
17
18
18
19
19
20
class Transformations (ABC ):
@@ -60,7 +61,7 @@ def run(self, data):
60
61
61
62
"""
62
63
clean_df = self ._remove_trailing_whitespace (data )
63
- if hasattr (self .dataset_info , 'horizon' ):
64
+ if isinstance (self .dataset_info , ForecastOperatorSpec ):
64
65
clean_df = self ._clean_column_names (clean_df )
65
66
if self .name == "historical_data" :
66
67
self ._check_historical_dataset (clean_df )
@@ -109,9 +110,11 @@ def _clean_column_names(self, df):
109
110
Returns:
110
111
pd.DataFrame: The DataFrame with cleaned column names.
111
112
"""
113
+
112
114
self .raw_column_names = {
113
115
col : col .replace (" " , "" ) for col in df .columns if " " in col
114
116
}
117
+ df .columns = [self .raw_column_names .get (col , col ) for col in df .columns ]
115
118
116
119
if self .target_column_name :
117
120
self .target_column_name = self .raw_column_names .get (
@@ -123,9 +126,9 @@ def _clean_column_names(self, df):
123
126
124
127
if self .target_category_columns :
125
128
self .target_category_columns = [
126
- self .raw_column_names .get (col , col ) for col in self .target_category_columns
129
+ self .raw_column_names .get (col , col )
130
+ for col in self .target_category_columns
127
131
]
128
- df .columns = df .columns .str .replace (" " , "" )
129
132
return df
130
133
131
134
def _set_series_id_column (self , df ):
0 commit comments