Skip to content

Commit 5164316

Browse files
uyoldasUemit Yoldaspintaoz-awspintaozsage-maker
authored
feature: integrate amtviz for visualization of tuning jobs (#5044)
* feature: integrate amtviz for visualization of tuning jobs * Move RecordSerializer and RecordDeserializer to sagemaker.serializers and sagemaker.deserialzers (#5037) * Move RecordSerializer and RecordDeserializer to sagemaker.serializers and sagemaker.deserializers * fix codestyle * fix test --------- Co-authored-by: pintaoz <[email protected]> * Add framework_version to all TensorFlowModel examples (#5038) * Add framework_version to all TensorFlowModel examples * update framework_version to x.x.x --------- Co-authored-by: pintaoz <[email protected]> * Fix hyperparameter strategy docs (#5045) * fix: pass in inference_ami_version to model_based endpoint type (#5043) * fix: pass in inference_ami_version to model_based endpoint type * documentation: update contributing.md w/ venv instructions and pip install fixes --------- Co-authored-by: Zhaoqi <[email protected]> * Add warning about not supporting torch.nn.SyncBatchNorm (#5046) * Add warning about not supporting * update wording --------- Co-authored-by: pintaoz <[email protected]> * prepare release v2.239.2 * update development version to v2.239.3.dev0 * change: update image_uri_configs 02-19-2025 06:18:15 PST * fix: codestyle, type hints, license, and docstrings * documentation: add docstring for amtviz module * fix: fix docstyle and flake8 errors * fix: code reformat using black --------- Co-authored-by: Uemit Yoldas <[email protected]> Co-authored-by: pintaoz-aws <[email protected]> Co-authored-by: pintaoz <[email protected]> Co-authored-by: parknate@ <[email protected]> Co-authored-by: timkuo-amazon <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: ci <ci> Co-authored-by: sagemaker-bot <[email protected]>
1 parent 31f34dd commit 5164316

File tree

7 files changed

+1597
-0
lines changed

7 files changed

+1597
-0
lines changed

src/sagemaker/amtviz/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Amazon SageMaker Automatic Model Tuning Visualization module.
14+
15+
This module provides visualization capabilities for SageMaker hyperparameter tuning jobs.
16+
It enables users to create interactive visualizations to analyze and understand the
17+
performance of hyperparameter optimization experiments.
18+
19+
Example:
20+
>>> from sagemaker.amtviz import visualize_tuning_job
21+
>>> visualize_tuning_job('my-tuning-job')
22+
"""
23+
from __future__ import absolute_import
24+
25+
from sagemaker.amtviz.visualization import visualize_tuning_job
26+
27+
__all__ = ["visualize_tuning_job"]

src/sagemaker/amtviz/job_metrics.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Helper functions to retrieve job metrics from CloudWatch."""
14+
from __future__ import absolute_import
15+
16+
from datetime import datetime, timedelta
17+
from typing import Callable, List, Optional, Tuple, Dict, Any
18+
import hashlib
19+
import os
20+
from pathlib import Path
21+
22+
import logging
23+
import pandas as pd
24+
import numpy as np
25+
import boto3
26+
27+
logger = logging.getLogger(__name__)
28+
29+
cw = boto3.client("cloudwatch")
30+
sm = boto3.client("sagemaker")
31+
32+
33+
def disk_cache(outer: Callable) -> Callable:
    """A decorator that implements disk-based caching for CloudWatch metrics data.

    The wrapped function's result is written as JSON Lines to
    ``~/.amtviz/cw_metrics_cache/req_<key>.jsonl.gz``, where ``<key>`` is the
    MD5 hex digest of ``str(args) + str(kwargs)``. On later calls with the same
    arguments the file is read back instead of calling the wrapped function.

    Loading from the cache is best-effort: if the cached file cannot be read or
    lacks the expected ``ts`` / ``rel_ts`` columns, the wrapped function is
    called again and the cache file is rewritten.

    Args:
        outer (Callable): The function to be wrapped. Must return a pandas DataFrame
            containing CloudWatch metrics data.

    Returns:
        Callable: A wrapper function that implements the caching logic.
    """
    # Local import keeps this block self-contained; used only to preserve the
    # wrapped function's name and docstring on the returned wrapper.
    import functools

    @functools.wraps(outer)
    def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
        key_input = str(args) + str(kwargs)
        # nosec b303 - Not used for cryptography, but to create lookup key
        key = hashlib.md5(key_input.encode("utf-8")).hexdigest()
        cache_dir = Path.home().joinpath(".amtviz/cw_metrics_cache")
        fn = f"{cache_dir}/req_{key}.jsonl.gz"
        if Path(fn).exists():
            try:
                df = pd.read_json(fn, lines=True)
                # Timestamps are stored as ISO strings; restore them as
                # timezone-naive datetimes.
                df["ts"] = pd.to_datetime(df["ts"])
                df["ts"] = df["ts"].dt.tz_localize(None)
                # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
                df["rel_ts"] = pd.to_datetime(df["rel_ts"])
                df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
                # Logger methods do not accept print-style keywords such as
                # ``end``; passing one raises TypeError once DEBUG is enabled.
                logger.debug("Cache hit: %s", fn)
                return df
            except KeyError:
                # Empty file leads to empty df, hence no df['ts'] possible
                pass
            # nosec b110 - doesn't matter why we could not load it.
            except Exception as e:  # pylint: disable=broad-except
                # Deliberately broad (but not BaseException, so Ctrl-C and
                # SystemExit still propagate): any unreadable cache entry is
                # simply refreshed below.
                logger.error("\nException: %s - %s", type(e), e)

        logger.debug("Cache miss: %s", fn)
        df = outer(*args, **kwargs)
        assert isinstance(df, pd.DataFrame), "Only caching Pandas DataFrames."

        os.makedirs(cache_dir, exist_ok=True)
        df.to_json(fn, orient="records", date_format="iso", lines=True)

        return df

    return inner
81+
82+
83+
def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]:
    """Returns a CloudWatch metric data query template.

    Args:
        metric_name (str): Name of the metric to query.
        dim_name (str): Name of the single dimension to filter on.
        dim_value (str): Value of that dimension.

    Returns:
        Dict[str, Any]: One ``MetricDataQueries`` entry requesting the 60-second
            ``Average`` of the metric in the ``/aws/sagemaker/TrainingJobs`` namespace.
    """
    # CloudWatch query ids may only contain letters, digits and underscores and
    # must start with a lowercase letter. Metric names are free-form, so map
    # every other character (":", "-", ".", "/", spaces, ...) to "_".
    query_id = "".join(
        ch if ch.isascii() and ch.isalnum() else "_" for ch in metric_name.lower()
    )
    if not query_id[:1].isalpha():
        # Covers empty names and names that start with a digit or punctuation.
        query_id = "m_" + query_id

    return {
        "Id": query_id,
        "MetricStat": {
            "Stat": "Average",
            "Metric": {
                "Namespace": "/aws/sagemaker/TrainingJobs",
                "MetricName": metric_name,
                "Dimensions": [
                    {"Name": dim_name, "Value": dim_value},
                ],
            },
            "Period": 60,
        },
        "ReturnData": True,
    }
100+
101+
102+
def _get_metric_data(
    queries: List[Dict[str, Any]], start_time: datetime, end_time: datetime
) -> pd.DataFrame:
    """Fetches CloudWatch metrics between timestamps, returns a DataFrame with selected columns.

    Args:
        queries (List[Dict[str, Any]]): ``MetricDataQueries`` entries to send to CloudWatch.
        start_time (datetime): Start of the window; one hour is subtracted before querying.
        end_time (datetime): End of the window; one hour is added before querying.

    Returns:
        pd.DataFrame: One row per data point with columns ``value``, ``ts`` and
            ``label``, plus ``rel_ts`` when at least one data point was returned.
            An empty DataFrame is returned when CloudWatch returns no results.
    """
    # Pad the requested window by an hour on both sides.
    start_time = start_time - timedelta(hours=1)
    end_time = end_time + timedelta(hours=1)

    request: Dict[str, Any] = {
        "MetricDataQueries": queries,
        "StartTime": start_time,
        "EndTime": end_time,
    }

    frames = []
    while True:
        response = cw.get_metric_data(**request)

        for metric_data in response.get("MetricDataResults", []):
            values = metric_data["Values"]
            ts = np.array(metric_data["Timestamps"], dtype=np.datetime64)
            labels = [metric_data["Label"]] * len(values)
            frames.append(pd.DataFrame({"value": values, "ts": ts, "label": labels}))

        # CloudWatch returns results in pages; follow NextToken so that long
        # jobs or many metrics are not silently truncated to the first page.
        next_token = response.get("NextToken")
        if not next_token:
            break
        request["NextToken"] = next_token

    if not frames:
        return pd.DataFrame()

    # Concatenate once instead of growing the DataFrame inside the loop.
    df = pd.concat(frames)

    # We now calculate the relative time based on the first actual observed
    # time stamps, not the potentially start time that we used to scope our CW
    # API call. The difference could be for example startup times or waiting
    # for Spot.
    # NOTE(review): datetime.fromtimestamp(1) is expressed in the local time
    # zone, so the rel_ts origin depends on the machine's TZ setting - confirm
    # this is intended.
    if not df.empty:
        df["rel_ts"] = datetime.fromtimestamp(1) + (df["ts"] - df["ts"].min())  # pyright: ignore
    return df
128+
129+
130+
@disk_cache
def _collect_metrics(
    dimensions: List[Tuple[str, str]], start_time: datetime, end_time: Optional[datetime]
) -> pd.DataFrame:
    """Collects SageMaker training job metrics from CloudWatch for dimensions and time range.

    For every ``(name, value)`` dimension pair, the metrics CloudWatch lists in
    the ``/aws/sagemaker/TrainingJobs`` namespace are looked up and their data
    points fetched; the per-dimension results are stacked into one DataFrame.

    Args:
        dimensions (List[Tuple[str, str]]): Dimension name/value pairs to query.
        start_time (datetime): Start of the time range.
        end_time (Optional[datetime]): End of the time range.

    Returns:
        pd.DataFrame: All data points found, or an empty DataFrame if none.
    """
    collected = pd.DataFrame()

    for dim_name, dim_value in dimensions:
        listing = cw.list_metrics(
            Namespace="/aws/sagemaker/TrainingJobs",
            Dimensions=[{"Name": dim_name, "Value": dim_value}],
        )
        available = listing["Metrics"]
        if not available:
            # No metric data yet, or not any longer, because the data were aged out
            continue

        queries = [
            _metric_data_query_tpl(entry["MetricName"], dim_name, dim_value)
            for entry in available
        ]
        fetched = _get_metric_data(queries, start_time, end_time)
        collected = pd.concat([collected, fetched])

    return collected
155+
156+
157+
def get_cw_job_metrics(
    job_name: str, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None
) -> pd.DataFrame:
    """Retrieves CloudWatch metrics for a SageMaker training job.

    Metrics are looked up under two dimensions: ``TrainingJobName`` set to the
    job name, and ``Host`` set to ``<job_name>/algo-1``.

    Args:
        job_name (str): Name of the SageMaker training job.
        start_time (datetime, optional): Start time for metrics collection.
            Defaults to now - 4 hours.
        end_time (datetime, optional): End time for metrics collection.
            Defaults to start_time + 4 hours.

    Returns:
        pd.DataFrame: Metrics data with columns for value, timestamp, and metric name.
            Results are cached to disk for improved performance.
    """
    # If not given, use reasonable defaults for start and end time
    if not start_time:
        start_time = datetime.now() - timedelta(hours=4)
    if not end_time:
        end_time = start_time + timedelta(hours=4)

    job_dimensions = [
        ("TrainingJobName", job_name),
        ("Host", job_name + "/algo-1"),
    ]
    return _collect_metrics(job_dimensions, start_time, end_time)

0 commit comments

Comments
 (0)