Skip to content

Commit

Permalink
use Datum instead of Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME committed Jun 1, 2024
1 parent 1c207cb commit 3715276
Showing 1 changed file with 120 additions and 103 deletions.
223 changes: 120 additions & 103 deletions crates/iceberg/src/writer/file_writer/parquet_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
use crate::arrow::DEFAULT_MAP_FIELD_NAME;
use crate::spec::{
visit_schema, ListType, Literal, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaRef,
SchemaVisitor, StructType,
visit_schema, Datum, ListType, MapType, NestedFieldRef, PrimitiveLiteral, PrimitiveType,
Schema, SchemaRef, SchemaVisitor, StructType,
};
use crate::ErrorKind;
use crate::{io::FileIO, io::FileWrite, Result};
Expand Down Expand Up @@ -225,8 +225,8 @@ pub struct ParquetWriter {

/// Used to aggregate min and max value of each column.
struct MinMaxColAggregator {
lower_bounds: HashMap<i32, Literal>,
upper_bounds: HashMap<i32, Literal>,
lower_bounds: HashMap<i32, Datum>,
upper_bounds: HashMap<i32, Datum>,
schema: SchemaRef,
}

Expand Down Expand Up @@ -259,36 +259,6 @@ impl MinMaxColAggregator {
};

macro_rules! update_stat {
($self:ident, $stat:ident, $convert_func:expr) => {
if $stat.min_is_exact() {
let val = $convert_func($stat.min().clone());
match $self.lower_bounds.entry(col_id) {
std::collections::hash_map::Entry::Occupied(mut entry) => {
if entry.get() > &val {
entry.insert(val);
}
}
std::collections::hash_map::Entry::Vacant(entry) => {
entry.insert(val);
}
}
}
if $stat.max_is_exact() {
let val = $convert_func($stat.max().clone());
match $self.upper_bounds.entry(col_id) {
std::collections::hash_map::Entry::Occupied(mut entry) => {
if entry.get() < &val {
entry.insert(val);
}
}
std::collections::hash_map::Entry::Vacant(entry) => {
entry.insert(val);
}
}
}
};
}
macro_rules! update_stat_with_err {
($self:ident, $stat:ident, $convert_func:expr) => {
if $stat.min_is_exact() {
let val = $convert_func($stat.min().clone())?;
Expand Down Expand Up @@ -321,41 +291,53 @@ impl MinMaxColAggregator {

match (ty, value) {
(PrimitiveType::Boolean, Statistics::Boolean(stat)) => {
update_stat!(self, stat, Literal::bool);
let convert_func = |v: bool| Result::<Datum>::Ok(Datum::bool(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Int, Statistics::Int32(stat)) => {
update_stat!(self, stat, Literal::int);
let convert_func = |v: i32| Result::<Datum>::Ok(Datum::int(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Long, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::long);
let convert_func = |v: i64| Result::<Datum>::Ok(Datum::long(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Float, Statistics::Float(stat)) => {
update_stat!(self, stat, Literal::float);
let convert_func = |v: f32| Result::<Datum>::Ok(Datum::float(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Double, Statistics::Double(stat)) => {
update_stat!(self, stat, Literal::double);
let convert_func = |v: f64| Result::<Datum>::Ok(Datum::double(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::String, Statistics::ByteArray(stat)) => {
let convert_func = |v: ByteArray| -> Result<Literal> {
Ok(Literal::string(v.as_utf8()?.to_string()))
let convert_func = |v: ByteArray| {
Result::<Datum>::Ok(Datum::string(
String::from_utf8(v.data().to_vec()).unwrap(),
))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Binary, Statistics::ByteArray(stat)) => {
let convert_func = |v: ByteArray| Literal::binary(v.data().to_vec());
let convert_func =
|v: ByteArray| Result::<Datum>::Ok(Datum::binary(v.data().to_vec()));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Date, Statistics::Int32(stat)) => {
update_stat!(self, stat, Literal::date);
let convert_func = |v: i32| Result::<Datum>::Ok(Datum::date(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Time, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::time);
let convert_func = |v: i64| Datum::time_micros(v);
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Timestamp, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::timestamp);
let convert_func = |v: i64| Result::<Datum>::Ok(Datum::timestamp_micros(v));
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Timestamptz, Statistics::Int64(stat)) => {
update_stat!(self, stat, Literal::timestamptz);
let convert_func = |v: i64| Result::<Datum>::Ok(Datum::timestamptz_micros(v));
update_stat!(self, stat, convert_func);
}
(
PrimitiveType::Decimal {
Expand All @@ -364,10 +346,15 @@ impl MinMaxColAggregator {
},
Statistics::ByteArray(stat),
) => {
let convert_func = |v: ByteArray| -> Result<Literal> {
Ok(Literal::decimal(i128::from_be_bytes(v.data().try_into()?)))
let convert_func = |v: ByteArray| -> Result<Datum> {
Result::<Datum>::Ok(Datum::new(
ty.clone(),
PrimitiveLiteral::Decimal(i128::from_le_bytes(
v.data().try_into().unwrap(),
)),
))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(
PrimitiveType::Decimal {
Expand All @@ -376,7 +363,12 @@ impl MinMaxColAggregator {
},
Statistics::Int32(stat),
) => {
let convert_func = |v: i32| Literal::decimal(v as i128);
let convert_func = |v: i32| {
Result::<Datum>::Ok(Datum::new(
ty.clone(),
PrimitiveLiteral::Decimal(i128::from(v)),
))
};
update_stat!(self, stat, convert_func);
}
(
Expand All @@ -386,7 +378,12 @@ impl MinMaxColAggregator {
},
Statistics::Int64(stat),
) => {
let convert_func = |v: i64| Literal::decimal(v as i128);
let convert_func = |v: i64| {
Result::<Datum>::Ok(Datum::new(
ty.clone(),
PrimitiveLiteral::Decimal(i128::from(v)),
))
};
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Uuid, Statistics::FixedLenByteArray(stat)) => {
Expand All @@ -397,11 +394,11 @@ impl MinMaxColAggregator {
"Invalid length of uuid bytes.",
));
}
Ok(Literal::uuid(Uuid::from_bytes(
Ok(Datum::uuid(Uuid::from_bytes(
v.data()[..16].try_into().unwrap(),
)))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(PrimitiveType::Fixed(len), Statistics::FixedLenByteArray(stat)) => {
let convert_func = |v: FixedLenByteArray| {
Expand All @@ -411,9 +408,9 @@ impl MinMaxColAggregator {
"Invalid length of fixed bytes.",
));
}
Ok(Literal::fixed(v.data().to_vec()))
Ok(Datum::fixed(v.data().to_vec()))
};
update_stat_with_err!(self, stat, convert_func);
update_stat!(self, stat, convert_func);
}
(ty, value) => {
return Err(Error::new(
Expand All @@ -425,7 +422,7 @@ impl MinMaxColAggregator {
Ok(())
}

fn produce(self) -> (HashMap<i32, Literal>, HashMap<i32, Literal>) {
fn produce(self) -> (HashMap<i32, Datum>, HashMap<i32, Datum>) {
(self.lower_bounds, self.upper_bounds)
}
}
Expand Down Expand Up @@ -936,11 +933,11 @@ mod tests {
assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)]));
assert_eq!(
*data_file.lower_bounds(),
HashMap::from([(0, Literal::long(0))])
HashMap::from([(0, Datum::long(0))])
);
assert_eq!(
*data_file.upper_bounds(),
HashMap::from([(0, Literal::long(1023))])
HashMap::from([(0, Datum::long(1023))])
);
assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)]));

Expand Down Expand Up @@ -1145,27 +1142,27 @@ mod tests {
assert_eq!(
*data_file.lower_bounds(),
HashMap::from([
(0, Literal::long(0)),
(5, Literal::long(0)),
(6, Literal::long(0)),
(2, Literal::string("0")),
(7, Literal::long(0)),
(9, Literal::long(0)),
(11, Literal::string("0")),
(13, Literal::long(0))
(0, Datum::long(0)),
(5, Datum::long(0)),
(6, Datum::long(0)),
(2, Datum::string("0")),
(7, Datum::long(0)),
(9, Datum::long(0)),
(11, Datum::string("0")),
(13, Datum::long(0))
])
);
assert_eq!(
*data_file.upper_bounds(),
HashMap::from([
(0, Literal::long(1023)),
(5, Literal::long(1023)),
(6, Literal::long(1023)),
(2, Literal::string("999")),
(7, Literal::long(1023)),
(9, Literal::long(1023)),
(11, Literal::string("999")),
(13, Literal::long(1023))
(0, Datum::long(1023)),
(5, Datum::long(1023)),
(6, Datum::long(1023)),
(2, Datum::string("999")),
(7, Datum::long(1023)),
(9, Datum::long(1023)),
(11, Datum::string("999")),
(13, Datum::long(1023))
])
);

Expand Down Expand Up @@ -1315,41 +1312,61 @@ mod tests {
assert_eq!(
*data_file.lower_bounds(),
HashMap::from([
(0, Literal::bool(false)),
(1, Literal::int(1)),
(2, Literal::long(1)),
(3, Literal::float(0.5)),
(4, Literal::double(0.5)),
(5, Literal::string("a")),
(6, Literal::binary(vec![])),
(7, Literal::date(0)),
(8, Literal::time(0)),
(9, Literal::timestamp(0)),
(10, Literal::timestamptz(0)),
(11, Literal::decimal(1)),
(12, Literal::uuid(Uuid::from_u128(0))),
(13, Literal::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
(0, Datum::bool(false)),
(1, Datum::int(1)),
(2, Datum::long(1)),
(3, Datum::float(0.5)),
(4, Datum::double(0.5)),
(5, Datum::string("a")),
(6, Datum::binary(vec![])),
(7, Datum::date(0)),
(8, Datum::time_micros(0).unwrap()),
(9, Datum::timestamp_micros(0)),
(10, Datum::timestamptz_micros(0)),
(
11,
Datum::new(
PrimitiveType::Decimal {
precision: 10,
scale: 5
},
PrimitiveLiteral::Decimal(1)
)
),
(12, Datum::uuid(Uuid::from_u128(0))),
(13, Datum::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
(12, Datum::uuid(Uuid::from_u128(0))),
(13, Datum::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
])
);
assert_eq!(
*data_file.upper_bounds(),
HashMap::from([
(0, Literal::bool(true)),
(1, Literal::int(4)),
(2, Literal::long(4)),
(3, Literal::float(3.5)),
(4, Literal::double(3.5)),
(5, Literal::string("d")),
(6, Literal::binary(vec![122, 122, 122, 122])),
(7, Literal::date(3)),
(8, Literal::time(3)),
(9, Literal::timestamp(3)),
(10, Literal::timestamptz(3)),
(11, Literal::decimal(100)),
(12, Literal::uuid(Uuid::from_u128(3))),
(0, Datum::bool(true)),
(1, Datum::int(4)),
(2, Datum::long(4)),
(3, Datum::float(3.5)),
(4, Datum::double(3.5)),
(5, Datum::string("d")),
(6, Datum::binary(vec![122, 122, 122, 122])),
(7, Datum::date(3)),
(8, Datum::time_micros(3).unwrap()),
(9, Datum::timestamp_micros(3)),
(10, Datum::timestamptz_micros(3)),
(
11,
Datum::new(
PrimitiveType::Decimal {
precision: 10,
scale: 5
},
PrimitiveLiteral::Decimal(100)
)
),
(12, Datum::uuid(Uuid::from_u128(3))),
(
13,
Literal::fixed(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30])
Datum::fixed(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30])
),
])
);
Expand Down

0 comments on commit 3715276

Please sign in to comment.