Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 100 additions & 19 deletions solvers/tfc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@
group id Chronos-2 keys on. When cutoff offsets-from-end are not
homogeneous across series, the solver falls back to a per-series loop.

Covariates
----------
Static / historical / future covariates from ``ForecastInput.covariates``
are forwarded to the SDK as extra ``train_df`` columns, named via its
``static_variables`` / ``historical_variables`` / ``future_variables``
parameters. Time-varying covariates span the full series timeline, so
``cross_validate`` reads each cutoff's future-covariate values directly
from ``train_df`` (no separate ``future_df`` is needed). Datasets without
covariates (e.g. Monash) carry empty sequences, so nothing is sent and
behaviour is unchanged. Whether a given model actually consumes a covariate
kind is left to the SDK.

Adding a new model
------------------
Pass any model id from ``theforecastingcompany.utils.TFCModels`` via the
Expand Down Expand Up @@ -53,6 +65,63 @@ def _to_pandas_freq(api_freq: str) -> str:
return _PD_FREQ_REMAP.get(api_freq, api_freq)


def _covariate_column_names(covariates) -> dict[str, list[str]]:
    """Map each covariate kind to stable train_df column names.

    Names are derived from the channel count of the first series in each
    covariate sequence. A kind that is empty for every series yields an
    empty list. We assume a kind is either present (with the same channel
    count) for all series or absent for all of them — true for every
    benchmark dataset today.

    Time-varying kinds (future / historical) are expected as ``(T, Ch)``.
    A 1-D time-varying array is a single channel of length ``T`` — the
    same interpretation ``_set_timed_columns`` applies when it promotes
    1-D input to ``(T, 1)`` — so it yields exactly one column name rather
    than ``T`` of them. Static arrays are ``(Ch,)``, so their last axis is
    the channel count; a 0-d value counts as one channel.

    Args:
        covariates: object exposing ``future_covars``, ``hist_covars`` and
            ``static_covars`` sequences (one entry per series).

    Returns:
        Dict with keys ``"future"``, ``"historical"`` and ``"static"``,
        each mapping to a list of column names (``future_k``, ``hist_k``,
        ``static_k``).
    """

    def _names(seq, prefix, time_varying):
        # Only the first entry is inspected; see the homogeneity assumption
        # in the docstring above.
        for first in seq:
            first = np.asarray(first)
            if time_varying:
                # Axis 0 is time. Without a second axis there is exactly
                # one channel, never ``T`` channels.
                n_channels = first.shape[-1] if first.ndim >= 2 else 1
            else:
                n_channels = first.shape[-1] if first.ndim else 1
            return [f"{prefix}_{k}" for k in range(n_channels)]
        return []

    return {
        "future": _names(covariates.future_covars, "future", True),
        "historical": _names(covariates.hist_covars, "hist", True),
        "static": _names(covariates.static_covars, "static", False),
    }


def _attach_covariates(frame, index_len, covariates, series_idx, col_names):
    """Write the covariates of one series into a per-channel ``frame``.

    Covariates belong to a series, not to a channel, so every channel
    frame built for ``series_idx`` receives identical covariate values.
    Time-varying kinds (future, then historical) are expected as
    ``(T, Ch)`` and are length-checked against ``index_len`` by
    ``_set_timed_columns``. Static values are flattened and each one is
    assigned as a scalar, which broadcasts it over every row.

    A kind whose sequence has no entry at ``series_idx`` is skipped.
    ``frame`` is mutated in place; nothing is returned.
    """
    # (attribute on ``covariates``, kind label). The kind label is both the
    # key into ``col_names`` and the label passed on for error messages.
    timed_kinds = (
        ("future_covars", "future"),
        ("hist_covars", "historical"),
    )
    for attr, kind in timed_kinds:
        per_series = getattr(covariates, attr)
        if series_idx >= len(per_series):
            continue
        values = np.asarray(per_series[series_idx], dtype=np.float32)
        _set_timed_columns(frame, values, index_len, col_names[kind], kind)

    statics = covariates.static_covars
    if series_idx >= len(statics):
        return
    flat = np.asarray(statics[series_idx], dtype=np.float32).reshape(-1)
    for position, column in enumerate(col_names["static"]):
        frame[column] = flat[position]


def _set_timed_columns(frame, arr, index_len, columns, kind):
    """Write time-varying covariate channels into ``frame`` as columns.

    A 1-D ``arr`` is treated as a single channel and promoted to
    ``(T, 1)``. The first axis must have exactly ``index_len`` entries;
    otherwise a ``ValueError`` naming ``kind`` is raised. Channel ``k`` of
    ``arr`` is assigned to ``columns[k]``. ``frame`` is mutated in place.
    """
    values = arr[:, None] if arr.ndim == 1 else arr

    n_steps = values.shape[0]
    if n_steps != index_len:
        raise ValueError(
            f"{kind} covariate has length {n_steps} but the series has "
            f"{index_len} steps; time-varying covariates must align with x."
        )

    for position, column in enumerate(columns):
        frame[column] = values[:, position]


