Skip to content

Commit 3a7f47b

Browse files
authored
feat: allow if expressions for fallbacks in downcast macro (#7322)
* feat: allow if expressions for fallbacks in downcast macro * add examples of using these guards * add examples of using these guards
1 parent 9f1ab95 commit 3a7f47b

File tree

1 file changed

+83
-45
lines changed

1 file changed

+83
-45
lines changed

arrow-array/src/cast.rs

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ macro_rules! repeat_pat {
6060
/// k.as_ref() => (dictionary_key_size_helper, u8),
6161
/// _ => unreachable!(),
6262
/// },
63+
/// // You can also add a guard to the pattern
64+
/// DataType::LargeUtf8 if true => u8::MAX,
6365
/// _ => u8::MAX,
6466
/// }
6567
/// }
@@ -72,7 +74,7 @@ macro_rules! repeat_pat {
7274
/// [`DataType`]: arrow_schema::DataType
7375
#[macro_export]
7476
macro_rules! downcast_integer {
75-
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => {
77+
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
7678
match ($($data_type),+) {
7779
$crate::repeat_pat!($crate::cast::__private::DataType::Int8, $($data_type),+) => {
7880
$m!($crate::types::Int8Type $(, $args)*)
@@ -98,7 +100,7 @@ macro_rules! downcast_integer {
98100
$crate::repeat_pat!($crate::cast::__private::DataType::UInt64, $($data_type),+) => {
99101
$m!($crate::types::UInt64Type $(, $args)*)
100102
}
101-
$($p => $fallback,)*
103+
$($p $(if $pred)* => $fallback,)*
102104
}
103105
};
104106
}
@@ -107,7 +109,7 @@ macro_rules! downcast_integer {
107109
/// with the corresponding array, along with match statements for any non integer array types
108110
///
109111
/// ```
110-
/// # use arrow_array::{Array, downcast_integer_array, cast::as_string_array};
112+
/// # use arrow_array::{Array, downcast_integer_array, cast::as_string_array, cast::as_largestring_array};
111113
/// # use arrow_schema::DataType;
112114
///
113115
/// fn print_integer(array: &dyn Array) {
@@ -122,6 +124,12 @@ macro_rules! downcast_integer {
122124
/// println!("{:?}", v);
123125
/// }
124126
/// }
127+
/// // You can also add a guard to the pattern
128+
/// DataType::LargeUtf8 if true => {
129+
/// for v in as_largestring_array(array) {
130+
/// println!("{:?}", v);
131+
/// }
132+
/// }
125133
/// t => println!("Unsupported datatype {}", t)
126134
/// )
127135
/// }
@@ -130,19 +138,19 @@ macro_rules! downcast_integer {
130138
/// [`DataType`]: arrow_schema::DataType
131139
#[macro_export]
132140
macro_rules! downcast_integer_array {
133-
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
134-
$crate::downcast_integer_array!($values => {$e} $($p => $fallback)*)
141+
($values:ident => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
142+
$crate::downcast_integer_array!($values => {$e} $($p $(if $pred)* => $fallback)*)
135143
};
136-
(($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
137-
$crate::downcast_integer_array!($($values),+ => {$e} $($p => $fallback)*)
144+
(($($values:ident),+) => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
145+
$crate::downcast_integer_array!($($values),+ => {$e} $($p $(if $pred)* => $fallback)*)
138146
};
139-
($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
140-
$crate::downcast_integer_array!(($($values),+) => $e $($p => $fallback)*)
147+
($($values:ident),+ => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
148+
$crate::downcast_integer_array!(($($values),+) => $e $($p $(if $pred)* => $fallback)*)
141149
};
142-
(($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
150+
(($($values:ident),+) => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
143151
$crate::downcast_integer!{
144152
$($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e),
145-
$($p => $fallback,)*
153+
$($p $(if $pred)* => $fallback,)*
146154
}
147155
};
148156
}
@@ -167,6 +175,8 @@ macro_rules! downcast_integer_array {
167175
/// k.data_type() => (run_end_size_helper, u8),
168176
/// _ => unreachable!(),
169177
/// },
178+
/// // You can also add a guard to the pattern
179+
/// DataType::LargeUtf8 if true => u8::MAX,
170180
/// _ => u8::MAX,
171181
/// }
172182
/// }
@@ -179,7 +189,7 @@ macro_rules! downcast_integer_array {
179189
/// [`DataType`]: arrow_schema::DataType
180190
#[macro_export]
181191
macro_rules! downcast_run_end_index {
182-
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => {
192+
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
183193
match ($($data_type),+) {
184194
$crate::repeat_pat!($crate::cast::__private::DataType::Int16, $($data_type),+) => {
185195
$m!($crate::types::Int16Type $(, $args)*)
@@ -190,7 +200,7 @@ macro_rules! downcast_run_end_index {
190200
$crate::repeat_pat!($crate::cast::__private::DataType::Int64, $($data_type),+) => {
191201
$m!($crate::types::Int64Type $(, $args)*)
192202
}
193-
$($p => $fallback,)*
203+
$($p $(if $pred)* => $fallback,)*
194204
}
195205
};
196206
}
@@ -211,6 +221,8 @@ macro_rules! downcast_run_end_index {
211221
/// fn temporal_size(t: &DataType) -> u8 {
212222
/// downcast_temporal! {
213223
/// t => (temporal_size_helper, u8),
224+
/// // You can also add a guard to the pattern
225+
/// DataType::LargeUtf8 if true => u8::MAX,
214226
/// _ => u8::MAX
215227
/// }
216228
/// }
@@ -222,7 +234,7 @@ macro_rules! downcast_run_end_index {
222234
/// [`DataType`]: arrow_schema::DataType
223235
#[macro_export]
224236
macro_rules! downcast_temporal {
225-
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => {
237+
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
226238
match ($($data_type),+) {
227239
$crate::repeat_pat!($crate::cast::__private::DataType::Time32($crate::cast::__private::TimeUnit::Second), $($data_type),+) => {
228240
$m!($crate::types::Time32SecondType $(, $args)*)
@@ -254,7 +266,7 @@ macro_rules! downcast_temporal {
254266
$crate::repeat_pat!($crate::cast::__private::DataType::Timestamp($crate::cast::__private::TimeUnit::Nanosecond, _), $($data_type),+) => {
255267
$m!($crate::types::TimestampNanosecondType $(, $args)*)
256268
}
257-
$($p => $fallback,)*
269+
$($p $(if $pred)* => $fallback,)*
258270
}
259271
};
260272
}
@@ -263,7 +275,7 @@ macro_rules! downcast_temporal {
263275
/// accepts a number of subsequent patterns to match the data type
264276
///
265277
/// ```
266-
/// # use arrow_array::{Array, downcast_temporal_array, cast::as_string_array};
278+
/// # use arrow_array::{Array, downcast_temporal_array, cast::as_string_array, cast::as_largestring_array};
267279
/// # use arrow_schema::DataType;
268280
///
269281
/// fn print_temporal(array: &dyn Array) {
@@ -278,6 +290,12 @@ macro_rules! downcast_temporal {
278290
/// println!("{:?}", v);
279291
/// }
280292
/// }
293+
/// // You can also add a guard to the pattern
294+
/// DataType::LargeUtf8 if true => {
295+
/// for v in as_largestring_array(array) {
296+
/// println!("{:?}", v);
297+
/// }
298+
/// }
281299
/// t => println!("Unsupported datatype {}", t)
282300
/// )
283301
/// }
@@ -286,19 +304,19 @@ macro_rules! downcast_temporal {
286304
/// [`DataType`]: arrow_schema::DataType
287305
#[macro_export]
288306
macro_rules! downcast_temporal_array {
289-
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
290-
$crate::downcast_temporal_array!($values => {$e} $($p => $fallback)*)
307+
($values:ident => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
308+
$crate::downcast_temporal_array!($values => {$e} $($p $(if $pred)* => $fallback)*)
291309
};
292-
(($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
293-
$crate::downcast_temporal_array!($($values),+ => {$e} $($p => $fallback)*)
310+
(($($values:ident),+) => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
311+
$crate::downcast_temporal_array!($($values),+ => {$e} $($p $(if $pred)* => $fallback)*)
294312
};
295-
($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
296-
$crate::downcast_temporal_array!(($($values),+) => $e $($p => $fallback)*)
313+
($($values:ident),+ => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
314+
$crate::downcast_temporal_array!(($($values),+) => $e $($p $(if $pred)* => $fallback)*)
297315
};
298-
(($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
316+
(($($values:ident),+) => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
299317
$crate::downcast_temporal!{
300318
$($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e),
301-
$($p => $fallback,)*
319+
$($p $(if $pred)* => $fallback,)*
302320
}
303321
};
304322
}
@@ -319,6 +337,8 @@ macro_rules! downcast_temporal_array {
319337
/// fn primitive_size(t: &DataType) -> u8 {
320338
/// downcast_primitive! {
321339
/// t => (primitive_size_helper, u8),
340+
/// // You can also add a guard to the pattern
341+
/// DataType::LargeUtf8 if true => u8::MAX,
322342
/// _ => u8::MAX
323343
/// }
324344
/// }
@@ -333,7 +353,7 @@ macro_rules! downcast_temporal_array {
333353
/// [`DataType`]: arrow_schema::DataType
334354
#[macro_export]
335355
macro_rules! downcast_primitive {
336-
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => {
356+
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
337357
$crate::downcast_integer! {
338358
$($data_type),+ => ($m $(, $args)*),
339359
$crate::repeat_pat!($crate::cast::__private::DataType::Float16, $($data_type),+) => {
@@ -375,7 +395,7 @@ macro_rules! downcast_primitive {
375395
_ => {
376396
$crate::downcast_temporal! {
377397
$($data_type),+ => ($m $(, $args)*),
378-
$($p => $fallback,)*
398+
$($p $(if $pred)* => $fallback,)*
379399
}
380400
}
381401
}
@@ -395,7 +415,7 @@ macro_rules! downcast_primitive_array_helper {
395415
/// accepts a number of subsequent patterns to match the data type
396416
///
397417
/// ```
398-
/// # use arrow_array::{Array, downcast_primitive_array, cast::as_string_array};
418+
/// # use arrow_array::{Array, downcast_primitive_array, cast::as_string_array, cast::as_largestring_array};
399419
/// # use arrow_schema::DataType;
400420
///
401421
/// fn print_primitive(array: &dyn Array) {
@@ -410,6 +430,12 @@ macro_rules! downcast_primitive_array_helper {
410430
/// println!("{:?}", v);
411431
/// }
412432
/// }
433+
/// // You can also add a guard to the pattern
434+
/// DataType::LargeUtf8 if true => {
435+
/// for v in as_largestring_array(array) {
436+
/// println!("{:?}", v);
437+
/// }
438+
/// }
413439
/// t => println!("Unsupported datatype {}", t)
414440
/// )
415441
/// }
@@ -418,19 +444,19 @@ macro_rules! downcast_primitive_array_helper {
418444
/// [`DataType`]: arrow_schema::DataType
419445
#[macro_export]
420446
macro_rules! downcast_primitive_array {
421-
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
422-
$crate::downcast_primitive_array!($values => {$e} $($p => $fallback)*)
447+
($values:ident => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
448+
$crate::downcast_primitive_array!($values => {$e} $($p $(if $pred)* => $fallback)*)
423449
};
424-
(($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
425-
$crate::downcast_primitive_array!($($values),+ => {$e} $($p => $fallback)*)
450+
(($($values:ident),+) => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
451+
$crate::downcast_primitive_array!($($values),+ => {$e} $($p $(if $pred)* => $fallback)*)
426452
};
427-
($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
428-
$crate::downcast_primitive_array!(($($values),+) => $e $($p => $fallback)*)
453+
($($values:ident),+ => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
454+
$crate::downcast_primitive_array!(($($values),+) => $e $($p $(if $pred)* => $fallback)*)
429455
};
430-
(($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
456+
(($($values:ident),+) => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
431457
$crate::downcast_primitive!{
432458
$($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e),
433-
$($p => $fallback,)*
459+
$($p $(if $pred)* => $fallback,)*
434460
}
435461
};
436462
}
@@ -482,7 +508,7 @@ macro_rules! downcast_dictionary_array_helper {
482508
/// a number of subsequent patterns to match the data type
483509
///
484510
/// ```
485-
/// # use arrow_array::{Array, StringArray, downcast_dictionary_array, cast::as_string_array};
511+
/// # use arrow_array::{Array, StringArray, downcast_dictionary_array, cast::as_string_array, cast::as_largestring_array};
486512
/// # use arrow_schema::DataType;
487513
///
488514
/// fn print_strings(array: &dyn Array) {
@@ -500,6 +526,12 @@ macro_rules! downcast_dictionary_array_helper {
500526
/// println!("{:?}", v);
501527
/// }
502528
/// }
529+
/// // You can also add a guard to the pattern
530+
/// DataType::LargeUtf8 if true => {
531+
/// for v in as_largestring_array(array) {
532+
/// println!("{:?}", v);
533+
/// }
534+
/// }
503535
/// t => println!("Unsupported datatype {}", t)
504536
/// )
505537
/// }
@@ -508,19 +540,19 @@ macro_rules! downcast_dictionary_array_helper {
508540
/// [`DataType`]: arrow_schema::DataType
509541
#[macro_export]
510542
macro_rules! downcast_dictionary_array {
511-
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
512-
downcast_dictionary_array!($values => {$e} $($p => $fallback)*)
543+
($values:ident => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
544+
downcast_dictionary_array!($values => {$e} $($p $(if $pred)* => $fallback)*)
513545
};
514546

515-
($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
547+
($values:ident => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
516548
match $values.data_type() {
517549
$crate::cast::__private::DataType::Dictionary(k, _) => {
518550
$crate::downcast_integer! {
519551
k.as_ref() => ($crate::downcast_dictionary_array_helper, $values, $e),
520552
k => unreachable!("unsupported dictionary key type: {}", k)
521553
}
522554
}
523-
$($p => $fallback,)*
555+
$($p $(if $pred)* => $fallback,)*
524556
}
525557
}
526558
}
@@ -584,7 +616,7 @@ macro_rules! downcast_run_array_helper {
584616
/// a number of subsequent patterns to match the data type
585617
///
586618
/// ```
587-
/// # use arrow_array::{Array, StringArray, downcast_run_array, cast::as_string_array};
619+
/// # use arrow_array::{Array, StringArray, downcast_run_array, cast::as_string_array, cast::as_largestring_array};
588620
/// # use arrow_schema::DataType;
589621
///
590622
/// fn print_strings(array: &dyn Array) {
@@ -602,6 +634,12 @@ macro_rules! downcast_run_array_helper {
602634
/// println!("{:?}", v);
603635
/// }
604636
/// }
637+
/// // You can also add a guard to the pattern
638+
/// DataType::LargeUtf8 if true => {
639+
/// for v in as_largestring_array(array) {
640+
/// println!("{:?}", v);
641+
/// }
642+
/// }
605643
/// t => println!("Unsupported datatype {}", t)
606644
/// )
607645
/// }
@@ -610,19 +648,19 @@ macro_rules! downcast_run_array_helper {
610648
/// [`DataType`]: arrow_schema::DataType
611649
#[macro_export]
612650
macro_rules! downcast_run_array {
613-
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
614-
downcast_run_array!($values => {$e} $($p => $fallback)*)
651+
($values:ident => $e:expr, $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
652+
downcast_run_array!($values => {$e} $($p $(if $pred)* => $fallback)*)
615653
};
616654

617-
($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
655+
($values:ident => $e:block $($p:pat $(if $pred:expr)* => $fallback:expr $(,)*)*) => {
618656
match $values.data_type() {
619657
$crate::cast::__private::DataType::RunEndEncoded(k, _) => {
620658
$crate::downcast_run_end_index! {
621659
k.data_type() => ($crate::downcast_run_array_helper, $values, $e),
622660
k => unreachable!("unsupported run end index type: {}", k)
623661
}
624662
}
625-
$($p => $fallback,)*
663+
$($p $(if $pred)* => $fallback,)*
626664
}
627665
}
628666
}

0 commit comments

Comments
 (0)