Skip to content

Commit 0fdbfff

Browse files
committed
Add downcast macros (#2635)
1 parent 41a2d2d commit 0fdbfff

File tree

3 files changed

+303
-208
lines changed

3 files changed

+303
-208
lines changed

arrow/src/array/cast.rs

+285
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,207 @@
2020
use crate::array::*;
2121
use crate::datatypes::*;
2222

23+
/// Downcast an [`Array`] to a [`PrimitiveArray`] based on its [`DataType`], accepts
24+
/// a number of subsequent patterns to match the data type
25+
///
26+
/// ```
27+
/// # use arrow::downcast_primitive_array;
28+
/// # use arrow::array::Array;
29+
/// # use arrow::datatypes::DataType;
30+
/// # use arrow::array::as_string_array;
31+
///
32+
/// fn print_primitive(array: &dyn Array) {
33+
/// downcast_primitive_array!(
34+
/// array => {
35+
/// for v in array {
36+
/// println!("{:?}", v);
37+
/// }
38+
/// }
39+
/// DataType::Utf8 => {
40+
/// for v in as_string_array(array) {
41+
/// println!("{:?}", v);
42+
/// }
43+
/// }
44+
/// t => println!("Unsupported datatype {}", t)
45+
/// )
46+
/// }
47+
/// ```
48+
///
49+
#[macro_export]
50+
macro_rules! downcast_primitive_array {
51+
($values:ident => $e:expr, $($p:pat => $fallback:expr)*) => {
52+
downcast_primitive_array!($values => {$e} $($p => $fallback)*)
53+
};
54+
55+
($values:ident => $e:block $($p:pat => $fallback:expr)*) => {
56+
match $values.data_type() {
57+
$crate::datatypes::DataType::Int8 => {
58+
let $values = $crate::array::as_primitive_array::<
59+
$crate::datatypes::Int8Type,
60+
>($values);
61+
$e
62+
}
63+
$crate::datatypes::DataType::Int16 => {
64+
let $values = $crate::array::as_primitive_array::<
65+
$crate::datatypes::Int16Type,
66+
>($values);
67+
$e
68+
}
69+
$crate::datatypes::DataType::Int32 => {
70+
let $values = $crate::array::as_primitive_array::<
71+
$crate::datatypes::Int32Type,
72+
>($values);
73+
$e
74+
}
75+
$crate::datatypes::DataType::Int64 => {
76+
let $values = $crate::array::as_primitive_array::<
77+
$crate::datatypes::Int64Type,
78+
>($values);
79+
$e
80+
}
81+
$crate::datatypes::DataType::UInt8 => {
82+
let $values = $crate::array::as_primitive_array::<
83+
$crate::datatypes::UInt8Type,
84+
>($values);
85+
$e
86+
}
87+
$crate::datatypes::DataType::UInt16 => {
88+
let $values = $crate::array::as_primitive_array::<
89+
$crate::datatypes::UInt16Type,
90+
>($values);
91+
$e
92+
}
93+
$crate::datatypes::DataType::UInt32 => {
94+
let $values = $crate::array::as_primitive_array::<
95+
$crate::datatypes::UInt32Type,
96+
>($values);
97+
$e
98+
}
99+
$crate::datatypes::DataType::UInt64 => {
100+
let $values = $crate::array::as_primitive_array::<
101+
$crate::datatypes::UInt64Type,
102+
>($values);
103+
$e
104+
}
105+
$crate::datatypes::DataType::Float32 => {
106+
let $values = $crate::array::as_primitive_array::<
107+
$crate::datatypes::Float32Type,
108+
>($values);
109+
$e
110+
}
111+
$crate::datatypes::DataType::Float64 => {
112+
let $values = $crate::array::as_primitive_array::<
113+
$crate::datatypes::Float64Type,
114+
>($values);
115+
$e
116+
}
117+
$crate::datatypes::DataType::Date32 => {
118+
let $values = $crate::array::as_primitive_array::<
119+
$crate::datatypes::Date32Type,
120+
>($values);
121+
$e
122+
}
123+
$crate::datatypes::DataType::Date64 => {
124+
let $values = $crate::array::as_primitive_array::<
125+
$crate::datatypes::Date64Type,
126+
>($values);
127+
$e
128+
}
129+
$crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Second) => {
130+
let $values = $crate::array::as_primitive_array::<
131+
$crate::datatypes::Time32SecondType,
132+
>($values);
133+
$e
134+
}
135+
$crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Millisecond) => {
136+
let $values = $crate::array::as_primitive_array::<
137+
$crate::datatypes::Time32MillisecondType,
138+
>($values);
139+
$e
140+
}
141+
$crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Microsecond) => {
142+
let $values = $crate::array::as_primitive_array::<
143+
$crate::datatypes::Time64MicrosecondType,
144+
>($values);
145+
$e
146+
}
147+
$crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Nanosecond) => {
148+
let $values = $crate::array::as_primitive_array::<
149+
$crate::datatypes::Time64NanosecondType,
150+
>($values);
151+
$e
152+
}
153+
$crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Second, _) => {
154+
let $values = $crate::array::as_primitive_array::<
155+
$crate::datatypes::TimestampSecondType,
156+
>($values);
157+
$e
158+
}
159+
$crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Millisecond, _) => {
160+
let $values = $crate::array::as_primitive_array::<
161+
$crate::datatypes::TimestampMillisecondType,
162+
>($values);
163+
$e
164+
}
165+
$crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Microsecond, _) => {
166+
let $values = $crate::array::as_primitive_array::<
167+
$crate::datatypes::TimestampMicrosecondType,
168+
>($values);
169+
$e
170+
}
171+
$crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Nanosecond, _) => {
172+
let $values = $crate::array::as_primitive_array::<
173+
$crate::datatypes::TimestampNanosecondType,
174+
>($values);
175+
$e
176+
}
177+
$crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::YearMonth) => {
178+
let $values = $crate::array::as_primitive_array::<
179+
$crate::datatypes::IntervalYearMonthType,
180+
>($values);
181+
$e
182+
}
183+
$crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::DayTime) => {
184+
let $values = $crate::array::as_primitive_array::<
185+
$crate::datatypes::IntervalDayTimeType,
186+
>($values);
187+
$e
188+
}
189+
$crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::MonthDayNano) => {
190+
let $values = $crate::array::as_primitive_array::<
191+
$crate::datatypes::IntervalMonthDayNanoType,
192+
>($values);
193+
$e
194+
}
195+
$crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Second) => {
196+
let $values = $crate::array::as_primitive_array::<
197+
$crate::datatypes::DurationSecondType,
198+
>($values);
199+
$e
200+
}
201+
$crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Millisecond) => {
202+
let $values = $crate::array::as_primitive_array::<
203+
$crate::datatypes::DurationMillisecondType,
204+
>($values);
205+
$e
206+
}
207+
$crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Microsecond) => {
208+
let $values = $crate::array::as_primitive_array::<
209+
$crate::datatypes::DurationMicrosecondType,
210+
>($values);
211+
$e
212+
}
213+
$crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Nanosecond) => {
214+
let $values = $crate::array::as_primitive_array::<
215+
$crate::datatypes::DurationNanosecondType,
216+
>($values);
217+
$e
218+
}
219+
$($p => $fallback,)*
220+
}
221+
};
222+
}
223+
23224
/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
24225
/// [`PrimitiveArray<T>`], panic'ing on failure.
25226
///
@@ -53,6 +254,90 @@ where
53254
.expect("Unable to downcast to primitive array")
54255
}
55256

