Skip to content

Commit 9b0aaba

Browse files
authored
Fix incorrect Not predicate evaluation in filtering (#78)
* fix NOT logic in predicate * format
1 parent 1b7a553 commit 9b0aaba

File tree

2 files changed

+316
-15
lines changed

2 files changed

+316
-15
lines changed

src/predicate.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ pub enum ComparisonOp {
6565
GreaterThanOrEqual,
6666
}
6767

68+
impl ComparisonOp {
69+
/// Returns the negated comparison operator.
70+
pub fn negate(&self) -> Self {
71+
match self {
72+
ComparisonOp::Equal => ComparisonOp::NotEqual,
73+
ComparisonOp::NotEqual => ComparisonOp::Equal,
74+
ComparisonOp::LessThan => ComparisonOp::GreaterThanOrEqual,
75+
ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThan,
76+
ComparisonOp::GreaterThan => ComparisonOp::LessThanOrEqual,
77+
ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThan,
78+
}
79+
}
80+
}
81+
6882
/// A predicate that can be evaluated against row group statistics
6983
///
7084
/// Predicates are simplified expressions used for filtering row groups before

src/row_group_filter.rs

Lines changed: 302 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,39 @@ fn evaluate_predicate_recursive(
106106
result[i] = temp_results.iter().any(|tr| tr[i]);
107107
}
108108
}
109-
Predicate::Not(predicate) => {
110-
// For NOT: evaluate predicate, then negate
111-
let mut temp_result = vec![true; result.len()];
112-
evaluate_predicate_recursive(predicate, row_index, schema, &mut temp_result)?;
113-
// NOT logic: result[i] = !temp_result[i]
114-
for (r, t) in result.iter_mut().zip(temp_result.iter()) {
115-
*r = !*t;
109+
Predicate::Not(predicate) => match &**predicate {
110+
Predicate::Not(inner) => {
111+
evaluate_predicate_recursive(inner, row_index, schema, result)?;
116112
}
117-
}
113+
Predicate::IsNull { column } => {
114+
evaluate_is_not_null(column, row_index, schema, result)?;
115+
}
116+
Predicate::IsNotNull { column } => {
117+
evaluate_is_null(column, row_index, schema, result)?;
118+
}
119+
Predicate::Comparison { column, op, value } => {
120+
evaluate_comparison(column, op.negate(), value, row_index, schema, result)?;
121+
}
122+
Predicate::And(predicates) => {
123+
let not_preds: Vec<Predicate> = predicates
124+
.iter()
125+
.map(|p| Predicate::Not(Box::new(p.clone())))
126+
.collect();
127+
evaluate_predicate_recursive(&Predicate::Or(not_preds), row_index, schema, result)?;
128+
}
129+
Predicate::Or(predicates) => {
130+
let not_preds: Vec<Predicate> = predicates
131+
.iter()
132+
.map(|p| Predicate::Not(Box::new(p.clone())))
133+
.collect();
134+
evaluate_predicate_recursive(
135+
&Predicate::And(not_preds),
136+
row_index,
137+
schema,
138+
result,
139+
)?;
140+
}
141+
},
118142
}
119143

120144
Ok(())
@@ -1015,26 +1039,289 @@ mod tests {
10151039
}
10161040

