Skip to content

Commit 45afcf0

Browse files
committed
Add uft8 support and bench
1 parent ee14adf commit 45afcf0

File tree

4 files changed

+155
-8
lines changed

4 files changed

+155
-8
lines changed

datafusion/functions/benches/regx.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
extern crate criterion;
1919

2020
use arrow::array::builder::StringBuilder;
21-
use arrow::array::{ArrayRef, StringArray};
21+
use arrow::array::{ArrayRef, Int64Array, StringArray};
2222
use criterion::{black_box, criterion_group, criterion_main, Criterion};
23+
use datafusion_functions::regex::regexpcount::regexp_count;
2324
use datafusion_functions::regex::regexplike::regexp_like;
2425
use datafusion_functions::regex::regexpmatch::regexp_match;
2526
use datafusion_functions::regex::regexpreplace::regexp_replace;
@@ -59,6 +60,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray {
5960
StringArray::from(data)
6061
}
6162

63+
fn start(rng: &mut ThreadRng) -> Int64Array {
64+
let mut data: Vec<i64> = vec![];
65+
for _ in 0..1000 {
66+
data.push(rng.gen_range(1..5));
67+
}
68+
69+
Int64Array::from(data)
70+
}
71+
6272
fn flags(rng: &mut ThreadRng) -> StringArray {
6373
let samples = [Some("i".to_string()), Some("im".to_string()), None];
6474
let mut sb = StringBuilder::new();
@@ -75,6 +85,26 @@ fn flags(rng: &mut ThreadRng) -> StringArray {
7585
}
7686

7787
fn criterion_benchmark(c: &mut Criterion) {
88+
c.bench_function("regexp_count_1000", |b| {
89+
let mut rng = rand::thread_rng();
90+
let data = Arc::new(data(&mut rng)) as ArrayRef;
91+
let regex = Arc::new(regex(&mut rng)) as ArrayRef;
92+
let start = Arc::new(start(&mut rng)) as ArrayRef;
93+
let flags = Arc::new(flags(&mut rng)) as ArrayRef;
94+
95+
b.iter(|| {
96+
black_box(
97+
regexp_count::<i32>(&[
98+
Arc::clone(&data),
99+
Arc::clone(&regex),
100+
Arc::clone(&start),
101+
Arc::clone(&flags),
102+
])
103+
.expect("regexp_count should work on valid values"),
104+
)
105+
})
106+
});
107+
78108
c.bench_function("regexp_like_1000", |b| {
79109
let mut rng = rand::thread_rng();
80110
let data = Arc::new(data(&mut rng)) as ArrayRef;

datafusion/functions/src/regex/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@ make_udf_function!(
3434
pub mod expr_fn {
3535
use datafusion_expr::Expr;
3636

37-
pub fn regexp_count(values: Expr, regex: Expr, flags: Option<Expr>) -> Expr {
37+
/// Returns the number of consecutive occurrences of a regular expression in a string.
38+
pub fn regexp_count(values: Expr, regex: Expr, start: Option<Expr>, flags: Option<Expr>) -> Expr {
3839
let mut args = vec![values, regex];
40+
if let Some(start) = start {
41+
args.push(start);
42+
};
43+
3944
if let Some(flags) = flags {
4045
args.push(flags);
4146
};

datafusion/functions/src/regex/regexpcount.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ use arrow::array::{
2222
Array, ArrayRef, AsArray, Datum, GenericStringArray, Int64Array, OffsetSizeTrait,
2323
Scalar,
2424
};
25-
use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8};
25+
use arrow::datatypes::DataType::{self, Int64, LargeUtf8, Utf8, Utf8View};
2626
use arrow::datatypes::Int64Type;
2727
use arrow::error::ArrowError;
2828
use datafusion_common::cast::{as_generic_string_array, as_primitive_array};
2929
use datafusion_common::{
3030
arrow_err, exec_err, internal_err, DataFusionError, Result, ScalarValue,
3131
};
32-
use datafusion_expr::TypeSignature::Exact;
32+
use datafusion_expr::TypeSignature::{Exact, Uniform};
3333
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
3434
use itertools::izip;
3535
use regex::Regex;
@@ -51,14 +51,13 @@ impl RegexpCountFunc {
5151
Self {
5252
signature: Signature::one_of(
5353
vec![
54-
Exact(vec![Utf8, Utf8]),
54+
Uniform(2, vec![Utf8, LargeUtf8, Utf8View]),
5555
Exact(vec![Utf8, Utf8, Int64]),
5656
Exact(vec![Utf8, Utf8, Int64, Utf8]),
57-
Exact(vec![Utf8, Utf8, Int64, LargeUtf8]),
58-
Exact(vec![LargeUtf8, LargeUtf8]),
5957
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
60-
Exact(vec![LargeUtf8, LargeUtf8, Int64, Utf8]),
6158
Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]),
59+
Exact(vec![Utf8View, Utf8View, Int64]),
60+
Exact(vec![Utf8View, Utf8View, Int64, Utf8View]),
6261
],
6362
Volatility::Immutable,
6463
),

datafusion/sqllogictest/test_files/regexp.slt

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,119 @@ SELECT regexp_count(str, 'ab', 1, 'i') from t;
553553
0
554554

555555

556+
query I
557+
SELECT regexp_count(str, pattern) from t;
558+
----
559+
1
560+
1
561+
0
562+
0
563+
0
564+
0
565+
1
566+
1
567+
1
568+
1
569+
1
570+
571+
query I
572+
SELECT regexp_count(str, pattern, start) from t;
573+
----
574+
1
575+
1
576+
0
577+
0
578+
0
579+
0
580+
0
581+
1
582+
1
583+
1
584+
1
585+
586+
query I
587+
SELECT regexp_count(str, pattern, start, flags) from t;
588+
----
589+
1
590+
1
591+
1
592+
0
593+
0
594+
0
595+
0
596+
1
597+
1
598+
1
599+
1
600+
601+
# test type coercion
602+
query I
603+
SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t;
604+
----
605+
1
606+
1
607+
1
608+
0
609+
0
610+
0
611+
0
612+
1
613+
1
614+
1
615+
1
616+
617+
# test string views
618+
619+
statement ok
620+
CREATE TABLE t_stringview AS
621+
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t;
622+
623+
query I
624+
SELECT regexp_count(str, '\w') from t;
625+
----
626+
3
627+
3
628+
3
629+
3
630+
3
631+
4
632+
4
633+
10
634+
6
635+
4
636+
7
637+
638+
query I
639+
SELECT regexp_count(str, '\w{2}', start) from t;
640+
----
641+
1
642+
1
643+
1
644+
1
645+
0
646+
2
647+
1
648+
4
649+
1
650+
2
651+
3
652+
653+
query I
654+
SELECT regexp_count(str, 'ab', 1, 'i') from t;
655+
----
656+
1
657+
1
658+
1
659+
1
660+
1
661+
0
662+
0
663+
0
664+
0
665+
0
666+
0
667+
668+
556669
query I
557670
SELECT regexp_count(str, pattern) from t;
558671
----

0 commit comments

Comments
 (0)