diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index f7f8c89ab2783..7cb9ade3d8d99 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -104,6 +104,7 @@
 from pyspark.pandas.plot import PandasOnSparkPlotAccessor
 from pyspark.pandas.utils import (
     combine_frames,
+    is_ansi_mode_enabled,
     is_name_like_tuple,
     is_name_like_value,
     name_like_string,
@@ -5081,33 +5082,68 @@ def replace(
                     )
                 )
             to_replace = {k: v for k, v in zip(to_replace, value)}
+
+        spark_session = self._internal.spark_frame.sparkSession
+        ansi_mode = is_ansi_mode_enabled(spark_session)
+        col_type = self.spark.data_type
+
         if isinstance(to_replace, dict):
             is_start = True
             if len(to_replace) == 0:
                 current = self.spark.column
             else:
                 for to_replace_, value in to_replace.items():
-                    cond = (
-                        (F.isnan(self.spark.column) | self.spark.column.isNull())
-                        if pd.isna(to_replace_)
-                        else (self.spark.column == F.lit(to_replace_))
-                    )
+                    if pd.isna(to_replace_):
+                        if ansi_mode and isinstance(col_type, NumericType):
+                            cond = F.isnan(self.spark.column) | self.spark.column.isNull()
+                        else:
+                            cond = self.spark.column.isNull()
+                    else:
+                        lit = (
+                            F.lit(to_replace_).try_cast(col_type)
+                            if ansi_mode
+                            else F.lit(to_replace_)
+                        )
+                        cond = self.spark.column == lit
+                    value_expr = F.lit(value).try_cast(col_type) if ansi_mode else F.lit(value)
                     if is_start:
-                        current = F.when(cond, value)
+                        current = F.when(cond, value_expr)
                         is_start = False
                     else:
-                        current = current.when(cond, value)
+                        current = current.when(cond, value_expr)
                 current = current.otherwise(self.spark.column)
         else:
             if regex:
                 # to_replace must be a string
                 cond = self.spark.column.rlike(cast(str, to_replace))
             else:
-                cond = self.spark.column.isin(to_replace)
+                if ansi_mode:
+                    to_replace_values = (
+                        [to_replace]
+                        if not is_list_like(to_replace) or isinstance(to_replace, str)
+                        else to_replace
+                    )
+                    to_replace_values = cast(List[Any], to_replace_values)
+                    literals = [F.lit(v).try_cast(col_type) for v in to_replace_values]
+                    cond = self.spark.column.isin(literals)
+                else:
+                    cond = self.spark.column.isin(to_replace)
                 # to_replace may be a scalar
                 if np.array(pd.isna(to_replace)).any():
-                    cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
-            current = F.when(cond, value).otherwise(self.spark.column)
+                    if ansi_mode:
+                        if isinstance(col_type, NumericType):
+                            cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
+                        else:
+                            cond = cond | self.spark.column.isNull()
+                    else:
+                        cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
+
+            if ansi_mode:
+                value_expr = F.lit(value).try_cast(col_type)
+                current = F.when(cond, value_expr).otherwise(self.spark.column.try_cast(col_type))
+
+            else:
+                current = F.when(cond, value).otherwise(self.spark.column)
 
         return self._with_new_scol(current)  # TODO: dtype?
 
diff --git a/python/pyspark/pandas/tests/computation/test_missing_data.py b/python/pyspark/pandas/tests/computation/test_missing_data.py
index dfecaf4be20b0..8489ce64b68ff 100644
--- a/python/pyspark/pandas/tests/computation/test_missing_data.py
+++ b/python/pyspark/pandas/tests/computation/test_missing_data.py
@@ -274,7 +274,6 @@ def test_fillna(self):
             pdf.fillna({("x", "a"): -1, ("x", "b"): -2, ("y", "c"): -5}),
         )
 
-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
     def test_replace(self):
         pdf = pd.DataFrame(
             {
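
Note (commentary, not part of the patch): the change routes every replacement literal through `Column.try_cast`, which under ANSI mode yields NULL for an invalid conversion instead of raising, so `Series.replace` stays evaluable when a `to_replace` or `value` literal cannot be represented in the column's Spark type. A minimal standalone sketch of that behavior, assuming a local session on a Spark version that ships `Column.try_cast`; the session setup and toy data are illustrative only, and the patch detects ANSI mode via `is_ansi_mode_enabled` rather than setting the conf itself:

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
# Enable ANSI mode at runtime for this illustration.
spark.conf.set("spark.sql.ansi.enabled", "true")

# A toy integer column standing in for the Series being replaced into.
df = spark.range(3).select(F.col("id").cast(IntegerType()).alias("x"))

# Under ANSI mode, a plain .cast(...) of an incompatible literal fails at
# runtime with CAST_INVALID_INPUT, aborting the whole job; .try_cast(...)
# degrades to NULL, so a replace() condition built from it still evaluates.
df.select(F.lit("not-a-number").try_cast(IntegerType()).alias("v")).show()
# every row of "v" is NULL instead of an error
```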