def _shared_offsets_from_end(x, cutoff_indexes):
"""Return per-series cutoff offsets if shared across series, else None."""
if not cutoff_indexes:
Expand Down Expand Up @@ -105,24 +174,27 @@ def __init__(
self.batch_size = batch_size

def predict(self, x: ForecastInput) -> ForecastOutput:
# TODO: thread ``x.covariates`` (static/hist/future) through to the SDK
# once the benchmark datasets populate them. Monash currently
# carries none, so the dataclass arrives with empty sequences.
# Static / historical / future covariates ride along as extra
# ``train_df`` columns; the SDK reads them via its ``*_variables``
# params. Datasets without covariates (e.g. Monash) carry empty
# sequences, so the column lists are empty and nothing is sent.
series_list, cutoff_indexes = x.x, x.cutoff_indexes
covariates = x.covariates
col_names = _covariate_column_names(covariates)
pd_freq = _to_pandas_freq(self.freq)

offsets = _shared_offsets_from_end(series_list, cutoff_indexes)
if getattr(self.model, "supports_batching", False) and offsets is not None:
per_series, levels = self._predict_batched(
series_list, cutoff_indexes, pd_freq, offsets
series_list, cutoff_indexes, pd_freq, offsets, covariates, col_names
)
else:
per_series, levels = self._predict_per_series(
series_list, cutoff_indexes, pd_freq
series_list, cutoff_indexes, pd_freq, covariates, col_names
)
return ForecastOutput(quantiles=per_series, quantile_levels=levels)

def _predict_per_series(self, x, cutoff_indexes, pd_freq):
def _predict_per_series(self, x, cutoff_indexes, pd_freq, covariates, col_names):
per_series = []
levels = None
for series_idx, (series, cutoffs) in enumerate(zip(x, cutoff_indexes)):
Expand All @@ -132,16 +204,17 @@ def _predict_per_series(self, x, cutoff_indexes, pd_freq):
T, C = series.shape
index = pd.date_range("2000-01-01", periods=T, freq=pd_freq)

frames = [
pd.DataFrame(
frames = []
for c in range(C):
frame = pd.DataFrame(
{
"unique_id": f"s{series_idx}_c{c}",
"ds": index,
"target": series[:, c],
}
)
for c in range(C)
]
_attach_covariates(frame, T, covariates, series_idx, col_names)
frames.append(frame)
train_df = pd.concat(frames, ignore_index=True)
fcds = [pd.Timestamp(index[cutoff]) for cutoff in cutoffs]

Expand All @@ -157,6 +230,9 @@ def _predict_per_series(self, x, cutoff_indexes, pd_freq):
add_events=self.add_events,
country_isocode=self.country_isocode,
batch_size=self.batch_size,
future_variables=col_names["future"] or None,
historical_variables=col_names["historical"] or None,
static_variables=col_names["static"] or None,
)

arr, series_levels = self._gather_series_output(
Expand All @@ -166,7 +242,9 @@ def _predict_per_series(self, x, cutoff_indexes, pd_freq):
levels = series_levels
return per_series, (levels if levels is not None else (0.5,))

def _predict_batched(self, x, cutoff_indexes, pd_freq, offsets):
def _predict_batched(
self, x, cutoff_indexes, pd_freq, offsets, covariates, col_names
):
"""One ``cross_validate`` call covering every series in ``x``.

Series are aligned to share an end date so all cutoffs collapse to
Expand All @@ -183,15 +261,15 @@ def _predict_batched(self, x, cutoff_indexes, pd_freq, offsets):
T, C = series.shape
index = pd.date_range(end=end, periods=T, freq=pd_freq)
for c in range(C):
frames.append(
pd.DataFrame(
{
"unique_id": f"s{series_idx}_c{c}",
"ds": index,
"target": series[:, c],
}
)
frame = pd.DataFrame(
{
"unique_id": f"s{series_idx}_c{c}",
"ds": index,
"target": series[:, c],
}
)
_attach_covariates(frame, T, covariates, series_idx, col_names)
frames.append(frame)
per_series_meta.append((series_idx, C, index, cutoffs))

train_df = pd.concat(frames, ignore_index=True)
Expand All @@ -213,6 +291,9 @@ def _predict_batched(self, x, cutoff_indexes, pd_freq, offsets):
add_events=self.add_events,
country_isocode=self.country_isocode,
batch_size=self.batch_size,
future_variables=col_names["future"] or None,
historical_variables=col_names["historical"] or None,
static_variables=col_names["static"] or None,
)

per_series = []
Expand Down
Loading