Skip to content

Commit 2f78536

Browse files
authored
Consistently coerce dictionaries for arithmetic (#6785)
* Coerce dictionaries for arithmetic
* Clippy
1 parent 25b60e4 commit 2f78536

File tree

4 files changed

+53
-237
lines changed

4 files changed

+53
-237
lines changed

datafusion/expr/src/type_coercion/binary.rs

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -226,13 +226,13 @@ fn math_decimal_coercion(
226226
use arrow::datatypes::DataType::*;
227227

228228
match (lhs_type, rhs_type) {
229-
(Dictionary(key_type, value_type), _) => {
229+
(Dictionary(_, value_type), _) => {
230230
let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?;
231-
Some((Dictionary(key_type.clone(), Box::new(value_type)), rhs_type))
231+
Some((value_type, rhs_type))
232232
}
233-
(_, Dictionary(key_type, value_type)) => {
233+
(_, Dictionary(_, value_type)) => {
234234
let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?;
235-
Some((lhs_type, Dictionary(key_type.clone(), Box::new(value_type))))
235+
Some((lhs_type, value_type))
236236
}
237237
(Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
238238
Some((dec_type.clone(), dec_type.clone()))
@@ -490,10 +490,8 @@ fn mathematics_numerical_coercion(
490490
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
491491
mathematics_numerical_coercion(lhs_value_type, rhs_value_type)
492492
}
493-
(Dictionary(key_type, value_type), _) => {
494-
let value_type = mathematics_numerical_coercion(value_type, rhs_type);
495-
value_type
496-
.map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)))
493+
(Dictionary(_, value_type), _) => {
494+
mathematics_numerical_coercion(value_type, rhs_type)
497495
}
498496
(_, Dictionary(_, value_type)) => {
499497
mathematics_numerical_coercion(lhs_type, value_type)

datafusion/physical-expr/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
3737
default = ["crypto_expressions", "regex_expressions", "unicode_expressions"]
3838
# Enables support for non-scalar, binary operations on dictionaries
3939
# Note: this results in significant additional codegen
40-
dictionary_expressions = ["arrow/dyn_cmp_dict", "arrow/dyn_arith_dict"]
40+
dictionary_expressions = ["arrow/dyn_cmp_dict"]
4141
regex_expressions = ["regex"]
4242
unicode_expressions = ["unicode-segmentation"]
4343

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 37 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -1351,15 +1351,20 @@ mod tests {
13511351
use datafusion_common::{ColumnStatistics, Result, Statistics};
13521352
use datafusion_expr::type_coercion::binary::get_input_types;
13531353

1354-
// Create a binary expression without coercion. Used here when we do not want to coerce the expressions
1355-
// to valid types. Usage can result in an execution (after plan) error.
1356-
fn binary_simple(
1357-
l: Arc<dyn PhysicalExpr>,
1354+
/// Performs a binary operation, applying any type coercion necessary
1355+
fn binary_op(
1356+
left: Arc<dyn PhysicalExpr>,
13581357
op: Operator,
1359-
r: Arc<dyn PhysicalExpr>,
1360-
input_schema: &Schema,
1361-
) -> Arc<dyn PhysicalExpr> {
1362-
binary(l, op, r, input_schema).unwrap()
1358+
right: Arc<dyn PhysicalExpr>,
1359+
schema: &Schema,
1360+
) -> Result<Arc<dyn PhysicalExpr>> {
1361+
let left_type = left.data_type(schema)?;
1362+
let right_type = right.data_type(schema)?;
1363+
let (lhs, rhs) = get_input_types(&left_type, &op, &right_type)?;
1364+
1365+
let left_expr = try_cast(left, schema, lhs)?;
1366+
let right_expr = try_cast(right, schema, rhs)?;
1367+
binary(left_expr, op, right_expr, schema)
13631368
}
13641369

13651370
#[test]
@@ -1372,12 +1377,12 @@ mod tests {
13721377
let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
13731378

13741379
// expression: "a < b"
1375-
let lt = binary_simple(
1380+
let lt = binary(
13761381
col("a", &schema)?,
13771382
Operator::Lt,
13781383
col("b", &schema)?,
13791384
&schema,
1380-
);
1385+
)?;
13811386
let batch =
13821387
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
13831388

@@ -1404,22 +1409,22 @@ mod tests {
14041409
let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
14051410

14061411
// expression: "a < b OR a == b"
1407-
let expr = binary_simple(
1408-
binary_simple(
1412+
let expr = binary(
1413+
binary(
14091414
col("a", &schema)?,
14101415
Operator::Lt,
14111416
col("b", &schema)?,
14121417
&schema,
1413-
),
1418+
)?,
14141419
Operator::Or,
1415-
binary_simple(
1420+
binary(
14161421
col("a", &schema)?,
14171422
Operator::Eq,
14181423
col("b", &schema)?,
14191424
&schema,
1420-
),
1425+
)?,
14211426
&schema,
1422-
);
1427+
)?;
14231428
let batch =
14241429
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
14251430

@@ -1492,7 +1497,7 @@ mod tests {
14921497
}
14931498

14941499
#[test]
1495-
fn test_type_coersion() -> Result<()> {
1500+
fn test_type_coercion() -> Result<()> {
14961501
test_coercion!(
14971502
Int32Array,
14981503
DataType::Int32,
@@ -1814,8 +1819,7 @@ mod tests {
18141819
// is no way at the time of this writing to create a dictionary
18151820
// array using the `From` trait
18161821
#[test]
1817-
#[cfg(feature = "dictionary_expressions")]
1818-
fn test_dictionary_type_to_array_coersion() -> Result<()> {
1822+
fn test_dictionary_type_to_array_coercion() -> Result<()> {
18191823
// Test a string array against a string dictionary
18201824
let dict_type =
18211825
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
@@ -1878,7 +1882,6 @@ mod tests {
18781882
}
18791883

18801884
#[test]
1881-
#[cfg(feature = "dictionary_expressions")]
18821885
fn plus_op_dict() -> Result<()> {
18831886
let schema = Schema::new(vec![
18841887
Field::new(
@@ -1912,7 +1915,6 @@ mod tests {
19121915
}
19131916

19141917
#[test]
1915-
#[cfg(feature = "dictionary_expressions")]
19161918
fn plus_op_dict_decimal() -> Result<()> {
19171919
let schema = Schema::new(vec![
19181920
Field::new(
@@ -2096,7 +2098,6 @@ mod tests {
20962098
}
20972099

20982100
#[test]
2099-
#[cfg(feature = "dictionary_expressions")]
21002101
fn minus_op_dict() -> Result<()> {
21012102
let schema = Schema::new(vec![
21022103
Field::new(
@@ -2130,7 +2131,6 @@ mod tests {
21302131
}
21312132

21322133
#[test]
2133-
#[cfg(feature = "dictionary_expressions")]
21342134
fn minus_op_dict_decimal() -> Result<()> {
21352135
let schema = Schema::new(vec![
21362136
Field::new(
@@ -2306,7 +2306,6 @@ mod tests {
23062306
}
23072307

23082308
#[test]
2309-
#[cfg(feature = "dictionary_expressions")]
23102309
fn multiply_op_dict() -> Result<()> {
23112310
let schema = Schema::new(vec![
23122311
Field::new(
@@ -2340,7 +2339,6 @@ mod tests {
23402339
}
23412340

23422341
#[test]
2343-
#[cfg(feature = "dictionary_expressions")]
23442342
fn multiply_op_dict_decimal() -> Result<()> {
23452343
let schema = Schema::new(vec![
23462344
Field::new(
@@ -2514,7 +2512,6 @@ mod tests {
25142512
}
25152513

25162514
#[test]
2517-
#[cfg(feature = "dictionary_expressions")]
25182515
fn divide_op_dict() -> Result<()> {
25192516
let schema = Schema::new(vec![
25202517
Field::new(
@@ -2554,7 +2551,6 @@ mod tests {
25542551
}
25552552

25562553
#[test]
2557-
#[cfg(feature = "dictionary_expressions")]
25582554
fn divide_op_dict_decimal() -> Result<()> {
25592555
let schema = Schema::new(vec![
25602556
Field::new(
@@ -2740,7 +2736,6 @@ mod tests {
27402736
}
27412737

27422738
#[test]
2743-
#[cfg(feature = "dictionary_expressions")]
27442739
fn modulus_op_dict() -> Result<()> {
27452740
let schema = Schema::new(vec![
27462741
Field::new(
@@ -2780,7 +2775,6 @@ mod tests {
27802775
}
27812776

27822777
#[test]
2783-
#[cfg(feature = "dictionary_expressions")]
27842778
fn modulus_op_dict_decimal() -> Result<()> {
27852779
let schema = Schema::new(vec![
27862780
Field::new(
@@ -2937,7 +2931,7 @@ mod tests {
29372931
expected: PrimitiveArray<T>,
29382932
) -> Result<()> {
29392933
let arithmetic_op =
2940-
binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema);
2934+
binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?;
29412935
let batch = RecordBatch::try_new(schema, data)?;
29422936
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
29432937

@@ -2953,7 +2947,7 @@ mod tests {
29532947
expected: ArrayRef,
29542948
) -> Result<()> {
29552949
let lit = Arc::new(Literal::new(literal));
2956-
let arithmetic_op = binary_simple(col("a", &schema)?, op, lit, &schema);
2950+
let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?;
29572951
let batch = RecordBatch::try_new(schema, data)?;
29582952
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
29592953

@@ -2968,16 +2962,10 @@ mod tests {
29682962
op: Operator,
29692963
expected: BooleanArray,
29702964
) -> Result<()> {
2971-
let left_type = left.data_type();
2972-
let right_type = right.data_type();
2973-
let (lhs, rhs) = get_input_types(left_type, &op, right_type)?;
2974-
2975-
let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
2976-
let right_expr = try_cast(col("b", schema)?, schema, rhs)?;
2977-
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
2965+
let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
29782966
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
29792967
let batch = RecordBatch::try_new(schema.clone(), data)?;
2980-
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
2968+
let result = op.evaluate(&batch)?.into_array(batch.num_rows());
29812969

29822970
assert_eq!(result.as_ref(), &expected);
29832971
Ok(())
@@ -2992,14 +2980,9 @@ mod tests {
29922980
expected: &BooleanArray,
29932981
) -> Result<()> {
29942982
let scalar = lit(scalar.clone());
2995-
let (lhs, rhs) =
2996-
get_input_types(&scalar.data_type(schema)?, &op, arr.data_type())?;
2997-
let left_expr = try_cast(scalar, schema, lhs)?;
2998-
let right_expr = try_cast(col("a", schema)?, schema, rhs)?;
2999-
3000-
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
2983+
let op = binary_op(scalar, op, col("a", schema)?, schema)?;
30012984
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
3002-
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
2985+
let result = op.evaluate(&batch)?.into_array(batch.num_rows());
30032986
assert_eq!(result.as_ref(), expected);
30042987

30052988
Ok(())
@@ -3014,14 +2997,9 @@ mod tests {
30142997
expected: &BooleanArray,
30152998
) -> Result<()> {
30162999
let scalar = lit(scalar.clone());
3017-
let (lhs, rhs) =
3018-
get_input_types(arr.data_type(), &op, &scalar.data_type(schema)?)?;
3019-
let left_expr = try_cast(col("a", schema)?, schema, lhs)?;
3020-
let right_expr = try_cast(scalar, schema, rhs)?;
3021-
3022-
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
3000+
let op = binary_op(col("a", schema)?, op, scalar, schema)?;
30233001
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
3024-
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
3002+
let result = op.evaluate(&batch)?.into_array(batch.num_rows());
30253003
assert_eq!(result.as_ref(), expected);
30263004

30273005
Ok(())
@@ -3587,7 +3565,7 @@ mod tests {
35873565
let tree_depth: i32 = 100;
35883566
let expr = (0..tree_depth)
35893567
.map(|_| col("a", schema.as_ref()).unwrap())
3590-
.reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema))
3568+
.reduce(|l, r| binary(l, Operator::Plus, r, &schema).unwrap())
35913569
.unwrap();
35923570

35933571
let result = expr
@@ -4069,26 +4047,7 @@ mod tests {
40694047
op: Operator,
40704048
expected: ArrayRef,
40714049
) -> Result<()> {
4072-
let (lhs_type, rhs_type) =
4073-
get_input_types(left.data_type(), &op, right.data_type()).unwrap();
4074-
4075-
let left_expr = try_cast(col("a", schema)?, schema, lhs_type.clone())?;
4076-
let right_expr = try_cast(col("b", schema)?, schema, rhs_type.clone())?;
4077-
4078-
let coerced_schema = Schema::new(vec![
4079-
Field::new(
4080-
schema.field(0).name(),
4081-
lhs_type,
4082-
schema.field(0).is_nullable(),
4083-
),
4084-
Field::new(
4085-
schema.field(1).name(),
4086-
rhs_type,
4087-
schema.field(1).is_nullable(),
4088-
),
4089-
]);
4090-
4091-
let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema);
4050+
let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
40924051
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
40934052
let batch = RecordBatch::try_new(schema.clone(), data)?;
40944053
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
@@ -4761,12 +4720,12 @@ mod tests {
47614720

47624721
// expression: "a >= 25"
47634722
let a = col("a", &schema).unwrap();
4764-
let gt = binary_simple(
4723+
let gt = binary(
47654724
a.clone(),
47664725
Operator::GtEq,
47674726
lit(ScalarValue::from(25)),
47684727
&schema,
4769-
);
4728+
)?;
47704729

47714730
let context = AnalysisContext::from_statistics(&schema, &statistics);
47724731
let predicate_boundaries = gt
@@ -4790,12 +4749,12 @@ mod tests {
47904749

47914750
// expression: "50 >= a"
47924751
let a = col("a", &schema).unwrap();
4793-
let gt = binary_simple(
4752+
let gt = binary(
47944753
lit(ScalarValue::from(50)),
47954754
Operator::GtEq,
47964755
a.clone(),
47974756
&schema,
4798-
);
4757+
)?;
47994758

48004759
let context = AnalysisContext::from_statistics(&schema, &statistics);
48014760
let predicate_boundaries = gt

0 commit comments

Comments
 (0)