Skip to content

Commit 09935d8

Browse files
(Backport v1.1.x)fix: prevent hash collisions while resolving subnets, security groups and AMIs from nodeclass selectors (#8661)
Co-authored-by: Saurav Agarwalla <[email protected]>
1 parent a4354cb commit 09935d8

File tree

16 files changed

+640
-147
lines changed

16 files changed

+640
-147
lines changed

pkg/cloudprovider/suite_test.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ var _ = Describe("CloudProvider", func() {
646646
},
647647
},
648648
})
649-
awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{
649+
awsEnv.EC2API.DescribeSecurityGroupsBehavior.Output.Set(&ec2.DescribeSecurityGroupsOutput{
650650
SecurityGroups: []ec2types.SecurityGroup{
651651
{
652652
GroupId: aws.String(validSecurityGroup),
@@ -660,7 +660,7 @@ var _ = Describe("CloudProvider", func() {
660660
},
661661
},
662662
})
663-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{
663+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{
664664
Subnets: []ec2types.Subnet{
665665
{
666666
SubnetId: aws.String(validSubnet1),
@@ -1142,7 +1142,7 @@ var _ = Describe("CloudProvider", func() {
11421142
})
11431143
It("should launch instances into subnet with the most available IP addresses", func() {
11441144
awsEnv.SubnetCache.Flush()
1145-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1145+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
11461146
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
11471147
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
11481148
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100),
@@ -1159,7 +1159,7 @@ var _ = Describe("CloudProvider", func() {
11591159
})
11601160
It("should launch instances into subnet with the most available IP addresses in-between cache refreshes", func() {
11611161
awsEnv.SubnetCache.Flush()
1162-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1162+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
11631163
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
11641164
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
11651165
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(11),
@@ -1187,7 +1187,7 @@ var _ = Describe("CloudProvider", func() {
11871187
Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1"))
11881188
})
11891189
It("should update in-flight IPs when a CreateFleet error occurs", func() {
1190-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1190+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
11911191
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int32(10),
11921192
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
11931193
}})
@@ -1198,12 +1198,20 @@ var _ = Describe("CloudProvider", func() {
11981198
Expect(len(bindings)).To(Equal(0))
11991199
})
12001200
It("should launch instances into subnets that are excluded by another NodePool", func() {
1201-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
1202-
{SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(10),
1203-
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}},
1204-
{SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(100),
1205-
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}},
1206-
}})
1201+
awsEnv.EC2API.Subnets.Store("test-zone-1a", ec2types.Subnet{
1202+
SubnetId: aws.String("test-subnet-1"),
1203+
AvailabilityZone: aws.String("test-zone-1a"),
1204+
AvailabilityZoneId: aws.String("tstz1-1a"),
1205+
AvailableIpAddressCount: aws.Int32(10),
1206+
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}},
1207+
})
1208+
awsEnv.EC2API.Subnets.Store("test-zone-1b", ec2types.Subnet{
1209+
SubnetId: aws.String("test-subnet-2"),
1210+
AvailabilityZone: aws.String("test-zone-1b"),
1211+
AvailabilityZoneId: aws.String("tstz1-1a"),
1212+
AvailableIpAddressCount: aws.Int32(100),
1213+
Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}},
1214+
})
12071215
nodeClass.Spec.SubnetSelectorTerms = []v1.SubnetSelectorTerm{{Tags: map[string]string{"Name": "test-subnet-1"}}}
12081216
ExpectApplied(ctx, env.Client, nodePool, nodeClass)
12091217
controller := status.NewController(env.Client, awsEnv.SubnetProvider, awsEnv.SecurityGroupProvider, awsEnv.AMIProvider, awsEnv.InstanceProfileProvider, awsEnv.LaunchTemplateProvider)

pkg/controllers/nodeclass/status/ami_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() {
607607
awsEnv.Clock.Step(40 * time.Minute)
608608

609609
// Flush Cache
610-
awsEnv.EC2Cache.Flush()
610+
awsEnv.AMICache.Flush()
611611

612612
ExpectObjectReconciled(ctx, env.Client, statusController, nodeClass)
613613
nodeClass = ExpectExists(ctx, env.Client, nodeClass)
@@ -706,7 +706,7 @@ var _ = Describe("NodeClass AMI Status Controller", func() {
706706
},
707707
})
708708

709-
awsEnv.EC2Cache.Flush()
709+
awsEnv.AMICache.Flush()
710710

711711
ExpectApplied(ctx, env.Client, nodeClass)
712712
ExpectObjectReconciled(ctx, env.Client, statusController, nodeClass)

pkg/controllers/nodeclass/status/subnet_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() {
8080
Expect(nodeClass.StatusConditions().IsTrue(v1.ConditionTypeSubnetsReady)).To(BeTrue())
8181
})
8282
It("Should have the correct ordering for the Subnets", func() {
83-
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
83+
awsEnv.EC2API.DescribeSubnetsBehavior.Output.Set(&ec2.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{
8484
{SubnetId: aws.String("subnet-test1"), AvailabilityZone: aws.String("test-zone-1a"), AvailabilityZoneId: aws.String("tstz1-1a"), AvailableIpAddressCount: aws.Int32(20)},
8585
{SubnetId: aws.String("subnet-test2"), AvailabilityZone: aws.String("test-zone-1b"), AvailabilityZoneId: aws.String("tstz1-1b"), AvailableIpAddressCount: aws.Int32(100)},
8686
{SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailabilityZoneId: aws.String("tstz1-1c"), AvailableIpAddressCount: aws.Int32(50)},

pkg/controllers/providers/ssm/invalidation/controller.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ import (
2929
v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1"
3030
"github.com/aws/karpenter-provider-aws/pkg/providers/amifamily"
3131
"github.com/aws/karpenter-provider-aws/pkg/providers/ssm"
32+
33+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
34+
"k8s.io/apimachinery/pkg/util/uuid"
3235
)
3336

3437
// The SSM Invalidation controller is responsible for invalidating "latest" SSM parameters when they point to deprecated
@@ -66,6 +69,9 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) {
6669
amis := []amifamily.AMI{}
6770
for _, nodeClass := range lo.Map(lo.Keys(amiIDsToParameters), func(amiID string, _ int) *v1.EC2NodeClass {
6871
return &v1.EC2NodeClass{
72+
ObjectMeta: metav1.ObjectMeta{
73+
UID: uuid.NewUUID(), // ensures that this doesn't hit the AMI cache.
74+
},
6975
Spec: v1.EC2NodeClassSpec{
7076
AMISelectorTerms: []v1.AMISelectorTerm{{ID: amiID}},
7177
},

pkg/fake/ec2api.go

Lines changed: 98 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,19 @@ type CapacityPool struct {
4848
type EC2Behavior struct {
4949
DescribeImagesOutput AtomicPtr[ec2.DescribeImagesOutput]
5050
DescribeLaunchTemplatesOutput AtomicPtr[ec2.DescribeLaunchTemplatesOutput]
51-
DescribeSubnetsOutput AtomicPtr[ec2.DescribeSubnetsOutput]
52-
DescribeSecurityGroupsOutput AtomicPtr[ec2.DescribeSecurityGroupsOutput]
5351
DescribeInstanceTypesOutput AtomicPtr[ec2.DescribeInstanceTypesOutput]
5452
DescribeInstanceTypeOfferingsOutput AtomicPtr[ec2.DescribeInstanceTypeOfferingsOutput]
5553
DescribeAvailabilityZonesOutput AtomicPtr[ec2.DescribeAvailabilityZonesOutput]
5654
DescribeSpotPriceHistoryBehavior MockedFunction[ec2.DescribeSpotPriceHistoryInput, ec2.DescribeSpotPriceHistoryOutput]
5755
CreateFleetBehavior MockedFunction[ec2.CreateFleetInput, ec2.CreateFleetOutput]
5856
TerminateInstancesBehavior MockedFunction[ec2.TerminateInstancesInput, ec2.TerminateInstancesOutput]
5957
DescribeInstancesBehavior MockedFunction[ec2.DescribeInstancesInput, ec2.DescribeInstancesOutput]
58+
DescribeSubnetsBehavior MockedFunction[ec2.DescribeSubnetsInput, ec2.DescribeSubnetsOutput]
59+
DescribeSecurityGroupsBehavior MockedFunction[ec2.DescribeSecurityGroupsInput, ec2.DescribeSecurityGroupsOutput]
6060
CreateTagsBehavior MockedFunction[ec2.CreateTagsInput, ec2.CreateTagsOutput]
6161
CalledWithCreateLaunchTemplateInput AtomicPtrSlice[ec2.CreateLaunchTemplateInput]
6262
CalledWithDescribeImagesInput AtomicPtrSlice[ec2.DescribeImagesInput]
63+
Subnets sync.Map
6364
Instances sync.Map
6465
LaunchTemplates sync.Map
6566
InsufficientCapacityPools atomic.Slice[CapacityPool]
@@ -83,8 +84,8 @@ var DefaultSupportedUsageClasses = []ec2types.UsageClassType{ec2types.UsageClass
8384
func (e *EC2API) Reset() {
8485
e.DescribeImagesOutput.Reset()
8586
e.DescribeLaunchTemplatesOutput.Reset()
86-
e.DescribeSubnetsOutput.Reset()
87-
e.DescribeSecurityGroupsOutput.Reset()
87+
e.DescribeSubnetsBehavior.Reset()
88+
e.DescribeSecurityGroupsBehavior.Reset()
8889
e.DescribeInstanceTypesOutput.Reset()
8990
e.DescribeInstanceTypeOfferingsOutput.Reset()
9091
e.DescribeAvailabilityZonesOutput.Reset()
@@ -379,107 +380,109 @@ func (e *EC2API) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunch
379380
}
380381

381382
func (e *EC2API) DescribeSubnets(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) {
382-
if !e.NextError.IsNil() {
383-
defer e.NextError.Reset()
384-
return nil, e.NextError.Get()
385-
}
386-
if !e.DescribeSubnetsOutput.IsNil() {
387-
describeSubnetsOutput := e.DescribeSubnetsOutput.Clone()
388-
describeSubnetsOutput.Subnets = FilterDescribeSubnets(describeSubnetsOutput.Subnets, input.Filters)
389-
return describeSubnetsOutput, nil
390-
}
391-
subnets := []ec2types.Subnet{
392-
{
393-
SubnetId: aws.String("subnet-test1"),
394-
AvailabilityZone: aws.String("test-zone-1a"),
395-
AvailabilityZoneId: aws.String("tstz1-1a"),
396-
AvailableIpAddressCount: aws.Int32(100),
397-
MapPublicIpOnLaunch: aws.Bool(false),
398-
Tags: []ec2types.Tag{
399-
{Key: aws.String("Name"), Value: aws.String("test-subnet-1")},
400-
{Key: aws.String("foo"), Value: aws.String("bar")},
383+
return e.DescribeSubnetsBehavior.Invoke(input, func(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) {
384+
output := &ec2.DescribeSubnetsOutput{}
385+
e.Subnets.Range(func(key, value any) bool {
386+
subnet := value.(ec2types.Subnet)
387+
if lo.Contains(input.SubnetIds, lo.FromPtr(subnet.SubnetId)) || len(input.Filters) != 0 && len(FilterDescribeSubnets([]ec2types.Subnet{subnet}, input.Filters)) != 0 {
388+
output.Subnets = append(output.Subnets, subnet)
389+
}
390+
return true
391+
})
392+
if len(output.Subnets) != 0 {
393+
return output, nil
394+
}
395+
396+
defaultSubnets := []ec2types.Subnet{
397+
{
398+
SubnetId: aws.String("subnet-test1"),
399+
AvailabilityZone: aws.String("test-zone-1a"),
400+
AvailabilityZoneId: aws.String("tstz1-1a"),
401+
AvailableIpAddressCount: aws.Int32(100),
402+
MapPublicIpOnLaunch: aws.Bool(false),
403+
Tags: []ec2types.Tag{
404+
{Key: aws.String("Name"), Value: aws.String("test-subnet-1")},
405+
{Key: aws.String("foo"), Value: aws.String("bar")},
406+
},
407+
VpcId: aws.String("vpc-test1"),
401408
},
402-
},
403-
{
404-
SubnetId: aws.String("subnet-test2"),
405-
AvailabilityZone: aws.String("test-zone-1b"),
406-
AvailabilityZoneId: aws.String("tstz1-1b"),
407-
AvailableIpAddressCount: aws.Int32(100),
408-
MapPublicIpOnLaunch: aws.Bool(true),
409-
Tags: []ec2types.Tag{
410-
{Key: aws.String("Name"), Value: aws.String("test-subnet-2")},
411-
{Key: aws.String("foo"), Value: aws.String("bar")},
409+
{
410+
SubnetId: aws.String("subnet-test2"),
411+
AvailabilityZone: aws.String("test-zone-1b"),
412+
AvailabilityZoneId: aws.String("tstz1-1b"),
413+
AvailableIpAddressCount: aws.Int32(100),
414+
MapPublicIpOnLaunch: aws.Bool(true),
415+
Tags: []ec2types.Tag{
416+
{Key: aws.String("Name"), Value: aws.String("test-subnet-2")},
417+
{Key: aws.String("foo"), Value: aws.String("bar")},
418+
},
419+
VpcId: aws.String("vpc-test1"),
412420
},
413-
},
414-
{
415-
SubnetId: aws.String("subnet-test3"),
416-
AvailabilityZone: aws.String("test-zone-1c"),
417-
AvailabilityZoneId: aws.String("tstz1-1c"),
418-
AvailableIpAddressCount: aws.Int32(100),
419-
Tags: []ec2types.Tag{
420-
{Key: aws.String("Name"), Value: aws.String("test-subnet-3")},
421-
{Key: aws.String("TestTag")},
422-
{Key: aws.String("foo"), Value: aws.String("bar")},
421+
{
422+
SubnetId: aws.String("subnet-test3"),
423+
AvailabilityZone: aws.String("test-zone-1c"),
424+
AvailabilityZoneId: aws.String("tstz1-1c"),
425+
AvailableIpAddressCount: aws.Int32(100),
426+
Tags: []ec2types.Tag{
427+
{Key: aws.String("Name"), Value: aws.String("test-subnet-3")},
428+
{Key: aws.String("TestTag")},
429+
{Key: aws.String("foo"), Value: aws.String("bar")},
430+
},
431+
VpcId: aws.String("vpc-test1"),
423432
},
424-
},
425-
{
426-
SubnetId: aws.String("subnet-test4"),
427-
AvailabilityZone: aws.String("test-zone-1a-local"),
428-
AvailabilityZoneId: aws.String("tstz1-1alocal"),
429-
AvailableIpAddressCount: aws.Int32(100),
430-
MapPublicIpOnLaunch: aws.Bool(true),
431-
Tags: []ec2types.Tag{
432-
{Key: aws.String("Name"), Value: aws.String("test-subnet-4")},
433+
{
434+
SubnetId: aws.String("subnet-test4"),
435+
AvailabilityZone: aws.String("test-zone-1a-local"),
436+
AvailabilityZoneId: aws.String("tstz1-1alocal"),
437+
AvailableIpAddressCount: aws.Int32(100),
438+
MapPublicIpOnLaunch: aws.Bool(true),
439+
Tags: []ec2types.Tag{
440+
{Key: aws.String("Name"), Value: aws.String("test-subnet-4")},
441+
},
442+
VpcId: aws.String("vpc-test1"),
433443
},
434-
},
435-
}
436-
if len(input.Filters) == 0 {
437-
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
438-
}
439-
return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(subnets, input.Filters)}, nil
444+
}
445+
if len(input.Filters) == 0 {
446+
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
447+
}
448+
return &ec2.DescribeSubnetsOutput{Subnets: FilterDescribeSubnets(defaultSubnets, input.Filters)}, nil
449+
})
440450
}
441451

442452
func (e *EC2API) DescribeSecurityGroups(_ context.Context, input *ec2.DescribeSecurityGroupsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) {
443-
if !e.NextError.IsNil() {
444-
defer e.NextError.Reset()
445-
return nil, e.NextError.Get()
446-
}
447-
if !e.DescribeSecurityGroupsOutput.IsNil() {
448-
describeSecurityGroupsOutput := e.DescribeSecurityGroupsOutput.Clone()
449-
describeSecurityGroupsOutput.SecurityGroups = FilterDescribeSecurtyGroups(describeSecurityGroupsOutput.SecurityGroups, input.Filters)
450-
return e.DescribeSecurityGroupsOutput.Clone(), nil
451-
}
452-
sgs := []ec2types.SecurityGroup{
453-
{
454-
GroupId: aws.String("sg-test1"),
455-
GroupName: aws.String("securityGroup-test1"),
456-
Tags: []ec2types.Tag{
457-
{Key: aws.String("Name"), Value: aws.String("test-security-group-1")},
458-
{Key: aws.String("foo"), Value: aws.String("bar")},
453+
return e.DescribeSecurityGroupsBehavior.Invoke(input, func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) {
454+
defaultSecurityGroups := []ec2types.SecurityGroup{
455+
{
456+
GroupId: aws.String("sg-test1"),
457+
GroupName: aws.String("securityGroup-test1"),
458+
Tags: []ec2types.Tag{
459+
{Key: aws.String("Name"), Value: aws.String("test-security-group-1")},
460+
{Key: aws.String("foo"), Value: aws.String("bar")},
461+
},
459462
},
460-
},
461-
{
462-
GroupId: aws.String("sg-test2"),
463-
GroupName: aws.String("securityGroup-test2"),
464-
Tags: []ec2types.Tag{
465-
{Key: aws.String("Name"), Value: aws.String("test-security-group-2")},
466-
{Key: aws.String("foo"), Value: aws.String("bar")},
463+
{
464+
GroupId: aws.String("sg-test2"),
465+
GroupName: aws.String("securityGroup-test2"),
466+
Tags: []ec2types.Tag{
467+
{Key: aws.String("Name"), Value: aws.String("test-security-group-2")},
468+
{Key: aws.String("foo"), Value: aws.String("bar")},
469+
},
467470
},
468-
},
469-
{
470-
GroupId: aws.String("sg-test3"),
471-
GroupName: aws.String("securityGroup-test3"),
472-
Tags: []ec2types.Tag{
473-
{Key: aws.String("Name"), Value: aws.String("test-security-group-3")},
474-
{Key: aws.String("TestTag")},
475-
{Key: aws.String("foo"), Value: aws.String("bar")},
471+
{
472+
GroupId: aws.String("sg-test3"),
473+
GroupName: aws.String("securityGroup-test3"),
474+
Tags: []ec2types.Tag{
475+
{Key: aws.String("Name"), Value: aws.String("test-security-group-3")},
476+
{Key: aws.String("TestTag")},
477+
{Key: aws.String("foo"), Value: aws.String("bar")},
478+
},
476479
},
477-
},
478-
}
479-
if len(input.Filters) == 0 {
480-
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
481-
}
482-
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(sgs, input.Filters)}, nil
480+
}
481+
if len(input.Filters) == 0 {
482+
return nil, fmt.Errorf("InvalidParameterValue: The filter 'null' is invalid")
483+
}
484+
return &ec2.DescribeSecurityGroupsOutput{SecurityGroups: FilterDescribeSecurtyGroups(defaultSecurityGroups, input.Filters)}, nil
485+
})
483486
}
484487

485488
func (e *EC2API) DescribeAvailabilityZones(context.Context, *ec2.DescribeAvailabilityZonesInput, ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) {

pkg/fake/types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
type MockedFunction[I any, O any] struct {
2727
Output AtomicPtr[O] // Output to return on call to this function
28+
MultiOut AtomicPtrSlice[O]
2829
OutputPages AtomicPtrSlice[O]
2930
CalledWithInput AtomicPtrSlice[I] // Slice used to keep track of passed input to this function
3031
Error AtomicError // Error to return a certain number of times defined by custom error options
@@ -38,6 +39,7 @@ type MockedFunction[I any, O any] struct {
3839
// each other.
3940
func (m *MockedFunction[I, O]) Reset() {
4041
m.Output.Reset()
42+
m.MultiOut.Reset()
4143
m.OutputPages.Reset()
4244
m.CalledWithInput.Reset()
4345
m.Error.Reset()
@@ -59,6 +61,11 @@ func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O,
5961
m.successfulCalls.Add(1)
6062
return m.Output.Clone(), nil
6163
}
64+
65+
if m.MultiOut.Len() > 0 {
66+
m.successfulCalls.Add(1)
67+
return m.MultiOut.Pop(), nil
68+
}
6269
// This output pages multi-threaded handling isn't perfect
6370
// It will fail if pages are asynchronously requested from the same NextToken
6471
if m.OutputPages.Len() > 0 {

0 commit comments

Comments
 (0)