diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index 2e67c18b1835..fb8c8912d979 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -656,7 +656,7 @@ var _ = Describe("CloudProvider", func() { }, }, }) - awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{ + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{ SecurityGroups: []ec2types.SecurityGroup{ { GroupId: aws.String(validSecurityGroup), @@ -670,7 +670,7 @@ var _ = Describe("CloudProvider", func() { }, }, }) - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{ Subnets: []ec2types.Subnet{ { SubnetId: aws.String(validSubnet1), @@ -1152,7 +1152,7 @@ var _ = Describe("CloudProvider", func() { }) It("should launch instances into subnet with the most available IP addresses", func() { awsEnv.SubnetCache.Flush() - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10), Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100), @@ -1169,7 +1169,7 @@ var _ = Describe("CloudProvider", func() { }) It("should launch instances into subnet with the most available IP addresses in-between cache refreshes", func() { awsEnv.SubnetCache.Flush() - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10), Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(11), @@ -1197,7 +1197,7 @@ var _ = Describe("CloudProvider", func() { Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1")) }) It("should update in-flight IPs when a CreateFleet error occurs", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int32(10), Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, }}) @@ -1208,12 +1208,20 @@ var _ = Describe("CloudProvider", func() { Expect(len(bindings)).To(Equal(0)) }) It("should launch instances into subnets that are excluded by another NodePool", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ - {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10), - Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, - {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100), - Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}}, - }}) + awsEnv.EC2API.Subnets.Store("test-zone-1a", ec2types.Subnet{ + SubnetId: aws.String("test-subnet-1"), + AvailabilityZone: aws.String("test-zone-1a"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int32(10), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }) + awsEnv.EC2API.Subnets.Store("test-zone-1b", ec2types.Subnet{ + SubnetId: aws.String("test-subnet-2"), + AvailabilityZone: aws.String("test-zone-1b"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int32(100), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}, + }) nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{{Tags: map[string]string{"Name": "test-subnet-1"}}} ExpectApplied(ctx, env.Client, nodePool, nodeClass) controller := nodeclass.NewController(env.Client, recorder, awsEnv.SubnetProvider, awsEnv.SecurityGroupProvider, awsEnv.AMIProvider, awsEnv.InstanceProfileProvider, awsEnv.LaunchTemplateProvider) diff --git a/pkg/controllers/nodeclass/ami_test.go b/pkg/controllers/nodeclass/ami_test.go index 601c15e7c931..63cc49a6d466 100644 --- a/pkg/controllers/nodeclass/ami_test.go +++ b/pkg/controllers/nodeclass/ami_test.go @@ -631,7 +631,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() { awsEnv.Clock.Step(40 * time.Minute) // Flush Cache - awsEnv.EC2Cache.Flush() + awsEnv.AMICache.Flush() ExpectObjectReconciled(ctx, env.Client, controller, nodeClass) nodeClass = ExpectExists(ctx, env.Client, nodeClass) @@ -730,7 +730,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() { }, }) - awsEnv.EC2Cache.Flush() + awsEnv.AMICache.Flush() ExpectApplied(ctx, env.Client, nodeClass) ExpectObjectReconciled(ctx, env.Client, controller, nodeClass) diff --git a/pkg/controllers/nodeclass/subnet_test.go b/pkg/controllers/nodeclass/subnet_test.go index 5770c1a351fc..1f7e3726d86f 100644 --- a/pkg/controllers/nodeclass/subnet_test.go +++ b/pkg/controllers/nodeclass/subnet_test.go @@ -80,7 +80,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { Expect(nodeClass.StatusConditions().IsTrue(v1.ConditionTypeSubnetsReady)).To(BeTrue()) }) It("Should have the correct ordering for the Subnets", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ {SubnetId: aws.String("subnet-test1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(20)}, {SubnetId: aws.String("subnet-test2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1b"), AvailableIpAddressCount: aws.Int32(100)}, {SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailabilityZoneId: aws.String("tstz1-1c"), AvailableIpAddressCount: aws.Int32(50)}, diff --git a/pkg/controllers/providers/ssm/invalidation/controller.go b/pkg/controllers/providers/ssm/invalidation/controller.go index d57be1594f89..81b44e69c6f9 100644 --- a/pkg/controllers/providers/ssm/invalidation/controller.go +++ b/pkg/controllers/providers/ssm/invalidation/controller.go @@ -29,6 +29,9 @@ import ( v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/providers/ssm" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/uuid" ) // The SSM Invalidation controller is responsible for invalidating "latest" SSM parameters when they point to deprecated @@ -66,6 +69,9 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { amis := []amifamily.AMI{} for _, nodeClass := range lo.Map(lo.Keys(amiIDsToParameters), func(amiID string, _ int) *v1.EC2NodeClass { return &v1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: uuid.NewUUID(), // ensures that this doesn't hit the AMI cache. + }, Spec: v1.EC2NodeClassSpec{ AMISelectorTerms: []v1.AMISelectorTerm{{ID: amiID}}, }, diff --git a/pkg/fake/ec2api.go b/pkg/fake/ec2api.go index 2690e6ece45c..65a4b817a616 100644 --- a/pkg/fake/ec2api.go +++ b/pkg/fake/ec2api.go @@ -48,8 +48,8 @@ type CapacityPool struct { type EC2Behavior struct { DescribeImagesOutput AtomicPtr[ec2.DescribeImagesOutput] DescribeLaunchTemplatesOutput AtomicPtr[ec2.DescribeLaunchTemplatesOutput] - DescribeSubnetsOutput AtomicPtr[ec2.DescribeSubnetsOutput] - DescribeSecurityGroupsOutput AtomicPtr[ec2.DescribeSecurityGroupsOutput] + DescribeSubnetsBehavior MockedFunction[ec2.DescribeSubnetsInput, ec2.DescribeSubnetsOutput] + DescribeSecurityGroupsBehavior MockedFunction[ec2.DescribeSecurityGroupsInput, ec2.DescribeSecurityGroupsOutput] DescribeInstanceTypesOutput AtomicPtr[ec2.DescribeInstanceTypesOutput] DescribeInstanceTypeOfferingsOutput AtomicPtr[ec2.DescribeInstanceTypeOfferingsOutput] DescribeAvailabilityZonesOutput AtomicPtr[ec2.DescribeAvailabilityZonesOutput] @@ -60,6 +60,7 @@ type EC2Behavior struct { CreateTagsBehavior MockedFunction[ec2.CreateTagsInput, ec2.CreateTagsOutput] CalledWithCreateLaunchTemplateInput AtomicPtrSlice[ec2.CreateLaunchTemplateInput] CalledWithDescribeImagesInput AtomicPtrSlice[ec2.DescribeImagesInput] + Subnets sync.Map Instances sync.Map LaunchTemplates sync.Map InsufficientCapacityPools atomic.Slice[CapacityPool] @@ -83,8 +84,8 @@ var DefaultSupportedUsageClasses = []ec2types.UsageClassType{ec2types.UsageClass func (e *EC2API) Reset() { e.DescribeImagesOutput.Reset() e.DescribeLaunchTemplatesOutput.Reset() - e.DescribeSubnetsOutput.Reset() - e.DescribeSecurityGroupsOutput.Reset() + e.DescribeSubnetsBehavior.Reset() + e.DescribeSecurityGroupsBehavior.Reset() e.DescribeInstanceTypesOutput.Reset() e.DescribeInstanceTypeOfferingsOutput.Reset() e.DescribeAvailabilityZonesOutput.Reset() @@ -380,107 +381,109 @@ func (e *EC2API) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunch } func (e *EC2API) DescribeSubnets(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { - if !e.NextError.IsNil() { - defer e.NextError.Reset() - return nil, e.NextError.Get() - } - if !e.DescribeSubnetsOutput.IsNil() { - describeSubnetsOutput := e.DescribeSubnetsOutput.Clone() - describeSubnetsOutput.Subnets = FilterDescribeSubnets(describeSubnetsOutput.Subnets, input.Filters) - return describeSubnetsOutput, nil - } - subnets := []ec2types.Subnet{ - { - SubnetId: aws.String("subnet-test1"), - AvailabilityZone: aws.String("test-zone-1a"), - AvailabilityZoneId: aws.String("tstz1-1a"), - AvailableIpAddressCount: aws.Int32(100), - MapPublicIpOnLaunch: aws.Bool(false), - Tags: []ec2types.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-1")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + return e.DescribeSubnetsBehavior.Invoke(input, func(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) { + output := &ec2.DescribeSubnetsOutput{} + e.Subnets.Range(func(key, value any) bool { + subnet := value.(ec2types.Subnet) + if lo.Contains(input.SubnetIds, lo.FromPtr(subnet.SubnetId)) || len(input.Filters) != 0 && len(FilterDescribeSubnets([]ec2types.Subnet{subnet}, input.Filters)) != 0 { + output.Subnets = append(output.Subnets, subnet) + } + return true + }) + if len(output.Subnets) != 0 { + return output, nil + } + + defaultSubnets := []ec2types.Subnet{ + { + SubnetId: aws.String("subnet-test1"), + AvailabilityZone: aws.String("test-zone-1a"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int32(100), + MapPublicIpOnLaunch: aws.Bool(false), + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-1")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test2"), - AvailabilityZone: aws.String("test-zone-1b"), - AvailabilityZoneId: aws.String("tstz1-1b"), - AvailableIpAddressCount: aws.Int32(100), - MapPublicIpOnLaunch: aws.Bool(true), - Tags: []ec2types.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-2")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + SubnetId: aws.String("subnet-test2"), + AvailabilityZone: aws.String("test-zone-1b"), + AvailabilityZoneId: aws.String("tstz1-1b"), + AvailableIpAddressCount: aws.Int32(100), + MapPublicIpOnLaunch: aws.Bool(true), + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-2")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test3"), - AvailabilityZone: aws.String("test-zone-1c"), - AvailabilityZoneId: aws.String("tstz1-1c"), - AvailableIpAddressCount: aws.Int32(100), - Tags: []ec2types.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-3")}, - {Key: aws.String("TestTag")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + SubnetId: aws.String("subnet-test3"), + AvailabilityZone: aws.String("test-zone-1c"), + AvailabilityZoneId: aws.String("tstz1-1c"), + AvailableIpAddressCount: aws.Int32(100), + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-3")}, + {Key: aws.String("TestTag")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test4"), - AvailabilityZone: aws.String("test-zone-1a-local"), - AvailabilityZoneId: aws.String("tstz1-1alocal"), - AvailableIpAddressCount: aws.Int32(100), - MapPublicIpOnLaunch: aws.Bool(true), - Tags: []ec2types.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-4")}, + { + SubnetId: aws.String("subnet-test4"), + AvailabilityZone: aws.String("test-zone-1a-local"), + AvailabilityZoneId: aws.String("tstz1-1alocal"), + AvailableIpAddressCount: aws.Int32(100), + MapPublicIpOnLaunch: aws.Bool(true), + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-4")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - } - if len(input.Filters) == 0 { - return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") - } - return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(subnets, input.Filters)}, nil + } + if len(input.Filters) == 0 { + return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") + } + return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(defaultSubnets, input.Filters)}, nil + }) } func (e *EC2API) DescribeSecurityGroups(_ context.Context, input *ec2.DescribeSecurityGroupsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { - if !e.NextError.IsNil() { - defer e.NextError.Reset() - return nil, e.NextError.Get() - } - if !e.DescribeSecurityGroupsOutput.IsNil() { - describeSecurityGroupsOutput := e.DescribeSecurityGroupsOutput.Clone() - describeSecurityGroupsOutput.SecurityGroups = FilterDescribeSecurtyGroups(describeSecurityGroupsOutput.SecurityGroups, input.Filters) - return e.DescribeSecurityGroupsOutput.Clone(), nil - } - sgs := []ec2types.SecurityGroup{ - { - GroupId: aws.String("sg-test1"), - GroupName: aws.String("securityGroup-test1"), - Tags: []ec2types.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-1")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + return e.DescribeSecurityGroupsBehavior.Invoke(input, func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) { + defaultSecurityGroups := []ec2types.SecurityGroup{ + { + GroupId: aws.String("sg-test1"), + GroupName: aws.String("securityGroup-test1"), + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-1")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - { - GroupId: aws.String("sg-test2"), - GroupName: aws.String("securityGroup-test2"), - Tags: []ec2types.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-2")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + GroupId: aws.String("sg-test2"), + GroupName: aws.String("securityGroup-test2"), + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-2")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - { - GroupId: aws.String("sg-test3"), - GroupName: aws.String("securityGroup-test3"), - Tags: []ec2types.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, - {Key: aws.String("TestTag")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + GroupId: aws.String("sg-test3"), + GroupName: aws.String("securityGroup-test3"), + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, + {Key: aws.String("TestTag")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - } - if len(input.Filters) == 0 { - return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") - } - return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(sgs, input.Filters)}, nil + } + if len(input.Filters) == 0 { + return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") + } + return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(defaultSecurityGroups, input.Filters)}, nil + }) } func (e *EC2API) DescribeAvailabilityZones(context.Context, *ec2.DescribeAvailabilityZonesInput, ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) { diff --git a/pkg/fake/types.go b/pkg/fake/types.go index ecb36e5cd3f2..30470954967f 100644 --- a/pkg/fake/types.go +++ b/pkg/fake/types.go @@ -25,6 +25,7 @@ import ( type MockedFunction[I any, O any] struct { Output AtomicPtr[O] // Output to return on call to this function + MultiOut AtomicPtrSlice[O] OutputPages AtomicPtrSlice[O] CalledWithInput AtomicPtrSlice[I] // Slice used to keep track of passed input to this function Error AtomicError // Error to return a certain number of times defined by custom error options @@ -38,6 +39,7 @@ type MockedFunction[I any, O any] struct { // each other. func (m *MockedFunction[I, O]) Reset() { m.Output.Reset() + m.MultiOut.Reset() m.OutputPages.Reset() m.CalledWithInput.Reset() m.Error.Reset() @@ -59,6 +61,11 @@ func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O, m.successfulCalls.Add(1) return m.Output.Clone(), nil } + + if m.MultiOut.Len() > 0 { + m.successfulCalls.Add(1) + return m.MultiOut.Pop(), nil + } // This output pages multi-threaded handling isn't perfect // It will fail if pages are asynchronously requested from the same NextToken if m.OutputPages.Len() > 0 { diff --git a/pkg/providers/amifamily/ami.go b/pkg/providers/amifamily/ami.go index ee6fdfc7257d..6d2418599cbc 100644 --- a/pkg/providers/amifamily/ami.go +++ b/pkg/providers/amifamily/ami.go @@ -26,12 +26,14 @@ import ( "github.com/patrickmn/go-cache" "github.com/samber/lo" "k8s.io/utils/clock" - "sigs.k8s.io/controller-runtime/pkg/log" + + "github.com/aws/karpenter-provider-aws/pkg/utils" v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" sdk "github.com/aws/karpenter-provider-aws/pkg/aws" "github.com/aws/karpenter-provider-aws/pkg/providers/version" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/karpenter/pkg/cloudprovider" "sigs.k8s.io/karpenter/pkg/scheduling" "sigs.k8s.io/karpenter/pkg/utils/pretty" @@ -69,11 +71,7 @@ func NewDefaultProvider(clk clock.Clock, versionProvider version.Provider, ssmPr func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) (AMIs, error) { p.Lock() defer p.Unlock() - queries, err := p.DescribeImageQueries(ctx, nodeClass) - if err != nil { - return nil, fmt.Errorf("getting AMI queries, %w", err) - } - amis, err := p.amis(ctx, queries) + amis, err := p.amis(ctx, nodeClass) if err != nil { return nil, err } @@ -143,12 +141,13 @@ func (p *DefaultProvider) DescribeImageQueries(ctx context.Context, nodeClass *v } //nolint:gocyclo -func (p *DefaultProvider) amis(ctx context.Context, queries []DescribeImageQuery) (AMIs, error) { - hash, err := hashstructure.Hash(queries, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) +func (p *DefaultProvider) amis(ctx context.Context, nodeClass *v1.EC2NodeClass) (AMIs, error) { + queries, err := p.DescribeImageQueries(ctx, nodeClass) if err != nil { - return nil, err + return nil, fmt.Errorf("getting AMI queries, %w", err) } - if images, ok := p.cache.Get(fmt.Sprintf("%d", hash)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if images, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a deep-copy of AMIs so alterations // to the data don't affect the original return append(AMIs{}, images.(AMIs)...), nil @@ -192,7 +191,7 @@ func (p *DefaultProvider) amis(ctx context.Context, queries []DescribeImageQuery } } } - p.cache.SetDefault(fmt.Sprintf("%d", hash), AMIs(lo.Values(images))) + p.cache.SetDefault(hash, AMIs(lo.Values(images))) return lo.Values(images), nil } diff --git a/pkg/providers/amifamily/suite_test.go b/pkg/providers/amifamily/suite_test.go index 4adc0ea4c558..dd77daf16afe 100644 --- a/pkg/providers/amifamily/suite_test.go +++ b/pkg/providers/amifamily/suite_test.go @@ -30,6 +30,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/onsi/gomega/gstruct" . "sigs.k8s.io/karpenter/pkg/utils/testing" "github.com/samber/lo" @@ -45,6 +46,8 @@ import ( "github.com/aws/karpenter-provider-aws/pkg/operator/options" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/test" + + . "sigs.k8s.io/karpenter/pkg/test/expectations" ) var ctx context.Context @@ -532,6 +535,219 @@ var _ = Describe("AMIProvider", func() { })) }) }) + Context("Provider Cache", func() { + It("should resolve AMIs from cache that are filtered by id", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + ID: "ami-123", + }, + { + ID: "ami-456", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-123"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-456"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by name", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Name: "ami-name-1", + }, + { + Name: "ami-name-2", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by tags", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"test": "test"}, + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-3"), + ImageId: aws.String("ami-789"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMIFamily = &v1.AMIFamilyAL2 + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + )) + + // OR semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []ec2types.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: "x86_64", + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: "arm64", + Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: ec2types.ImageStateAvailable, + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err = awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + + cacheItems := awsEnv.AMICache.Items() + Expect(cacheItems).To(HaveLen(2)) + cachedImages := make([]amifamily.AMIs, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedImages = append(cachedImages, item.Object.(amifamily.AMIs)) + } + + Expect(cachedImages).To(ConsistOf( + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + ), + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + ), + )) + }) + }) Context("AMI Selectors", func() { // When you tag public or shared resources, the tags you assign are available only to your AWS account; no other AWS account will have access to those tags // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index 10d975a6ff29..7a3ef3c85419 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -1820,7 +1820,7 @@ var _ = Describe("InstanceTypeProvider", func() { } }) It("shouldn't report more resources than are actually available on instances", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{ Subnets: []ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2a"), diff --git a/pkg/providers/securitygroup/securitygroup.go b/pkg/providers/securitygroup/securitygroup.go index 983a2e1ee195..265cfe50006f 100644 --- a/pkg/providers/securitygroup/securitygroup.go +++ b/pkg/providers/securitygroup/securitygroup.go @@ -22,7 +22,6 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" "sigs.k8s.io/controller-runtime/pkg/log" @@ -31,6 +30,7 @@ import ( v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" sdk "github.com/aws/karpenter-provider-aws/pkg/aws" + "github.com/aws/karpenter-provider-aws/pkg/utils" ) type Provider interface { @@ -57,9 +57,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) p.Lock() defer p.Unlock() - // Get SecurityGroups - filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) - securityGroups, err := p.getSecurityGroups(ctx, filterSets) + securityGroups, err := p.getSecurityGroups(ctx, nodeClass) if err != nil { return nil, err } @@ -72,12 +70,10 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) return securityGroups, nil } -func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][]ec2types.Filter) ([]ec2types.SecurityGroup, error) { - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if sg, ok := p.cache.Get(fmt.Sprint(hash)); ok { +func (p *DefaultProvider) getSecurityGroups(ctx context.Context, nodeClass *v1.EC2NodeClass) ([]ec2types.SecurityGroup, error) { + filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) + hash := utils.GetNodeClassHash(nodeClass) + if sg, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]ec2types.SecurityGroup{}, sg.([]ec2types.SecurityGroup)...), nil @@ -92,7 +88,7 @@ func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][] securityGroups[lo.FromPtr(output.SecurityGroups[i].GroupId)] = output.SecurityGroups[i] } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(securityGroups)) + p.cache.SetDefault(hash, lo.Values(securityGroups)) return lo.Values(securityGroups), nil } diff --git a/pkg/providers/securitygroup/suite_test.go b/pkg/providers/securitygroup/suite_test.go index 629ebd4ade56..dd83011ec39d 100644 --- a/pkg/providers/securitygroup/suite_test.go +++ b/pkg/providers/securitygroup/suite_test.go @@ -117,7 +117,7 @@ var _ = Describe("SecurityGroupProvider", func() { }, securityGroups) }) It("should discover security groups by tag", func() { - awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ {GroupName: aws.String("test-sgName-1"), GroupId: aws.String("test-sg-1"), Tags: []ec2types.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-1")}}}, {GroupName: aws.String("test-sgName-2"), GroupId: aws.String("test-sg-2"), Tags: []ec2types.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-2")}}}, }}) @@ -273,7 +273,13 @@ var _ = Describe("SecurityGroupProvider", func() { }) Context("Provider Cache", func() { It("should resolve security groups from cache that are filtered by id", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ { @@ -285,6 +291,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]ec2types.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -292,7 +299,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by Name", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ { @@ -304,6 +317,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]ec2types.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -311,7 +325,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by tags", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) tagSet := lo.Map(expectedSecurityGroups, func(sg ec2types.SecurityGroup, _ int) map[string]string { tag, _ := lo.Find(sg.Tags, func(tag ec2types.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -329,12 +349,95 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } - for _, cachedObject := range awsEnv.SubnetCache.Items() { + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) + for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]ec2types.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) lo.Contains(lo.ToSlicePtr(expectedSecurityGroups), lo.ToPtr(cachedSecurityGroup[0])) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ + {GroupName: aws.String("test-sgName-3"), GroupId: aws.String("test-sg-3"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + }, + }, securityGroups) + + // OR semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ + {GroupName: aws.String("test-sgName-2"), GroupId: aws.String("test-sg-2"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{ + {GroupName: aws.String("test-sgName-1"), GroupId: aws.String("test-sg-1"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err = awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + }, + }, securityGroups) + + cacheItems := awsEnv.SecurityGroupCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached security group arrays for comparison + cachedSecurityGroups := make([][]ec2types.SecurityGroup, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSecurityGroups = append(cachedSecurityGroups, item.Object.([]ec2types.SecurityGroup)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSecurityGroups).To(ContainElement(ContainElements( + []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSecurityGroups).To(ContainElement( + []ec2types.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index 959b8bc4ddc8..64807aa487b5 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -24,13 +24,13 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/log" v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/aws/karpenter-provider-aws/pkg/utils" karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1" "sigs.k8s.io/karpenter/pkg/cloudprovider" @@ -85,11 +85,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) if len(filterSets) == 0 { return []ec2types.Subnet{}, nil } - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if subnets, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]ec2types.Subnet{}, subnets.([]ec2types.Subnet)...), nil @@ -110,7 +107,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) delete(p.inflightIPs, lo.FromPtr(output.Subnets[i].SubnetId)) // remove any previously tracked IP addresses since we just refreshed from EC2 } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(subnets)) + p.cache.SetDefault(hash, lo.Values(subnets)) if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), lo.Keys(subnets)) { log.FromContext(ctx). WithValues("subnets", lo.Map(lo.Values(subnets), func(s ec2types.Subnet, _ int) v1.Subnet { diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index 7c16a485a55d..735f72648baf 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -22,6 +22,8 @@ import ( "sigs.k8s.io/karpenter/pkg/test/v1alpha1" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/samber/lo" @@ -232,7 +234,13 @@ var _ = Describe("SubnetProvider", func() { }) Context("Provider Cache", func() { It("should resolve subnets from cache that are filtered by id", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets + expectedSubnets := []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) for _, subnet := range expectedSubnets { nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ { @@ -244,6 +252,7 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]ec2types.Subnet) Expect(cachedSubnet).To(HaveLen(1)) @@ -251,7 +260,13 @@ var _ = Describe("SubnetProvider", func() { } }) It("should resolve subnets from cache that are filtered by tags", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets + expectedSubnets := []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) tagSet := lo.Map(expectedSubnets, func(subnet ec2types.Subnet, _ int) map[string]string { tag, _ := lo.Find(subnet.Tags, func(tag ec2types.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -269,12 +284,98 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]ec2types.Subnet) Expect(cachedSubnet).To(HaveLen(1)) lo.Contains(lo.ToSlicePtr(expectedSubnets), lo.ToPtr(cachedSubnet[0])) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + }, + }, subnets) + + // OR semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("test-subnet-id-2"), SubnetArn: aws.String("test-subnet-arn-2"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err = awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + }, + }, subnets) + + cacheItems := awsEnv.SubnetCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached subnet arrays for comparison + cachedSubnets := make([][]ec2types.Subnet, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSubnets = append(cachedSubnets, item.Object.([]ec2types.Subnet)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSubnets).To(ContainElement(ContainElements( + []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSubnets).To(ContainElement( + []ec2types.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []ec2types.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} @@ -311,6 +412,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1b"), @@ -329,6 +431,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1c"), @@ -348,6 +451,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1a-local"), @@ -361,6 +465,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("test-subnet-4"), }, }, + VpcId: aws.String("vpc-test1"), }, })) }() diff --git a/pkg/test/environment.go b/pkg/test/environment.go index 2d9a9d243083..cb5b58418c45 100644 --- a/pkg/test/environment.go +++ b/pkg/test/environment.go @@ -62,6 +62,7 @@ type Environment struct { PricingAPI *fake.PricingAPI // Cache + AMICache *cache.Cache EC2Cache *cache.Cache InstanceTypeCache *cache.Cache UnavailableOfferingsCache *awscache.UnavailableOfferings @@ -99,6 +100,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment iamapi := fake.NewIAMAPI() // cache + amiCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) ec2Cache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) instanceTypeCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) discoveredCapacityCache := cache.New(awscache.DiscoveredCapacityCacheTTL, awscache.DefaultCleanupInterval) @@ -123,7 +125,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment lo.Must0(versionProvider.UpdateVersion(ctx)) instanceProfileProvider := instanceprofile.NewDefaultProvider(fake.DefaultRegion, iamapi, instanceProfileCache) ssmProvider := ssmp.NewDefaultProvider(ssmapi, ssmCache) - amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, ec2Cache) + amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, amiCache) amiResolver := amifamily.NewDefaultResolver() instanceTypesResolver := instancetype.NewDefaultResolver(fake.DefaultRegion, pricingProvider, unavailableOfferingsCache) instanceTypesProvider := instancetype.NewDefaultProvider(instanceTypeCache, discoveredCapacityCache, ec2api, subnetProvider, instanceTypesResolver) @@ -159,6 +161,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment IAMAPI: iamapi, PricingAPI: fakePricingAPI, + AMICache: amiCache, EC2Cache: ec2Cache, InstanceTypeCache: instanceTypeCache, LaunchTemplateCache: launchTemplateCache, @@ -195,6 +198,7 @@ func (env *Environment) Reset() { env.PricingProvider.Reset() env.InstanceTypesProvider.Reset() + env.AMICache.Flush() env.EC2Cache.Flush() env.UnavailableOfferingsCache.Flush() env.LaunchTemplateCache.Flush() diff --git a/pkg/utils/suite_test.go b/pkg/utils/suite_test.go new file mode 100644 index 000000000000..241a861e96b2 --- /dev/null +++ b/pkg/utils/suite_test.go @@ -0,0 +1,44 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/aws/karpenter-provider-aws/pkg/utils" +) + +func TestUtils(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Utils Suite") +} + +var _ = Describe("GetNodeClassHash", func() { + It("should return formatted hash with UID and Generation", func() { + nodeClass := &v1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: "test-uid-123", + Generation: 5, + }, + } + hash := utils.GetNodeClassHash(nodeClass) + Expect(hash).To(Equal("test-uid-123-5")) + }) +}) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index d558b0790223..6da29870531a 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -24,6 +24,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1" + "github.com/samber/lo" ) @@ -83,3 +85,6 @@ func WithDefaultFloat64(key string, def float64) float64 { } return f } +func GetNodeClassHash(nodeClass *v1.EC2NodeClass) string { + return fmt.Sprintf("%s-%d", nodeClass.UID, nodeClass.Generation) +}