257+
/// Downcast an [`Array`] to a [`DictionaryArray`] based on its [`DataType`], accepts
258+
/// a number of subsequent patterns to match the data type
259+
///
260+
/// ```
261+
/// # use arrow::downcast_dict_array;
262+
/// # use arrow::array::Array;
263+
/// # use arrow::datatypes::DataType;
264+
/// # use arrow::array::as_string_array;
265+
///
266+
/// fn print_keys(array: &dyn Array) {
267+
/// downcast_dict_array!(
268+
/// array => {
269+
/// for v in array.keys() {
270+
/// println!("{:?}", v);
271+
/// }
272+
/// }
273+
/// t => println!("Unsupported datatype {}", t)
274+
/// )
275+
/// }
276+
/// ```
277+
#[macro_export]
278+
macro_rules! downcast_dict_array {
279+
($values:ident => $e:expr, $($p:pat => $fallback:expr)*) => {
280+
downcast_dict_array!($values => {$e} $($p => $fallback)*)
281+
};
282+
283+
($values:ident => $e:block $($p:pat => $fallback:expr)*) => {
284+
match $values.data_type() {
285+
$crate::datatypes::DataType::Dictionary(k, _) => match k.as_ref() {
286+
$crate::datatypes::DataType::Int8 => {
287+
let $values = $crate::array::as_dictionary_array::<
288+
$crate::datatypes::Int8Type,
289+
>($values);
290+
$e
291+
},
292+
$crate::datatypes::DataType::Int16 => {
293+
let $values = $crate::array::as_dictionary_array::<
294+
$crate::datatypes::Int16Type,
295+
>($values);
296+
$e
297+
},
298+
$crate::datatypes::DataType::Int32 => {
299+
let $values = $crate::array::as_dictionary_array::<
300+
$crate::datatypes::Int32Type,
301+
>($values);
302+
$e
303+
},
304+
$crate::datatypes::DataType::Int64 => {
305+
let $values = $crate::array::as_dictionary_array::<
306+
$crate::datatypes::Int64Type,
307+
>($values);
308+
$e
309+
},
310+
$crate::datatypes::DataType::UInt8 => {
311+
let $values = $crate::array::as_dictionary_array::<
312+
$crate::datatypes::UInt8Type,
313+
>($values);
314+
$e
315+
},
316+
$crate::datatypes::DataType::UInt16 => {
317+
let $values = $crate::array::as_dictionary_array::<
318+
$crate::datatypes::UInt16Type,
319+
>($values);
320+
$e
321+
},
322+
$crate::datatypes::DataType::UInt32 => {
323+
let $values = $crate::array::as_dictionary_array::<
324+
$crate::datatypes::UInt32Type,
325+
>($values);
326+
$e
327+
},
328+
$crate::datatypes::DataType::UInt64 => {
329+
let $values = $crate::array::as_dictionary_array::<
330+
$crate::datatypes::UInt64Type,
331+
>($values);
332+
$e
333+
},
334+
k => unreachable!("unsupported dictionary key type: {}", k)
335+
}
336+
$($p => $fallback,)*
337+
}
338+
}
339+
}
340+
56341
/// Force downcast of an [`Array`], such as an [`ArrayRef`] to
57342
/// [`DictionaryArray<T>`], panic'ing on failure.
58343
///

0 commit comments

Comments
 (0)