diff --git a/src/checkedframe/_utils.py b/src/checkedframe/_utils.py index f8ded0b..db40da2 100644 --- a/src/checkedframe/_utils.py +++ b/src/checkedframe/_utils.py @@ -63,8 +63,11 @@ def get_class_members(object, predicate=None): def _parse_args_into_iterable( args: tuple[Any, ...] | tuple[Iterable[Any]], ) -> Iterable[Any]: - if len(args) == 1 and isinstance(args[0], Iterable): - return args[0] + if len(args) == 1: + x = args[0] + + if isinstance(x, Iterable) and not isinstance(x, str): + return x return args diff --git a/src/checkedframe/selectors.py b/src/checkedframe/selectors.py index f4289ab..4fe4c1b 100644 --- a/src/checkedframe/selectors.py +++ b/src/checkedframe/selectors.py @@ -5,6 +5,7 @@ from typing import Callable, Union from ._dtypes import Boolean, Categorical, CfUnion, Date, Datetime, String, TypedColumn +from ._utils import _parse_args_into_iterable TypeOrInstance = Union[TypedColumn, type[TypedColumn], CfUnion] @@ -41,13 +42,19 @@ def __xor__(self, other: Selector) -> Selector: lambda col, dtype: self.condition(col, dtype) != other.condition(col, dtype) ) - def exclude(self, other: str | Iterable[str] | Selector) -> Selector: - if not isinstance(other, Selector): - selector_other = by_name(other) - else: - selector_other = other + def exclude(self, *other: str | Iterable[str] | Selector) -> Selector: + + other_list = _parse_args_into_iterable(other) + + for o in other_list: + if not isinstance(o, Selector): + selector_other = by_name(o) + else: + selector_other = o + + self = self.__sub__(selector_other) - return self.__sub__(selector_other) + return self def _flatten_str_iterable(lst: Iterable[str | Iterable[str]]) -> list[str]: diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 81c7b6f..7244f34 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -81,6 +81,12 @@ def test_exclude(): assert cfs.numeric().exclude("int8")(SCHEMA) == ["float32"] assert cfs.numeric().exclude(cfs.by_name("int8"))(SCHEMA) == ["float32"] assert cfs.numeric().exclude(["int8"])(SCHEMA) == ["float32"] + assert cfs.all().exclude("list_string", "list_list_string")(SCHEMA) == [ + "string", + "int8", + "float32", + "boolean", + ] def test_temporal_example():