10171041
#[test]
1018-
fn test_evaluate_predicate_missing_statistics() {
1042+
fn test_evaluate_predicate_not_is_null() {
1043+
use crate::predicate::Predicate;
1044+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1045+
use std::collections::HashMap;
1046+
1047+
// Create row index with mixed nulls and values
1048+
let mut columns = HashMap::new();
1049+
let entries = vec![RowGroupEntry::new(
1050+
Some({
1051+
let proto_stats = proto::ColumnStatistics {
1052+
number_of_values: Some(5000),
1053+
has_null: Some(true),
1054+
int_statistics: Some(proto::IntegerStatistics {
1055+
minimum: Some(18),
1056+
maximum: Some(25),
1057+
sum: Some(107500),
1058+
}),
1059+
..Default::default()
1060+
};
1061+
ColumnStatistics::try_from(&proto_stats).unwrap()
1062+
}),
1063+
vec![],
1064+
)];
1065+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1066+
let row_index = StripeRowIndex::new(columns, 10000, 10000);
1067+
let schema = create_test_schema();
1068+
1069+
// Test: Not(age IS NULL) -> age IS NOT NULL
1070+
let predicate = Predicate::not(Predicate::is_null("age"));
1071+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1072+
1073+
assert_eq!(result.len(), 1);
1074+
assert!(result[0]); // Should keep because there are non-null values
1075+
}
1076+
1077+
#[test]
1078+
fn test_evaluate_predicate_not_is_not_null() {
1079+
use crate::predicate::Predicate;
1080+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1081+
use std::collections::HashMap;
1082+
1083+
// Create row index with mixed nulls and values
1084+
let mut columns = HashMap::new();
1085+
let entries = vec![
1086+
// Row group 0: Has nulls (and values)
1087+
RowGroupEntry::new(
1088+
Some({
1089+
let proto_stats = proto::ColumnStatistics {
1090+
number_of_values: Some(5000),
1091+
has_null: Some(true),
1092+
int_statistics: Some(proto::IntegerStatistics {
1093+
minimum: Some(18),
1094+
maximum: Some(25),
1095+
sum: Some(107500),
1096+
}),
1097+
..Default::default()
1098+
};
1099+
ColumnStatistics::try_from(&proto_stats).unwrap()
1100+
}),
1101+
vec![],
1102+
),
1103+
// Row group 1: No nulls
1104+
RowGroupEntry::new(
1105+
Some({
1106+
let proto_stats = proto::ColumnStatistics {
1107+
number_of_values: Some(10000),
1108+
has_null: Some(false),
1109+
int_statistics: Some(proto::IntegerStatistics {
1110+
minimum: Some(26),
1111+
maximum: Some(65),
1112+
sum: Some(455000),
1113+
}),
1114+
..Default::default()
1115+
};
1116+
ColumnStatistics::try_from(&proto_stats).unwrap()
1117+
}),
1118+
vec![],
1119+
),
1120+
];
1121+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1122+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1123+
let schema = create_test_schema();
1124+
1125+
// Test: Not(age IS NOT NULL) -> age IS NULL
1126+
let predicate = Predicate::not(Predicate::is_not_null("age"));
1127+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1128+
1129+
assert_eq!(result.len(), 2);
1130+
assert!(result[0]); // Row group 0: has_null = true -> Keep
1131+
assert!(!result[1]); // Row group 1: has_null = false -> Skip
1132+
}
1133+
1134+
#[test]
1135+
fn test_evaluate_predicate_not_comparison() {
1136+
use crate::predicate::{Predicate, PredicateValue};
1137+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1138+
use std::collections::HashMap;
1139+
1140+
let mut columns = HashMap::new();
1141+
let entries = vec![RowGroupEntry::new(
1142+
Some({
1143+
let proto_stats = proto::ColumnStatistics {
1144+
number_of_values: Some(10000),
1145+
has_null: Some(false),
1146+
int_statistics: Some(proto::IntegerStatistics {
1147+
minimum: Some(0),
1148+
maximum: Some(10),
1149+
sum: Some(50000),
1150+
}),
1151+
..Default::default()
1152+
};
1153+
ColumnStatistics::try_from(&proto_stats).unwrap()
1154+
}),
1155+
vec![],
1156+
)];
1157+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1158+
let row_index = StripeRowIndex::new(columns, 10000, 10000);
1159+
let schema = create_test_schema();
1160+
1161+
// Test: Not(age > 5) -> age <= 5
1162+
let predicate = Predicate::not(Predicate::gt("age", PredicateValue::Int32(Some(5))));
1163+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1164+
1165+
assert_eq!(result.len(), 1);
1166+
assert!(result[0]);
1167+
}
1168+
1169+
#[test]
1170+
fn test_evaluate_predicate_not_and() {
10191171
use crate::predicate::{Predicate, PredicateValue};
10201172
use crate::row_index::{RowGroupEntry, RowGroupIndex};
10211173
use std::collections::HashMap;
10221174

1023-
// Create row index with missing statistics
10241175
let mut columns = HashMap::new();
10251176
let entries = vec![
1026-
RowGroupEntry::new(None, vec![]), // No statistics
1177+
RowGroupEntry::new(
1178+
Some({
1179+
let proto_stats = proto::ColumnStatistics {
1180+
number_of_values: Some(10000),
1181+
has_null: Some(false),
1182+
int_statistics: Some(proto::IntegerStatistics {
1183+
minimum: Some(0),
1184+
maximum: Some(10),
1185+
sum: Some(50000),
1186+
}),
1187+
..Default::default()
1188+
};
1189+
ColumnStatistics::try_from(&proto_stats).unwrap()
1190+
}),
1191+
vec![],
1192+
),
1193+
RowGroupEntry::new(
1194+
Some({
1195+
let proto_stats = proto::ColumnStatistics {
1196+
number_of_values: Some(10000),
1197+
has_null: Some(false),
1198+
int_statistics: Some(proto::IntegerStatistics {
1199+
minimum: Some(20),
1200+
maximum: Some(30),
1201+
sum: Some(250000),
1202+
}),
1203+
..Default::default()
1204+
};
1205+
ColumnStatistics::try_from(&proto_stats).unwrap()
1206+
}),
1207+
vec![],
1208+
),
1209+
];
1210+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1211+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1212+
let schema = create_test_schema();
1213+
1214+
// Test: Not(age >= 15 AND age <= 25)
1215+
// Equivalent to: age < 15 OR age > 25
1216+
// Row Group 1: [0, 10] -> Fits age < 15 -> Keep
1217+
// Row Group 2: [20, 30] -> Fits age > 25 -> Keep
1218+
let predicate = Predicate::not(Predicate::and(vec![
1219+
Predicate::gte("age", PredicateValue::Int32(Some(15))),
1220+
Predicate::lte("age", PredicateValue::Int32(Some(25))),
1221+
]));
1222+
1223+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1224+
1225+
assert_eq!(result.len(), 2);
1226+
assert!(result[0]); // [0, 10] is < 15
1227+
assert!(result[1]); // [20, 30] contains values > 25 (26..30)
1228+
}
1229+
1230+
#[test]
1231+
fn test_evaluate_predicate_not_or() {
1232+
use crate::predicate::{Predicate, PredicateValue};
1233+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1234+
use std::collections::HashMap;
1235+
1236+
let mut columns = HashMap::new();
1237+
let entries = vec![
1238+
RowGroupEntry::new(
1239+
Some({
1240+
let proto_stats = proto::ColumnStatistics {
1241+
number_of_values: Some(10000),
1242+
has_null: Some(false),
1243+
int_statistics: Some(proto::IntegerStatistics {
1244+
minimum: Some(0),
1245+
maximum: Some(5),
1246+
sum: Some(25000),
1247+
}),
1248+
..Default::default()
1249+
};
1250+
ColumnStatistics::try_from(&proto_stats).unwrap()
1251+
}),
1252+
vec![],
1253+
),
1254+
RowGroupEntry::new(
1255+
Some({
1256+
let proto_stats = proto::ColumnStatistics {
1257+
number_of_values: Some(10000),
1258+
has_null: Some(false),
1259+
int_statistics: Some(proto::IntegerStatistics {
1260+
minimum: Some(5),
1261+
maximum: Some(15),
1262+
sum: Some(100000),
1263+
}),
1264+
..Default::default()
1265+
};
1266+
ColumnStatistics::try_from(&proto_stats).unwrap()
1267+
}),
1268+
vec![],
1269+
),
10271270
];
10281271
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1272+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1273+
let schema = create_test_schema();
1274+
1275+
// Test: Not(age < 10 OR age > 30)
1276+
// Equivalent to: age >= 10 AND age <= 30
1277+
let predicate = Predicate::not(Predicate::or(vec![
1278+
Predicate::lt("age", PredicateValue::Int32(Some(10))),
1279+
Predicate::gt("age", PredicateValue::Int32(Some(30))),
1280+
]));
1281+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1282+
1283+
assert_eq!(result.len(), 2);
1284+
assert!(!result[0]); // [0, 5] is outside [10, 30] -> Skip
1285+
assert!(result[1]); // [5, 15] overlaps [10, 30] -> Keep
1286+
}
1287+
1288+
#[test]
1289+
fn test_evaluate_predicate_double_negation() {
1290+
use crate::predicate::{Predicate, PredicateValue};
1291+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1292+
use std::collections::HashMap;
1293+
1294+
let mut columns = HashMap::new();
1295+
// Row group: [0, 10]
1296+
let entries = vec![RowGroupEntry::new(
1297+
Some({
1298+
let proto_stats = proto::ColumnStatistics {
1299+
number_of_values: Some(10000),
1300+
has_null: Some(false),
1301+
int_statistics: Some(proto::IntegerStatistics {
1302+
minimum: Some(0),
1303+
maximum: Some(10),
1304+
sum: Some(50000),
1305+
}),
1306+
..Default::default()
1307+
};
1308+
ColumnStatistics::try_from(&proto_stats).unwrap()
1309+
}),
1310+
vec![],
1311+
)];
1312+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
10291313
let row_index = StripeRowIndex::new(columns, 10000, 10000);
10301314
let schema = create_test_schema();
10311315

1032-
// Test: age > 10
1033-
// Should keep row group when statistics are missing (conservative)
1034-
let predicate = Predicate::gt("age", PredicateValue::Int32(Some(10)));
1316+
// Test: Not(Not(age > 5)) -> age > 5
1317+
// Row group [0, 10] contains values > 5 -> Keep
1318+
let predicate = Predicate::not(Predicate::not(Predicate::gt(
1319+
"age",
1320+
PredicateValue::Int32(Some(5)),
1321+
)));
10351322
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
10361323

10371324
assert_eq!(result.len(), 1);
1038-
assert!(result[0]); // Keep when statistics missing
1325+
assert!(result[0]);
10391326
}
10401327
}

0 commit comments

Comments
 (0)