Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/chronos/df_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ def validate_df_inputs(
df[timestamp_column] = pd.to_datetime(df[timestamp_column])
df = df.sort_values([id_column, timestamp_column])

# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()
# Get series lengths in the exact order that appears in the sorted dataframe.
# This avoids dtype-specific ordering differences (e.g., string[python]) that can
# break the alignment with contiguous timestamp slices below.
series_lengths = df.groupby(id_column, sort=False).size().to_list()

def validate_freq(timestamps: pd.DatetimeIndex, series_id: str):
freq = pd.infer_freq(timestamps)
Expand Down Expand Up @@ -273,8 +275,8 @@ def convert_df_input_to_list_of_dicts_input(
# Get the original order of time series IDs
original_order = df[id_column].unique()

# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()
# Keep lengths aligned with dataframe row order regardless of ID dtype.
series_lengths = df.groupby(id_column, sort=False).size().to_list()

# If freq is not provided, infer from the first series with >= 3 points
if freq is None:
Expand Down
80 changes: 80 additions & 0 deletions test/test_df_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import hashlib
from unittest.mock import patch

import numpy as np
Expand All @@ -16,6 +17,22 @@
# Tests for validate_df_inputs function


def _create_unequal_length_weekly_df(
    num_series: int = 200, num_periods: int = 60, period_variation: int = 30, seed: int = 42
) -> pd.DataFrame:
    """Build a long-format weekly dataframe whose series have different lengths.

    Every series uses a SHA-256 hex digest as its ``item_id``, ends on the same
    Monday, and has a length drawn uniformly from
    ``[num_periods - period_variation, num_periods + period_variation]``.

    Args:
        num_series: Number of distinct series to generate.
        num_periods: Centre of the range the series lengths are drawn from.
        period_variation: Maximum deviation of a series length from ``num_periods``.
        seed: Seed for the random generator, making the output reproducible.

    Returns:
        A dataframe with columns ``item_id``, ``timestamp`` and ``target`` and a
        fresh ``RangeIndex``; series are stacked in generation order.
    """
    generator = np.random.default_rng(seed=seed)
    shortest = num_periods - period_variation
    longest = num_periods + period_variation
    # The longest possible series starts on 2023-01-02; all series share its last Monday.
    last_timestamp = pd.date_range(start="2023-01-02", periods=longest, freq="W-MON")[-1]

    def _build_series(index: int) -> pd.DataFrame:
        # Draw order (length first, then values) determines the generated data.
        identifier = hashlib.sha256(f"series_{index}".encode()).hexdigest()
        length = int(generator.integers(shortest, longest + 1))
        stamps = pd.date_range(end=last_timestamp, periods=length, freq="W-MON")
        values = generator.normal(size=length)
        return pd.DataFrame({"item_id": identifier, "timestamp": stamps, "target": values})

    frames = [_build_series(index) for index in range(num_series)]
    return pd.concat(frames, ignore_index=True)


@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq):
"""Test that function returns validated dataframes, frequency, series lengths, and original order."""
Expand Down Expand Up @@ -71,6 +88,47 @@ def test_validate_df_inputs_casts_mixed_dtypes_correctly():
assert validated_df["bool_cov"].dtype == np.float32 # booleans are cast to float32


def test_validate_df_inputs_accepts_string_python_ids_with_unequal_lengths():
    """Regression test for issue #440 with string[python] IDs and unequal series lengths.

    Validation must succeed, infer the weekly frequency, and report one length
    per series with the lengths adding up to the total number of rows.
    """
    frame = _create_unequal_length_weekly_df()
    frame["item_id"] = frame["item_id"].astype("string[python]")

    validated = validate_df_inputs(
        df=frame,
        future_df=None,
        target_columns=["target"],
        prediction_length=5,
    )
    # Only the inferred frequency and the per-series lengths are checked here.
    _, _, inferred_freq, series_lengths, _ = validated

    assert inferred_freq == "W-MON"
    num_unique_ids = frame["item_id"].nunique()
    assert len(series_lengths) == num_unique_ids
    total_rows = len(frame)
    assert sum(series_lengths) == total_rows


def test_validate_df_inputs_has_consistent_metadata_for_object_and_string_python_ids():
    """Validation metadata should not depend on whether ID dtype is object or string[python].

    The same data is validated twice, once with its original ID dtype and once
    with IDs cast to ``string[python]``; frequency, series lengths and series
    order must agree.
    """

    def _validation_metadata(frame):
        # Return (freq, series_lengths, original_order) reported by validation.
        _, _, freq, lengths, order = validate_df_inputs(
            df=frame,
            future_df=None,
            target_columns=["target"],
            prediction_length=5,
        )
        return freq, lengths, order

    object_df = _create_unequal_length_weekly_df(seed=7)
    string_df = object_df.copy()
    string_df["item_id"] = string_df["item_id"].astype("string[python]")

    object_freq, object_lengths, object_order = _validation_metadata(object_df)
    string_freq, string_lengths, string_order = _validation_metadata(string_df)

    assert string_freq == object_freq
    assert string_lengths == object_lengths
    # Compare IDs as plain strings so the container dtype does not matter.
    string_ids = [str(item) for item in string_order]
    object_ids = [str(item) for item in object_order]
    assert string_ids == object_ids


def test_validate_df_inputs_raises_error_when_series_has_insufficient_data():
"""Test that ValueError is raised for series with < 3 data points."""
# Create dataframe with one series having only 2 points
Expand Down Expand Up @@ -460,3 +518,25 @@ def test_convert_df_with_mismatched_freq_uses_user_provided_freq(use_future_df):
# Verify the frequency matches user-provided freq
inferred_freq = pd.infer_freq(pred_ts)
assert inferred_freq == user_freq


def test_convert_df_with_validate_inputs_false_handles_string_python_ids():
    """validate_inputs=False should work with string[python] IDs and preserve per-series lengths.

    The dataframe is sorted by ID and timestamp before the call; each produced
    task must then hold exactly as many target steps as its series has rows.
    """
    frame = _create_unequal_length_weekly_df(seed=11)
    frame["item_id"] = frame["item_id"].astype("string[python]")
    frame = frame.sort_values(["item_id", "timestamp"])

    converted = convert_df_input_to_list_of_dicts_input(
        df=frame,
        future_df=None,
        target_columns=["target"],
        prediction_length=5,
        validate_inputs=False,
        freq="W-MON",
    )
    inputs, original_order, _ = converted

    # Row counts per ID, in the order the IDs appear in the sorted dataframe.
    expected_lengths = frame.groupby("item_id", sort=False).size().to_list()
    observed_lengths = []
    for task in inputs:
        # The second axis of each task's target is compared against the row count.
        observed_lengths.append(task["target"].shape[1])

    assert observed_lengths == expected_lengths
    assert len(original_order) == len(expected_lengths)