diff --git a/pkg/controllers/nodeclass/ami_test.go b/pkg/controllers/nodeclass/ami_test.go index f7e025ada009..eaf04f916d97 100644 --- a/pkg/controllers/nodeclass/ami_test.go +++ b/pkg/controllers/nodeclass/ami_test.go @@ -651,7 +651,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() { awsEnv.Clock.Step(40 * time.Minute) // Flush Cache - awsEnv.EC2Cache.Flush() + awsEnv.AMICache.Flush() ExpectObjectReconciled(ctx, env.Client, controller, nodeClass) nodeClass = ExpectExists(ctx, env.Client, nodeClass) @@ -750,7 +750,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() { }, }) - awsEnv.EC2Cache.Flush() + awsEnv.AMICache.Flush() ExpectApplied(ctx, env.Client, nodeClass) ExpectObjectReconciled(ctx, env.Client, controller, nodeClass) diff --git a/pkg/controllers/providers/ssm/invalidation/controller.go b/pkg/controllers/providers/ssm/invalidation/controller.go index 741999c34d79..cb7033b6c136 100644 --- a/pkg/controllers/providers/ssm/invalidation/controller.go +++ b/pkg/controllers/providers/ssm/invalidation/controller.go @@ -29,6 +29,9 @@ import ( v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/providers/ssm" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/uuid" ) // The SSM Invalidation controller is responsible for invalidating "latest" SSM parameters when they point to deprecated @@ -66,6 +69,9 @@ func (c *Controller) Reconcile(ctx context.Context) (reconciler.Result, error) { amis := []amifamily.AMI{} for _, nodeClass := range lo.Map(lo.Keys(amiIDsToParameters), func(amiID string, _ int) *v1.EC2NodeClass { return &v1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: uuid.NewUUID(), // ensures that this doesn't hit the AMI cache. + }, Spec: v1.EC2NodeClassSpec{ AMISelectorTerms: []v1.AMISelectorTerm{{ID: amiID}}, }, diff --git a/pkg/fake/types.go b/pkg/fake/types.go index cc96de003d68..af65e7bf4257 100644 --- a/pkg/fake/types.go +++ b/pkg/fake/types.go @@ -25,6 +25,7 @@ import ( type MockedFunction[I any, O any] struct { Output AtomicPtr[O] // Output to return on call to this function + MultiOut AtomicPtrSlice[O] OutputPages AtomicPtrSlice[O] CalledWithInput AtomicPtrSlice[I] // Slice used to keep track of passed input to this function Error AtomicError // Error to return a certain number of times defined by custom error options @@ -38,6 +39,7 @@ type MockedFunction[I any, O any] struct { // each other. func (m *MockedFunction[I, O]) Reset() { m.Output.Reset() + m.MultiOut.Reset() m.OutputPages.Reset() m.CalledWithInput.Reset() m.Error.Reset() @@ -60,6 +62,11 @@ func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O, m.successfulCalls.Add(1) return m.Output.Clone(), nil } + + if m.MultiOut.Len() > 0 { + m.successfulCalls.Add(1) + return m.MultiOut.Pop(), nil + } // This output pages multi-threaded handling isn't perfect // It will fail if pages are asynchronously requested from the same NextToken if m.OutputPages.Len() > 0 { diff --git a/pkg/providers/amifamily/ami.go b/pkg/providers/amifamily/ami.go index d6275a94aa75..f2bfe46c942b 100644 --- a/pkg/providers/amifamily/ami.go +++ b/pkg/providers/amifamily/ami.go @@ -30,6 +30,7 @@ import ( "k8s.io/utils/clock" "github.com/aws/karpenter-provider-aws/pkg/errors" + "github.com/aws/karpenter-provider-aws/pkg/utils" v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" sdk "github.com/aws/karpenter-provider-aws/pkg/aws" @@ -70,11 +71,7 @@ func NewDefaultProvider(clk clock.Clock, versionProvider version.Provider, ssmPr func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) (AMIs, error) { p.Lock() defer p.Unlock() - queries, err := p.DescribeImageQueries(ctx, nodeClass) - if err != nil { - return nil, fmt.Errorf("getting AMI queries, %w", err) - } - amis, err := p.amis(ctx, queries) + amis, err := p.amis(ctx, nodeClass) if err != nil { return nil, err } @@ -165,12 +162,13 @@ func (p *DefaultProvider) DescribeImageQueries(ctx context.Context, nodeClass *v } //nolint:gocyclo -func (p *DefaultProvider) amis(ctx context.Context, queries []DescribeImageQuery) (AMIs, error) { - hash, err := hashstructure.Hash(queries, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) +func (p *DefaultProvider) amis(ctx context.Context, nodeClass *v1.EC2NodeClass) (AMIs, error) { + queries, err := p.DescribeImageQueries(ctx, nodeClass) if err != nil { - return nil, err + return nil, fmt.Errorf("getting AMI queries, %w", err) } - if images, ok := p.cache.Get(fmt.Sprintf("%d", hash)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if images, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a deep-copy of AMIs so alterations // to the data don't affect the original return append(AMIs{}, images.(AMIs)...), nil @@ -214,7 +212,7 @@ func (p *DefaultProvider) amis(ctx context.Context, queries []DescribeImageQuery } } } - p.cache.SetDefault(fmt.Sprintf("%d", hash), AMIs(lo.Values(images))) + p.cache.SetDefault(hash, AMIs(lo.Values(images))) return lo.Values(images), nil } diff --git a/pkg/providers/amifamily/suite_test.go b/pkg/providers/amifamily/suite_test.go index ee263049d035..4a8ae8738a20 100644 --- a/pkg/providers/amifamily/suite_test.go +++ b/pkg/providers/amifamily/suite_test.go @@ -31,6 +31,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/onsi/gomega/gstruct" . "sigs.k8s.io/karpenter/pkg/utils/testing" "github.com/samber/lo" @@ -51,6 +52,8 @@ import ( "github.com/aws/karpenter-provider-aws/pkg/operator/options" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/test" + + . "sigs.k8s.io/karpenter/pkg/test/expectations" ) var ctx context.Context @@ -579,6 +582,219 @@ var _ = Describe("AMIProvider", func() { })) }) }) + Context("Provider Cache", func() { + It("should resolve AMIs from cache that are filtered by id", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + ID: "ami-123", + }, + { + ID: "ami-456", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-123"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-456"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by name", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Name: "ami-name-1", + }, + { + Name: "ami-name-2", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by tags", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"test": "test"}, + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-3"), + ImageId: aws.String("ami-789"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMIFamily = &v1.AMIFamilyAL2 + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + )) + + // OR semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err = awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + + cacheItems := awsEnv.AMICache.Items() + Expect(cacheItems).To(HaveLen(2)) + cachedImages := make([]amifamily.AMIs, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedImages = append(cachedImages, item.Object.(amifamily.AMIs)) + } + + Expect(cachedImages).To(ConsistOf( + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + ), + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + ), + )) + }) + }) Context("AMI Selectors", func() { // When you tag public or shared resources, the tags you assign are available only to your AWS account; no other AWS account will have access to those tags // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions diff --git a/pkg/providers/securitygroup/securitygroup.go b/pkg/providers/securitygroup/securitygroup.go index 086cadf6ec30..f6fb280cf427 100644 --- a/pkg/providers/securitygroup/securitygroup.go +++ b/pkg/providers/securitygroup/securitygroup.go @@ -22,7 +22,6 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" "sigs.k8s.io/controller-runtime/pkg/log" @@ -31,6 +30,7 @@ import ( v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" sdk "github.com/aws/karpenter-provider-aws/pkg/aws" + "github.com/aws/karpenter-provider-aws/pkg/utils" ) type Provider interface { @@ -57,9 +57,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) p.Lock() defer p.Unlock() - // Get SecurityGroups - filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) - securityGroups, err := p.getSecurityGroups(ctx, filterSets) + securityGroups, err := p.getSecurityGroups(ctx, nodeClass) if err != nil { return nil, err } @@ -72,12 +70,10 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) return securityGroups, nil } -func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][]ec2types.Filter) ([]ec2types.SecurityGroup, error) { - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if sg, ok := p.cache.Get(fmt.Sprint(hash)); ok { +func (p *DefaultProvider) getSecurityGroups(ctx context.Context, nodeClass *v1.EC2NodeClass) ([]ec2types.SecurityGroup, error) { + filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) + hash := utils.GetNodeClassHash(nodeClass) + if sg, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]ec2types.SecurityGroup{}, sg.([]ec2types.SecurityGroup)...), nil @@ -98,7 +94,7 @@ func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][] } } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(securityGroups)) + p.cache.SetDefault(hash, lo.Values(securityGroups)) return lo.Values(securityGroups), nil } diff --git a/pkg/providers/securitygroup/suite_test.go b/pkg/providers/securitygroup/suite_test.go index 8022cc6e34f5..9178d4deff05 100644 --- a/pkg/providers/securitygroup/suite_test.go +++ b/pkg/providers/securitygroup/suite_test.go @@ -273,7 +273,13 @@ var _ = Describe("SecurityGroupProvider", func() { }) Context("Provider Cache", func() { It("should resolve security groups from cache that are filtered by id", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Clone().SecurityGroups + expectedSecurityGroups := []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ { @@ -285,6 +291,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]ec2types.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -292,7 +299,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by Name", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Clone().SecurityGroups + expectedSecurityGroups := []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ { @@ -304,6 +317,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]ec2types.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -311,7 +325,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by tags", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Clone().SecurityGroups + expectedSecurityGroups := []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) tagSet := lo.Map(expectedSecurityGroups, func(sg ec2types.SecurityGroup, _ int) map[string]string { tag, _ := lo.Find(sg.Tags, func(tag ec2types.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -329,12 +349,95 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } - for _, cachedObject := range awsEnv.SubnetCache.Items() { + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) + for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]ec2types.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) lo.Contains(lo.ToSlicePtr(expectedSecurityGroups), lo.ToPtr(cachedSecurityGroup[0])) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ + {GroupName: aws.String("test-sgName-3"), GroupId: aws.String("test-sg-3"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + }, + }, securityGroups) + + // OR semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ + {GroupName: aws.String("test-sgName-2"), GroupId: aws.String("test-sg-2"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ + {GroupName: aws.String("test-sgName-1"), GroupId: aws.String("test-sg-1"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err = awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + }, + }, securityGroups) + + cacheItems := awsEnv.SecurityGroupCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached security group arrays for comparison + cachedSecurityGroups := make([][]ec2types.SecurityGroup, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSecurityGroups = append(cachedSecurityGroups, item.Object.([]ec2types.SecurityGroup)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSecurityGroups).To(ContainElement(ContainElements( + []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSecurityGroups).To(ContainElement( + []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index ef6315d9b458..be0951e51704 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -25,13 +25,13 @@ import ( ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/awslabs/operatorpkg/serrors" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/log" v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/aws/karpenter-provider-aws/pkg/utils" karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1" "sigs.k8s.io/karpenter/pkg/cloudprovider" @@ -86,11 +86,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) if len(filterSets) == 0 { return []ec2types.Subnet{}, nil } - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if subnets, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]ec2types.Subnet{}, subnets.([]ec2types.Subnet)...), nil @@ -117,7 +114,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) } } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(subnets)) + p.cache.SetDefault(hash, lo.Values(subnets)) if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), lo.Keys(subnets)) { log.FromContext(ctx). WithValues("subnets", lo.Map(lo.Values(subnets), func(s ec2types.Subnet, _ int) v1.Subnet { diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index 7cd66a8c53cd..2123c68ba6d1 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -311,7 +311,13 @@ var _ = Describe("SubnetProvider", func() { }) Context("Provider Cache", func() { It("should resolve subnets from cache that are filtered by id", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsBehavior.Output.Clone().Subnets + expectedSubnets := []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) for _, subnet := range expectedSubnets { nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ { @@ -323,6 +329,7 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]ec2types.Subnet) Expect(cachedSubnet).To(HaveLen(1)) @@ -330,7 +337,13 @@ var _ = Describe("SubnetProvider", func() { } }) It("should resolve subnets from cache that are filtered by tags", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsBehavior.Output.Clone().Subnets + expectedSubnets := []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) tagSet := lo.Map(expectedSubnets, func(subnet ec2types.Subnet, _ int) map[string]string { tag, _ := lo.Find(subnet.Tags, func(tag ec2types.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -348,12 +361,98 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]ec2types.Subnet) Expect(cachedSubnet).To(HaveLen(1)) lo.Contains(lo.ToSlicePtr(expectedSubnets), lo.ToPtr(cachedSubnet[0])) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + }, + }, subnets) + + // OR semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("test-subnet-id-2"), SubnetArn: aws.String("test-subnet-arn-2"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err = awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + }, + }, subnets) + + cacheItems := awsEnv.SubnetCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached subnet arrays for comparison + cachedSubnets := make([][]ec2types.Subnet, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSubnets = append(cachedSubnets, item.Object.([]ec2types.Subnet)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSubnets).To(ContainElement(ContainElements( + []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSubnets).To(ContainElement( + []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} diff --git a/pkg/test/environment.go b/pkg/test/environment.go index a44849663bec..5ff5f8bab1bd 100644 --- a/pkg/test/environment.go +++ b/pkg/test/environment.go @@ -68,6 +68,7 @@ type Environment struct { PricingAPI *fake.PricingAPI // Cache + AMICache *cache.Cache EC2Cache *cache.Cache InstanceTypeCache *cache.Cache InstanceCache *cache.Cache @@ -116,6 +117,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment iamapi := fake.NewIAMAPI() // cache + amiCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) ec2Cache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) instanceTypeCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) instanceCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) @@ -149,7 +151,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment lo.Must0(versionProvider.UpdateVersion(ctx)) instanceProfileProvider := instanceprofile.NewDefaultProvider(iamapi, instanceProfileCache, roleCache, protectedProfilesCache, fake.DefaultRegion) ssmProvider := ssmp.NewDefaultProvider(ssmapi, ssmCache) - amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, ec2Cache) + amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, amiCache) amiResolver := amifamily.NewDefaultResolver(fake.DefaultRegion) instanceTypesResolver := instancetype.NewDefaultResolver(fake.DefaultRegion) capacityReservationProvider := capacityreservation.NewProvider(ec2api, clock, capacityReservationCache, capacityReservationAvailabilityCache) @@ -198,6 +200,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment IAMAPI: iamapi, PricingAPI: fakePricingAPI, + AMICache: amiCache, EC2Cache: ec2Cache, InstanceTypeCache: instanceTypeCache, InstanceCache: instanceCache, @@ -245,6 +248,7 @@ func (env *Environment) Reset() { env.PricingProvider.Reset() env.InstanceTypesProvider.Reset() + env.AMICache.Flush() env.EC2Cache.Flush() env.InstanceCache.Flush() env.UnavailableOfferingsCache.Flush() diff --git a/pkg/utils/suite_test.go b/pkg/utils/suite_test.go new file mode 100644 index 000000000000..241a861e96b2 --- /dev/null +++ b/pkg/utils/suite_test.go @@ -0,0 +1,44 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/aws/karpenter-provider-aws/pkg/utils" +) + +func TestUtils(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Utils Suite") +} + +var _ = Describe("GetNodeClassHash", func() { + It("should return formatted hash with UID and Generation", func() { + nodeClass := &v1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: "test-uid-123", + Generation: 5, + }, + } + hash := utils.GetNodeClassHash(nodeClass) + Expect(hash).To(Equal("test-uid-123-5")) + }) +}) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 3a8d6e0f038c..ba343525ac9a 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -121,3 +121,7 @@ func GetTags(nodeClass *v1.EC2NodeClass, nodeClaim *karpv1.NodeClaim, clusterNam } return lo.Assign(nodeClass.Spec.Tags, staticTags), nil } + +func GetNodeClassHash(nodeClass *v1.EC2NodeClass) string { + return fmt.Sprintf("%s-%d", nodeClass.UID, nodeClass.Generation) +} diff --git a/test/suites/ami/suite_test.go b/test/suites/ami/suite_test.go index 4d8f61ddb970..01ce774acf00 100644 --- a/test/suites/ami/suite_test.go +++ b/test/suites/ami/suite_test.go @@ -93,7 +93,7 @@ var _ = Describe("AMI", func() { }) It("should use the most recent AMI when discovering multiple", func() { // choose an old static image that will definitely have an older creation date - oldCustomAMI := env.GetAMIBySSMPath(fmt.Sprintf("/aws/service/eks/optimized-ami/%[1]s/amazon-linux-2023/x86_64/standard/amazon-eks-node-al2023-x86_64-standard-%[1]s-v20250819/image_id", env.K8sVersion())) + oldCustomAMI := env.GetAMIBySSMPath(fmt.Sprintf("/aws/service/eks/optimized-ami/%[1]s/amazon-linux-2023/x86_64/standard/amazon-eks-node-al2023-x86_64-standard-%[1]s-v20250915/image_id", env.K8sVersion())) nodeClass.Spec.AMIFamily = lo.ToPtr(v1.AMIFamilyAL2023) nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ {ID: customAMI}, @@ -234,8 +234,8 @@ var _ = Describe("AMI", func() { Entry("AL2 (latest)", "al2@latest"), Entry("AL2 (pinned)", "al2@v20250116"), Entry("Bottlerocket (latest)", "bottlerocket@latest"), - Entry("Bottlerocket (pinned with v prefix)", "bottlerocket@v1.45.0"), - Entry("Bottlerocket (pinned without v prefix)", "bottlerocket@1.45.0"), + Entry("Bottlerocket (pinned with v prefix)", "bottlerocket@v1.47.0"), + Entry("Bottlerocket (pinned without v prefix)", "bottlerocket@1.47.0"), ) It("should support Custom AMIFamily with AMI Selectors", func() { al2023AMI := env.GetAMIBySSMPath(fmt.Sprintf("/aws/service/eks/optimized-ami/%s/amazon-linux-2023/x86_64/standard/recommended/image_id", env.K8sVersion()))