Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 171 additions & 32 deletions datafusion/functions-window/src/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ impl WindowUDFImpl for NthValue {
state,
ignore_nulls: partition_evaluator_args.ignore_nulls(),
n: 0,
valid_index_cache: None,
}));
}

Expand Down Expand Up @@ -313,6 +314,7 @@ impl WindowUDFImpl for NthValue {
state,
ignore_nulls: partition_evaluator_args.ignore_nulls(),
n,
valid_index_cache: None,
}))
}

Expand Down Expand Up @@ -375,6 +377,80 @@ pub(crate) struct NthValueEvaluator {
state: NthValueState,
ignore_nulls: bool,
n: i64,
valid_index_cache: Option<ValidIndexCache>,
}

#[derive(Debug)]
struct ValidIndexCache {
nulls: NullBuffer,
scanned_to: usize,
valid_indices: Vec<usize>,
}

impl ValidIndexCache {
fn new(nulls: &NullBuffer) -> Self {
Self {
nulls: nulls.clone(),
scanned_to: 0,
valid_indices: Vec::new(),
}
}

fn matches(&self, nulls: &NullBuffer) -> bool {
self.nulls.validity().as_ptr() == nulls.validity().as_ptr()
&& self.nulls.offset() == nulls.offset()
&& self.nulls.len() == nulls.len()
&& self.nulls.null_count() == nulls.null_count()
}

fn extend_to(&mut self, end: usize) {
let end = end.min(self.nulls.len());
if end <= self.scanned_to {
return;
}

self.valid_indices
.extend((self.scanned_to..end).filter(|index| self.nulls.is_valid(*index)));
self.scanned_to = end;
}

fn valid_index(
&mut self,
range: &Range<usize>,
kind: NthValueKind,
n: i64,
) -> Option<usize> {
self.extend_to(range.end);

let start = self
.valid_indices
.partition_point(|index| *index < range.start);
let end = self
.valid_indices
.partition_point(|index| *index < range.end);

match kind {
NthValueKind::First => self
.valid_indices
.get(start)
.copied()
.filter(|index| *index < range.end),
NthValueKind::Last => (start < end).then(|| self.valid_indices[end - 1]),
NthValueKind::Nth => match n.cmp(&0) {
Ordering::Greater => {
let index = (n as usize) - 1;
let target = start.checked_add(index)?;
(target < end).then(|| self.valid_indices[target])
}
Ordering::Less => {
let reverse_index = (-n) as usize;
(reverse_index <= end - start)
.then(|| self.valid_indices[end - reverse_index])
}
Ordering::Equal => None,
},
}
}
}

impl PartitionEvaluator for NthValueEvaluator {
Expand Down Expand Up @@ -483,15 +559,20 @@ impl PartitionEvaluator for NthValueEvaluator {
}

impl NthValueEvaluator {
fn valid_index(&self, array: &ArrayRef, range: &Range<usize>) -> Option<usize> {
fn valid_index(&mut self, array: &ArrayRef, range: &Range<usize>) -> Option<usize> {
let n_range = range.end - range.start;
if self.ignore_nulls {
// Calculate valid indices, inside the window frame boundaries.
let slice = array.slice(range.start, n_range);
if let Some(nulls) = slice.nulls()
if let Some(nulls) = array.nulls()
&& nulls.null_count() > 0
{
return self.valid_index_with_nulls(nulls, range.start);
let cache = self
.valid_index_cache
.get_or_insert_with(|| ValidIndexCache::new(nulls));
if !cache.matches(nulls) {
*cache = ValidIndexCache::new(nulls);
}
return cache.valid_index(range, self.state.kind, self.n);
}
}
// Either no nulls, or nulls are regarded as valid rows
Expand Down Expand Up @@ -522,34 +603,6 @@ impl NthValueEvaluator {
},
}
}

fn valid_index_with_nulls(&self, nulls: &NullBuffer, offset: usize) -> Option<usize> {
match self.state.kind {
NthValueKind::First => nulls.valid_indices().next().map(|idx| idx + offset),
NthValueKind::Last => nulls.valid_indices().last().map(|idx| idx + offset),
NthValueKind::Nth => {
match self.n.cmp(&0) {
Ordering::Greater => {
// SQL indices are not 0-based.
let index = (self.n as usize) - 1;
nulls.valid_indices().nth(index).map(|idx| idx + offset)
}
Ordering::Less => {
let reverse_index = (-self.n) as usize;
let valid_indices_len = nulls.len() - nulls.null_count();
if reverse_index > valid_indices_len {
return None;
}
nulls
.valid_indices()
.nth(valid_indices_len - reverse_index)
.map(|idx| idx + offset)
}
Ordering::Equal => None,
}
}
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -670,6 +723,92 @@ mod tests {
Ok(())
}

fn test_i32_ignore_nulls_result(
expr: NthValue,
input_exprs: &[Arc<dyn PhysicalExpr>],
expected: Int32Array,
) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![
None,
Some(2),
None,
Some(4),
Some(5),
None,
]));
let values = vec![arr];
let ranges = [
Range { start: 0, end: 6 },
Range { start: 0, end: 1 },
Range { start: 2, end: 3 },
Range { start: 0, end: 2 },
Range { start: 0, end: 4 },
];
let input_fields: Vec<FieldRef> =
vec![Field::new("f", DataType::Int32, true).into()];

let mut evaluator = expr.partition_evaluator(PartitionEvaluatorArgs::new(
input_exprs,
&input_fields,
false,
true,
))?;
let result = ranges
.iter()
.map(|range| evaluator.evaluate(&values, range))
.collect::<Result<Vec<ScalarValue>>>()?;
let result = ScalarValue::iter_to_array(result)?;
let result = as_int32_array(&result)?;
assert_eq!(expected, *result);
Ok(())
}

#[test]
fn first_value_ignore_nulls_cached_ranges() -> Result<()> {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
test_i32_ignore_nulls_result(
NthValue::first(),
&[expr],
Int32Array::from(vec![Some(2), None, None, Some(2), Some(2)]),
)
}

#[test]
fn last_value_ignore_nulls_cached_ranges() -> Result<()> {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
test_i32_ignore_nulls_result(
NthValue::last(),
&[expr],
Int32Array::from(vec![Some(5), None, None, Some(2), Some(4)]),
)
}

#[test]
fn nth_value_ignore_nulls_cached_ranges() -> Result<()> {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
let n_value =
Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;

test_i32_ignore_nulls_result(
NthValue::nth(),
&[expr, n_value],
Int32Array::from(vec![Some(4), None, None, None, Some(4)]),
)
}

#[test]
fn nth_value_negative_ignore_nulls_cached_ranges() -> Result<()> {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
let n_value =
Arc::new(Literal::new(ScalarValue::Int32(Some(-2)))) as Arc<dyn PhysicalExpr>;

test_i32_ignore_nulls_result(
NthValue::nth(),
&[expr, n_value],
Int32Array::from(vec![Some(4), None, None, None, Some(2)]),
)
}

#[test]
fn nth_value_i64_min_returns_error() {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
Expand Down
Loading