Output of function step is not compatible with sagemaker.clarify.ModelConfig() #4320

@acere

Describe the bug
When a pipeline combines steps defined with @step and a sagemaker.clarify.ModelConfig() that receives a @step output as an argument, compiling the pipeline fails with
AttributeError: 'NoneType' object has no attribute 'sagemaker_session'.
This makes it hard to combine @step functions with Clarify steps.

To reproduce
Run the following script:

import sagemaker
from sagemaker.clarify import BiasConfig, DataConfig, ModelConfig
from sagemaker.workflow.check_job_config import CheckJobConfig
from sagemaker.workflow.clarify_check_step import (
    ClarifyCheckStep,
    ModelBiasCheckConfig,
    ModelPredictedLabelConfig,
)
from sagemaker.workflow.function_step import step
from sagemaker.workflow.pipeline import Pipeline

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
bucket = sagemaker_session.default_bucket()
instance_type = "ml.c5"

@step(instance_type=instance_type)
def dummy_func():
    return "single-text"


@step(instance_type=instance_type)
def generate_data():
    return "data-uri"


check_job_config = CheckJobConfig(
    role=role,
    instance_count=1,
    instance_type="ml.c5.xlarge",
    volume_size_in_gb=120,
    sagemaker_session=sagemaker_session,
)

bias_config = BiasConfig(
    label_values_or_threshold=15.0,
    facet_name=["facet"],
    facet_values_or_threshold=None,
)

model_bias_data_config = DataConfig(
    s3_data_input_path=generate_data(),
    s3_output_path=f"s3://{bucket}/model-bias",
    dataset_type="text/csv",
    label="label",
    predicted_label="prediction",
    s3_analysis_config_output_path=f"s3://{bucket}/model-bias/analysis_cfg",
)

model_bias_check_config = ModelBiasCheckConfig(
    data_config=model_bias_data_config,
    data_bias_config=bias_config,
    model_predicted_label_config=ModelPredictedLabelConfig(),
    model_config=ModelConfig(
        # The output of the @step-decorated dummy_func is passed as the model name.
        model_name=dummy_func(),
        instance_count=1,
        instance_type="ml.m5.xlarge",
    ),
)

model_bias_check_step = ClarifyCheckStep(
    name="ModelBiasCheckStep",
    clarify_check_config=model_bias_check_config,
    check_job_config=check_job_config,
    skip_check=True,
    register_new_baseline=True,
    model_package_group_name="ModelPackageName",
)
pipeline = Pipeline(name="TestPipeline", steps=[model_bias_check_step])
definition = pipeline.definition()

System information
A description of your system. Please provide:

  • SageMaker Python SDK version: 2.199.0
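
One way to confirm the installed SDK version is to read the package metadata from the standard library. This is a minimal sketch assuming Python 3.8 or later; `installed_version` is a hypothetical helper name, and it returns None when the package is not installed in the current environment.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# "sagemaker" is the distribution name of the SageMaker Python SDK on PyPI.
print("sagemaker:", installed_version("sagemaker"))
```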
