Skip to content

Commit b7da86e

Browse files
authored
feat(spark): Implement Spark string function luhn_check (#16848)
* feat(spark): Implement Spark luhn_check function Signed-off-by: Alan Tang <[email protected]> * test(spark): add more tests Signed-off-by: Alan Tang <[email protected]> * feat(spark): set the signature to be Utf8 type Signed-off-by: Alan Tang <[email protected]> * chore: add more types for luhn_check function Signed-off-by: Alan Tang <[email protected]> * test: add test for array input Signed-off-by: Alan Tang <[email protected]> --------- Signed-off-by: Alan Tang <[email protected]>
1 parent a6d4798 commit b7da86e

File tree

3 files changed

+303
-21
lines changed

3 files changed

+303
-21
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::{any::Any, sync::Arc};
19+
20+
use arrow::array::{Array, AsArray, BooleanArray};
21+
use arrow::datatypes::DataType;
22+
use arrow::datatypes::DataType::Boolean;
23+
use datafusion_common::utils::take_function_args;
24+
use datafusion_common::{exec_err, Result, ScalarValue};
25+
use datafusion_expr::{
26+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
27+
Volatility,
28+
};
29+
30+
/// Spark-compatible `luhn_check` expression
31+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#luhn_check>
32+
#[derive(Debug)]
33+
pub struct SparkLuhnCheck {
34+
signature: Signature,
35+
}
36+
37+
impl Default for SparkLuhnCheck {
38+
fn default() -> Self {
39+
Self::new()
40+
}
41+
}
42+
43+
impl SparkLuhnCheck {
44+
pub fn new() -> Self {
45+
Self {
46+
signature: Signature::one_of(
47+
vec![
48+
TypeSignature::Exact(vec![DataType::Utf8]),
49+
TypeSignature::Exact(vec![DataType::Utf8View]),
50+
TypeSignature::Exact(vec![DataType::LargeUtf8]),
51+
],
52+
Volatility::Immutable,
53+
),
54+
}
55+
}
56+
}
57+
58+
impl ScalarUDFImpl for SparkLuhnCheck {
59+
fn as_any(&self) -> &dyn Any {
60+
self
61+
}
62+
63+
fn name(&self) -> &str {
64+
"luhn_check"
65+
}
66+
67+
fn signature(&self) -> &Signature {
68+
&self.signature
69+
}
70+
71+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
72+
Ok(Boolean)
73+
}
74+
75+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
76+
let [array] = take_function_args(self.name(), &args.args)?;
77+
78+
match array {
79+
ColumnarValue::Array(array) => match array.data_type() {
80+
DataType::Utf8View => {
81+
let str_array = array.as_string_view();
82+
let values = str_array
83+
.iter()
84+
.map(|s| s.map(luhn_check_impl))
85+
.collect::<BooleanArray>();
86+
Ok(ColumnarValue::Array(Arc::new(values)))
87+
}
88+
DataType::Utf8 => {
89+
let str_array = array.as_string::<i32>();
90+
let values = str_array
91+
.iter()
92+
.map(|s| s.map(luhn_check_impl))
93+
.collect::<BooleanArray>();
94+
Ok(ColumnarValue::Array(Arc::new(values)))
95+
}
96+
DataType::LargeUtf8 => {
97+
let str_array = array.as_string::<i64>();
98+
let values = str_array
99+
.iter()
100+
.map(|s| s.map(luhn_check_impl))
101+
.collect::<BooleanArray>();
102+
Ok(ColumnarValue::Array(Arc::new(values)))
103+
}
104+
other => {
105+
exec_err!("Unsupported data type {other:?} for function `luhn_check`")
106+
}
107+
},
108+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
109+
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s)))
110+
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok(
111+
ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))),
112+
),
113+
ColumnarValue::Scalar(ScalarValue::Utf8(None))
114+
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))
115+
| ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {
116+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
117+
}
118+
other => {
119+
exec_err!("Unsupported data type {other:?} for function `luhn_check`")
120+
}
121+
}
122+
}
123+
}
124+
125+
/// Validates a string using the Luhn algorithm.
126+
/// Returns `true` if the input is a valid Luhn number.
127+
fn luhn_check_impl(input: &str) -> bool {
128+
let mut sum = 0u32;
129+
let mut alt = false;
130+
let mut digits_processed = 0;
131+
132+
for b in input.as_bytes().iter().rev() {
133+
let digit = match b {
134+
b'0'..=b'9' => {
135+
digits_processed += 1;
136+
b - b'0'
137+
}
138+
_ => return false,
139+
};
140+
141+
let mut val = digit as u32;
142+
if alt {
143+
val *= 2;
144+
if val > 9 {
145+
val -= 9;
146+
}
147+
}
148+
sum += val;
149+
alt = !alt;
150+
}
151+
152+
digits_processed > 0 && sum % 10 == 0
153+
}

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
pub mod ascii;
1919
pub mod char;
20+
pub mod luhn_check;
2021

