Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pandas-stubs/core/indexes/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ from pandas._typing import (
S2_NSDT,
T_COMPLEX,
AnyAll,
AnyArrayLike,
ArrayLike,
AxesData,
CategoryDtypeArg,
Expand Down Expand Up @@ -440,7 +441,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
@property
def values(self) -> np_1darray: ...
def memory_usage(self, deep: bool = False): ...
def where(self, cond, other: Scalar | ArrayLike | None = None): ...
def where(self, cond, other: Scalar | AnyArrayLike | None = None) -> Self: ...
def __contains__(self, key) -> bool: ...
@final
def __setitem__(self, key, value) -> None: ...
Expand Down
25 changes: 25 additions & 0 deletions tests/indexes/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pandas.core.arrays.timedeltas import TimedeltaArray
from pandas.core.indexes.base import Index
from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.datetimes import DatetimeIndex
from typing_extensions import (
Never,
assert_type,
Expand Down Expand Up @@ -1541,3 +1542,27 @@ def test_multiindex_swaplevel() -> None:
"""Test that MultiIndex.swaplevel returns MultiIndex"""
mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"])
check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex)


def test_index_where() -> None:
"""Test Index.where with multiple types of other GH1419."""
idx = pd.Index(range(48))
mask = np.ones(48, dtype=bool)
val_idx = idx.where(mask, idx)
check(assert_type(val_idx, "pd.Index[int]"), pd.Index, int)

val_sr = idx.where(mask, (idx).to_series())
check(assert_type(val_sr, "pd.Index[int]"), pd.Index, int)


def test_datetimeindex_where() -> None:
"""Test DatetimeIndex.where with multiple types of other GH1419."""
datetime_index = pd.date_range(start="2025-01-01", freq="h", periods=48)
mask = np.ones(48, dtype=bool)
val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1))
check(assert_type(val_idx, DatetimeIndex), DatetimeIndex)

val_sr = datetime_index.where(
mask, (datetime_index - pd.Timedelta(days=1)).to_series()
)
check(assert_type(val_sr, DatetimeIndex), DatetimeIndex)