Use single file write when an extension is present in the path. #13079

Merged · 3 commits · Nov 1, 2024
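
In user-facing terms, the change works like this (a minimal sketch under assumed paths and inputs, using the public DataFrame API; the file names below are illustrative and not taken from this diff): when the last segment of the output path carries an extension, the demuxer writes exactly one file at that path, while a directory-style path keeps the existing multi-file behavior.

use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Assumed input file, for illustration only.
    let df = ctx.read_csv("input.csv", CsvReadOptions::new()).await?;

    // Last path segment has an extension -> a single file named
    // `data.custom_ext` is written (the behavior this PR introduces).
    df.clone()
        .write_parquet(
            "/tmp/out/data.custom_ext",
            DataFrameWriteOptions::new(),
            None,
        )
        .await?;

    // Directory-style path -> output may still be split across multiple
    // `.parquet` files, as before.
    df.write_parquet("/tmp/out_dir/", DataFrameWriteOptions::new(), None)
        .await?;

    Ok(())
}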
176 changes: 135 additions & 41 deletions datafusion/core/src/datasource/file_format/parquet.rs
@@ -2246,47 +2246,7 @@ mod tests {

#[tokio::test]
async fn parquet_sink_write() -> Result<()> {
let field_a = Field::new("a", DataType::Utf8, false);
let field_b = Field::new("b", DataType::Utf8, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
let object_store_url = ObjectStoreUrl::local_filesystem();

let file_sink_config = FileSinkConfig {
object_store_url: object_store_url.clone(),
file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
table_paths: vec![ListingTableUrl::parse("file:///")?],
output_schema: schema.clone(),
table_partition_cols: vec![],
insert_op: InsertOp::Overwrite,
keep_partition_by_columns: false,
};
let parquet_sink = Arc::new(ParquetSink::new(
file_sink_config,
TableParquetOptions {
key_value_metadata: std::collections::HashMap::from([
("my-data".to_string(), Some("stuff".to_string())),
("my-data-bool-key".to_string(), None),
]),
..Default::default()
},
));

// create data
let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();

// write stream
parquet_sink
.write_all(
Box::pin(RecordBatchStreamAdapter::new(
schema,
futures::stream::iter(vec![Ok(batch)]),
)),
&build_ctx(object_store_url.as_ref()),
)
.await
.unwrap();
let parquet_sink = create_written_parquet_sink("file:///").await?;

// assert written
let mut written = parquet_sink.written();
@@ -2338,6 +2298,140 @@ mod tests {
Ok(())
}

#[tokio::test]
Review comment (Contributor): This is really nice too

async fn parquet_sink_write_with_extension() -> Result<()> {
let filename = "test_file.custom_ext";
let file_path = format!("file:///path/to/{}", filename);
let parquet_sink = create_written_parquet_sink(file_path.as_str()).await?;

// assert written
let mut written = parquet_sink.written();
let written = written.drain();
assert_eq!(
written.len(),
1,
"expected a single parquet file to be written, instead found {}",
written.len()
);

let (path, ..) = written.take(1).next().unwrap();

let path_parts = path.parts().collect::<Vec<_>>();
assert_eq!(
path_parts.len(),
3,
"Expected 3 path parts, instead found {}",
path_parts.len()
);
assert_eq!(path_parts.last().unwrap().as_ref(), filename);

Ok(())
}

#[tokio::test]
async fn parquet_sink_write_with_directory_name() -> Result<()> {
let file_path = "file:///path/to";
let parquet_sink = create_written_parquet_sink(file_path).await?;

// assert written
let mut written = parquet_sink.written();
let written = written.drain();
assert_eq!(
written.len(),
1,
"expected a single parquet file to be written, instead found {}",
written.len()
);

let (path, ..) = written.take(1).next().unwrap();

let path_parts = path.parts().collect::<Vec<_>>();
assert_eq!(
path_parts.len(),
3,
"Expected 3 path parts, instead found {}",
path_parts.len()
);
assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet"));

Ok(())
}

#[tokio::test]
async fn parquet_sink_write_with_folder_ending() -> Result<()> {
let file_path = "file:///path/to/";
let parquet_sink = create_written_parquet_sink(file_path).await?;

// assert written
let mut written = parquet_sink.written();
let written = written.drain();
assert_eq!(
written.len(),
1,
"expected a single parquet file to be written, instead found {}",
written.len()
);

let (path, ..) = written.take(1).next().unwrap();

let path_parts = path.parts().collect::<Vec<_>>();
assert_eq!(
path_parts.len(),
3,
"Expected 3 path parts, instead found {}",
path_parts.len()
);
assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet"));

Ok(())
}

async fn create_written_parquet_sink(table_path: &str) -> Result<Arc<ParquetSink>> {
let field_a = Field::new("a", DataType::Utf8, false);
let field_b = Field::new("b", DataType::Utf8, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
let object_store_url = ObjectStoreUrl::local_filesystem();

let file_sink_config = FileSinkConfig {
object_store_url: object_store_url.clone(),
file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
table_paths: vec![ListingTableUrl::parse(table_path)?],
output_schema: schema.clone(),
table_partition_cols: vec![],
insert_op: InsertOp::Overwrite,
keep_partition_by_columns: false,
};
let parquet_sink = Arc::new(ParquetSink::new(
file_sink_config,
TableParquetOptions {
key_value_metadata: std::collections::HashMap::from([
("my-data".to_string(), Some("stuff".to_string())),
("my-data-bool-key".to_string(), None),
]),
..Default::default()
},
));

// create data
let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();

// write stream
parquet_sink
.write_all(
Box::pin(RecordBatchStreamAdapter::new(
schema,
futures::stream::iter(vec![Ok(batch)]),
)),
&build_ctx(object_store_url.as_ref()),
)
.await
.unwrap();

Ok(parquet_sink)
}

#[tokio::test]
async fn parquet_sink_write_partitions() -> Result<()> {
let field_a = Field::new("a", DataType::Utf8, false);
14 changes: 8 additions & 6 deletions datafusion/core/src/datasource/file_format/write/demux.rs
@@ -59,8 +59,9 @@ type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
/// which should be contained within the same output file. The outer channel
/// is used to send a dynamic number of inner channels, representing a dynamic
/// number of total output files. The caller is also responsible to monitor
/// the demux task for errors and abort accordingly. The single_file_output parameter
/// overrides all other settings to force only a single file to be written.
/// the demux task for errors and abort accordingly. A path with an extension will
/// force only a single file to be written with the extension from the path. Otherwise
/// the default extension will be used and the output will be split into multiple files.
/// The partition_by parameter will additionally split the input based on the unique
/// values of a specific column: `<https://github.com/apache/datafusion/issues/7744>`
/// ┌───────────┐ ┌────────────┐ ┌─────────────┐
@@ -79,12 +80,13 @@ pub(crate) fn start_demuxer_task(
context: &Arc<TaskContext>,
partition_by: Option<Vec<(String, DataType)>>,
base_output_path: ListingTableUrl,
file_extension: String,
default_extension: String,
keep_partition_by_columns: bool,
) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
let (tx, rx) = mpsc::unbounded_channel();
let context = context.clone();
let single_file_output = !base_output_path.is_collection();
let single_file_output =
!base_output_path.is_collection() && base_output_path.file_extension().is_some();
let task = match partition_by {
Some(parts) => {
// There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot
@@ -96,7 +98,7 @@ pub(crate) fn start_demuxer_task(
context,
parts,
base_output_path,
file_extension,
default_extension,
keep_partition_by_columns,
)
.await
@@ -108,7 +110,7 @@ pub(crate) fn start_demuxer_task(
input,
context,
base_output_path,
file_extension,
default_extension,
single_file_output,
)
.await
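
The core of the demux change is the new single_file_output condition above. A standalone restatement of the same decision rule (a hypothetical helper written for illustration; the real code uses ListingTableUrl::is_collection and the new file_extension method shown in the next file):

// Write a single file only when the base path is not a collection
// (does not end in '/') AND its last segment carries an extension.
fn is_single_file_output(base_output_path: &str) -> bool {
    let is_collection = base_output_path.ends_with('/');
    let has_extension = base_output_path
        .rsplit('/')
        .next()
        .map(|seg| seg.contains('.') && !seg.ends_with('.'))
        .unwrap_or(false);
    !is_collection && has_extension
}

fn main() {
    assert!(is_single_file_output("/path/to/data.custom_ext")); // single file
    assert!(!is_single_file_output("/path/to")); // no extension -> multiple files
    assert!(!is_single_file_output("/path/to/")); // collection -> multiple files
}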
63 changes: 63 additions & 0 deletions datafusion/core/src/datasource/listing/url.rs
@@ -190,6 +190,19 @@ impl ListingTableUrl {
self.url.path().ends_with(DELIMITER)
}

/// Returns the file extension of the last path segment if it exists
pub fn file_extension(&self) -> Option<&str> {
if let Some(segments) = self.url.path_segments() {
if let Some(last_segment) = segments.last() {
if last_segment.contains(".") && !last_segment.ends_with(".") {
return last_segment.split('.').last();
}
}
}

None
}

/// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning
/// an iterator of the remaining path segments
pub(crate) fn strip_prefix<'a, 'b: 'a>(
@@ -493,4 +506,54 @@ mod tests {
"path not ends with / - fragment ends with / - not collection",
);
}

#[test]
fn test_file_extension() {
Review comment (Contributor): Really nice tests 👏

fn test(input: &str, expected: Option<&str>, message: &str) {
let url = ListingTableUrl::parse(input).unwrap();
assert_eq!(url.file_extension(), expected, "{message}");
}

test("https://a.b.c/path/", None, "path ends with / - not a file");
test(
"https://a.b.c/path/?a=b",
None,
"path ends with / - with query args - not a file",
);
test(
"https://a.b.c/path?a=b/",
None,
"path not ends with / - query ends with / but no file extension",
);
test(
"https://a.b.c/path/#a=b",
None,
"path ends with / - with fragment - not a file",
);
test(
"https://a.b.c/path#a=b/",
None,
"path not ends with / - fragment ends with / but no file extension",
);
test(
"file///some/path/",
None,
"file path ends with / - not a file",
);
test(
    "file:///some/path/file",
    None,
    "file path does not contain . - no extension",
);
test(
    "file:///some/path/file.",
    None,
    "file path ends with . - no value after .",
);
test(
    "file:///some/path/file.ext",
    Some("ext"),
    "file path ends with .ext - extension is ext",
);
}
}
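
For reference, a hedged usage sketch of the new accessor outside the test module (the bucket and paths below are assumptions for illustration; ListingTableUrl::parse only parses the URL and needs no registered object store):

use datafusion::datasource::listing::ListingTableUrl;

fn main() -> datafusion::error::Result<()> {
    // Extension present in the last segment -> single-file write path.
    let url = ListingTableUrl::parse("s3://bucket/data/output.parquet")?;
    assert_eq!(url.file_extension(), Some("parquet"));

    // Directory-style URL -> no extension, multi-file behavior retained.
    let dir = ListingTableUrl::parse("s3://bucket/data/")?;
    assert_eq!(dir.file_extension(), None);

    Ok(())
}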