diff --git a/go.mod b/go.mod index 4027498cbac1..530d12b2ad61 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/awslabs/amazon-eks-ami/nodeadm v0.0.0-20240229193347-cfab22a10647 github.com/awslabs/operatorpkg v0.0.0-20240518001059-1e35978ba21b github.com/go-logr/zapr v1.3.0 + github.com/google/uuid v1.6.0 github.com/imdario/mergo v0.3.16 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/onsi/ginkgo/v2 v2.20.0 @@ -64,7 +65,6 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.21.0 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index 330d4fc521c1..bc0e9314e0bd 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -618,7 +618,7 @@ var _ = Describe("CloudProvider", func() { }, }, }) - awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{ + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{ SecurityGroups: []*ec2.SecurityGroup{ { GroupId: aws.String(validSecurityGroup), @@ -632,7 +632,7 @@ var _ = Describe("CloudProvider", func() { }, }, }) - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{ Subnets: []*ec2.Subnet{ { SubnetId: aws.String(validSubnet1), @@ -1084,7 +1084,7 @@ var _ = Describe("CloudProvider", func() { }) It("should launch instances into subnet with the most available IP addresses", func() { awsEnv.SubnetCache.Flush() - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(100), @@ -1101,7 +1101,7 @@ var _ = Describe("CloudProvider", func() { }) It("should launch instances into subnet with the most available IP addresses in-between cache refreshes", func() { awsEnv.SubnetCache.Flush() - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(11), @@ -1126,7 +1126,7 @@ var _ = Describe("CloudProvider", func() { Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1")) }) It("should update in-flight IPs when a CreateFleet error occurs", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, }}) @@ -1137,12 +1137,20 @@ var _ = Describe("CloudProvider", func() { Expect(len(bindings)).To(Equal(0)) }) It("should launch instances into subnets that are excluded by another NodePool", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ - {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(10), - Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, - {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(100), - Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}}, - }}) + awsEnv.EC2API.Subnets.Store("test-zone-1a", &ec2.Subnet{ + SubnetId: aws.String("test-subnet-1"), + AvailabilityZone: aws.String("test-zone-1a"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int64(10), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }) + awsEnv.EC2API.Subnets.Store("test-zone-1b", &ec2.Subnet{ + SubnetId: aws.String("test-subnet-2"), + AvailabilityZone: aws.String("test-zone-1b"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int64(100), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}, + }) nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{{Tags: map[string]string{"Name": "test-subnet-1"}}} ExpectApplied(ctx, env.Client, nodePool, nodeClass) controller := status.NewController(env.Client, awsEnv.SubnetProvider, awsEnv.SecurityGroupProvider, awsEnv.AMIProvider, awsEnv.InstanceProfileProvider, awsEnv.LaunchTemplateProvider) diff --git a/pkg/controllers/nodeclass/status/subnet_test.go b/pkg/controllers/nodeclass/status/subnet_test.go index 5658e0fc7cce..e95e5b70617d 100644 --- a/pkg/controllers/nodeclass/status/subnet_test.go +++ b/pkg/controllers/nodeclass/status/subnet_test.go @@ -77,7 +77,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { })) }) It("Should have the correct ordering for the Subnets", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("subnet-test1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int64(20)}, {SubnetId: aws.String("subnet-test2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1b"), AvailableIpAddressCount: aws.Int64(100)}, {SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailabilityZoneId: aws.String("tstz1-1c"), AvailableIpAddressCount: aws.Int64(50)}, diff --git a/pkg/controllers/providers/ssm/invalidation/controller.go b/pkg/controllers/providers/ssm/invalidation/controller.go index d94c9586f90a..bf70dbcdc5de 100644 --- a/pkg/controllers/providers/ssm/invalidation/controller.go +++ b/pkg/controllers/providers/ssm/invalidation/controller.go @@ -25,6 +25,9 @@ import ( "sigs.k8s.io/karpenter/pkg/operator/controller" "sigs.k8s.io/karpenter/pkg/operator/injection" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/uuid" + "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/providers/ssm" @@ -65,6 +68,9 @@ func (c *Controller) Reconcile(ctx context.Context, _ reconcile.Request) (reconc amis := []amifamily.AMI{} for _, nodeClass := range lo.Map(lo.Keys(amiIDsToParameters), func(amiID string, _ int) *v1beta1.EC2NodeClass { return &v1beta1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: uuid.NewUUID(), // ensures that this doesn't hit the AMI cache. + }, Spec: v1beta1.EC2NodeClassSpec{ AMISelectorTerms: []v1beta1.AMISelectorTerm{{ID: amiID}}, }, diff --git a/pkg/controllers/providers/ssm/invalidation/suite_test.go b/pkg/controllers/providers/ssm/invalidation/suite_test.go index c901da8d6957..26fa1ed59950 100644 --- a/pkg/controllers/providers/ssm/invalidation/suite_test.go +++ b/pkg/controllers/providers/ssm/invalidation/suite_test.go @@ -87,7 +87,7 @@ var _ = Describe("SSM Invalidation Controller", func() { Expect(err).To(BeNil()) currentEntries := getSSMCacheEntries() Expect(len(currentEntries)).To(Equal(2)) - awsEnv.EC2Cache.Flush() + awsEnv.AMICache.Flush() ExpectReconcileSucceeded(ctx, invalidationController, client.ObjectKey{}) awsEnv.SSMAPI.Reset() _, err = awsEnv.AMIProvider.List(ctx, nodeClass) @@ -106,7 +106,7 @@ var _ = Describe("SSM Invalidation Controller", func() { currentEntries := getSSMCacheEntries() deprecateAMIs(lo.Values(currentEntries)...) Expect(len(currentEntries)).To(Equal(2)) - awsEnv.EC2Cache.Flush() + awsEnv.AMICache.Flush() ExpectReconcileSucceeded(ctx, invalidationController, client.ObjectKey{}) awsEnv.SSMAPI.Reset() _, err = awsEnv.AMIProvider.List(ctx, nodeClass) diff --git a/pkg/fake/atomic.go b/pkg/fake/atomic.go index d9a34f03a065..866e880452fb 100644 --- a/pkg/fake/atomic.go +++ b/pkg/fake/atomic.go @@ -161,6 +161,13 @@ func (a *AtomicPtrSlice[T]) Pop() *T { return last } +func (a *AtomicPtrSlice[T]) At(index int) *T { + a.mu.Lock() + defer a.mu.Unlock() + + return clone(a.values[index]) +} + func (a *AtomicPtrSlice[T]) ForEach(fn func(*T)) { a.mu.RLock() defer a.mu.RUnlock() diff --git a/pkg/fake/ec2api.go b/pkg/fake/ec2api.go index 654e986084da..0d128a9c6bc7 100644 --- a/pkg/fake/ec2api.go +++ b/pkg/fake/ec2api.go @@ -48,8 +48,8 @@ type CapacityPool struct { type EC2Behavior struct { DescribeImagesOutput AtomicPtr[ec2.DescribeImagesOutput] DescribeLaunchTemplatesOutput AtomicPtr[ec2.DescribeLaunchTemplatesOutput] - DescribeSubnetsOutput AtomicPtr[ec2.DescribeSubnetsOutput] - DescribeSecurityGroupsOutput AtomicPtr[ec2.DescribeSecurityGroupsOutput] + DescribeSubnetsBehavior MockedFunction[ec2.DescribeSubnetsInput, ec2.DescribeSubnetsOutput] + DescribeSecurityGroupsBehavior MockedFunction[ec2.DescribeSecurityGroupsInput, ec2.DescribeSecurityGroupsOutput] DescribeInstanceTypesOutput AtomicPtr[ec2.DescribeInstanceTypesOutput] DescribeInstanceTypeOfferingsOutput AtomicPtr[ec2.DescribeInstanceTypeOfferingsOutput] DescribeAvailabilityZonesOutput AtomicPtr[ec2.DescribeAvailabilityZonesOutput] @@ -61,6 +61,7 @@ type EC2Behavior struct { CreateTagsBehavior MockedFunction[ec2.CreateTagsInput, ec2.CreateTagsOutput] CalledWithCreateLaunchTemplateInput AtomicPtrSlice[ec2.CreateLaunchTemplateInput] CalledWithDescribeImagesInput AtomicPtrSlice[ec2.DescribeImagesInput] + Subnets sync.Map Instances sync.Map LaunchTemplates sync.Map InsufficientCapacityPools atomic.Slice[CapacityPool] @@ -84,8 +85,8 @@ var DefaultSupportedUsageClasses = aws.StringSlice([]string{"on-demand", "spot"} func (e *EC2API) Reset() { e.DescribeImagesOutput.Reset() e.DescribeLaunchTemplatesOutput.Reset() - e.DescribeSubnetsOutput.Reset() - e.DescribeSecurityGroupsOutput.Reset() + e.DescribeSubnetsBehavior.Reset() + e.DescribeSecurityGroupsBehavior.Reset() e.DescribeInstanceTypesOutput.Reset() e.DescribeInstanceTypeOfferingsOutput.Reset() e.DescribeAvailabilityZonesOutput.Reset() @@ -405,107 +406,109 @@ func (e *EC2API) DeleteLaunchTemplateWithContext(_ context.Context, input *ec2.D } func (e *EC2API) DescribeSubnetsWithContext(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...request.Option) (*ec2.DescribeSubnetsOutput, error) { - if !e.NextError.IsNil() { - defer e.NextError.Reset() - return nil, e.NextError.Get() - } - if !e.DescribeSubnetsOutput.IsNil() { - describeSubnetsOutput := e.DescribeSubnetsOutput.Clone() - describeSubnetsOutput.Subnets = FilterDescribeSubnets(describeSubnetsOutput.Subnets, input.Filters) - return describeSubnetsOutput, nil - } - subnets := []*ec2.Subnet{ - { - SubnetId: aws.String("subnet-test1"), - AvailabilityZone: aws.String("test-zone-1a"), - AvailabilityZoneId: aws.String("tstz1-1a"), - AvailableIpAddressCount: aws.Int64(100), - MapPublicIpOnLaunch: aws.Bool(false), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-1")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + return e.DescribeSubnetsBehavior.Invoke(input, func(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) { + output := &ec2.DescribeSubnetsOutput{} + e.Subnets.Range(func(key, value any) bool { + subnet := value.(*ec2.Subnet) + if lo.Contains(lo.Map(input.SubnetIds, func(s *string, _ int) string { return lo.FromPtr(s) }), lo.FromPtr(subnet.SubnetId)) || len(input.Filters) != 0 && len(FilterDescribeSubnets([]*ec2.Subnet{subnet}, input.Filters)) != 0 { + output.Subnets = append(output.Subnets, subnet) + } + return true + }) + if len(output.Subnets) != 0 { + return output, nil + } + + defaultSubnets := []*ec2.Subnet{ + { + SubnetId: aws.String("subnet-test1"), + AvailabilityZone: aws.String("test-zone-1a"), + AvailabilityZoneId: aws.String("tstz1-1a"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(false), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-1")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test2"), - AvailabilityZone: aws.String("test-zone-1b"), - AvailabilityZoneId: aws.String("tstz1-1b"), - AvailableIpAddressCount: aws.Int64(100), - MapPublicIpOnLaunch: aws.Bool(true), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-2")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + SubnetId: aws.String("subnet-test2"), + AvailabilityZone: aws.String("test-zone-1b"), + AvailabilityZoneId: aws.String("tstz1-1b"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(true), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-2")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test3"), - AvailabilityZone: aws.String("test-zone-1c"), - AvailabilityZoneId: aws.String("tstz1-1c"), - AvailableIpAddressCount: aws.Int64(100), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-3")}, - {Key: aws.String("TestTag")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + SubnetId: aws.String("subnet-test3"), + AvailabilityZone: aws.String("test-zone-1c"), + AvailabilityZoneId: aws.String("tstz1-1c"), + AvailableIpAddressCount: aws.Int64(100), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-3")}, + {Key: aws.String("TestTag")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - { - SubnetId: aws.String("subnet-test4"), - AvailabilityZone: aws.String("test-zone-1a-local"), - AvailabilityZoneId: aws.String("tstz1-1alocal"), - AvailableIpAddressCount: aws.Int64(100), - MapPublicIpOnLaunch: aws.Bool(true), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-subnet-4")}, + { + SubnetId: aws.String("subnet-test4"), + AvailabilityZone: aws.String("test-zone-1a-local"), + AvailabilityZoneId: aws.String("tstz1-1alocal"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(true), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-subnet-4")}, + }, + VpcId: aws.String("vpc-test1"), }, - }, - } - if len(input.Filters) == 0 { - return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") - } - return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(subnets, input.Filters)}, nil + } + if len(input.Filters) == 0 { + return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") + } + return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(defaultSubnets, input.Filters)}, nil + }) } func (e *EC2API) DescribeSecurityGroupsWithContext(_ context.Context, input *ec2.DescribeSecurityGroupsInput, _ ...request.Option) (*ec2.DescribeSecurityGroupsOutput, error) { - if !e.NextError.IsNil() { - defer e.NextError.Reset() - return nil, e.NextError.Get() - } - if !e.DescribeSecurityGroupsOutput.IsNil() { - describeSecurityGroupsOutput := e.DescribeSecurityGroupsOutput.Clone() - describeSecurityGroupsOutput.SecurityGroups = FilterDescribeSecurtyGroups(describeSecurityGroupsOutput.SecurityGroups, input.Filters) - return e.DescribeSecurityGroupsOutput.Clone(), nil - } - sgs := []*ec2.SecurityGroup{ - { - GroupId: aws.String("sg-test1"), - GroupName: aws.String("securityGroup-test1"), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-1")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + return e.DescribeSecurityGroupsBehavior.Invoke(input, func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) { + defaultSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("sg-test1"), + GroupName: aws.String("securityGroup-test1"), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-1")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - { - GroupId: aws.String("sg-test2"), - GroupName: aws.String("securityGroup-test2"), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-2")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + GroupId: aws.String("sg-test2"), + GroupName: aws.String("securityGroup-test2"), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-2")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - { - GroupId: aws.String("sg-test3"), - GroupName: aws.String("securityGroup-test3"), - Tags: []*ec2.Tag{ - {Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, - {Key: aws.String("TestTag")}, - {Key: aws.String("foo"), Value: aws.String("bar")}, + { + GroupId: aws.String("sg-test3"), + GroupName: aws.String("securityGroup-test3"), + Tags: []*ec2.Tag{ + {Key: aws.String("Name"), Value: aws.String("test-security-group-3")}, + {Key: aws.String("TestTag")}, + {Key: aws.String("foo"), Value: aws.String("bar")}, + }, }, - }, - } - if len(input.Filters) == 0 { - return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") - } - return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(sgs, input.Filters)}, nil + } + if len(input.Filters) == 0 { + return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid") + } + return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(defaultSecurityGroups, input.Filters)}, nil + }) } func (e *EC2API) DescribeAvailabilityZonesWithContext(context.Context, *ec2.DescribeAvailabilityZonesInput, ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) { diff --git a/pkg/fake/types.go b/pkg/fake/types.go index 88fe2ca83bc9..c0c545d600e8 100644 --- a/pkg/fake/types.go +++ b/pkg/fake/types.go @@ -15,14 +15,22 @@ limitations under the License. package fake import ( + "reflect" + "sync" "sync/atomic" + + "github.com/google/uuid" + "github.com/samber/lo" ) type MockedFunction[I any, O any] struct { - Output AtomicPtr[O] // Output to return on call to this function + Output AtomicPtr[O] // Output to return on call to this function + MultiOut AtomicPtrSlice[O] + OutputPages AtomicPtrSlice[O] CalledWithInput AtomicPtrSlice[I] // Slice used to keep track of passed input to this function Error AtomicError // Error to return a certain number of times defined by custom error options + pageMapping sync.Map // token uuid -> page number: Internal construct to keep track of the page that we are on successfulCalls atomic.Int32 // Internal construct to keep track of the number of times this function has successfully been called failedCalls atomic.Int32 // Internal construct to keep track of the number of times this function has failed (with error) } @@ -31,6 +39,8 @@ type MockedFunction[I any, O any] struct { // each other. func (m *MockedFunction[I, O]) Reset() { m.Output.Reset() + m.MultiOut.Reset() + m.OutputPages.Reset() m.CalledWithInput.Reset() m.Error.Reset() @@ -50,6 +60,29 @@ func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O, m.successfulCalls.Add(1) return m.Output.Clone(), nil } + + if m.MultiOut.Len() > 0 { + m.successfulCalls.Add(1) + return m.MultiOut.Pop(), nil + } + // This output pages multi-threaded handling isn't perfect + // It will fail if pages are asynchronously requested from the same NextToken + if m.OutputPages.Len() > 0 { + token := uuid.New().String() // generate a token so that each paginated request set gets its own mapping + if !reflect.ValueOf(input).Elem().FieldByName("NextToken").Elem().CanSet() { + m.pageMapping.Store(token, 0) + } else { + token = reflect.ValueOf(input).Elem().FieldByName("NextToken").Elem().String() + } + pageNum := lo.Must(m.pageMapping.Load(token)).(int) + page := m.OutputPages.At(pageNum) + if pageNum < m.OutputPages.Len()-1 { + reflect.ValueOf(page).Elem().FieldByName("NextToken").Set(reflect.ValueOf(lo.ToPtr(token))) + } + m.pageMapping.Store(token, pageNum+1) + m.successfulCalls.Add(1) + return page, nil + } out, err := defaultTransformer(input) if err != nil { m.failedCalls.Add(1) diff --git a/pkg/providers/amifamily/ami.go b/pkg/providers/amifamily/ami.go index 356409cb19f3..bb70e203339d 100644 --- a/pkg/providers/amifamily/ami.go +++ b/pkg/providers/amifamily/ami.go @@ -34,6 +34,7 @@ import ( "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" "github.com/aws/karpenter-provider-aws/pkg/providers/ssm" "github.com/aws/karpenter-provider-aws/pkg/providers/version" + "github.com/aws/karpenter-provider-aws/pkg/utils" "sigs.k8s.io/karpenter/pkg/cloudprovider" "sigs.k8s.io/karpenter/pkg/scheduling" @@ -115,7 +116,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl return nil, err } } else { - amis, err = p.getAMIs(ctx, nodeClass.Spec.AMISelectorTerms) + amis, err = p.getAMIs(ctx, nodeClass) if err != nil { return nil, err } @@ -130,7 +131,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl } func (p *DefaultProvider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) (res AMIs, err error) { - if images, ok := p.cache.Get(lo.FromPtr(nodeClass.Spec.AMIFamily)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if images, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a deep-copy of AMIs so alterations // to the data don't affect the original return append(AMIs{}, images.(AMIs)...), nil @@ -167,7 +169,7 @@ func (p *DefaultProvider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1 }); err != nil { return nil, fmt.Errorf("describing images, %w", err) } - p.cache.SetDefault(lo.FromPtr(nodeClass.Spec.AMIFamily), res) + p.cache.SetDefault(hash, res) return res, nil } @@ -182,20 +184,18 @@ func (p *DefaultProvider) resolveSSMParameter(ctx context.Context, name string) return imageID, nil } -func (p *DefaultProvider) getAMIs(ctx context.Context, terms []v1beta1.AMISelectorTerm) (AMIs, error) { - filterAndOwnerSets := GetFilterAndOwnerSets(terms) - hash, err := hashstructure.Hash(filterAndOwnerSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if images, ok := p.cache.Get(fmt.Sprintf("%d", hash)); ok { +//nolint:gocyclo +func (p *DefaultProvider) getAMIs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) (AMIs, error) { + filterAndOwnerSets := GetFilterAndOwnerSets(nodeClass.Spec.AMISelectorTerms) + hash := utils.GetNodeClassHash(nodeClass) + if images, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a deep-copy of AMIs so alterations // to the data don't affect the original return append(AMIs{}, images.(AMIs)...), nil } images := map[uint64]AMI{} for _, filtersAndOwners := range filterAndOwnerSets { - if err = p.ec2api.DescribeImagesPagesWithContext(ctx, &ec2.DescribeImagesInput{ + if err := p.ec2api.DescribeImagesPagesWithContext(ctx, &ec2.DescribeImagesInput{ // Don't include filters in the Describe Images call as EC2 API doesn't allow empty filters. Filters: lo.Ternary(len(filtersAndOwners.Filters) > 0, filtersAndOwners.Filters, nil), Owners: lo.Ternary(len(filtersAndOwners.Owners) > 0, aws.StringSlice(filtersAndOwners.Owners), nil), @@ -231,7 +231,7 @@ func (p *DefaultProvider) getAMIs(ctx context.Context, terms []v1beta1.AMISelect return nil, fmt.Errorf("describing images, %w", err) } } - p.cache.SetDefault(fmt.Sprintf("%d", hash), AMIs(lo.Values(images))) + p.cache.SetDefault(hash, AMIs(lo.Values(images))) return lo.Values(images), nil } diff --git a/pkg/providers/amifamily/suite_test.go b/pkg/providers/amifamily/suite_test.go index 56c4660901c1..567c5a1c2645 100644 --- a/pkg/providers/amifamily/suite_test.go +++ b/pkg/providers/amifamily/suite_test.go @@ -27,6 +27,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/onsi/gomega/gstruct" . "sigs.k8s.io/karpenter/pkg/utils/testing" "github.com/samber/lo" @@ -43,6 +44,8 @@ import ( "github.com/aws/karpenter-provider-aws/pkg/operator/options" "github.com/aws/karpenter-provider-aws/pkg/providers/amifamily" "github.com/aws/karpenter-provider-aws/pkg/test" + + . "sigs.k8s.io/karpenter/pkg/test/expectations" ) var ctx context.Context @@ -329,6 +332,219 @@ var _ = Describe("AMIProvider", func() { })) }) }) + Context("Provider Cache", func() { + It("should resolve AMIs from cache that are filtered by id", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String(coretest.RandomName()), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{ + { + ID: "ami-123", + }, + { + ID: "ami-456", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-123"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "AmiID": Equal("ami-456"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by name", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{ + { + Name: "ami-name-1", + }, + { + Name: "ami-name-2", + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should resolve AMIs from cache that are filtered by tags", func() { + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: lo.ToPtr("test"), Value: lo.ToPtr("test")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{ + { + Tags: map[string]string{"test": "test"}, + }, + } + _, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(awsEnv.AMICache.Items()).To(HaveLen(1)) + cachedImages := lo.Values(awsEnv.AMICache.Items())[0].Object.(amifamily.AMIs) + Expect(cachedImages).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-3"), + ImageId: aws.String("ami-789"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMIFamily = &v1beta1.AMIFamilyAL2 + nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err := awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + )) + + // OR semantics + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + { + Name: aws.String("ami-name-1"), + ImageId: aws.String("ami-123"), + Architecture: lo.ToPtr("x86_64"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + { + Name: aws.String("ami-name-2"), + ImageId: aws.String("ami-456"), + Architecture: lo.ToPtr("arm64"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + CreationDate: aws.String("2022-08-15T12:00:00Z"), + State: lo.ToPtr(ec2.ImageStateAvailable), + }, + }}) + nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + amis, err = awsEnv.AMIProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + + Expect(amis).To(ContainElements( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + )) + + cacheItems := awsEnv.AMICache.Items() + Expect(cacheItems).To(HaveLen(2)) + cachedImages := make([]amifamily.AMIs, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedImages = append(cachedImages, item.Object.(amifamily.AMIs)) + } + + Expect(cachedImages).To(ConsistOf( + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-3"), + }), + ), + ConsistOf( + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-1"), + }), + gstruct.MatchFields(gstruct.IgnoreExtras, gstruct.Fields{ + "Name": Equal("ami-name-2"), + }), + ), + )) + }) + }) Context("AMI Selectors", func() { // When you tag public or shared resources, the tags you assign are available only to your AWS account; no other AWS account will have access to those tags // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index 4a62ccd66056..fff3836c9164 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -1698,7 +1698,7 @@ var _ = Describe("InstanceTypeProvider", func() { } }) It("shouldn't report more resources than are actually available on instances", func() { - awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{ + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{ Subnets: []*ec2.Subnet{ { AvailabilityZone: aws.String("us-west-2a"), diff --git a/pkg/providers/securitygroup/securitygroup.go b/pkg/providers/securitygroup/securitygroup.go index db955dde3e12..4a747596a87c 100644 --- a/pkg/providers/securitygroup/securitygroup.go +++ b/pkg/providers/securitygroup/securitygroup.go @@ -22,7 +22,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" "sigs.k8s.io/controller-runtime/pkg/log" @@ -30,6 +29,7 @@ import ( "sigs.k8s.io/karpenter/pkg/utils/pretty" "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" + "github.com/aws/karpenter-provider-aws/pkg/utils" ) type Provider interface { @@ -56,9 +56,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl p.Lock() defer p.Unlock() - // Get SecurityGroups - filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) - securityGroups, err := p.getSecurityGroups(ctx, filterSets) + securityGroups, err := p.getSecurityGroups(ctx, nodeClass) if err != nil { return nil, err } @@ -72,12 +70,10 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl return securityGroups, nil } -func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][]*ec2.Filter) ([]*ec2.SecurityGroup, error) { - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if sg, ok := p.cache.Get(fmt.Sprint(hash)); ok { +func (p *DefaultProvider) getSecurityGroups(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) ([]*ec2.SecurityGroup, error) { + filterSets := getFilterSets(nodeClass.Spec.SecurityGroupSelectorTerms) + hash := utils.GetNodeClassHash(nodeClass) + if sg, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]*ec2.SecurityGroup{}, sg.([]*ec2.SecurityGroup)...), nil @@ -92,7 +88,7 @@ func (p *DefaultProvider) getSecurityGroups(ctx context.Context, filterSets [][] securityGroups[lo.FromPtr(output.SecurityGroups[i].GroupId)] = output.SecurityGroups[i] } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(securityGroups)) + p.cache.SetDefault(hash, lo.Values(securityGroups)) return lo.Values(securityGroups), nil } diff --git a/pkg/providers/securitygroup/suite_test.go b/pkg/providers/securitygroup/suite_test.go index b1ed34b6a38c..ef1026fc4b80 100644 --- a/pkg/providers/securitygroup/suite_test.go +++ b/pkg/providers/securitygroup/suite_test.go @@ -113,7 +113,7 @@ var _ = Describe("SecurityGroupProvider", func() { }, securityGroups) }) It("should discover security groups by tag", func() { - awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ {GroupName: aws.String("test-sgName-1"), GroupId: aws.String("test-sg-1"), Tags: []*ec2.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-1")}}}, {GroupName: aws.String("test-sgName-2"), GroupId: aws.String("test-sg-2"), Tags: []*ec2.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-2")}}}, }}) @@ -269,7 +269,13 @@ var _ = Describe("SecurityGroupProvider", func() { }) Context("Provider Cache", func() { It("should resolve security groups from cache that are filtered by id", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1beta1.SecurityGroupSelectorTerm{ { @@ -281,6 +287,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]*ec2.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -288,7 +295,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by Name", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) for _, sg := range expectedSecurityGroups { nodeClass.Spec.SecurityGroupSelectorTerms = []v1beta1.SecurityGroupSelectorTerm{ { @@ -300,6 +313,7 @@ var _ = Describe("SecurityGroupProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SecurityGroupCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SecurityGroupCache.Items() { cachedSecurityGroup := cachedObject.Object.([]*ec2.SecurityGroup) Expect(cachedSecurityGroup).To(HaveLen(1)) @@ -307,7 +321,13 @@ var _ = Describe("SecurityGroupProvider", func() { } }) It("should resolve security groups from cache that are filtered by tags", func() { - expectedSecurityGroups := awsEnv.EC2API.DescribeSecurityGroupsOutput.Clone().SecurityGroups + expectedSecurityGroups := []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-id-1"), GroupName: aws.String("test-sg-name-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-sg-1")}}, + }, + } + awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: expectedSecurityGroups}) tagSet := lo.Map(expectedSecurityGroups, func(sg *ec2.SecurityGroup, _ int) map[string]string { tag, _ := lo.Find(sg.Tags, func(tag *ec2.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -331,6 +351,88 @@ var _ = Describe("SecurityGroupProvider", func() { lo.Contains(expectedSecurityGroups, cachedSecurityGroup[0]) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupName: aws.String("test-sgName-3"), GroupId: aws.String("test-sg-3"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1beta1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + }, + }, securityGroups) + + // OR semantics + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupName: aws.String("test-sgName-2"), GroupId: aws.String("test-sg-2"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSecurityGroupsBehavior.MultiOut.Add(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + {GroupName: aws.String("test-sgName-1"), GroupId: aws.String("test-sg-1"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SecurityGroupSelectorTerms = []v1beta1.SecurityGroupSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + securityGroups, err = awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSecurityGroups([]*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + }, + }, securityGroups) + + cacheItems := awsEnv.SecurityGroupCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached security group arrays for comparison + cachedSecurityGroups := make([][]*ec2.SecurityGroup, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSecurityGroups = append(cachedSecurityGroups, item.Object.([]*ec2.SecurityGroup)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSecurityGroups).To(ContainElement(ContainElements( + []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-1"), + GroupName: aws.String("test-sgName-1"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + GroupId: aws.String("test-sg-2"), + GroupName: aws.String("test-sgName-2"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSecurityGroups).To(ContainElement( + []*ec2.SecurityGroup{ + { + GroupId: aws.String("test-sg-3"), + GroupName: aws.String("test-sgName-3"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index 95bba2532361..2fc52b7dddb7 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -23,13 +23,13 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/mitchellh/hashstructure/v2" "github.com/patrickmn/go-cache" "github.com/samber/lo" v1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/log" "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" + "github.com/aws/karpenter-provider-aws/pkg/utils" corev1beta1 "sigs.k8s.io/karpenter/pkg/apis/v1beta1" "sigs.k8s.io/karpenter/pkg/cloudprovider" @@ -83,11 +83,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl if len(filterSets) == 0 { return []*ec2.Subnet{}, nil } - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) - if err != nil { - return nil, err - } - if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok { + hash := utils.GetNodeClassHash(nodeClass) + if subnets, ok := p.cache.Get(hash); ok { // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) // so that modifications to the ordering of the data don't affect the original return append([]*ec2.Subnet{}, subnets.([]*ec2.Subnet)...), nil @@ -109,8 +106,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl delete(p.inflightIPs, lo.FromPtr(output.Subnets[i].SubnetId)) // remove any previously tracked IP addresses since we just refreshed from EC2 } } - p.cache.SetDefault(fmt.Sprint(hash), lo.Values(subnets)) - if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), subnets) { + p.cache.SetDefault(hash, lo.Values(subnets)) + if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), lo.Keys(subnets)) { log.FromContext(ctx). WithValues("subnets", lo.Map(lo.Values(subnets), func(s *ec2.Subnet, _ int) v1beta1.Subnet { return v1beta1.Subnet{ diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index 4760e6ff7354..5d91f931d86d 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -277,7 +277,13 @@ var _ = Describe("SubnetProvider", func() { }) Context("Provider Cache", func() { It("should resolve subnets from cache that are filtered by id", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets + expectedSubnets := []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) for _, subnet := range expectedSubnets { nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{ { @@ -289,6 +295,7 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) @@ -296,7 +303,13 @@ var _ = Describe("SubnetProvider", func() { } }) It("should resolve subnets from cache that are filtered by tags", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets + expectedSubnets := []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}, + }, + } + awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: expectedSubnets}) tagSet := lo.Map(expectedSubnets, func(subnet *ec2.Subnet, _ int) map[string]string { tag, _ := lo.Find(subnet.Tags, func(tag *ec2.Tag) bool { return lo.FromPtr(tag.Key) == "Name" @@ -314,12 +327,98 @@ var _ = Describe("SubnetProvider", func() { Expect(err).To(BeNil()) } + Expect(awsEnv.SubnetCache.Items()).To(HaveLen(1)) for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) lo.Contains(expectedSubnets, cachedSubnet[0]) } }) + It("should correctly disambiguate AND vs OR semantics for tags", func() { + // AND semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1", "tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + }, + }, subnets) + + // OR semantics + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + {SubnetId: aws.String("test-subnet-id-2"), SubnetArn: aws.String("test-subnet-arn-2"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}}, + }}) + awsEnv.EC2API.DescribeSubnetsBehavior.MultiOut.Add(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + {SubnetId: aws.String("test-subnet-id-1"), SubnetArn: aws.String("test-subnet-arn-1"), Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}}, + }}) + nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{ + { + Tags: map[string]string{"tag-key-1": "tag-value-1"}, + }, + { + Tags: map[string]string{"tag-key-2": "tag-value-2"}, + }, + } + ExpectApplied(ctx, env.Client, nodeClass) + subnets, err = awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + ExpectConsistsOfSubnets([]*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + }, + }, subnets) + + cacheItems := awsEnv.SubnetCache.Items() + // There should be 2 cache entries one for each semantic. + Expect(cacheItems).To(HaveLen(2)) + // Extract cached subnet arrays for comparison + cachedSubnets := make([][]*ec2.Subnet, 0, len(cacheItems)) + for _, item := range cacheItems { + cachedSubnets = append(cachedSubnets, item.Object.([]*ec2.Subnet)) + } + // Expect cache to contain result of both look ups. + Expect(cachedSubnets).To(ContainElement(ContainElements( + []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-1"), + SubnetArn: aws.String("test-subnet-arn-1"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}}, + }, + { + SubnetId: aws.String("test-subnet-id-2"), + SubnetArn: aws.String("test-subnet-arn-2"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + ))) + Expect(cachedSubnets).To(ContainElement( + []*ec2.Subnet{ + { + SubnetId: aws.String("test-subnet-id-3"), + SubnetArn: aws.String("test-subnet-arn-3"), + Tags: []*ec2.Tag{{Key: aws.String("tag-key-1"), Value: aws.String("tag-value-1")}, {Key: aws.String("tag-key-2"), Value: aws.String("tag-value-2")}}, + }, + }, + )) + }) }) It("should not cause data races when calling List() simultaneously", func() { wg := sync.WaitGroup{} @@ -356,6 +455,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1b"), @@ -374,6 +474,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1c"), @@ -393,6 +494,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("bar"), }, }, + VpcId: aws.String("vpc-test1"), }, { AvailabilityZone: lo.ToPtr("test-zone-1a-local"), @@ -406,6 +508,7 @@ var _ = Describe("SubnetProvider", func() { Value: lo.ToPtr("test-subnet-4"), }, }, + VpcId: aws.String("vpc-test1"), }, })) }() diff --git a/pkg/test/environment.go b/pkg/test/environment.go index 732e4b4dfa07..34c1a774330f 100644 --- a/pkg/test/environment.go +++ b/pkg/test/environment.go @@ -60,6 +60,7 @@ type Environment struct { PricingAPI *fake.PricingAPI // Cache + AMICache *cache.Cache EC2Cache *cache.Cache KubernetesVersionCache *cache.Cache InstanceTypeCache *cache.Cache @@ -97,6 +98,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment iamapi := fake.NewIAMAPI() // cache + amiCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) ec2Cache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) kubernetesVersionCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) instanceTypeCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) @@ -117,7 +119,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment versionProvider := version.NewDefaultProvider(env.KubernetesInterface, kubernetesVersionCache) instanceProfileProvider := instanceprofile.NewDefaultProvider(fake.DefaultRegion, iamapi, instanceProfileCache) ssmProvider := ssmp.NewDefaultProvider(ssmapi, ssmCache) - amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, ec2Cache) + amiProvider := amifamily.NewDefaultProvider(clock, versionProvider, ssmProvider, ec2api, amiCache) amiResolver := amifamily.NewResolver(amiProvider) instanceTypesProvider := instancetype.NewDefaultProvider(fake.DefaultRegion, instanceTypeCache, ec2api, subnetProvider, unavailableOfferingsCache, pricingProvider) launchTemplateProvider := @@ -151,6 +153,7 @@ func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment IAMAPI: iamapi, PricingAPI: fakePricingAPI, + AMICache: amiCache, EC2Cache: ec2Cache, KubernetesVersionCache: kubernetesVersionCache, LaunchTemplateCache: launchTemplateCache, @@ -188,6 +191,7 @@ func (env *Environment) Reset() { env.PricingProvider.Reset() env.InstanceTypesProvider.Reset() + env.AMICache.Flush() env.EC2Cache.Flush() env.KubernetesVersionCache.Flush() env.UnavailableOfferingsCache.Flush() diff --git a/pkg/utils/suite_test.go b/pkg/utils/suite_test.go new file mode 100644 index 000000000000..2d6552385302 --- /dev/null +++ b/pkg/utils/suite_test.go @@ -0,0 +1,44 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" + "github.com/aws/karpenter-provider-aws/pkg/utils" +) + +func TestUtils(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Utils Suite") +} + +var _ = Describe("GetNodeClassHash", func() { + It("should return formatted hash with UID and Generation", func() { + nodeClass := &v1beta1.EC2NodeClass{ + ObjectMeta: metav1.ObjectMeta{ + UID: "test-uid-123", + Generation: 5, + }, + } + hash := utils.GetNodeClassHash(nodeClass) + Expect(hash).To(Equal("test-uid-123-5")) + }) +}) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 2dc741e87e3d..78de6ff35094 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -22,6 +22,8 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/samber/lo" + + "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" ) var ( @@ -66,3 +68,7 @@ func PrettySlice[T any](s []T, maxItems int) string { } return sb.String() } + +func GetNodeClassHash(nodeClass *v1beta1.EC2NodeClass) string { + return fmt.Sprintf("%s-%d", nodeClass.UID, nodeClass.Generation) +}