Add custom group_ids support to Chronos2Pipeline#429
Add custom group_ids support to Chronos2Pipeline#429StatMixedML wants to merge 2 commits intoamazon-science:mainfrom
Conversation
|
@StatMixedML Thanks for the PR. We are currently working towards AutoGluon 1.5, so I will take a careful look after the release. Before that, one small feedback I have after a quick skim is that a 900 line PR sounds a bit too much to enable this capability. Could you please check if the size of the PR can be reduced? |
| batch_future_covariates = batch["future_covariates"] | ||
| batch_target_idx_ranges = batch["target_idx_ranges"] | ||
|
|
||
| if cross_learning: |
There was a problem hiding this comment.
Could the user instead just run something like the following code?
prediction_per_group = []
for _, group_df in df.groupby("group_id"):
prediction_per_group.append(pipeline.predict(group_df.drop(columns=["group_id"], cross_learning=True, ...))
predictions = pd.concat(prediction_per_group)There was a problem hiding this comment.
@shchur I suppose your suggestion is an elegant way of doing it :-) It works the same way the PR suggests
Most of the additions come from the notebook examples and the unit tests, so the actual changes in |
|
@StatMixedML Thanks! In that case, maybe users can just go with the idea that @shchur suggested. That said, it would be cool to cover non-trivial grouping in the "advanced" section of the tutorials. Do you happen to have good (preferably not synthetic) examples of such grouping helping accuracy? |
Shall I then close the PR?
I can keep the PR open and add some examples to the notebook using publicly available data? Sth. like M5 or some monthly seasonal data with geographic grouping? |
Summary
This PR adds support for custom group IDs in Chronos2Pipeline, enabling fine-grained control over which time series share information during prediction through cross-attention. Users can now specify meaningful groupings (e.g., by geography, sector, ...) to improve forecast accuracy while preventing information leakage between unrelated series.
Motivation
Currently, users can either:
This PR adds a middle ground: selective information sharing where only series within the same group exchange information, while different groups remain independent.
Changes
Core API Changes
Added group_ids parameter to predict_df() and predict_quantiles()
Added helper functions in src/chronos/utils.py
Backward Compatibility
✅ Fully backward compatible
Documentation