Skip to content

Commit c0c017a

Browse files
zhengruifengdongjoon-hyun
authored andcommitted
[SPARK-52573][PYTHON][TESTS] Add tests for decimal type
### What changes were proposed in this pull request? Add tests for decimal type ### Why are the changes needed? for test coverage ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #51278 from zhengruifeng/py_arrow_decimal. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent d24c718 commit c0c017a

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,9 @@ def test_arrow_udf_datatype_string(self):
536536
F.col("id").alias("long"),
537537
F.col("id").cast("float").alias("float"),
538538
F.col("id").cast("double").alias("double"),
539-
# F.col("id").cast("decimal").alias("decimal"),
539+
F.col("id").cast("decimal").alias("decimal1"),
540+
F.col("id").cast("decimal(10, 0)").alias("decimal2"),
541+
F.col("id").cast("decimal(38, 18)").alias("decimal3"),
540542
F.col("id").cast("boolean").alias("bool"),
541543
)
542544

@@ -549,15 +551,19 @@ def f(x):
549551
long_f = arrow_udf(f, "long", udf_type)
550552
float_f = arrow_udf(f, "float", udf_type)
551553
double_f = arrow_udf(f, "double", udf_type)
552-
# decimal_f = arrow_udf(f, "decimal(38, 18)", udf_type)
554+
decimal1_f = arrow_udf(f, "decimal", udf_type)
555+
decimal2_f = arrow_udf(f, "decimal(10, 0)", udf_type)
556+
decimal3_f = arrow_udf(f, "decimal(38, 18)", udf_type)
553557
bool_f = arrow_udf(f, "boolean", udf_type)
554558
res = df.select(
555559
str_f(F.col("str")),
556560
int_f(F.col("int")),
557561
long_f(F.col("long")),
558562
float_f(F.col("float")),
559563
double_f(F.col("double")),
560-
# decimal_f("decimal"),
564+
decimal1_f("decimal1"),
565+
decimal2_f("decimal2"),
566+
decimal3_f("decimal3"),
561567
bool_f(F.col("bool")),
562568
)
563569
self.assertEqual(df.collect(), res.collect())

0 commit comments

Comments
 (0)