diff --git a/parquet-variant-compute/src/unshred_variant.rs b/parquet-variant-compute/src/unshred_variant.rs index 264ef458d99f..64eaa46ed06b 100644 --- a/parquet-variant-compute/src/unshred_variant.rs +++ b/parquet-variant-compute/src/unshred_variant.rs @@ -25,15 +25,18 @@ use arrow::array::{ }; use arrow::buffer::NullBuffer; use arrow::datatypes::{ - ArrowPrimitiveType, DataType, Date32Type, Float32Type, Float64Type, Int8Type, Int16Type, - Int32Type, Int64Type, Time64MicrosecondType, TimeUnit, TimestampMicrosecondType, - TimestampNanosecondType, + ArrowPrimitiveType, DataType, Date32Type, Decimal32Type, Decimal64Type, Decimal128Type, + DecimalType, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, + Time64MicrosecondType, TimeUnit, TimestampMicrosecondType, TimestampNanosecondType, }; use arrow::error::{ArrowError, Result}; use arrow::temporal_conversions::time64us_to_time; use chrono::{DateTime, Utc}; use indexmap::IndexMap; -use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt, VariantMetadata}; +use parquet_variant::{ + ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8, + VariantDecimal16, VariantMetadata, +}; use uuid::Uuid; /// Removes all (nested) typed_value columns from a VariantArray by converting them back to binary @@ -92,6 +95,9 @@ enum UnshredVariantRowBuilder<'a> { PrimitiveInt64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveFloat32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveFloat64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), + Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Spec>), + Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Spec>), + Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Spec>), PrimitiveDate32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveTime64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), TimestampMicrosecond(TimestampUnshredRowBuilder<'a, TimestampMicrosecondType>), @@ -130,6 +136,9 @@ impl<'a> 
UnshredVariantRowBuilder<'a> { Self::PrimitiveInt64(b) => b.append_row(builder, metadata, index), Self::PrimitiveFloat32(b) => b.append_row(builder, metadata, index), Self::PrimitiveFloat64(b) => b.append_row(builder, metadata, index), + Self::Decimal32(b) => b.append_row(builder, metadata, index), + Self::Decimal64(b) => b.append_row(builder, metadata, index), + Self::Decimal128(b) => b.append_row(builder, metadata, index), Self::PrimitiveDate32(b) => b.append_row(builder, metadata, index), Self::PrimitiveTime64(b) => b.append_row(builder, metadata, index), Self::TimestampMicrosecond(b) => b.append_row(builder, metadata, index), @@ -176,6 +185,26 @@ impl<'a> UnshredVariantRowBuilder<'a> { DataType::Int64 => primitive_builder!(PrimitiveInt64, as_primitive), DataType::Float32 => primitive_builder!(PrimitiveFloat32, as_primitive), DataType::Float64 => primitive_builder!(PrimitiveFloat64, as_primitive), + DataType::Decimal32(_, scale) => Self::Decimal32(DecimalUnshredRowBuilder::new( + value, + typed_value.as_primitive(), + *scale, + )), + DataType::Decimal64(_, scale) => Self::Decimal64(DecimalUnshredRowBuilder::new( + value, + typed_value.as_primitive(), + *scale, + )), + DataType::Decimal128(_, scale) => Self::Decimal128(DecimalUnshredRowBuilder::new( + value, + typed_value.as_primitive(), + *scale, + )), + DataType::Decimal256(_, _) => { + return Err(ArrowError::InvalidArgumentError( + "Decimal256 is not a valid variant shredding type".to_string(), + )); + } DataType::Date32 => primitive_builder!(PrimitiveDate32, as_primitive), DataType::Time64(TimeUnit::Microsecond) => { primitive_builder!(PrimitiveTime64, as_primitive) @@ -475,6 +504,96 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> { } } +/// Trait to unify decimal unshredding across Decimal32/64/128 types +trait DecimalSpec { + type Arrow: ArrowPrimitiveType + DecimalType; + + fn into_variant( + raw: ::Native, + scale: i8, + ) -> Result>; +} + +/// Spec for Decimal32 -> VariantDecimal4 +struct 
Decimal32Spec; + +impl DecimalSpec for Decimal32Spec { + type Arrow = Decimal32Type; + + fn into_variant(raw: i32, scale: i8) -> Result> { + let scale = + u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; + let value = VariantDecimal4::try_new(raw, scale) + .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; + Ok(value.into()) + } +} + +/// Spec for Decimal64 -> VariantDecimal8 +struct Decimal64Spec; + +impl DecimalSpec for Decimal64Spec { + type Arrow = Decimal64Type; + + fn into_variant(raw: i64, scale: i8) -> Result> { + let scale = + u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; + let value = VariantDecimal8::try_new(raw, scale) + .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; + Ok(value.into()) + } +} + +/// Spec for Decimal128 -> VariantDecimal16 +struct Decimal128Spec; + +impl DecimalSpec for Decimal128Spec { + type Arrow = Decimal128Type; + + fn into_variant(raw: i128, scale: i8) -> Result> { + let scale = + u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; + let value = VariantDecimal16::try_new(raw, scale) + .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; + Ok(value.into()) + } +} + +/// Generic builder for decimal unshredding that caches scale +struct DecimalUnshredRowBuilder<'a, S: DecimalSpec> { + value: Option<&'a BinaryViewArray>, + typed_value: &'a PrimitiveArray, + scale: i8, +} + +impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> { + fn new( + value: Option<&'a BinaryViewArray>, + typed_value: &'a PrimitiveArray, + scale: i8, + ) -> Self { + Self { + value, + typed_value, + scale, + } + } + + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + metadata: &VariantMetadata, + index: usize, + ) -> Result<()> { + handle_unshredded_case!(self, builder, metadata, index, false); + + let raw = self.typed_value.value(index); + let variant = S::into_variant(raw, self.scale)?; + 
builder.append_value(variant); + Ok(()) + } +} + /// Builder for unshredding struct/object types with nested fields struct StructUnshredVariantBuilder<'a> { value: Option<&'a arrow::array::BinaryViewArray>, diff --git a/parquet-variant-compute/src/variant_array.rs b/parquet-variant-compute/src/variant_array.rs index 51ed10b3cf5d..5686d102d3fd 100644 --- a/parquet-variant-compute/src/variant_array.rs +++ b/parquet-variant-compute/src/variant_array.rs @@ -26,7 +26,10 @@ use arrow::datatypes::{ TimestampMicrosecondType, TimestampNanosecondType, }; use arrow_schema::extension::ExtensionType; -use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, TimeUnit}; +use arrow_schema::{ + ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, + DataType, Field, FieldRef, Fields, TimeUnit, +}; use chrono::DateTime; use parquet_variant::Uuid; use parquet_variant::Variant; @@ -926,6 +929,11 @@ fn typed_value_to_variant<'a>( /// So cast them to get the right type. fn cast_to_binary_view_arrays(array: &dyn Array) -> Result { let new_type = canonicalize_and_verify_data_type(array.data_type())?; + if let Cow::Borrowed(_) = new_type { + if let Some(array) = array.as_struct_opt() { + return Ok(Arc::new(array.clone())); // bypass the unnecessary cast + } + } cast(array, new_type.as_ref()) } @@ -972,9 +980,20 @@ fn canonicalize_and_verify_data_type( UInt8 | UInt16 | UInt32 | UInt64 | Float16 => fail!(), // Most decimal types are allowed, with restrictions on precision and scale - Decimal32(p, s) if is_valid_variant_decimal(p, s, 9) => borrow!(), - Decimal64(p, s) if is_valid_variant_decimal(p, s, 18) => borrow!(), - Decimal128(p, s) if is_valid_variant_decimal(p, s, 38) => borrow!(), + // + // NOTE: arrow-parquet reads widen 32- and 64-bit decimals to 128-bit, but the variant spec + // requires using the narrowest decimal type for a given precision. Fix those up first. 
+ Decimal64(p, s) | Decimal128(p, s) + if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) => + { + Cow::Owned(Decimal32(*p, *s)) + } + Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => { + Cow::Owned(Decimal64(*p, *s)) + } + Decimal32(p, s) if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) => borrow!(), + Decimal64(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => borrow!(), + Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL128_MAX_PRECISION) => borrow!(), Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..) => fail!(), // Only micro and nano timestamps are allowed diff --git a/parquet/tests/variant_integration.rs b/parquet/tests/variant_integration.rs index 98fa04555d77..48f23c46b83c 100644 --- a/parquet/tests/variant_integration.rs +++ b/parquet/tests/variant_integration.rs @@ -86,31 +86,12 @@ variant_test_case!(20); variant_test_case!(21); variant_test_case!(22); variant_test_case!(23); -// https://github.com/apache/arrow-rs/issues/8332 -variant_test_case!( - 24, - "Unshredding not yet supported for type: Decimal128(9, 4)" -); -variant_test_case!( - 25, - "Unshredding not yet supported for type: Decimal128(9, 4)" -); -variant_test_case!( - 26, - "Unshredding not yet supported for type: Decimal128(18, 9)" -); -variant_test_case!( - 27, - "Unshredding not yet supported for type: Decimal128(18, 9)" -); -variant_test_case!( - 28, - "Unshredding not yet supported for type: Decimal128(38, 9)" -); -variant_test_case!( - 29, - "Unshredding not yet supported for type: Decimal128(38, 9)" -); +variant_test_case!(24); +variant_test_case!(25); +variant_test_case!(26); +variant_test_case!(27); +variant_test_case!(28); +variant_test_case!(29); variant_test_case!(30); variant_test_case!(31); variant_test_case!(32);