diff --git a/lib/data/features/transform.py b/lib/data/features/transform.py index a96911d..6f65d5f 100644 --- a/lib/data/features/transform.py +++ b/lib/data/features/transform.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd from abc import abstractmethod from typing import Callable, Iterable, List @@ -11,7 +12,13 @@ def transform(iterable: Iterable, inplace: bool = True, columns: List[str] = Non else: transformed_iterable = iterable.copy() - transformed_iterable = transformed_iterable.fillna(method='bfill') + if isinstance(transformed_iterable, pd.DataFrame): + is_list = False + transformed_iterable.fillna(method='bfill', inplace=True) + else: + is_list = True + transformed_iterable = pd.DataFrame(transformed_iterable) + transformed_iterable.fillna(method='bfill', axis=1, inplace=True) if transform_fn is None: raise NotImplementedError() @@ -23,6 +30,9 @@ def transform(iterable: Iterable, inplace: bool = True, columns: List[str] = Non transformed_iterable[column] = transform_fn( transformed_iterable[column]) + if is_list: + transformed_iterable = transformed_iterable.as_matrix() + return transformed_iterable