Skip to content

Commit 9435513

Browse files
dqkqdalamb
andauthored
fix: DataFrame::select_columns and DataFrame::drop_columns for qualified duplicated field names (#18236)
## Which issue does this PR close? - Closes #18212. ## Rationale for this change `DataFrame::drop_columns` only considers one field for each `name`, it fails to drop columns from dataframe containing duplicated names from different relations. Such as `mark` columns created by multiples `Join::LeftMark`. `DataFrame::select_columns` has the same issue, it fails to select columns with the same name from different relations. ## What changes are included in this PR? Allow `DataFrame::drop_columns` and `DataFrame::select_columns` work with duplicated names from different relations. ## Are these changes tested? Yes. ## Are there any user-facing changes? No. --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 67550f1 commit 9435513

File tree

2 files changed

+103
-7
lines changed

2 files changed

+103
-7
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,12 @@ impl DataFrame {
310310
pub fn select_columns(self, columns: &[&str]) -> Result<DataFrame> {
311311
let fields = columns
312312
.iter()
313-
.map(|name| {
313+
.flat_map(|name| {
314314
self.plan
315315
.schema()
316-
.qualified_field_with_unqualified_name(name)
316+
.qualified_fields_with_unqualified_name(name)
317317
})
318-
.collect::<Result<Vec<_>>>()?;
318+
.collect::<Vec<_>>();
319319
let expr: Vec<Expr> = fields
320320
.into_iter()
321321
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
@@ -439,13 +439,12 @@ impl DataFrame {
439439
pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
440440
let fields_to_drop = columns
441441
.iter()
442-
.map(|name| {
442+
.flat_map(|name| {
443443
self.plan
444444
.schema()
445-
.qualified_field_with_unqualified_name(name)
445+
.qualified_fields_with_unqualified_name(name)
446446
})
447-
.filter(|r| r.is_ok())
448-
.collect::<Result<Vec<_>>>()?;
447+
.collect::<Vec<_>>();
449448
let expr: Vec<Expr> = self
450449
.plan
451450
.schema()

datafusion/core/tests/dataframe/mod.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,55 @@ async fn select_with_periods() -> Result<()> {
404404
Ok(())
405405
}
406406

407+
#[tokio::test]
408+
async fn select_columns_duplicated_names_from_different_qualifiers() -> Result<()> {
409+
let t1 = test_table_with_name("t1")
410+
.await?
411+
.select_columns(&["c1"])?
412+
.limit(0, Some(3))?;
413+
let t2 = test_table_with_name("t2")
414+
.await?
415+
.select_columns(&["c1"])?
416+
.limit(3, Some(3))?;
417+
let t3 = test_table_with_name("t3")
418+
.await?
419+
.select_columns(&["c1"])?
420+
.limit(6, Some(3))?;
421+
422+
let join_res = t1
423+
.join(t2, JoinType::Left, &["t1.c1"], &["t2.c1"], None)?
424+
.join(t3, JoinType::Left, &["t1.c1"], &["t3.c1"], None)?;
425+
assert_snapshot!(
426+
batches_to_sort_string(&join_res.clone().collect().await.unwrap()),
427+
@r"
428+
+----+----+----+
429+
| c1 | c1 | c1 |
430+
+----+----+----+
431+
| b | b | |
432+
| b | b | |
433+
| c | | |
434+
| d | | d |
435+
+----+----+----+
436+
"
437+
);
438+
439+
let select_res = join_res.select_columns(&["c1"])?;
440+
assert_snapshot!(
441+
batches_to_sort_string(&select_res.clone().collect().await.unwrap()),
442+
@r"
443+
+----+----+----+
444+
| c1 | c1 | c1 |
445+
+----+----+----+
446+
| b | b | |
447+
| b | b | |
448+
| c | | |
449+
| d | | d |
450+
+----+----+----+
451+
"
452+
);
453+
Ok(())
454+
}
455+
407456
#[tokio::test]
408457
async fn drop_columns() -> Result<()> {
409458
// build plan using Table API
@@ -542,6 +591,54 @@ async fn drop_with_periods() -> Result<()> {
542591
Ok(())
543592
}
544593

594+
#[tokio::test]
595+
async fn drop_columns_duplicated_names_from_different_qualifiers() -> Result<()> {
596+
let t1 = test_table_with_name("t1")
597+
.await?
598+
.select_columns(&["c1"])?
599+
.limit(0, Some(3))?;
600+
let t2 = test_table_with_name("t2")
601+
.await?
602+
.select_columns(&["c1"])?
603+
.limit(3, Some(3))?;
604+
let t3 = test_table_with_name("t3")
605+
.await?
606+
.select_columns(&["c1"])?
607+
.limit(6, Some(3))?;
608+
609+
let join_res = t1
610+
.join(t2, JoinType::LeftMark, &["c1"], &["c1"], None)?
611+
.join(t3, JoinType::LeftMark, &["c1"], &["c1"], None)?;
612+
assert_snapshot!(
613+
batches_to_sort_string(&join_res.clone().collect().await.unwrap()),
614+
@r"
615+
+----+-------+-------+
616+
| c1 | mark | mark |
617+
+----+-------+-------+
618+
| b | true | false |
619+
| c | false | false |
620+
| d | false | true |
621+
+----+-------+-------+
622+
"
623+
);
624+
625+
let drop_res = join_res.drop_columns(&["mark"])?;
626+
assert_snapshot!(
627+
batches_to_sort_string(&drop_res.clone().collect().await.unwrap()),
628+
@r"
629+
+----+
630+
| c1 |
631+
+----+
632+
| b |
633+
| c |
634+
| d |
635+
+----+
636+
"
637+
);
638+
639+
Ok(())
640+
}
641+
545642
#[tokio::test]
546643
async fn aggregate() -> Result<()> {
547644
// build plan using DataFrame API

0 commit comments

Comments
 (0)