2122
use datafusion_expr::ScalarUDF;
2223
use datafusion_functions::make_udf_function;
2324
use std::sync::Arc;
2425

2526
make_udf_function!(ascii::SparkAscii, ascii);
2627
make_udf_function!(char::SparkChar, char);
28+
make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check);
2729

2830
pub mod expr_fn {
2931
use datafusion_functions::export_functions;
@@ -38,8 +40,13 @@ pub mod expr_fn {
3840
"Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).",
3941
arg1
4042
));
43+
export_functions!((
44+
luhn_check,
45+
"Returns whether the input string of digits is valid according to the Luhn algorithm.",
46+
arg1
47+
));
4148
}
4249

4350
pub fn functions() -> Vec<Arc<ScalarUDF>> {
44-
vec![ascii(), char()]
51+
vec![ascii(), char(), luhn_check()]
4552
}

datafusion/sqllogictest/test_files/spark/string/luhn_check.slt

Lines changed: 142 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,145 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
# This file was originally created by a porting script from:
19-
# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function
20-
# This file is part of the implementation of the datafusion-spark function library.
21-
# For more information, please see:
22-
# https://github.com/apache/datafusion/issues/15914
23-
24-
## Original Query: SELECT luhn_check('79927398713');
25-
## PySpark 3.5.5 Result: {'luhn_check(79927398713)': True, 'typeof(luhn_check(79927398713))': 'boolean', 'typeof(79927398713)': 'string'}
26-
#query
27-
#SELECT luhn_check('79927398713'::string);
28-
29-
## Original Query: SELECT luhn_check('79927398714');
30-
## PySpark 3.5.5 Result: {'luhn_check(79927398714)': False, 'typeof(luhn_check(79927398714))': 'boolean', 'typeof(79927398714)': 'string'}
31-
#query
32-
#SELECT luhn_check('79927398714'::string);
33-
34-
## Original Query: SELECT luhn_check('8112189876');
35-
## PySpark 3.5.5 Result: {'luhn_check(8112189876)': True, 'typeof(luhn_check(8112189876))': 'boolean', 'typeof(8112189876)': 'string'}
36-
#query
37-
#SELECT luhn_check('8112189876'::string);
18+
19+
query B
20+
SELECT luhn_check('79927398713'::string);
21+
----
22+
true
23+
24+
25+
query B
26+
SELECT luhn_check('79927398714'::string);
27+
----
28+
false
29+
30+
31+
query B
32+
SELECT luhn_check('8112189876'::string);
33+
----
34+
true
35+
36+
query B
37+
select luhn_check('4111111111111111'::string);
38+
----
39+
true
40+
41+
query B
42+
select luhn_check('5500000000000004'::string);
43+
----
44+
true
45+
46+
query B
47+
select luhn_check('340000000000009'::string);
48+
----
49+
true
50+
51+
query B
52+
select luhn_check('6011000000000004'::string);
53+
----
54+
true
55+
56+
57+
query B
58+
select luhn_check('6011000000000005'::string);
59+
----
60+
false
61+
62+
63+
query B
64+
select luhn_check('378282246310006'::string);
65+
----
66+
false
67+
68+
69+
query B
70+
select luhn_check('0'::string);
71+
----
72+
true
73+
74+
75+
query B
76+
select luhn_check('79927398713'::string)
77+
----
78+
true
79+
80+
query B
81+
select luhn_check('4417123456789113'::string)
82+
----
83+
true
84+
85+
query B
86+
select luhn_check('7992 7398 714'::string)
87+
----
88+
false
89+
90+
query B
91+
select luhn_check('79927398714'::string)
92+
----
93+
false
94+
95+
query B
96+
select luhn_check('4111111111111111 '::string)
97+
----
98+
false
99+
100+
101+
query B
102+
select luhn_check('4111111 111111111'::string)
103+
----
104+
false
105+
106+
query B
107+
select luhn_check(' 4111111111111111'::string)
108+
----
109+
false
110+
111+
query B
112+
select luhn_check(''::string)
113+
----
114+
false
115+
116+
query B
117+
select luhn_check(' ')
118+
----
119+
false
120+
121+
122+
query B
123+
select luhn_check('510B105105105106'::string)
124+
----
125+
false
126+
127+
128+
query B
129+
select luhn_check('ABCDED'::string)
130+
----
131+
false
132+
133+
query B
134+
select luhn_check(null);
135+
----
136+
NULL
137+
138+
query B
139+
select luhn_check(6011111111111117::BIGINT)
140+
----
141+
true
142+
143+
144+
query B
145+
select luhn_check(6011111111111118::BIGINT)
146+
----
147+
false
148+
149+
150+
query B
151+
select luhn_check(123.456::decimal(6,3))
152+
----
153+
false
154+
155+
query B
156+
SELECT luhn_check(a) FROM (VALUES ('79927398713'::string), ('79927398714'::string)) AS t(a);
157+
----
158+
true
159+
false

0 commit comments

Comments
 (0)