Skip to content

Commit

Permalink
preserve field order in Schema.select
Browse files: browse the repository at this point in the history
Signed-off-by: Felix Scherz <[email protected]>
  • Loading branch information
felixscherz committed Apr 14, 2024
1 parent 4b91105 commit 9ede25c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,17 @@ def select(self, *names: str, case_sensitive: bool = True) -> Schema:
"""
try:
if case_sensitive:
ids = {self._name_to_id[name] for name in names}
ids = [self._name_to_id[name] for name in names]
else:
ids = {self._lazy_name_to_id_lower[name.lower()] for name in names}
ids = [self._lazy_name_to_id_lower[name.lower()] for name in names]
except KeyError as e:
raise ValueError(f"Could not find column: {e}") from e

return prune_columns(self, ids)
pruned_schema = prune_columns(self, set(ids))

fields = sorted(pruned_schema.fields, key=lambda f: ids.index(f.field_id))

return Schema(*fields, schema_id=pruned_schema.schema_id, identifier_field_ids=pruned_schema.identifier_field_ids)

@property
def field_ids(self) -> Set[int]:
Expand Down
11 changes: 11 additions & 0 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,17 @@ def test_table_scan_projection_single_column(table_v2: Table) -> None:
assert projection_schema.schema_id == 1


def test_table_scan_select_preserves_order(table_v2: Table) -> None:
    """Check that a scan projection lists fields in the order they were selected.

    The columns are requested as ``y, x, z`` and the projected schema is
    expected to list its fields in that same order.
    """
    expected_schema = Schema(
        NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"),
        NestedField(field_id=1, name="x", field_type=LongType(), required=True),
        NestedField(field_id=3, name="z", field_type=LongType(), required=True),
        schema_id=1,
        identifier_field_ids=[1, 2],
    )

    projected_schema = table_v2.scan().select("y", "x", "z").projection()

    assert projected_schema == expected_schema


def test_table_scan_projection_single_column_case_sensitive(table_v2: Table) -> None:
scan = table_v2.scan()
projection_schema = scan.with_case_sensitive(False).select("Y").projection()
Expand Down

0 comments on commit 9ede25c

Please sign in to comment.