From 9b9ed534b2022cb9a687f4ed876fadcc2457b31b Mon Sep 17 00:00:00 2001 From: Anthony Lam <2289591+AnthonyLam@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:36:30 -0700 Subject: [PATCH] fix: Invert `case_sensitive` logic in StructType (#1147) * fix: Invert logic in StructType * Add test for StructType.field_by_name * Remove var I forgot about. * Fix formatting post-lint --- pyiceberg/types.py | 6 +++--- tests/test_types.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 97ddea0e57..8fa745384d 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -377,13 +377,13 @@ def field(self, field_id: int) -> Optional[NestedField]: def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]: if case_sensitive: - name_lower = name.lower() for field in self.fields: - if field.name.lower() == name_lower: + if field.name == name: return field else: + name_lower = name.lower() for field in self.fields: - if field.name == name: + if field.name.lower() == name_lower: return field return None diff --git a/tests/test_types.py b/tests/test_types.py index 52bdce4de8..b19df17e08 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -149,6 +149,20 @@ def test_struct_type() -> None: assert type_var == pickle.loads(pickle.dumps(type_var)) +def test_struct_field_by_name() -> None: + lower_field = NestedField(1, "lower_case_field", IntegerType(), required=True) + upper_field = NestedField(2, "UPPER_CASE_FIELD", IntegerType(), required=True) + type_var = StructType(lower_field, upper_field) + + assert type_var.field_by_name("lower_case_field", case_sensitive=False) == lower_field + assert type_var.field_by_name("upper_case_field", case_sensitive=False) == upper_field + assert type_var.field_by_name("nonexistent_field", case_sensitive=False) is None + + assert type_var.field_by_name("lower_case_field", case_sensitive=True) == lower_field + assert type_var.field_by_name("upper_case_field", case_sensitive=True) is None + assert type_var.field_by_name("nonexistent_field", case_sensitive=True) is None + + def test_list_type() -> None: type_var = ListType( 1,