Skip to content

Commit 043fe73

Browse files
committed
unify replace op, check for forecastoperatorspec
1 parent a20b90f commit 043fe73

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
InvalidParameterError,
1515
)
1616
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
17+
from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorSpec
1718

1819

1920
class Transformations(ABC):
@@ -60,7 +61,7 @@ def run(self, data):
6061
6162
"""
6263
clean_df = self._remove_trailing_whitespace(data)
63-
if hasattr(self.dataset_info, 'horizon'):
64+
if isinstance(self.dataset_info, ForecastOperatorSpec):
6465
clean_df = self._clean_column_names(clean_df)
6566
if self.name == "historical_data":
6667
self._check_historical_dataset(clean_df)
@@ -109,9 +110,11 @@ def _clean_column_names(self, df):
109110
Returns:
110111
pd.DataFrame: The DataFrame with cleaned column names.
111112
"""
113+
112114
self.raw_column_names = {
113115
col: col.replace(" ", "") for col in df.columns if " " in col
114116
}
117+
df.columns = [self.raw_column_names.get(col, col) for col in df.columns]
115118

116119
if self.target_column_name:
117120
self.target_column_name = self.raw_column_names.get(
@@ -123,9 +126,9 @@ def _clean_column_names(self, df):
123126

124127
if self.target_category_columns:
125128
self.target_category_columns = [
126-
self.raw_column_names.get(col, col) for col in self.target_category_columns
129+
self.raw_column_names.get(col, col)
130+
for col in self.target_category_columns
127131
]
128-
df.columns = df.columns.str.replace(" ", "")
129132
return df
130133

131134
def _set_series_id_column(self, df):

0 commit comments

Comments
 (0)