Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/checkedframe/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ def get_class_members(object, predicate=None):
def _parse_args_into_iterable(
args: tuple[Any, ...] | tuple[Iterable[Any]],
) -> Iterable[Any]:
if len(args) == 1 and isinstance(args[0], Iterable):
return args[0]
if len(args) == 1:
x = args[0]

if isinstance(x, Iterable) and not isinstance(x, str):
return x

return args

Expand Down
19 changes: 13 additions & 6 deletions src/checkedframe/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, Union

from ._dtypes import Boolean, Categorical, CfUnion, Date, Datetime, String, TypedColumn
from ._utils import _parse_args_into_iterable

TypeOrInstance = Union[TypedColumn, type[TypedColumn], CfUnion]

Expand Down Expand Up @@ -41,13 +42,19 @@ def __xor__(self, other: Selector) -> Selector:
lambda col, dtype: self.condition(col, dtype) != other.condition(col, dtype)
)

def exclude(self, other: str | Iterable[str] | Selector) -> Selector:
if not isinstance(other, Selector):
selector_other = by_name(other)
else:
selector_other = other
def exclude(self, *other: str | Iterable[str] | Selector) -> Selector:

other_list = _parse_args_into_iterable(other)

for o in other_list:
if not isinstance(o, Selector):
selector_other = by_name(o)
else:
selector_other = o

self = self.__sub__(selector_other)

return self.__sub__(selector_other)
return self


def _flatten_str_iterable(lst: Iterable[str | Iterable[str]]) -> list[str]:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def test_exclude():
assert cfs.numeric().exclude("int8")(SCHEMA) == ["float32"]
assert cfs.numeric().exclude(cfs.by_name("int8"))(SCHEMA) == ["float32"]
assert cfs.numeric().exclude(["int8"])(SCHEMA) == ["float32"]
assert cfs.all().exclude("list_string", "list_list_string")(SCHEMA) == [
"string",
"int8",
"float32",
"boolean",
]


def test_temporal_example():
Expand Down
Loading