diff --git a/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider.go b/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider.go index 3117b1c73654..26eff930c0bd 100644 --- a/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider.go +++ b/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider.go @@ -115,6 +115,13 @@ func (aws *awsCloudProvider) NodeGroupForNode(node *apiv1.Node) (cloudprovider.N klog.Warningf("Node %v has no providerId", node.Name) return nil, nil } + + // Skip SageMaker HyperPod instances + if strings.HasPrefix(node.GetName(), "hyperpod") { + klog.V(4).Infof("Skipping SageMaker HyperPod node %s", node.Name) + return nil, nil + } + ref, err := AwsRefFromProviderId(node.Spec.ProviderID) if err != nil { // Dropping this into V as it will be noisy with many Hybrid Nodes @@ -143,6 +150,11 @@ func (aws *awsCloudProvider) HasInstance(node *apiv1.Node) (bool, error) { return true, cloudprovider.ErrNotImplemented } + // Skip SageMaker HyperPod instances + if strings.HasPrefix(node.GetName(), "hyperpod") { + return true, cloudprovider.ErrNotImplemented + } + // avoid log spam for not autoscaled asgs: // Nodes that belong to an asg that is not autoscaled will not be found in the asgCache below, // so do not trigger warning spam by returning an error from being unable to find them. @@ -205,10 +217,19 @@ type AwsInstanceRef struct { } var validAwsRefIdRegex = regexp.MustCompile(fmt.Sprintf(`^aws\:\/\/\/[-0-9a-z]*\/[-0-9a-z]*(\/[-0-9a-z\.]*)?$|aws\:\/\/\/[-0-9a-z]*\/%s.*$`, placeholderInstanceNamePrefix)) +var sageMakerRefIdRegex = regexp.MustCompile(`^aws:///[-0-9a-z]+/sagemaker/.*$`) // AwsRefFromProviderId creates AwsInstanceRef object from provider id which // must be in format: aws:///zone/name func AwsRefFromProviderId(id string) (*AwsInstanceRef, error) { + // Special case for SageMaker format: aws:///zone/sagemaker/... 
+ if sageMakerRefIdRegex.MatchString(id) { + return &AwsInstanceRef{ + ProviderID: id, + Name: "sagemaker-node", + }, nil + } + if validAwsRefIdRegex.FindStringSubmatch(id) == nil { return nil, fmt.Errorf("wrong id: expected format aws:////, got %v", id) } @@ -313,6 +334,11 @@ func (ng *AwsNodeGroup) DecreaseTargetSize(delta int) error { // Belongs returns true if the given node belongs to the NodeGroup. func (ng *AwsNodeGroup) Belongs(node *apiv1.Node) (bool, error) { + // Skip SageMaker HyperPod instances + if strings.HasPrefix(node.GetName(), "hyperpod") { + return false, nil + } + ref, err := AwsRefFromProviderId(node.Spec.ProviderID) if err != nil { return false, err diff --git a/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider_test.go b/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider_test.go index 1910f8752fd5..1c6bbed8ef3b 100644 --- a/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider_test.go +++ b/cluster-autoscaler/cloudprovider/aws/aws_cloud_provider_test.go @@ -725,8 +725,21 @@ func TestHasInstance(t *testing.T) { assert.Equal(t, cloudprovider.ErrNotImplemented, err) assert.True(t, present) - // Case 3: correct node - not present in AWS + // Case 3: incorrect node - sagemaker hyperpod is unsupported node3 := &apiv1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "hyperpod-node-1", + }, + Spec: apiv1.NodeSpec{ + ProviderID: "aws:///use1-az2/sagemaker/cluster/hyperpod-abc123-i-abc123", + }, + } + present, err = provider.HasInstance(node3) + assert.Equal(t, cloudprovider.ErrNotImplemented, err) + assert.True(t, present) + + // Case 4: correct node - not present in AWS + node4 := &apiv1.Node{ ObjectMeta: metav1.ObjectMeta{ Name: "node-2", }, @@ -734,12 +747,12 @@ func TestHasInstance(t *testing.T) { ProviderID: "aws:///us-east-1a/test-instance-id-2", }, } - present, err = provider.HasInstance(node3) + present, err = provider.HasInstance(node4) assert.ErrorContains(t, err, nodeNotPresentErr) assert.False(t, present) - // Case 4: correct 
node - not autoscaled -> not present in AWS -> no warning - node4 := &apiv1.Node{ + // Case 5: correct node - not autoscaled -> not present in AWS -> no warning + node5 := &apiv1.Node{ ObjectMeta: metav1.ObjectMeta{ Name: "node-2", Annotations: map[string]string{ @@ -750,7 +763,7 @@ func TestHasInstance(t *testing.T) { ProviderID: "aws:///us-east-1a/test-instance-id-2", }, } - present, err = provider.HasInstance(node4) + present, err = provider.HasInstance(node5) assert.NoError(t, err) assert.False(t, present) }