diff --git a/src/autogluon/cloud/backend/timeseries_sagemaker_backend.py b/src/autogluon/cloud/backend/timeseries_sagemaker_backend.py index 3bee739..e2efcb1 100644 --- a/src/autogluon/cloud/backend/timeseries_sagemaker_backend.py +++ b/src/autogluon/cloud/backend/timeseries_sagemaker_backend.py @@ -166,7 +166,7 @@ def predict_real_time( static_features: Optional[pd.DataFrame] An optional data frame describing the metadata attributes of individual items in the item index. For more detail, please refer to `TimeSeriesDataFrame` documentation: - https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe + https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html target: str Name of column that contains the target values to forecast accept: str, default = application/x-parquet @@ -225,7 +225,7 @@ def predict( static_features: Optional[Union[str, pd.DataFrame]] An optional data frame describing the metadata attributes of individual items in the item index. For more detail, please refer to `TimeSeriesDataFrame` documentation: - https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe + https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html target: str Name of column that contains the target values to forecast kwargs: diff --git a/src/autogluon/cloud/predictor/timeseries_cloud_predictor.py b/src/autogluon/cloud/predictor/timeseries_cloud_predictor.py index 8216957..97cb04c 100644 --- a/src/autogluon/cloud/predictor/timeseries_cloud_predictor.py +++ b/src/autogluon/cloud/predictor/timeseries_cloud_predictor.py @@ -80,7 +80,7 @@ def fit( static_features: Optional[pd.DataFrame] An optional data frame describing the metadata attributes of individual items in the item index. For more detail, please refer to `TimeSeriesDataFrame` documentation: - https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe + https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html framework_version: str, default = `latest` Training container version of autogluon. If `latest`, will use the latest available container version. @@ -159,6 +159,7 @@ def predict_real_time( self, test_data: Union[str, pd.DataFrame], static_features: Optional[Union[str, pd.DataFrame]] = None, + known_covariates: Optional[pd.DataFrame] = None, accept: str = "application/x-parquet", **kwargs, ) -> pd.DataFrame: @@ -175,7 +176,12 @@ def predict_real_time( static_features: Optional[pd.DataFrame] An optional data frame describing the metadata attributes of individual items in the item index. For more detail, please refer to `TimeSeriesDataFrame` documentation: - https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe + https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html + known_covariates : Optional[pd.DataFrame] + If ``known_covariates_names`` were specified when creating the predictor, it is necessary to provide the + values of the known covariates for each time series during the forecast horizon. + For more details, please refer to the `TimeSeriesPredictor.predictor` documentation: + https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesPredictor.predict.html accept: str, default = application/x-parquet Type of accept output content. Valid options are application/x-parquet, text/csv, application/json @@ -198,6 +204,7 @@ def predict_real_time( target=self.target_column, static_features=static_features, accept=accept, + inference_kwargs=dict(known_covariates=known_covariates, **kwargs), ) def predict_proba_real_time(self, **kwargs) -> pd.DataFrame: @@ -224,6 +231,9 @@ def predict( This method would first create a AutoGluonSagemakerInferenceModel with the trained predictor, then create a transformer with it, and call transform in the end. + Note that batch prediction with `known_covariates` is currently not supported. Please use `predict_real_time` + to predict with `known_covariates` instead. + Parameters ---------- test_data: str @@ -232,7 +242,7 @@ def predict( static_features: Optional[Union[str, pd.DataFrame]] An optional data frame describing the metadata attributes of individual items in the item index. For more detail, please refer to `TimeSeriesDataFrame` documentation: - https://auto.gluon.ai/stable/api/autogluon.predictor.html#timeseriesdataframe + https://auto.gluon.ai/stable/api/autogluon.timeseries.TimeSeriesDataFrame.html target: str Name of column that contains the target values to forecast predictor_path: str