diff --git a/go.mod b/go.mod index 74c4f6698867..fc3185090bb4 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/awslabs/amazon-eks-ami/nodeadm v0.0.0-20240229193347-cfab22a10647 github.com/awslabs/operatorpkg v0.0.0-20240805231134-67d0acfb6306 github.com/go-logr/zapr v1.3.0 + github.com/google/uuid v1.6.0 github.com/imdario/mergo v0.3.16 github.com/jonathan-innis/aws-sdk-go-prometheus v0.1.1-0.20240804232425-54c8227e0bab github.com/mitchellh/hashstructure/v2 v2.0.2 @@ -65,7 +66,6 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index ab0077632c31..ddb78450a174 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -651,7 +651,7 @@ var _ = Describe("CloudProvider", func() { }, }, }) - awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{ + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{ SecurityGroups: []*ec2.SecurityGroup{ { GroupId: aws.String(validSecurityGroup), @@ -665,7 +665,7 @@ var _ = Describe("CloudProvider", func() { }, }, }) - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{ Subnets: []*ec2.Subnet{ { SubnetId: aws.String(validSubnet1), @@ -1129,7 +1129,7 @@ var _ = Describe("CloudProvider", func() { }) It("should launch instances into subnet with the most available IP addresses", func() { awsEnv.SubnetCache.Flush() - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(100), @@ -1146,7 +1146,7 @@ var _ = Describe("CloudProvider", func() { }) It("should launch instances into subnet with the most available IP addresses in-between cache refreshes", func() { awsEnv.SubnetCache.Flush() - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(11), @@ -1174,7 +1174,7 @@ var _ = Describe("CloudProvider", func() { Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1")) }) It("should update in-flight IPs when a CreateFleet error occurs", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, }}) @@ -1185,12 +1185,20 @@ var _ = Describe("CloudProvider", func() { Expect(len(bindings)).To(Equal(0)) }) It("should launch instances into subnets that are excluded by another NodePool", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ - {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(10), - Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, - {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(100), - Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}}, - }}) + awsEnv.EC2API.Subnets.Store("test-zone-1a", &ec2.Subnet{ + SubnetId: aws.String("test-subnet-1"), + AvailabilityZone: aws.String("test-zone-1a"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int64(10), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }) + awsEnv.EC2API.Subnets.Store("test-zone-1b", &ec2.Subnet{ + SubnetId: aws.String("test-subnet-2"), + AvailabilityZone: aws.String("test-zone-1b"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int64(100), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}, + }) nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{{Tags: map[string]string{"Name": "test-subnet-1"}}} ExpectApplied(ctx, env.Client, nodePool, nodeClass) controller := status.NewController(env.Client, awsEnv.SubnetProvider, awsEnv.SecurityGroupProvider, awsEnv.AMIProvider, awsEnv.InstanceProfileProvider, awsEnv.LaunchTemplateProvider) diff --git a/pkg/controllers/nodeclass/status/subnet_test.go b/pkg/controllers/nodeclass/status/subnet_test.go index 21a11493614d..d6be47bbe88f 100644 --- a/pkg/controllers/nodeclass/status/subnet_test.go +++ b/pkg/controllers/nodeclass/status/subnet_test.go @@ -79,7 +79,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { Expect(nodeClass.StatusConditions().IsTrue(v1.ConditionTypeSubnetsReady)).To(BeTrue()) }) It("Should have the correct ordering for the Subnets", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("subnet-test1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(20)}, {SubnetId: aws.String("subnet-test2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1b"), AvailableIpAddressCount: aws.Int64(100)}, {SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailabilityZoneId: aws.String("tstz1-1c"), AvailableIpAddressCount: aws.Int64(50)}, diff --git a/pkg/controllers/providers/ssm/invalidation/controller.go b/pkg/controllers/providers/ssm/invalidation/controller.go index d57be1594f89..81b44e69c6f9 100644 --- a/pkg/controllers/providers/ssm/invalidation/controller.go +++ b/pkg/controllers/providers/ssm/invalidation/controller.go @@ -29,6 +29,9 @@ import ( v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/providers/ssm" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/uuid" ) // The SSM Invalidation controller is responsible for invalidating "latest" SSM parameters when they point to deprecated @@ -66,6 +69,9 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { amis := []amifamily.AMI{} for _, nodeClass := range lo.Map(lo.Keys(amiIDsToParameters), func(amiID string, _ int) *v1.EC2NodeClass { return &v1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: uuid.NewUUID(), // ensures that this doesn't hit the AMI cache. + }, Spec: v1.EC2NodeClassSpec{ AMISelectorTerms: []v1.AMISelectorTerm{{ID: amiID}}, }, diff --git a/pkg/fake/atomic.go b/pkg/fake/atomic.go index d9a34f03a065..866e880452fb 100644 --- a/pkg/fake/atomic.go +++ b/pkg/fake/atomic.go @@ -161,6 +161,13 @@ func (a *AtomicPtrSlice[T]) Pop() *T { return last } +func (a *AtomicPtrSlice[T]) At(index int) *T { + a.mu.Lock() + defer a.mu.Unlock() + + return clone(a.values[index]) +} + func (a *AtomicPtrSlice[T]) ForEach(fn func(*T)) { a.mu.RLock() defer a.mu.RUnlock() diff --git a/pkg/fake/ec2api.go b/pkg/fake/ec2api.go index 060e0fb67134..ff79624a988e 100644 --- a/pkg/fake/ec2api.go +++ b/pkg/fake/ec2api.go @@ -48,8 +48,8 @@ type CapacityPool struct { type EC2Behavior struct { DescribeImagesOutput AtomicPtr[ec2.DescribeImagesOutput] DescribeLaunchTemplatesOutput AtomicPtr[ec2.DescribeLaunchTemplatesOutput] - DescribeSubnetsOutput AtomicPtr[ec2.DescribeSubnetsOutput] - DescribeSecurityGroupsOutput AtomicPtr[ec2.DescribeSecurityGroupsOutput] + DescribeSubnetsBehavior MockedFunction[ec2.DescribeSubnetsInput, ec2.DescribeSubnetsOutput] + DescribeSecurityGroupsBehavior MockedFunction[ec2.DescribeSecurityGroupsInput, ec2.DescribeSecurityGroupsOutput] DescribeInstanceTypesOutput AtomicPtr[ec2.DescribeInstanceTypesOutput] DescribeInstanceTypeOfferingsOutput AtomicPtr[ec2.DescribeInstanceTypeOfferingsOutput] DescribeAvailabilityZonesOutput AtomicPtr[ec2.DescribeAvailabilityZonesOutput] @@ -61,6 +61,7 @@ type EC2Behavior struct { CreateTagsBehavior MockedFunction[ec2.CreateTagsInput, ec2.CreateTagsOutput] CalledWithCreateLaunchTemplateInput AtomicPtrSlice[ec2.CreateLaunchTemplateInput] CalledWithDescribeImagesInput AtomicPtrSlice[ec2.DescribeImagesInput] + Subnets sync.Map Instances sync.Map LaunchTemplates sync.Map InsufficientCapacityPools atomic.Slice[CapacityPool] @@ -84,8 +85,8 @@ var DefaultSupportedUsageClasses = aws.StringSlice([]string{"on-demand", "spot"} func (e *EC2API) Reset() { e.DescribeImagesOutput.Reset() e.DescribeLaunchTemplatesOutput.Reset() - e.DescribeSubnetsOutput.Reset() - e.DescribeSecurityGroupsOutput.Reset() + e.DescribeSubnetsBehavior.Reset() + e.DescribeSecurityGroupsBehavior.Reset() e.DescribeInstanceTypesOutput.Reset() e.DescribeInstanceTypeOfferingsOutput.Reset() e.DescribeAvailabilityZonesOutput.Reset() @@ -405,107 +406,109 @@ func (e *EC2API) DeleteLaunchTemplateWithContext(_ context.Context, input *ec2.D } func (e *EC2API) DescribeSubnetsWithContext(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...request.Option) (*ec2.DescribeSubnetsOutput, error) { - if !e.NextError.IsNil() { - defer e.NextError.Reset() - return nil, e.NextError.Get() - } - if !e.DescribeSubnetsOutput.IsNil() { - describeSubnetsOutput := e.DescribeSubnetsOutput.Clone() - describeSubnetsOutput.Subnets = FilterDescribeSubnets(describeSubnetsOutput.Subnets, input.Filters) - return describeSubnetsOutput, nil - } - subnets := []*ec2.Subnet{ - { - SubnetId: aws.String("subnet-test1"), - AvailabilityZone: aws.String("test-zone-1a"), - AvailabilityZoneId: aws.String("tstz1-1a"), - AvailableIpAddressCount: aws.Int64(100), - MapPublicIpOnLaunch: aws.Bool(false), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-1")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + return e.DescribeSubnetsBehavior.Invoke(input, func(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) { + output := &ec2.DescribeSubnetsOutput{} + e.Subnets.Range(func(key, value any) bool { + subnet := value.(*ec2.Subnet) + if lo.Contains(lo.Map(input.SubnetIds, func(s *string, _ int) string { return lo.FromPtr(s) }), lo.FromPtr(subnet.SubnetId)) || len(input.Filters) != 0 && len(FilterDescribeSubnets([]*ec2.Subnet{subnet}, input.Filters)) != 0 { + output.Subnets = append(output.Subnets, subnet) + } + return true + }) + if len(output.Subnets) != 0 { + return output, nil + } + + defaultSubnets := []*ec2.Subnet{ + { + SubnetId: aws.String("subnet-test1"), + AvailabilityZone: aws.String("test-zone-1a"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(false), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-1")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test2"), - AvailabilityZone: aws.String("test-zone-1b"), - AvailabilityZoneId: aws.String("tstz1-1b"), - AvailableIpAddressCount: aws.Int64(100), - MapPublicIpOnLaunch: aws.Bool(true), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-2")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + SubnetId: aws.String("subnet-test2"), + AvailabilityZone: aws.String("test-zone-1b"), + AvailabilityZoneId: aws.String("tstz1-1b"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(true), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-2")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test3"), - AvailabilityZone: aws.String("test-zone-1c"), - AvailabilityZoneId: aws.String("tstz1-1c"), - AvailableIpAddressCount: aws.Int64(100), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-3")}, - {Key: aws.String("TestTag")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + SubnetId: aws.String("subnet-test3"), + AvailabilityZone: aws.String("test-zone-1c"), + AvailabilityZoneId: aws.String("tstz1-1c"), + AvailableIpAddressCount: aws.Int64(100), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-3")}, + {Key: aws.String("TestTag")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test4"), - AvailabilityZone: aws.String("test-zone-1a-local"), - AvailabilityZoneId: aws.String("tstz1-1alocal"), - AvailableIpAddressCount: aws.Int64(100), - MapPublicIpOnLaunch: aws.Bool(true), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-4")}, + { + SubnetId: aws.String("subnet-test4"), + AvailabilityZone: aws.String("test-zone-1a-local"), + AvailabilityZoneId: aws.String("tstz1-1alocal"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(true), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-4")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - } - if len(input.Filters) == 0 { - return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") - } - return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(subnets, input.Filters)}, nil + } + if len(input.Filters) == 0 { + return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") + } + return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(defaultSubnets, input.Filters)}, nil + }) } func (e *EC2API) DescribeSecurityGroupsWithContext(_ context.Context, input *ec2.DescribeSecurityGroupsInput, _ ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) { - if !e.NextError.IsNil() { - defer e.NextError.Reset() - return nil, e.NextError.Get() - } - if !e.DescribeSecurityGroupsOutput.IsNil() { - describeSecurityGroupsOutput := e.DescribeSecurityGroupsOutput.Clone() - describeSecurityGroupsOutput.SecurityGroups = FilterDescribeSecurtyGroups(describeSecurityGroupsOutput.SecurityGroups, input.Filters) - return e.DescribeSecurityGroupsOutput.Clone(), nil - } - sgs := []*ec2.SecurityGroup{ - { - GroupId: aws.String("sg-test1"), - GroupName: aws.String("securityGroup-test1"), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-1")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + return e.DescribeSecurityGroupsBehavior.Invoke(input, func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) { + defaultSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("sg-test1"), + GroupName: aws.String("securityGroup-test1"), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-1")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - { - GroupId: aws.String("sg-test2"), - GroupName: aws.String("securityGroup-test2"), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-2")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + GroupId: aws.String("sg-test2"), + GroupName: aws.String("securityGroup-test2"), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-2")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - { - GroupId: aws.String("sg-test3"), - GroupName: aws.String("securityGroup-test3"), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, - {Key: aws.String("TestTag")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + GroupId: aws.String("sg-test3"), + GroupName: aws.String("securityGroup-test3"), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, + {Key: aws.String("TestTag")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - } - if len(input.Filters) == 0 { - return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") - } - return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(sgs, input.Filters)}, nil + } + if len(input.Filters) == 0 { + return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") + } + return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(defaultSecurityGroups, input.Filters)}, nil + }) } func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) { diff --git a/pkg/fake/types.go b/pkg/fake/types.go index 88fe2ca83bc9..0edc7bc40a2a 100644 --- a/pkg/fake/types.go +++ b/pkg/fake/types.go @@ -15,14 +15,22 @@ limitations under the License. package fake import ( + "reflect" + "sync" "sync/atomic" + + "github.com/google/uuid" + "github.com/samber/lo" ) type MockedFunction[I any, O any] struct { - Output AtomicPtr[O] // Output to return on call to this function + Output AtomicPtr[O] // Output to return on call to this function + MultiOut AtomicPtrSlice[O] + OutputPages AtomicPtrSlice[O] CalledWithInput AtomicPtrSlice[I] // Slice used to keep track of passed input to this function Error AtomicError // Error to return a certain number of times defined by custom error options + pageMapping sync.Map // token uuid -> page number: Internal construct to keep track of the page that we are on successfulCalls atomic.Int32 // Internal construct to keep track of the number of times this function has successfully been called failedCalls atomic.Int32 // Internal construct to keep track of the number of times this function has failed (with error) } @@ -31,11 +39,15 @@ type MockedFunction[I any, O any] struct { // each other. func (m *MockedFunction[I, O]) Reset() { m.Output.Reset() + m.MultiOut.Reset() + m.OutputPages.Reset() m.CalledWithInput.Reset() m.Error.Reset() m.successfulCalls.Store(0) m.failedCalls.Store(0) + m.pageMapping.Clear() + } func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O, error)) (*O, error) { @@ -50,6 +62,29 @@ func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O, m.successfulCalls.Add(1) return m.Output.Clone(), nil } + + if m.MultiOut.Len() > 0 { + m.successfulCalls.Add(1) + return m.MultiOut.Pop(), nil + } + // This output pages multi-threaded handling isn't perfect + // It will fail if pages are asynchronously requested from the same NextToken + if m.OutputPages.Len() > 0 { + token := uuid.New().String() // generate a token so that each paginated request set gets its own mapping + if !reflect.ValueOf(input).Elem().FieldByName("NextToken").Elem().CanSet() { + m.pageMapping.Store(token, 0) + } else { + token = reflect.ValueOf(input).Elem().FieldByName("NextToken").Elem().String() + } + pageNum := lo.Must(m.pageMapping.Load(token)).(int) + page := m.OutputPages.At(pageNum) + if pageNum < m.OutputPages.Len()-1 { + reflect.ValueOf(page).Elem().FieldByName("NextToken").Set(reflect.ValueOf(lo.ToPtr(token))) + } + m.pageMapping.Store(token, pageNum+1) + m.successfulCalls.Add(1) + return page, nil + } out, err := defaultTransformer(input) if err != nil { m.failedCalls.Add(1) diff --git a/pkg/providers/amifamily/ami.go b/pkg/providers/amifamily/ami.go index 4eeb475a3fea..1f343becb177 100644 --- a/pkg/providers/amifamily/ami.go +++ b/pkg/providers/amifamily/ami.go @@ -27,11 +27,13 @@ import ( "github.com/patrickmn/go-cache" "github.com/samber/lo" "k8s.io/utils/clock" - "sigs.k8s.io/controller-runtime/pkg/log" + + "github.com/aws/karpenter-provider-aws/pkg/utils" v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" "github.com/aws/karpenter-provider-aws/pkg/providers/version" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/karpenter/pkg/cloudprovider" "sigs.k8s.io/karpenter/pkg/scheduling" "sigs.k8s.io/karpenter/pkg/utils/pretty" @@ -68,17 +70,13 @@ func NewDefaultProvider(clock clock.Clock, versionProvider version.Provider, ssm func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) (AMIs, error) { p.Lock() defer p.Unlock() - queries, err := p.DescribeImageQueries(ctx, nodeClass) - if err != nil { - return nil, fmt.Errorf("getting AMI queries, %w", err) - } // Discover deprecated AMIs if automatic AMI discovery and upgrade is enabled. This ensures we'll be able to // provision in the event of an EKS optimized AMI being deprecated. includeDeprecated := false if alias := nodeClass.Alias(); alias != nil { includeDeprecated = alias.Version == v1.AliasVersionLatest } - amis, err := p.amis(ctx, queries, includeDeprecated) + amis, err := p.amis(ctx, nodeClass, includeDeprecated) if err != nil { return nil, err } @@ -151,12 +149,13 @@ func (p *DefaultProvider) DescribeImageQueries(ctx context.Context, nodeClass *v } //nolint:gocyclo -func (p *DefaultProvider) amis(ctx context.Context, queries []DescribeImageQuery, includeDeprecated bool) (AMIs, error) { - hash, err := hashstructure.Hash(queries, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) +func (p *DefaultProvider) amis(ctx context.Context, nodeClass *v1.EC2NodeClass, includeDeprecated bool) (AMIs, error) { + queries, err := p.DescribeImageQueries(ctx, nodeClass) if err != nil { - return nil, err + return nil, fmt.Errorf("getting AMI queries, %w", err) } - if images, ok := p.cache.Get(fmt.Sprintf("%d", hash)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if images, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a deep-copy of AMIs so alterations // to the data don't affect the original return append(AMIs{}, images.(AMIs)...), nil @@ -198,7 +197,7 @@ func (p *DefaultProvider) amis(ctx context.Context, queries []DescribeImageQuery return nil, fmt.Errorf("describing images, %w", err) } } - p.cache.SetDefault(fmt.Sprintf("%d", hash), AMIs(lo.Values(images))) + p.cache.SetDefault(hash, AMIs(lo.Values(images))) return lo.Values(images), nil } diff --git a/pkg/providers/amifamily/suite_test.go b/pkg/providers/amifamily/suite_test.go index 8e9110322e0d..dcb33e6b64c4 100644 --- a/pkg/providers/amifamily/suite_test.go +++ b/pkg/providers/amifamily/suite_test.go @@ -29,6 +29,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/onsi/gomega/gstruct" . "sigs.k8s.io/karpenter/pkg/utils/testing" "github.com/samber/lo" @@ -44,6 +45,8 @@ import ( "github.com/aws/karpenter-provider-aws/pkg/operator/options" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/test" + + . "sigs.k8s.io/karpenter/pkg/test/expectations" ) var ctx context.Context @@ -303,6 +306,219 @@ var _ = Describe("AMIProvider", func() { })) }) }) + Context("Provider Cache", func() { + It("should resolve AMIs from cache that are filtered by id", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + ID: "ami-123", + }, + { + ID: "ami-456", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-123"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-456"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by name", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Name: "ami-name-1", + }, + { + Name: "ami-name-2", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by tags", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"test": "test"}, + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-3"), + ImageId: aws.String("ami-789"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMIFamily = &v1.AMIFamilyAL2 + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + )) + + // OR semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err = awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + + cacheItems := awsEnv.AMICache.Items() + Expect(cacheItems).To(HaveLen(2)) + cachedImages := make([]amifamily.AMIs, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedImages = append(cachedImages, item.Object.(amifamily.AMIs)) + } + + Expect(cachedImages).To(ConsistOf( + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + ), + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + ), + )) + }) + }) Context("AMI Selectors", func() { // When you tag public or shared resources, the tags you assign are available only to your AWS account; no other AWS account will have access to those tags // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index 09413d513707..e779bd900b3e 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -1812,7 +1812,7 @@ var _ = Describe("InstanceTypeProvider", func() { } }) It("shouldn't report more resources than are actually available on instances", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{ Subnets: []*ec2.Subnet{ { AvailabilityZone: aws.String("us-west-2a"), diff --git a/pkg/providers/securitygroup/securitygroup.go b/pkg/providers/securitygroup/securitygroup.go index 998498277f32..d82c5a88440e 100644 --- a/pkg/providers/securitygroup/securitygroup.go +++ b/pkg/providers/securitygroup/securitygroup.go @@ -22,7 +22,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" "sigs.k8s.io/controller-runtime/pkg/log" @@ -30,6 +29,7 @@ import ( "sigs.k8s.io/karpenter/pkg/utils/pretty" v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/aws/karpenter-provider-aws/pkg/utils" ) type Provider interface { @@ -56,9 +56,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) p.Lock() defer p.Unlock() - // Get SecurityGroups - filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) - securityGroups, err := p.getSecurityGroups(ctx, filterSets) + securityGroups, err := p.getSecurityGroups(ctx, nodeClass) if err != nil { return nil, err } @@ -71,12 +69,10 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) return securityGroups, nil } -func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][]*ec2.Filter) ([]*ec2.SecurityGroup, error) { - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if sg, ok := p.cache.Get(fmt.Sprint(hash)); ok { +func (p *DefaultProvider) getSecurityGroups(ctx context.Context, nodeClass *v1.EC2NodeClass) ([]*ec2.SecurityGroup, error) { + filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) + hash := utils.GetNodeClassHash(nodeClass) + if sg, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]*ec2.SecurityGroup{}, sg.([]*ec2.SecurityGroup)...), nil @@ -91,7 +87,7 @@ func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][] securityGroups[lo.FromPtr(output.SecurityGroups[i].GroupId)] = output.SecurityGroups[i] } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(securityGroups)) + p.cache.SetDefault(hash, lo.Values(securityGroups)) return lo.Values(securityGroups), nil } diff --git a/pkg/providers/securitygroup/suite_test.go b/pkg/providers/securitygroup/suite_test.go index 6a2e18fe8111..4178a5db45fd 100644 --- a/pkg/providers/securitygroup/suite_test.go +++ b/pkg/providers/securitygroup/suite_test.go @@ -116,7 +116,7 @@ var _ = Describe("SecurityGroupProvider", func() { }, securityGroups) }) It("should discover security groups by tag", func() { - awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ {GroupName: aws.String("test-sgName-1"), GroupId: aws.String("test-sg-1"), Tags: []*ec2.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-1")}}}, {GroupName: aws.String("test-sgName-2"), GroupId: aws.String("test-sg-2"), Tags: []*ec2.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-2")}}}, }}) @@ -272,7 +272,13 @@ var _ = Describe("SecurityGroupProvider", func() { }) Context("Provider Cache", func() { It("should resolve security groups from cache that are filtered by id", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ { @@ -284,6 +290,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]*ec2.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -291,7 +298,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by Name", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ { @@ -303,6 +316,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]*ec2.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -310,7 +324,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by tags", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) tagSet := lo.Map(expectedSecurityGroups, func(sg *ec2.SecurityGroup, _ int) map[string]string { tag, _ := lo.Find(sg.Tags, func(tag *ec2.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -334,6 +354,88 @@ var _ = Describe("SecurityGroupProvider", func() { lo.Contains(expectedSecurityGroups, cachedSecurityGroup[0]) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupName: aws.String("test-sgName-3"), GroupId: aws.String("test-sg-3"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + }, + }, securityGroups) + + // OR semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupName: aws.String("test-sgName-2"), GroupId: aws.String("test-sg-2"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupName: aws.String("test-sgName-1"), GroupId: aws.String("test-sg-1"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err = awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + }, + }, securityGroups) + + cacheItems := awsEnv.SecurityGroupCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached security group arrays for comparison + cachedSecurityGroups := make([][]*ec2.SecurityGroup, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSecurityGroups = append(cachedSecurityGroups, item.Object.([]*ec2.SecurityGroup)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSecurityGroups).To(ContainElement(ContainElements( + []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSecurityGroups).To(ContainElement( + []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index 58fc1776804c..82323042b10f 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -23,13 +23,13 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/log" v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/aws/karpenter-provider-aws/pkg/utils" karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1" "sigs.k8s.io/karpenter/pkg/cloudprovider" @@ -82,11 +82,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) if len(filterSets) == 0 { return []*ec2.Subnet{}, nil } - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if subnets, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]*ec2.Subnet{}, subnets.([]*ec2.Subnet)...), nil @@ -108,7 +105,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) delete(p.inflightIPs, lo.FromPtr(output.Subnets[i].SubnetId)) // remove any previously tracked IP addresses since we just refreshed from EC2 } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(subnets)) + p.cache.SetDefault(hash, lo.Values(subnets)) if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), lo.Keys(subnets)) { log.FromContext(ctx). WithValues("subnets", lo.Map(lo.Values(subnets), func(s *ec2.Subnet, _ int) v1.Subnet { diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index f6dedfa1d2e7..f3a4d714575a 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -22,6 +22,7 @@ import ( "sigs.k8s.io/karpenter/pkg/test/v1alpha1" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/samber/lo" @@ -232,7 +233,13 @@ var _ = Describe("SubnetProvider", func() { }) Context("Provider Cache", func() { It("should resolve subnets from cache that are filtered by id", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets + expectedSubnets := []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) for _, subnet := range expectedSubnets { nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ { @@ -244,6 +251,7 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) @@ -251,7 +259,13 @@ var _ = Describe("SubnetProvider", func() { } }) It("should resolve subnets from cache that are filtered by tags", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets + expectedSubnets := []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) tagSet := lo.Map(expectedSubnets, func(subnet *ec2.Subnet, _ int) map[string]string { tag, _ := lo.Find(subnet.Tags, func(tag *ec2.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -269,12 +283,98 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) lo.Contains(expectedSubnets, cachedSubnet[0]) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + }, + }, subnets) + + // OR semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + {SubnetId: aws.String("test-subnet-id-2"), SubnetArn: aws.String("test-subnet-arn-2"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + {SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err = awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + }, + }, subnets) + + cacheItems := awsEnv.SubnetCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached subnet arrays for comparison + cachedSubnets := make([][]*ec2.Subnet, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSubnets = append(cachedSubnets, item.Object.([]*ec2.Subnet)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSubnets).To(ContainElement(ContainElements( + []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSubnets).To(ContainElement( + []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} @@ -311,6 +411,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1b"), @@ -329,6 +430,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1c"), @@ -348,6 +450,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1a-local"), @@ -361,6 +464,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("test-subnet-4"), }, }, + VpcId: aws.String("vpc-test1"), }, })) }() diff --git a/pkg/test/environment.go b/pkg/test/environment.go index 8067f578c1d1..f2d8dce8a601 100644 --- a/pkg/test/environment.go +++ b/pkg/test/environment.go @@ -59,6 +59,7 @@ type Environment struct { PricingAPI *fake.PricingAPI // Cache + AMICache *cache.Cache EC2Cache *cache.Cache KubernetesVersionCache *cache.Cache InstanceTypeCache *cache.Cache @@ -96,6 +97,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment iamapi := fake.NewIAMAPI() // cache + amiCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) ec2Cache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) kubernetesVersionCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) instanceTypeCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) @@ -116,7 +118,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment versionProvider := version.NewDefaultProvider(env.KubernetesInterface, kubernetesVersionCache) instanceProfileProvider := instanceprofile.NewDefaultProvider(fake.DefaultRegion, iamapi, instanceProfileCache) ssmProvider := ssmp.NewDefaultProvider(ssmapi, ssmCache) - amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, ec2Cache) + amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, amiCache) amiResolver := amifamily.NewResolver(amiProvider) instanceTypesProvider := instancetype.NewDefaultProvider(fake.DefaultRegion, instanceTypeCache, ec2api, subnetProvider, unavailableOfferingsCache, pricingProvider) launchTemplateProvider := @@ -150,6 +152,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment IAMAPI: iamapi, PricingAPI: fakePricingAPI, + AMICache: amiCache, EC2Cache: ec2Cache, KubernetesVersionCache: kubernetesVersionCache, LaunchTemplateCache: launchTemplateCache, @@ -187,6 +190,7 @@ func (env *Environment) Reset() { env.PricingProvider.Reset() env.InstanceTypesProvider.Reset() + env.AMICache.Flush() env.EC2Cache.Flush() env.KubernetesVersionCache.Flush() env.UnavailableOfferingsCache.Flush() diff --git a/pkg/utils/suite_test.go b/pkg/utils/suite_test.go new file mode 100644 index 000000000000..241a861e96b2 --- /dev/null +++ b/pkg/utils/suite_test.go @@ -0,0 +1,44 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/aws/karpenter-provider-aws/pkg/utils" +) + +func TestUtils(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Utils Suite") +} + +var _ = Describe("GetNodeClassHash", func() { + It("should return formatted hash with UID and Generation", func() { + nodeClass := &v1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: "test-uid-123", + Generation: 5, + }, + } + hash := utils.GetNodeClassHash(nodeClass) + Expect(hash).To(Equal("test-uid-123-5")) + }) +}) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 72c110cce8ee..4b9eee450845 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -177,3 +177,7 @@ func ResolveNodePoolFromNodeClaim(ctx context.Context, kubeClient client.Client, // There will be no nodePool referenced inside the nodeClaim in case of standalone nodeClaims return nil, nil } + +func GetNodeClassHash(nodeClass *v1.EC2NodeClass) string { + return fmt.Sprintf("%s-%d", nodeClass.UID, nodeClass.Generation) +}