diff --git a/crates/iceberg/src/writer/file_writer/parquet_writer.rs b/crates/iceberg/src/writer/file_writer/parquet_writer.rs
index 5561b1913..ef08a66dc 100644
--- a/crates/iceberg/src/writer/file_writer/parquet_writer.rs
+++ b/crates/iceberg/src/writer/file_writer/parquet_writer.rs
@@ -17,11 +17,13 @@
 
 //! The module contains the file writer for parquet file format.
 
+use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::sync::atomic::AtomicI64;
 use std::sync::Arc;
 
-use arrow_schema::SchemaRef as ArrowSchemaRef;
+use arrow_array::Float32Array;
+use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
 use bytes::Bytes;
 use futures::future::BoxFuture;
 use itertools::Itertools;
@@ -101,6 +103,7 @@ impl FileWriterBuilder for ParquetWriterBuilder {
             written_size,
             current_row_num: 0,
             out_file,
+            nan_value_counts: HashMap::new(),
         })
     }
 }
@@ -216,6 +219,7 @@ pub struct ParquetWriter {
     writer: AsyncArrowWriter<AsyncFileWriter<TrackWriter>>,
     written_size: Arc<AtomicI64>,
     current_row_num: usize,
+    nan_value_counts: HashMap<i32, u64>,
 }
 
 /// Used to aggregate min and max value of each column.
@@ -312,6 +316,7 @@ impl ParquetWriter {
         metadata: FileMetaData,
         written_size: usize,
         file_path: String,
+        nan_value_counts: HashMap<i32, u64>,
     ) -> Result<DataFileBuilder> {
         let index_by_parquet_path = {
             let mut visitor = IndexByParquetPathName::new();
@@ -378,8 +383,8 @@ impl ParquetWriter {
             .null_value_counts(null_value_counts)
             .lower_bounds(lower_bounds)
             .upper_bounds(upper_bounds)
+            .nan_value_counts(nan_value_counts)
             // # TODO(#417)
-            // - nan_value_counts
             // - distinct_counts
             .key_metadata(metadata.footer_signing_key_metadata)
             .split_offsets(
@@ -396,6 +401,37 @@ impl FileWriter for ParquetWriter {
 
     async fn write(&mut self, batch: &arrow_array::RecordBatch) -> crate::Result<()> {
         self.current_row_num += batch.num_rows();
+
+        for (col, field) in batch
+            .columns()
+            .iter()
+            .zip(self.schema.as_struct().fields().iter())
+        {
+            let dt = col.data_type();
+
+            let nan_val_cnt: u64 = match dt {
+                DataType::Float32 => {
+                    let float_array = col.as_any().downcast_ref::<Float32Array>().unwrap();
+
+                    float_array
+                        .iter()
+                        .filter(|value| value.map_or(false, |v| v.is_nan()))
+                        .count() as u64
+                }
+                _ => 0,
+            };
+
+            match self.nan_value_counts.entry(field.id) {
+                Entry::Occupied(mut ele) => {
+                    let total_nan_val_cnt = ele.get() + nan_val_cnt;
+                    ele.insert(total_nan_val_cnt);
+                }
+                Entry::Vacant(v) => {
+                    v.insert(nan_val_cnt);
+                }
+            }
+        }
+
         self.writer.write(batch).await.map_err(|err| {
             Error::new(
                 ErrorKind::Unexpected,
@@ -403,6 +439,7 @@ impl FileWriter for ParquetWriter {
             )
             .with_source(err)
         })?;
+
         Ok(())
     }
 
@@ -418,6 +455,7 @@ impl FileWriter for ParquetWriter {
             metadata,
             written_size as usize,
             self.out_file.location().to_string(),
+            self.nan_value_counts,
         )?])
     }
 }
@@ -478,8 +516,8 @@ mod tests {
     use anyhow::Result;
    use arrow_array::types::Int64Type;
    use arrow_array::{
-        Array, ArrayRef, BooleanArray, Decimal128Array, Int32Array, Int64Array, ListArray,
-        RecordBatch, StructArray,
+        Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Int32Array, Int64Array,
+        ListArray, RecordBatch, StructArray,
     };
     use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
     use arrow_select::concat::concat_batches;
@@ -659,13 +697,27 @@ mod tests {
                 arrow_schema::Field::new("col", arrow_schema::DataType::Int64, true).with_metadata(
                     HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "0".to_string())]),
                 ),
+                arrow_schema::Field::new("col1", arrow_schema::DataType::Float32, true)
+                    .with_metadata(HashMap::from([(
+                        PARQUET_FIELD_ID_META_KEY.to_string(),
+                        "1".to_string(),
+                    )])),
             ];
             Arc::new(arrow_schema::Schema::new(fields))
         };
         let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef;
         let null_col = Arc::new(Int64Array::new_null(1024)) as ArrayRef;
-        let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();
-        let to_write_null = RecordBatch::try_new(schema.clone(), vec![null_col]).unwrap();
+        let float_col = Arc::new(Float32Array::from_iter_values((0..1024).map(|x| {
+            if x % 100 == 0 {
+                // There will be 11 NaNs in the 1024 entries (x = 0, 100, ..., 1000).
+                f32::NAN
+            } else {
+                x as f32
+            }
+        }))) as ArrayRef;
+        let to_write = RecordBatch::try_new(schema.clone(), vec![col, float_col.clone()]).unwrap();
+        let to_write_null =
+            RecordBatch::try_new(schema.clone(), vec![null_col, float_col]).unwrap();
 
         // write data
         let mut pw = ParquetWriterBuilder::new(
@@ -677,6 +729,7 @@ mod tests {
         )
         .build()
         .await?;
+
         pw.write(&to_write).await?;
         pw.write(&to_write_null).await?;
         let res = pw.close().await?;
@@ -693,16 +746,26 @@ mod tests {
 
         // check data file
         assert_eq!(data_file.record_count(), 2048);
-        assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)]));
+        assert_eq!(
+            *data_file.value_counts(),
+            HashMap::from([(0, 2048), (1, 2048)])
+        );
         assert_eq!(
             *data_file.lower_bounds(),
-            HashMap::from([(0, Datum::long(0))])
+            HashMap::from([(0, Datum::long(0)), (1, Datum::float(1.0))])
         );
         assert_eq!(
             *data_file.upper_bounds(),
-            HashMap::from([(0, Datum::long(1023))])
+            HashMap::from([(0, Datum::long(1023)), (1, Datum::float(1023.0))])
+        );
+        assert_eq!(
+            *data_file.null_value_counts(),
+            HashMap::from([(0, 1024), (1, 0)])
+        );
+        assert_eq!(
+            *data_file.nan_value_counts(),
+            HashMap::from([(0, 0), (1, 22)]) // 22 NaNs because the float column was written twice
         );
-        assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)]));
 
         // check the written file
         let expect_batch = concat_batches(&schema, vec![&to_write, &to_write_null]).unwrap();
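For reviewers, here is a minimal, self-contained sketch of the counting logic this patch adds, runnable against arrow-rs on its own. The `count_nans` helper, the `Float64` arm, and keying by column index are illustrative assumptions for the sketch only: the patch itself matches `DataType::Float32` exclusively (so Float64 columns currently report a count of 0) and keys the map by Iceberg field id rather than by column position.

```rust
// Sketch only: mirrors the per-batch NaN counting in the patch, outside the writer.
use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Float32Array, Float64Array, Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

/// Count NaN values in one column. Nulls are not NaNs and are not counted.
fn count_nans(col: &ArrayRef) -> u64 {
    match col.data_type() {
        DataType::Float32 => {
            let arr = col.as_any().downcast_ref::<Float32Array>().unwrap();
            arr.iter()
                .filter(|v| v.map_or(false, |x| x.is_nan()))
                .count() as u64
        }
        // Hypothetical extension; the patch above does not handle Float64 yet.
        DataType::Float64 => {
            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
            arr.iter()
                .filter(|v| v.map_or(false, |x| x.is_nan()))
                .count() as u64
        }
        _ => 0,
    }
}

fn main() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Float32, true),
    ]));
    let batch = RecordBatch::try_new(schema, vec![
        Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
        Arc::new(Float32Array::from(vec![Some(1.0), None, Some(f32::NAN)])) as ArrayRef,
    ])
    .unwrap();

    // Accumulate across batches, keyed by column index here; the writer keys
    // the map by Iceberg field id instead.
    let mut nan_value_counts: HashMap<usize, u64> = HashMap::new();
    for (idx, col) in batch.columns().iter().enumerate() {
        *nan_value_counts.entry(idx).or_insert(0) += count_nans(col);
    }

    assert_eq!(nan_value_counts[&0], 0);
    assert_eq!(nan_value_counts[&1], 1); // the null slot is not counted as NaN
}
```

Two small style notes on the patch itself: `entry(...).or_insert(0)`, as in the sketch, collapses the explicit `Entry::Occupied`/`Entry::Vacant` match into one line with the same semantics, and on Rust 1.70+ the filter could be written `v.is_some_and(f32::is_nan)` instead of `map_or(false, ...)`.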