File tree Expand file tree Collapse file tree 4 files changed +8
-6
lines changed
scripts/sagemaker_scripts Expand file tree Collapse file tree 4 files changed +8
-6
lines changed Original file line number Diff line number Diff line change 1
- import json
2
1
import logging
3
2
from abc import abstractmethod
4
3
from typing import Optional
5
4
6
- import pandas as pd
7
5
import sagemaker
8
6
9
7
from ..utils .ag_sagemaker import (
Original file line number Diff line number Diff line change @@ -181,15 +181,15 @@ def predict_real_time(
181
181
self .id_column = id_column or self .id_column
182
182
self .timestamp_column = timestamp_column or self .timestamp_column
183
183
self .target_column = target or self .target_column
184
-
184
+
185
185
return self .backend .predict_real_time (
186
186
test_data = test_data ,
187
187
id_column = self .id_column ,
188
188
timestamp_column = self .timestamp_column ,
189
189
target = self .target_column ,
190
190
static_features = static_features ,
191
191
accept = accept ,
192
- inference_kwargs = kwargs
192
+ inference_kwargs = kwargs ,
193
193
)
194
194
195
195
def predict_proba_real_time (self , ** kwargs ) -> pd .DataFrame :
Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ def _save_image_and_update_dataframe_column(bytes):
33
33
34
34
return im_path
35
35
36
+
36
37
def _custom_json_deserializer (serialized_str ):
37
38
"""
38
39
Deserialize the JSON string that may include representations of complex data types like DataFrames
@@ -55,6 +56,7 @@ def _custom_json_deserializer(serialized_str):
55
56
56
57
return deserialized_kwargs
57
58
59
+
58
60
def model_fn (model_dir ):
59
61
"""loads model from previously saved artifact"""
60
62
logger .info ("Loading the model" )
Original file line number Diff line number Diff line change 1
1
# flake8: noqa
2
+ import logging
2
3
import os
3
4
import pickle
4
5
import shutil
6
+ import sys
5
7
from io import BytesIO , StringIO
6
8
7
9
import pandas as pd
8
- import logging
9
- import sys
10
10
11
11
from autogluon .timeseries import TimeSeriesDataFrame , TimeSeriesPredictor
12
+
12
13
logging .basicConfig (stream = sys .stdout , level = logging .INFO )
13
14
logger = logging .getLogger (__name__ )
14
15
16
+
15
17
def model_fn (model_dir ):
16
18
"""loads model from previously saved artifact"""
17
19
# TSPredictor will write to the model file during inference while the default model_dir is read only
You can’t perform that action at this time.
0 commit comments