Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,15 @@ impl CaseExpr {
let mut current_value = new_null_array(&return_type, batch.num_rows());
// We only consider non-null values while comparing with whens
let mut remainder = not(&base_nulls)?;
let mut non_null_remainder_count = remainder.true_count();
for i in 0..self.when_then_expr.len() {
let when_value = self.when_then_expr[i]
.0
.evaluate_selection(batch, &remainder)?;
// If there are no rows left to process, break out of the loop early
if non_null_remainder_count == 0 {
break;
}

let when_predicate = &self.when_then_expr[i].0;
let when_value = when_predicate.evaluate_selection(batch, &remainder)?;
let when_value = when_value.into_array(batch.num_rows())?;
// build boolean array representing which rows match the "when" value
let when_match = compare_with_eq(
Expand All @@ -224,41 +229,46 @@ impl CaseExpr {
_ => Cow::Owned(prep_null_mask_filter(&when_match)),
};
// Make sure we only consider rows that have not been matched yet
let when_match = and(&when_match, &remainder)?;
let when_value = and(&when_match, &remainder)?;

// When no rows available for when clause, skip then clause
if when_match.true_count() == 0 {
// If the predicate did not match any rows, continue to the next branch immediately
let when_match_count = when_value.true_count();
if when_match_count == 0 {
continue;
}

let then_value = self.when_then_expr[i]
.1
.evaluate_selection(batch, &when_match)?;
let then_expression = &self.when_then_expr[i].1;
let then_value = then_expression.evaluate_selection(batch, &when_value)?;

current_value = match then_value {
ColumnarValue::Scalar(ScalarValue::Null) => {
nullif(current_value.as_ref(), &when_match)?
nullif(current_value.as_ref(), &when_value)?
}
ColumnarValue::Scalar(then_value) => {
zip(&when_match, &then_value.to_scalar()?, &current_value)?
zip(&when_value, &then_value.to_scalar()?, &current_value)?
}
ColumnarValue::Array(then_value) => {
zip(&when_match, &then_value, &current_value)?
zip(&when_value, &then_value, &current_value)?
}
};

remainder = and_not(&remainder, &when_match)?;
remainder = and_not(&remainder, &when_value)?;
non_null_remainder_count -= when_match_count;
}

if let Some(e) = self.else_expr() {
// keep `else_expr`'s data type and return type consistent
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
// null and unmatched tuples should be assigned else value
remainder = or(&base_nulls, &remainder)?;
let else_ = expr
.evaluate_selection(batch, &remainder)?
.into_array(batch.num_rows())?;
current_value = zip(&remainder, &else_, &current_value)?;

if remainder.true_count() > 0 {
// keep `else_expr`'s data type and return type consistent
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;

let else_ = expr
.evaluate_selection(batch, &remainder)?
.into_array(batch.num_rows())?;
current_value = zip(&remainder, &else_, &current_value)?;
}
}

Ok(ColumnarValue::Array(current_value))
Expand All @@ -277,10 +287,15 @@ impl CaseExpr {
// start with nulls as default output
let mut current_value = new_null_array(&return_type, batch.num_rows());
let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
let mut remainder_count = batch.num_rows();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given the similarity of this code and the one above, I wonder if there is some way to avoid the duplication (as part of a follow on PR)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #18152 the code is changing a bit further. If the approach there pans out I want to try to do the same for case_when_with_expr. It'll be easier to see if there's an extractable pattern once that work settles down, so if it's ok with you I would like to postpone your suggestion for a little bit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, of course -- we can always make the code better as follow on PRs

for i in 0..self.when_then_expr.len() {
let when_value = self.when_then_expr[i]
.0
.evaluate_selection(batch, &remainder)?;
// If there are no rows left to process, break out of the loop early
if remainder_count == 0 {
break;
}

let when_predicate = &self.when_then_expr[i].0;
let when_value = when_predicate.evaluate_selection(batch, &remainder)?;
let when_value = when_value.into_array(batch.num_rows())?;
let when_value = as_boolean_array(&when_value).map_err(|_| {
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
Expand All @@ -293,14 +308,14 @@ impl CaseExpr {
// Make sure we only consider rows that have not been matched yet
let when_value = and(&when_value, &remainder)?;

// When no rows available for when clause, skip then clause
if when_value.true_count() == 0 {
// If the predicate did not match any rows, continue to the next branch immediately
let when_match_count = when_value.true_count();
if when_match_count == 0 {
continue;
}

let then_value = self.when_then_expr[i]
.1
.evaluate_selection(batch, &when_value)?;
let then_expression = &self.when_then_expr[i].1;
let then_value = then_expression.evaluate_selection(batch, &when_value)?;

current_value = match then_value {
ColumnarValue::Scalar(ScalarValue::Null) => {
Expand All @@ -317,10 +332,11 @@ impl CaseExpr {
// Succeed tuples should be filtered out for short-circuit evaluation,
// null values for the current when expr should be kept
remainder = and_not(&remainder, &when_value)?;
remainder_count -= when_match_count;
}

if let Some(e) = self.else_expr() {
if remainder.true_count() > 0 {
if remainder_count > 0 {
// keep `else_expr`'s data type and return type consistent
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
let else_ = expr
Expand Down
76 changes: 76 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,79 @@ query I
SELECT case when false then 1 / 0 else 1 / 1 end;
----
1

# Else branch evaluation with case expression, 1 when branch, null input
query I
SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL)) t(a)
----
1

# Else branch evaluation with case expression, 2 when branches, null input
query I
SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL)) t(a)
----
2

# Else branch evaluation without case expression, 1 when branch, null input
query I
SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL)) t(a)
----
1

# Else branch evaluation without case expression, 2 when branches, null input
query I
SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL)) t(a)
----
2

# Else branch evaluation with case expression, 1 when branch, non-null input
query I
SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES ('z')) t(a)
----
1

# Else branch evaluation with case expression, 2 when branches, non-null input
query I
SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES ('z')) t(a)
----
2

# Else branch evaluation without case expression, 1 when branch, non-null input
query I
SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES ('z')) t(a)
----
1

# Else branch evaluation without case expression, 2 when branches, non-null input
query I
SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES ('z')) t(a)
----
2

# Else branch evaluation with case expression, 1 when branch, mixed input
query I
SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL), ('z')) t(a)
----
1
1

# Else branch evaluation with case expression, 2 when branches, mixed input
query I
SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL), ('z')) t(a)
----
2
2

# Else branch evaluation without case expression, 1 when branch, mixed input
query I
SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL), ('z')) t(a)
----
1
1

# Else branch evaluation without case expression, 2 when branches, mixed input
query I
SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL), ('z')) t(a)
----
2
2