From b843e3d3f0e46a41e8b2ecd1c5f3fcbf9d87fe0e Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 10 Apr 2025 08:37:24 -0700 Subject: [PATCH 1/3] feat: Adding `with_schema_force` method for `RecordBatch` --- arrow-array/src/record_batch.rs | 72 ++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index a6c2aee7cbc6..80d1ca1bee88 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -359,6 +359,19 @@ impl RecordBatch { }) } + /// Forcibly overrides the schema of this [`RecordBatch`] + /// No schema checks executed so the method is faster than safe `with_schema` + /// + /// If provided schema is not compatible with this [`RecordBatch`] columns the runtime behavior + /// is undefined + pub fn with_schema_force(self, schema: SchemaRef) -> Result { + Ok(Self { + schema, + columns: self.columns, + row_count: self.row_count, + }) + } + /// Returns the [`Schema`] of the record batch. pub fn schema(&self) -> SchemaRef { self.schema.clone() @@ -744,12 +757,14 @@ impl RecordBatchOptions { row_count: None, } } - /// Sets the row_count of RecordBatchOptions and returns self + + /// Sets the `row_count` of `RecordBatchOptions` and returns this [`RecordBatch`] pub fn with_row_count(mut self, row_count: Option) -> Self { self.row_count = row_count; self } - /// Sets the match_field_names of RecordBatchOptions and returns self + + /// Sets the `match_field_names` of `RecordBatchOptions` and returns this [`RecordBatch`] pub fn with_match_field_names(mut self, match_field_names: bool) -> Self { self.match_field_names = match_field_names; self @@ -1637,4 +1652,57 @@ mod tests { "bar" ); } + + #[test] + fn test_batch_with_force_schema() { + fn force_schema_and_get_err_from_batch( + record_batch: &RecordBatch, + schema_ref: SchemaRef, + idx: usize, + ) -> Option { + record_batch + .clone() + .with_schema_force(schema_ref) + .unwrap() + .project(&[idx]) + .err() + } + + let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); + + let record_batch = + RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion"); + + // Test empty schema for non-empty schema batch + let invalid_schema_empty = Schema::empty(); + assert_eq!( + force_schema_and_get_err_from_batch(&record_batch, invalid_schema_empty.into(), 0) + .unwrap() + .to_string(), + "Schema error: project index 0 out of bounds, max field 0" + ); + + // Wrong number of columns + let invalid_schema_more_cols = Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("a", DataType::Int32, false), + ]); + assert!(force_schema_and_get_err_from_batch( + &record_batch, + invalid_schema_more_cols.clone().into(), + 0 + ) + .is_none()); + assert_eq!( + force_schema_and_get_err_from_batch(&record_batch, invalid_schema_more_cols.into(), 1) + .unwrap() + .to_string(), + "Schema error: project index 1 out of bounds, max field 1" + ); + + // Wrong datatype + let invalid_schema_wrong_datatype = + Schema::new(vec![Field::new("a", DataType::Int32, false)]); + assert_eq!(force_schema_and_get_err_from_batch(&record_batch, invalid_schema_wrong_datatype.into(), 0).unwrap().to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0"); + } } From 3dbaf1052c588b98016206db1a9227145d626ff3 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 10 Apr 2025 08:45:12 -0700 Subject: [PATCH 2/3] feat: Adding `with_schema_force` method for `RecordBatch` --- arrow-array/src/record_batch.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index 80d1ca1bee88..1fe0aa64edb8 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -360,7 +360,8 @@ impl RecordBatch { } /// Forcibly overrides the schema of this [`RecordBatch`] - /// No schema checks executed so the method is faster than safe `with_schema` + /// without additional schema checks however bringing all the schema compatibility responsibilities + /// to the caller site. /// /// If provided schema is not compatible with this [`RecordBatch`] columns the runtime behavior /// is undefined From 8e2e7fb66f21f4519c574cea121aaa771a4dd6cf Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 10 Apr 2025 12:29:23 -0700 Subject: [PATCH 3/3] address comments --- arrow-array/src/record_batch.rs | 61 ++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index 1fe0aa64edb8..cf20f9a059cb 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -359,13 +359,11 @@ impl RecordBatch { }) } - /// Forcibly overrides the schema of this [`RecordBatch`] - /// without additional schema checks however bringing all the schema compatibility responsibilities - /// to the caller site. - /// - /// If provided schema is not compatible with this [`RecordBatch`] columns the runtime behavior - /// is undefined - pub fn with_schema_force(self, schema: SchemaRef) -> Result { + /// Overrides the schema of this [`RecordBatch`] + /// without additional schema checks. Note, however, that this pushes all the schema compatibility responsibilities + /// to the caller site. In particular, the caller guarantees that `schema` is a superset + /// of the current schema as determined by [`Schema::contains`]. + pub fn with_schema_unchecked(self, schema: SchemaRef) -> Result { Ok(Self { schema, columns: self.columns, @@ -1655,15 +1653,15 @@ mod tests { } #[test] - fn test_batch_with_force_schema() { - fn force_schema_and_get_err_from_batch( + fn test_batch_with_unchecked_schema() { + fn apply_schema_unchecked( record_batch: &RecordBatch, schema_ref: SchemaRef, idx: usize, ) -> Option { record_batch .clone() - .with_schema_force(schema_ref) + .with_schema_unchecked(schema_ref) .unwrap() .project(&[idx]) .err() @@ -1677,7 +1675,7 @@ mod tests { // Test empty schema for non-empty schema batch let invalid_schema_empty = Schema::empty(); assert_eq!( - force_schema_and_get_err_from_batch(&record_batch, invalid_schema_empty.into(), 0) + apply_schema_unchecked(&record_batch, invalid_schema_empty.into(), 0) .unwrap() .to_string(), "Schema error: project index 0 out of bounds, max field 0" @@ -1686,16 +1684,16 @@ mod tests { // Wrong number of columns let invalid_schema_more_cols = Schema::new(vec![ Field::new("a", DataType::Utf8, false), - Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), ]); - assert!(force_schema_and_get_err_from_batch( - &record_batch, - invalid_schema_more_cols.clone().into(), - 0 - ) - .is_none()); + + assert!( + apply_schema_unchecked(&record_batch, invalid_schema_more_cols.clone().into(), 0) + .is_none() + ); + assert_eq!( - force_schema_and_get_err_from_batch(&record_batch, invalid_schema_more_cols.into(), 1) + apply_schema_unchecked(&record_batch, invalid_schema_more_cols.into(), 1) .unwrap() .to_string(), "Schema error: project index 1 out of bounds, max field 1" @@ -1704,6 +1702,29 @@ mod tests { // Wrong datatype let invalid_schema_wrong_datatype = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - assert_eq!(force_schema_and_get_err_from_batch(&record_batch, invalid_schema_wrong_datatype.into(), 0).unwrap().to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0"); + assert_eq!(apply_schema_unchecked(&record_batch, invalid_schema_wrong_datatype.into(), 0).unwrap().to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0"); + + // Wrong column name. A instead C + let invalid_schema_wrong_col_name = + Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + assert!(record_batch + .clone() + .with_schema_unchecked(invalid_schema_wrong_col_name.into()) + .unwrap() + .column_by_name("c") + .is_none()); + + // Valid schema + let valid_schema = Schema::new(vec![Field::new("c", DataType::Utf8, false)]); + + assert_eq!( + record_batch + .clone() + .with_schema_unchecked(valid_schema.into()) + .unwrap() + .column_by_name("c"), + record_batch.column_by_name("c") + ); } }