From 6db3c014ada160a4c7c4e5ae5a9121862850fbe4 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Wed, 2 Oct 2024 15:22:46 +1000 Subject: [PATCH 1/5] Aggregate NSG destination addresses optimally --- pkg/provider/loadbalancer/iputil/family.go | 7 + .../securitygroup/securitygroup.go | 12 ++ .../securitygroup/securitygroup_test.go | 166 ++++++++++++++++++ .../securitygroup/securityrule.go | 31 ++++ .../securitygroup/securityrule_test.go | 96 ++++++++++ 5 files changed, 312 insertions(+) create mode 100644 pkg/provider/loadbalancer/securitygroup/securityrule_test.go diff --git a/pkg/provider/loadbalancer/iputil/family.go b/pkg/provider/loadbalancer/iputil/family.go index 9ac42b92fd..a7b15a6930 100644 --- a/pkg/provider/loadbalancer/iputil/family.go +++ b/pkg/provider/loadbalancer/iputil/family.go @@ -29,6 +29,13 @@ const ( IPv6 Family = "IPv6" ) +func (f Family) MaxMask() int { + if f == IPv4 { + return 32 + } + return 128 +} + func FamilyOfAddr(addr netip.Addr) Family { if addr.Is4() { return IPv4 diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup.go b/pkg/provider/loadbalancer/securitygroup/securitygroup.go index 8fb6213fcf..08f1113adb 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup.go @@ -193,6 +193,18 @@ func (helper *RuleHelper) addAllowRule( { // Destination addresses := append(ListDestinationPrefixes(rule), dstPrefixes...) + + // Aggregate the prefixes + prefixes, serviceTags := SeparateIPsAndServiceTags(addresses) + prefixes = iputil.AggregatePrefixes(prefixes) + addresses = append(fnutil.Map(func(p netip.Prefix) string { + if p.Bits() == ipFamily.MaxMask() { + // Keep it as an IP address to avoid additional operation for old rules. + return p.Addr().String() + } + return p.String() + }, prefixes), serviceTags...) + SetDestinationPrefixes(rule, addresses) rule.Properties.DestinationPortRanges = to.SliceOfPtrs(dstPortRanges...) } diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go b/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go index dd7c3e5cb4..09b6435b31 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go @@ -444,6 +444,172 @@ func TestSecurityGroupHelper_AddRuleForAllowedIPRanges(t *testing.T) { testutil.ExpectHasSecurityRules(t, outputSG, []*armnetwork.SecurityRule{targetRule}, "[`%s`] the target rule remain unchanged", c.TestName) } }) + t.Run("when destination address prefixes can be aggregated, it should aggregate them", func(t *testing.T) { + var ( + sg = fx.Azure().SecurityGroup().Build() + helper = ExpectNewSecurityGroupHelper(t, sg) + + protocol = armnetwork.SecurityRuleProtocolTCP + ipFamily = iputil.IPv4 + serviceTag = "AzureCloud" + dstAddresses = []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), // 10.0.0.b0000_0001 + netip.MustParseAddr("10.0.0.2"), // 10.0.0.b0000_0010 + netip.MustParseAddr("10.0.0.3"), // 10.0.0.b0000_0011 + netip.MustParseAddr("10.0.0.4"), // 10.0.0.b0000_0100 + } + dstPorts = []int32{80, 443} + + expectedRule = &armnetwork.SecurityRule{ + Name: ptr.To(GenerateAllowSecurityRuleName(protocol, ipFamily, []string{serviceTag}, dstPorts)), + Properties: &armnetwork.SecurityRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: to.Ptr(serviceTag), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: to.SliceOfPtrs("10.0.0.1", "10.0.0.2/31", "10.0.0.4"), + DestinationPortRanges: to.SliceOfPtrs("443", "80"), + Priority: ptr.To(int32(500)), + }, + } + ) + + err := helper.AddRuleForAllowedServiceTag(serviceTag, protocol, dstAddresses, dstPorts) + assert.NoError(t, err) + + outputSG, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + assert.Equal(t, 1, len(outputSG.Properties.SecurityRules)) + testutil.ExpectHasSecurityRules(t, outputSG, []*armnetwork.SecurityRule{expectedRule}) + }) + t.Run("when adding new IP to existing rule, it should aggregate with old IPs", func(t *testing.T) { + var ( + sg = fx.Azure().SecurityGroup().Build() + helper = ExpectNewSecurityGroupHelper(t, sg) + + protocol = armnetwork.SecurityRuleProtocolTCP + ipFamily = iputil.IPv4 + serviceTag = "AzureCloud" + initialDstAddresses = []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + } + newDstAddress = netip.MustParseAddr("10.0.0.3") + dstPorts = []int32{80, 443} + + expectedInitialRule = &armnetwork.SecurityRule{ + Name: ptr.To(GenerateAllowSecurityRuleName(protocol, ipFamily, []string{serviceTag}, dstPorts)), + Properties: &armnetwork.SecurityRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: to.Ptr(serviceTag), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: to.SliceOfPtrs("10.0.0.1", "10.0.0.2"), + DestinationPortRanges: to.SliceOfPtrs("443", "80"), + Priority: ptr.To(int32(500)), + }, + } + + expectedUpdatedRule = &armnetwork.SecurityRule{ + Name: ptr.To(GenerateAllowSecurityRuleName(protocol, ipFamily, []string{serviceTag}, dstPorts)), + Properties: &armnetwork.SecurityRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: to.Ptr(serviceTag), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: to.SliceOfPtrs("10.0.0.1", "10.0.0.2/31"), + DestinationPortRanges: to.SliceOfPtrs("443", "80"), + Priority: ptr.To(int32(500)), + }, + } + ) + + // Add initial rule + err := helper.AddRuleForAllowedServiceTag(serviceTag, protocol, initialDstAddresses, dstPorts) + assert.NoError(t, err) + + outputSG, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + assert.Equal(t, 1, len(outputSG.Properties.SecurityRules)) + testutil.ExpectHasSecurityRules(t, outputSG, []*armnetwork.SecurityRule{expectedInitialRule}) + + // Add new IP to existing rule + err = helper.AddRuleForAllowedServiceTag(serviceTag, protocol, []netip.Addr{newDstAddress}, dstPorts) + assert.NoError(t, err) + + outputSG, updated, err = helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + assert.Equal(t, 1, len(outputSG.Properties.SecurityRules)) + testutil.ExpectHasSecurityRules(t, outputSG, []*armnetwork.SecurityRule{expectedUpdatedRule}) + }) + t.Run("should aggregate IPv6 addresses", func(t *testing.T) { + var ( + sg = fx.Azure().SecurityGroup().Build() + helper = ExpectNewSecurityGroupHelper(t, sg) + + protocol = armnetwork.SecurityRuleProtocolTCP + serviceTag = "AzureCloud" + ipFamily = iputil.IPv6 + initialDstAddresses = []netip.Addr{netip.MustParseAddr("2001:db8::1"), netip.MustParseAddr("2001:db8::2")} + newDstAddr = netip.MustParseAddr("2001:db8::3") + dstPorts = []int32{80, 443} + + expectedInitialRule = &armnetwork.SecurityRule{ + Name: ptr.To(GenerateAllowSecurityRuleName(protocol, ipFamily, []string{serviceTag}, dstPorts)), + Properties: &armnetwork.SecurityRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: to.Ptr(serviceTag), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: to.SliceOfPtrs("2001:db8::1", "2001:db8::2"), + DestinationPortRanges: to.SliceOfPtrs("443", "80"), + Priority: ptr.To(int32(500)), + }, + } + + expectedUpdatedRule = &armnetwork.SecurityRule{ + Name: ptr.To(GenerateAllowSecurityRuleName(protocol, ipFamily, []string{serviceTag}, dstPorts)), + Properties: &armnetwork.SecurityRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: to.Ptr(serviceTag), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: to.SliceOfPtrs("2001:db8::1", "2001:db8::2/127"), + DestinationPortRanges: to.SliceOfPtrs("443", "80"), + Priority: ptr.To(int32(500)), + }, + } + ) + + // Add initial rule + err := helper.AddRuleForAllowedServiceTag(serviceTag, protocol, initialDstAddresses, dstPorts) + assert.NoError(t, err) + + outputSG, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + assert.Equal(t, 1, len(outputSG.Properties.SecurityRules)) + testutil.ExpectHasSecurityRules(t, outputSG, []*armnetwork.SecurityRule{expectedInitialRule}) + + // Add new IPv6 address to existing rule + err = helper.AddRuleForAllowedServiceTag(serviceTag, protocol, []netip.Addr{newDstAddr}, dstPorts) + assert.NoError(t, err) + + outputSG, updated, err = helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + assert.Equal(t, 1, len(outputSG.Properties.SecurityRules)) + testutil.ExpectHasSecurityRules(t, outputSG, []*armnetwork.SecurityRule{expectedUpdatedRule}) + }) + } func TestSecurityGroupHelper_AddRuleForAllowedServiceTag(t *testing.T) { diff --git a/pkg/provider/loadbalancer/securitygroup/securityrule.go b/pkg/provider/loadbalancer/securitygroup/securityrule.go index 1cd843d44f..b688c0270f 100644 --- a/pkg/provider/loadbalancer/securitygroup/securityrule.go +++ b/pkg/provider/loadbalancer/securitygroup/securityrule.go @@ -19,6 +19,7 @@ package securitygroup import ( "crypto/md5" //nolint:gosec "fmt" + "net/netip" "sort" "strconv" "strings" @@ -152,3 +153,33 @@ func ProtocolFromKubernetes(p v1.Protocol) (armnetwork.SecurityRuleProtocol, err } return "", fmt.Errorf("unsupported protocol %s", p) } + +// SeparateIPsAndServiceTags divides a list of prefixes into IP addresses/ranges and Azure service tags. +// +// The input prefixes can be sourced from networkSecurityGroup.SourceAddressPrefixes or +// networkSecurityGroup.DestinationAddressPrefixes, which are of type []string and may contain +// both IP addresses/ranges and Azure service tags. +// +// Returns: +// - []netip.Prefix: A slice of IP addresses and ranges parsed as netip.Prefix +// - []string: A slice of Azure service tags +func SeparateIPsAndServiceTags(prefixes []string) ([]netip.Prefix, []string) { + var ( + ips []netip.Prefix + serviceTags []string + ) + + for _, p := range prefixes { + if addr, err := netip.ParseAddr(p); err == nil { + ips = append(ips, netip.PrefixFrom(addr, addr.BitLen())) + continue + } + if prefix, err := netip.ParsePrefix(p); err == nil { + ips = append(ips, prefix) + continue + } + serviceTags = append(serviceTags, p) + } + + return ips, serviceTags +} diff --git a/pkg/provider/loadbalancer/securitygroup/securityrule_test.go b/pkg/provider/loadbalancer/securitygroup/securityrule_test.go new file mode 100644 index 0000000000..e44db32548 --- /dev/null +++ b/pkg/provider/loadbalancer/securitygroup/securityrule_test.go @@ -0,0 +1,96 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package securitygroup + +import ( + "net/netip" + "reflect" + "testing" +) + +func TestSeparateIPsAndServiceTags(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []string + expectedIPs []netip.Prefix + expectedTags []string + }{ + { + name: "Mixed IPs and service tags", + input: []string{"192.168.0.1", "10.0.0.0/24", "Internet", "172.16.0.1/32", "AzureLoadBalancer"}, + expectedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.1/32"), + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.1/32"), + }, + expectedTags: []string{"Internet", "AzureLoadBalancer"}, + }, + { + name: "Only IPs", + input: []string{"192.168.0.1", "10.0.0.0/24", "172.16.0.1/32"}, + expectedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.1/32"), + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.1/32"), + }, + }, + { + name: "Only service tags", + input: []string{"Internet", "AzureLoadBalancer", "VirtualNetwork"}, + expectedTags: []string{"Internet", "AzureLoadBalancer", "VirtualNetwork"}, + }, + { + name: "Empty input", + input: []string{}, + }, + { + name: "With Asterisk", + input: []string{"192.168.0.1", "*", "10.0.0.0/24", "Internet"}, + expectedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.1/32"), + netip.MustParsePrefix("10.0.0.0/24"), + }, + expectedTags: []string{"*", "Internet"}, + }, + { + name: "IPv6 addresses and prefixes", + input: []string{"2001:db8::1", "2001:db8::/32", "fe80::1234:5678:9abc:def0", "Internet"}, + expectedIPs: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("fe80::1234:5678:9abc:def0/128"), + }, + expectedTags: []string{"Internet"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ips, tags := SeparateIPsAndServiceTags(tc.input) + + if !reflect.DeepEqual(ips, tc.expectedIPs) { + t.Errorf("Expected IPs %v, but got %v", tc.expectedIPs, ips) + } + + if !reflect.DeepEqual(tags, tc.expectedTags) { + t.Errorf("Expected tags %v, but got %v", tc.expectedTags, tags) + } + }) + } +} From 43a467a9a85bc51367ce268549e67a6c0e20dc30 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Wed, 2 Oct 2024 18:06:36 +1000 Subject: [PATCH 2/5] Implement `Remove` for IP PrefixTree --- .../loadbalancer/accesscontrol_test.go | 48 +++--- pkg/provider/loadbalancer/fnutil/set.go | 35 +++++ pkg/provider/loadbalancer/iputil/prefix.go | 45 ++++++ .../loadbalancer/iputil/prefix_test.go | 81 +++++++++++ .../loadbalancer/iputil/prefix_tree.go | 118 +++++++++++---- .../loadbalancer/iputil/prefix_tree_test.go | 82 +++++++++++ .../securitygroup/addressprefix.go | 137 ++++++++++++++++++ ...rityrule_test.go => addressprefix_test.go} | 0 .../securitygroup/securitygroup.go | 32 ++-- .../securitygroup/securitygroup_test.go | 16 +- .../securitygroup/securityrule.go | 31 ---- 11 files changed, 512 insertions(+), 113 deletions(-) create mode 100644 pkg/provider/loadbalancer/fnutil/set.go create mode 100644 pkg/provider/loadbalancer/securitygroup/addressprefix.go rename pkg/provider/loadbalancer/securitygroup/{securityrule_test.go => addressprefix_test.go} (100%) diff --git a/pkg/provider/loadbalancer/accesscontrol_test.go b/pkg/provider/loadbalancer/accesscontrol_test.go index 548e930ebf..b1e6289248 100644 --- a/pkg/provider/loadbalancer/accesscontrol_test.go +++ b/pkg/provider/loadbalancer/accesscontrol_test.go @@ -1234,14 +1234,14 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { { Name: ptr.To("test-rule-2"), Properties: &armnetwork.SecurityRulePropertiesFormat{ - Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), - Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), - Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), - SourceAddressPrefixes: to.SliceOfPtrs("*"), - SourcePortRange: ptr.To("*"), - DestinationAddressPrefixes: to.SliceOfPtrs("8.8.8.8"), - DestinationPortRanges: to.SliceOfPtrs("5000"), - Priority: ptr.To(int32(502)), + Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefixes: to.SliceOfPtrs("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefix: to.Ptr("8.8.8.8"), + DestinationPortRanges: to.SliceOfPtrs("5000"), + Priority: ptr.To(int32(502)), }, }, }, outputSG.Properties.SecurityRules) @@ -1336,14 +1336,14 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { { Name: ptr.To("test-rule-2"), Properties: &armnetwork.SecurityRulePropertiesFormat{ - Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), - Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), - Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), - SourceAddressPrefixes: to.SliceOfPtrs("*"), - SourcePortRange: ptr.To("*"), - DestinationAddressPrefixes: to.SliceOfPtrs("8.8.8.8"), - DestinationPortRanges: to.SliceOfPtrs("5000"), - Priority: ptr.To(int32(502)), + Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefixes: to.SliceOfPtrs("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefix: to.Ptr("8.8.8.8"), + DestinationPortRanges: to.SliceOfPtrs("5000"), + Priority: ptr.To(int32(502)), }, }, }, outputSG.Properties.SecurityRules) @@ -1440,14 +1440,14 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { { Name: ptr.To("test-rule-2"), Properties: &armnetwork.SecurityRulePropertiesFormat{ - Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), - Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), - Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), - SourceAddressPrefixes: to.SliceOfPtrs("*"), - SourcePortRange: ptr.To("*"), - DestinationAddressPrefixes: to.SliceOfPtrs("8.8.8.8"), - DestinationPortRanges: to.SliceOfPtrs("5000"), - Priority: ptr.To(int32(502)), + Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefixes: to.SliceOfPtrs("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefix: to.Ptr("8.8.8.8"), + DestinationPortRanges: to.SliceOfPtrs("5000"), + Priority: ptr.To(int32(502)), }, }, { diff --git a/pkg/provider/loadbalancer/fnutil/set.go b/pkg/provider/loadbalancer/fnutil/set.go new file mode 100644 index 0000000000..00c36ed341 --- /dev/null +++ b/pkg/provider/loadbalancer/fnutil/set.go @@ -0,0 +1,35 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fnutil + +type Set[T comparable] map[T]struct{} + +func SliceToSet[T comparable](xs []T) Set[T] { + rv := make(Set[T], len(xs)) + for _, x := range xs { + rv[x] = struct{}{} + } + return rv +} + +func SetToSlice[T comparable](s Set[T]) []T { + rv := make([]T, 0, len(s)) + for x := range s { + rv = append(rv, x) + } + return rv +} diff --git a/pkg/provider/loadbalancer/iputil/prefix.go b/pkg/provider/loadbalancer/iputil/prefix.go index 77395c37a8..dd9fbedb82 100644 --- a/pkg/provider/loadbalancer/iputil/prefix.go +++ b/pkg/provider/loadbalancer/iputil/prefix.go @@ -32,6 +32,20 @@ func IsPrefixesAllowAll(prefixes []netip.Prefix) bool { return false } +// AddressesAsPrefixes converts a list of IP addresses to a list of prefixes. +// Each address is converted to a prefix with a mask length equal to its bit length. +// +// Examples: +// - 192.168.1.2 becomes 192.168.1.2/32 +// - 2001:db8::1 becomes 2001:db8::1/128 +func AddressesAsPrefixes(addresses []netip.Addr) []netip.Prefix { + var rv []netip.Prefix + for _, addr := range addresses { + rv = append(rv, netip.PrefixFrom(addr, addr.BitLen())) + } + return rv +} + // ParsePrefix parses a CIDR string and returns a Prefix. func ParsePrefix(v string) (netip.Prefix, error) { prefix, err := netip.ParsePrefix(v) @@ -79,3 +93,34 @@ func AggregatePrefixes(prefixes []netip.Prefix) []netip.Prefix { return append(v4Tree.List(), v6Tree.List()...) } + +// ExcludePrefixes excludes prefixes from the given prefixes. +func ExcludePrefixes(prefixes []netip.Prefix, exclude []netip.Prefix) []netip.Prefix { + var ( + v4Tree = newPrefixTreeForIPv4() + v6Tree = newPrefixTreeForIPv6() + ) + + // Build the prefix tree for the prefixes. + { + v4, v6 := GroupPrefixesByFamily(prefixes) + for _, p := range v4 { + v4Tree.Add(p) + } + for _, p := range v6 { + v6Tree.Add(p) + } + } + + // Exclude the prefixes. + v4, v6 := GroupPrefixesByFamily(exclude) + for _, p := range v4 { + v4Tree.Remove(p) + } + for _, p := range v6 { + v6Tree.Remove(p) + } + + // Return the remaining prefixes. + return append(v4Tree.List(), v6Tree.List()...) +} diff --git a/pkg/provider/loadbalancer/iputil/prefix_test.go b/pkg/provider/loadbalancer/iputil/prefix_test.go index 30661c4a4e..0c2c7651b8 100644 --- a/pkg/provider/loadbalancer/iputil/prefix_test.go +++ b/pkg/provider/loadbalancer/iputil/prefix_test.go @@ -343,3 +343,84 @@ func BenchmarkAggregatePrefixes(b *testing.B) { runMixedTests(b, n) } } + +func TestExcludePrefixes(t *testing.T) { + tests := []struct { + name string + prefixes []string + exclude []string + expected []string + }{ + { + name: "Exclude single IPv4 prefix", + prefixes: []string{"192.168.0.0/16", "10.0.0.0/8"}, + exclude: []string{"192.168.0.0/16"}, + expected: []string{"10.0.0.0/8"}, + }, + { + name: "Exclude non-existent IPv4 prefix", + prefixes: []string{"192.168.0.0/16", "10.0.0.0/8"}, + exclude: []string{"172.16.0.0/12"}, + expected: []string{"192.168.0.0/16", "10.0.0.0/8"}, + }, + { + name: "Exclude multiple IPv4 prefixes", + prefixes: []string{"192.168.0.0/16", "10.0.0.0/8", "172.16.0.0/12"}, + exclude: []string{"192.168.0.0/16", "10.0.0.0/8"}, + expected: []string{"172.16.0.0/12"}, + }, + { + name: "Exclude single IPv6 prefix", + prefixes: []string{"2001:db8::/32", "2001::/32"}, + exclude: []string{"2001:db8::/32"}, + expected: []string{"2001::/32"}, + }, + { + name: "Exclude non-existent IPv6 prefix", + prefixes: []string{"2001:db8::/32", "2001::/32"}, + exclude: []string{"2001:abc::/32"}, + expected: []string{"2001:db8::/32", "2001::/32"}, + }, + { + name: "Exclude multiple IPv6 prefixes", + prefixes: []string{"2001:db8::/32", "2001::/32", "2001:abc::/32"}, + exclude: []string{"2001:db8::/32", "2001::/32"}, + expected: []string{"2001:abc::/32"}, + }, + { + name: "Exclude subnet and split IPv4", + prefixes: []string{"192.168.0.0/16"}, + exclude: []string{"192.168.1.0/24"}, + expected: []string{"192.168.0.0/24", "192.168.2.0/23", "192.168.4.0/22", "192.168.8.0/21", "192.168.16.0/20", "192.168.32.0/19", "192.168.64.0/18", "192.168.128.0/17"}, + }, + { + name: "Exclude subnet and split IPv6", + prefixes: []string{"2001:db8::/32"}, + exclude: []string{"2001:db8:1::/48"}, + expected: []string{"2001:db8::/48", "2001:db8:2::/47", "2001:db8:4::/46", "2001:db8:8::/45", "2001:db8:10::/44", "2001:db8:20::/43", "2001:db8:40::/42", "2001:db8:80::/41", "2001:db8:100::/40", "2001:db8:200::/39", "2001:db8:400::/38", "2001:db8:800::/37", "2001:db8:1000::/36", "2001:db8:2000::/35", "2001:db8:4000::/34", "2001:db8:8000::/33"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefixes := make([]netip.Prefix, len(tt.prefixes)) + for i, p := range tt.prefixes { + prefixes[i] = netip.MustParsePrefix(p) + } + + exclude := make([]netip.Prefix, len(tt.exclude)) + for i, p := range tt.exclude { + exclude[i] = netip.MustParsePrefix(p) + } + + result := ExcludePrefixes(prefixes, exclude) + + resultStrings := make([]string, len(result)) + for i, p := range result { + resultStrings[i] = p.String() + } + + assert.ElementsMatch(t, tt.expected, resultStrings) + }) + } +} diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree.go b/pkg/provider/loadbalancer/iputil/prefix_tree.go index 2399c09ebd..4135a35740 100644 --- a/pkg/provider/loadbalancer/iputil/prefix_tree.go +++ b/pkg/provider/loadbalancer/iputil/prefix_tree.go @@ -20,6 +20,13 @@ import ( "net/netip" ) +// bitAt returns the bit at the i-th position in the byte slice. +// The return value is either 0 or 1 as uint8. +// Panics if the index is out of bounds. +func bitAt(bytes []byte, i int) uint8 { + return bytes[i/8] >> (7 - i%8) & 1 +} + type prefixTreeNode struct { masked bool prefix netip.Prefix @@ -29,11 +36,38 @@ type prefixTreeNode struct { r *prefixTreeNode // right child node } -// pruneToRoot prunes the tree to the root. -// If a node's left and right children are both masked, -// it is masked and its children are pruned. -// This is done recursively up to the root. -func (n *prefixTreeNode) pruneToRoot() { +func (n *prefixTreeNode) NewLeftChild() *prefixTreeNode { + prefix := netip.PrefixFrom(n.prefix.Addr(), n.prefix.Bits()+1) + n.l = &prefixTreeNode{ + prefix: prefix, + p: n, + } + return n.l +} + +func (n *prefixTreeNode) NewRightChild() *prefixTreeNode { + prefixBytes := n.prefix.Addr().AsSlice() + { + // Set the next bit to 1 for the new prefix (it's the right child) + byteIndex := n.prefix.Bits() / 8 + bitIndex := n.prefix.Bits() % 8 + prefixBytes[byteIndex] |= 1 << (7 - bitIndex) + } + + addr, _ := netip.AddrFromSlice(prefixBytes) + prefix := netip.PrefixFrom(addr, n.prefix.Bits()+1) + n.r = &prefixTreeNode{ + prefix: prefix, + p: n, + } + return n.r +} + +// MaskAndPruneToRoot masks the current node and prunes the tree upwards. +// It recursively checks parent nodes, masking and pruning them if both +// children are masked. This process continues until reaching the root +// or a node that cannot be pruned. +func (n *prefixTreeNode) MaskAndPruneToRoot() { var node = n for node.p != nil { p := node.p @@ -75,48 +109,72 @@ func newPrefixTreeForIPv6() *prefixTree { // Add adds a prefix to the tree. func (t *prefixTree) Add(prefix netip.Prefix) { var ( - n = t.root - bits = prefix.Addr().AsSlice() + n = t.root + bytes = prefix.Addr().AsSlice() ) for i := 0; i < prefix.Bits(); i++ { if n.masked { break // It's already masked, the rest of the bits are irrelevant } - var bit = bits[i/8] >> (7 - i%8) & 1 - switch bit { - case 0: + if bitAt(bytes, i) == 0 { if n.l == nil { - next, err := prefix.Addr().Prefix(i + 1) - if err != nil { - panic("unreachable: invalid prefix") - } - n.l = &prefixTreeNode{ - prefix: next, - p: n, - } + n.NewLeftChild() } n = n.l - case 1: + } else { if n.r == nil { - next, err := prefix.Addr().Prefix(i + 1) - if err != nil { - panic("unreachable: invalid prefix") - } - n.r = &prefixTreeNode{ - prefix: next, - p: n, - } + n.NewRightChild() } n = n.r - default: - panic("unreachable: unexpected bit") } } n.masked = true n.l, n.r = nil, nil - n.pruneToRoot() + n.MaskAndPruneToRoot() +} + +// Remove removes a prefix from the tree. +// If the prefix is not in the tree, it does nothing. +func (t *prefixTree) Remove(prefix netip.Prefix) { + var ( + n = t.root + bytes = prefix.Addr().AsSlice() + ) + + isMasked := false + for i := 0; n != nil && i < prefix.Bits(); i++ { + var bit = bitAt(bytes, i) + + if !n.masked && !isMasked { + // Keep going down until it finds a masked node + if bit == 0 { + n = n.l + } else { + n = n.r + } + continue + } + + isMasked = true + n.masked = false + + // If the node is masked, it should have no children, + // and we need to split it. The other side should be masked. + n.NewLeftChild() + n.NewRightChild() + if bit == 0 { + n.r.masked = true + n = n.l + } else { + n.l.masked = true + n = n.r + } + } + if n != nil { + n.masked = false + } } // List returns all prefixes in the tree. diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree_test.go b/pkg/provider/loadbalancer/iputil/prefix_tree_test.go index fe3f3dddf1..72eddbbc12 100644 --- a/pkg/provider/loadbalancer/iputil/prefix_tree_test.go +++ b/pkg/provider/loadbalancer/iputil/prefix_tree_test.go @@ -261,3 +261,85 @@ func BenchmarkPrefixTree_List(b *testing.B) { } }) } +func TestPrefixTree_Remove(t *testing.T) { + tests := []struct { + name string + add []string + remove []string + expected []string + }{ + { + name: "Remove single IPv4 prefix", + add: []string{"192.168.0.0/16", "10.0.0.0/8"}, + remove: []string{"192.168.0.0/16"}, + expected: []string{"10.0.0.0/8"}, + }, + { + name: "Remove non-existent IPv4 prefix", + add: []string{"192.168.0.0/16", "10.0.0.0/8"}, + remove: []string{"172.16.0.0/12"}, + expected: []string{"192.168.0.0/16", "10.0.0.0/8"}, + }, + { + name: "Remove multiple IPv4 prefixes", + add: []string{"192.168.0.0/16", "10.0.0.0/8", "172.16.0.0/12"}, + remove: []string{"192.168.0.0/16", "10.0.0.0/8"}, + expected: []string{"172.16.0.0/12"}, + }, + { + name: "Remove single IPv6 prefix", + add: []string{"2001:db8::/32", "2001::/32"}, + remove: []string{"2001:db8::/32"}, + expected: []string{"2001::/32"}, + }, + { + name: "Remove non-existent IPv6 prefix", + add: []string{"2001:db8::/32", "2001::/32"}, + remove: []string{"2001:abc::/32"}, + expected: []string{"2001:db8::/32", "2001::/32"}, + }, + { + name: "Remove multiple IPv6 prefixes", + add: []string{"2001:db8::/32", "2001::/32", "2001:abc::/32"}, + remove: []string{"2001:db8::/32", "2001::/32"}, + expected: []string{"2001:abc::/32"}, + }, + { + name: "Remove subnet and split IPv4", + add: []string{"192.168.0.0/16"}, + remove: []string{"192.168.1.0/24"}, + expected: []string{"192.168.0.0/24", "192.168.2.0/23", "192.168.4.0/22", "192.168.8.0/21", "192.168.16.0/20", "192.168.32.0/19", "192.168.64.0/18", "192.168.128.0/17"}, + }, + { + name: "Remove subnet and split IPv6", + add: []string{"2001:db8::/32"}, + remove: []string{"2001:db8:1::/48"}, + expected: []string{"2001:db8::/48", "2001:db8:2::/47", "2001:db8:4::/46", "2001:db8:8::/45", "2001:db8:10::/44", "2001:db8:20::/43", "2001:db8:40::/42", "2001:db8:80::/41", "2001:db8:100::/40", "2001:db8:200::/39", "2001:db8:400::/38", "2001:db8:800::/37", "2001:db8:1000::/36", "2001:db8:2000::/35", "2001:db8:4000::/34", "2001:db8:8000::/33"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree := newPrefixTreeForIPv4() + if len(tt.add) > 0 && netip.MustParsePrefix(tt.add[0]).Addr().Is6() { + tree = newPrefixTreeForIPv6() + } + + for _, prefix := range tt.add { + tree.Add(netip.MustParsePrefix(prefix)) + } + + for _, prefix := range tt.remove { + tree.Remove(netip.MustParsePrefix(prefix)) + } + + result := tree.List() + var resultStrings []string + for _, prefix := range result { + resultStrings = append(resultStrings, prefix.String()) + } + + assert.ElementsMatch(t, tt.expected, resultStrings) + }) + } +} diff --git a/pkg/provider/loadbalancer/securitygroup/addressprefix.go b/pkg/provider/loadbalancer/securitygroup/addressprefix.go new file mode 100644 index 0000000000..bfa4b48ab2 --- /dev/null +++ b/pkg/provider/loadbalancer/securitygroup/addressprefix.go @@ -0,0 +1,137 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package securitygroup + +import ( + "net/netip" + "sort" + + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" +) + +type AddressPrefixes struct { + ipPrefixes []netip.Prefix + serviceTagIndex map[string]struct{} +} + +// NewAddressPrefixes creates a new AddressPrefixes instance from a slice of strings. +// It is designed to be used with networkSecurityGroup.SourceAddressPrefixes or +// networkSecurityGroup.DestinationAddressPrefixes, which are of type []string and may contain +// both IP addresses/ranges and Azure service tags. +func NewAddressPrefixes(s []string) *AddressPrefixes { + ipPrefixes, serviceTags := SeparateIPsAndServiceTags(s) + rv := &AddressPrefixes{ + ipPrefixes: ipPrefixes, + serviceTagIndex: fnutil.SliceToSet(serviceTags), + } + rv.tidyIPPrefixes() + return rv +} + +// tidyIPPrefixes aggregates and sorts the IP prefixes in the AddressPrefixes instance. +func (ap *AddressPrefixes) tidyIPPrefixes() { + ap.ipPrefixes = iputil.AggregatePrefixes(ap.ipPrefixes) +} + +// AddIPAddresses adds one or more IP addresses to the AddressPrefixes instance. +func (ap *AddressPrefixes) AddIPAddresses(addresses ...netip.Addr) { + for _, addr := range addresses { + ap.ipPrefixes = append(ap.ipPrefixes, netip.PrefixFrom(addr, addr.BitLen())) + } + ap.tidyIPPrefixes() +} + +// RemoveIPAddresses removes one or more IP addresses from the AddressPrefixes instance. +func (ap *AddressPrefixes) RemoveIPAddresses(addresses ...netip.Addr) { + ap.RemoveIPPrefixes(iputil.AddressesAsPrefixes(addresses)...) +} + +// RemoveIPPrefixes removes one or more IP prefixes from the AddressPrefixes instance. +func (ap *AddressPrefixes) RemoveIPPrefixes(prefixes ...netip.Prefix) { + ap.ipPrefixes = iputil.ExcludePrefixes(ap.ipPrefixes, prefixes) + // No need to tidyIPPrefixes here, as ExcludePrefixes already does that. +} + +// AddIPPrefixes adds one or more IP prefixes to the AddressPrefixes instance. +func (ap *AddressPrefixes) AddIPPrefixes(prefixes ...netip.Prefix) { + ap.ipPrefixes = append(ap.ipPrefixes, prefixes...) + ap.tidyIPPrefixes() +} + +// AddServiceTags adds one or more service tags to the AddressPrefixes instance. +func (ap *AddressPrefixes) AddServiceTags(tags ...string) { + for _, tag := range tags { + ap.serviceTagIndex[tag] = struct{}{} + } +} + +// RemoveServiceTags removes one or more service tags from the AddressPrefixes instance. +func (ap *AddressPrefixes) RemoveServiceTags(tags ...string) { + for _, tag := range tags { + delete(ap.serviceTagIndex, tag) + } +} + +// StringSlice returns a slice of strings representing all IP addresses, prefixes, and service tags. +func (ap *AddressPrefixes) StringSlice() []string { + var rv []string + + for _, ip := range ap.ipPrefixes { + if ip.Bits() == ip.Addr().BitLen() { + // Prefer IP address over IP range if possible. + rv = append(rv, ip.Addr().String()) + } else { + rv = append(rv, ip.String()) + } + } + sort.Slice(ap.ipPrefixes, func(i, j int) bool { + return ap.ipPrefixes[i].String() < ap.ipPrefixes[j].String() + }) + + return append(rv, fnutil.SetToSlice(ap.serviceTagIndex)...) +} + +// SeparateIPsAndServiceTags divides a list of prefixes into IP addresses/ranges and Azure service tags. +// +// The input prefixes can be sourced from networkSecurityGroup.SourceAddressPrefixes or +// networkSecurityGroup.DestinationAddressPrefixes, which are of type []string and may contain +// both IP addresses/ranges and Azure service tags. +// +// Returns: +// - []netip.Prefix: A slice of IP addresses and ranges parsed as netip.Prefix +// - []string: A slice of Azure service tags +func SeparateIPsAndServiceTags(prefixes []string) ([]netip.Prefix, []string) { + var ( + ips []netip.Prefix + serviceTags []string + ) + + for _, p := range prefixes { + if addr, err := netip.ParseAddr(p); err == nil { + ips = append(ips, netip.PrefixFrom(addr, addr.BitLen())) + continue + } + if prefix, err := netip.ParsePrefix(p); err == nil { + ips = append(ips, prefix) + continue + } + serviceTags = append(serviceTags, p) + } + + return ips, serviceTags +} diff --git a/pkg/provider/loadbalancer/securitygroup/securityrule_test.go b/pkg/provider/loadbalancer/securitygroup/addressprefix_test.go similarity index 100% rename from pkg/provider/loadbalancer/securitygroup/securityrule_test.go rename to pkg/provider/loadbalancer/securitygroup/addressprefix_test.go diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup.go b/pkg/provider/loadbalancer/securitygroup/securitygroup.go index 08f1113adb..ed949c37a4 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup.go @@ -192,20 +192,10 @@ func (helper *RuleHelper) addAllowRule( } { // Destination - addresses := append(ListDestinationPrefixes(rule), dstPrefixes...) - - // Aggregate the prefixes - prefixes, serviceTags := SeparateIPsAndServiceTags(addresses) - prefixes = iputil.AggregatePrefixes(prefixes) - addresses = append(fnutil.Map(func(p netip.Prefix) string { - if p.Bits() == ipFamily.MaxMask() { - // Keep it as an IP address to avoid additional operation for old rules. - return p.Addr().String() - } - return p.String() - }, prefixes), serviceTags...) - SetDestinationPrefixes(rule, addresses) + // Tidy up and aggregate the destination prefixes. + addressPrefixes := NewAddressPrefixes(append(ListDestinationPrefixes(rule), dstPrefixes...)) + SetDestinationPrefixes(rule, addressPrefixes.StringSlice()) rule.Properties.DestinationPortRanges = to.SliceOfPtrs(dstPortRanges...) } @@ -295,8 +285,10 @@ func (helper *RuleHelper) AddRuleForDenyAll(dstAddresses []netip.Addr) error { { // Destination addresses := fnutil.Map(func(ip netip.Addr) string { return ip.String() }, dstAddresses) - addresses = append(addresses, ListDestinationPrefixes(rule)...) - SetDestinationPrefixes(rule, addresses) + + // Tidy up and aggregate the destination prefixes. + addressPrefixes := NewAddressPrefixes(append(addresses, ListDestinationPrefixes(rule)...)) + SetDestinationPrefixes(rule, addressPrefixes.StringSlice()) rule.Properties.DestinationPortRange = ptr.To("*") } @@ -342,12 +334,12 @@ func (helper *RuleHelper) removeDestinationFromRule(rule *armnetwork.SecurityRul WithValues("security-rule-name", rule.Name) var ( - prefixIndex = fnutil.IndexSet(prefixes) // Used to check whether the prefix should be removed. - currentPrefixes = ListDestinationPrefixes(rule) - - expectedPrefixes = prefixIndex.SubtractedBy(currentPrefixes) // The prefixes to keep. - targetPrefixes = fnutil.Intersection(currentPrefixes, prefixes) // The prefixes to remove. + dstPrefixes = NewAddressPrefixes(ListDestinationPrefixes(rule)) + targetPrefixes, targetServiceTags = SeparateIPsAndServiceTags(prefixes) ) + dstPrefixes.RemoveIPPrefixes(targetPrefixes...) + dstPrefixes.RemoveServiceTags(targetServiceTags...) + expectedPrefixes := dstPrefixes.StringSlice() // Clean DenyAll rule if *rule.Properties.Access == armnetwork.SecurityRuleAccessDeny && len(retainDstPorts) == 0 { diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go b/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go index 09b6435b31..511e997fae 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go @@ -1570,14 +1570,14 @@ func TestRuleHelper_RemoveDestinationFromRules(t *testing.T) { { Name: ptr.To("test-rule-2"), Properties: &armnetwork.SecurityRulePropertiesFormat{ - Protocol: to.Ptr(armnetwork.SecurityRuleProtocolTCP), - Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), - Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), - SourceAddressPrefix: ptr.To("*"), - SourcePortRange: ptr.To("*"), - DestinationAddressPrefixes: to.SliceOfPtrs("8.8.8.8"), - DestinationPortRanges: to.SliceOfPtrs("5000"), - Priority: ptr.To(int32(502)), + Protocol: to.Ptr(armnetwork.SecurityRuleProtocolTCP), + Access: to.Ptr(armnetwork.SecurityRuleAccessAllow), + Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: ptr.To("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefix: to.Ptr("8.8.8.8"), + DestinationPortRanges: to.SliceOfPtrs("5000"), + Priority: ptr.To(int32(502)), }, }, { diff --git a/pkg/provider/loadbalancer/securitygroup/securityrule.go b/pkg/provider/loadbalancer/securitygroup/securityrule.go index b688c0270f..1cd843d44f 100644 --- a/pkg/provider/loadbalancer/securitygroup/securityrule.go +++ b/pkg/provider/loadbalancer/securitygroup/securityrule.go @@ -19,7 +19,6 @@ package securitygroup import ( "crypto/md5" //nolint:gosec "fmt" - "net/netip" "sort" "strconv" "strings" @@ -153,33 +152,3 @@ func ProtocolFromKubernetes(p v1.Protocol) (armnetwork.SecurityRuleProtocol, err } return "", fmt.Errorf("unsupported protocol %s", p) } - -// SeparateIPsAndServiceTags divides a list of prefixes into IP addresses/ranges and Azure service tags. -// -// The input prefixes can be sourced from networkSecurityGroup.SourceAddressPrefixes or -// networkSecurityGroup.DestinationAddressPrefixes, which are of type []string and may contain -// both IP addresses/ranges and Azure service tags. -// -// Returns: -// - []netip.Prefix: A slice of IP addresses and ranges parsed as netip.Prefix -// - []string: A slice of Azure service tags -func SeparateIPsAndServiceTags(prefixes []string) ([]netip.Prefix, []string) { - var ( - ips []netip.Prefix - serviceTags []string - ) - - for _, p := range prefixes { - if addr, err := netip.ParseAddr(p); err == nil { - ips = append(ips, netip.PrefixFrom(addr, addr.BitLen())) - continue - } - if prefix, err := netip.ParsePrefix(p); err == nil { - ips = append(ips, prefix) - continue - } - serviceTags = append(serviceTags, p) - } - - return ips, serviceTags -} From d50a2a770143cfc1c9b03fbc7250769beaf66532 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Thu, 3 Oct 2024 12:36:52 +1000 Subject: [PATCH 3/5] Refactor Prefix Tree --- pkg/provider/loadbalancer/accesscontrol.go | 4 +- pkg/provider/loadbalancer/fnutil/slice.go | 6 + pkg/provider/loadbalancer/fnutil/string.go | 9 + .../loadbalancer/iputil/internal/prefix.go | 15 + .../iputil/internal/prefix_test.go | 70 +++++ .../prefixtree.go} | 80 +++++- .../prefixtree_test.go} | 265 +++++++++++++----- .../fuzz/FuzzPrefixTree/26ad160f482a1840 | 3 + .../fuzz/FuzzPrefixTree/3863b12be77d5992 | 3 + .../fuzz/FuzzPrefixTree/3c2ebd299c6fbdb9 | 3 + pkg/provider/loadbalancer/iputil/prefix.go | 10 +- .../securitygroup/securitygroup.go | 8 +- 12 files changed, 376 insertions(+), 100 deletions(-) create mode 100644 pkg/provider/loadbalancer/fnutil/string.go create mode 100644 pkg/provider/loadbalancer/iputil/internal/prefix.go create mode 100644 pkg/provider/loadbalancer/iputil/internal/prefix_test.go rename pkg/provider/loadbalancer/iputil/{prefix_tree.go => internal/prefixtree.go} (62%) rename pkg/provider/loadbalancer/iputil/{prefix_tree_test.go => internal/prefixtree_test.go} (58%) create mode 100644 pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/26ad160f482a1840 create mode 100644 pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3863b12be77d5992 create mode 100644 pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3c2ebd299c6fbdb9 diff --git a/pkg/provider/loadbalancer/accesscontrol.go b/pkg/provider/loadbalancer/accesscontrol.go index 9372895b1c..ef7e32da43 100644 --- a/pkg/provider/loadbalancer/accesscontrol.go +++ b/pkg/provider/loadbalancer/accesscontrol.go @@ -281,8 +281,8 @@ func (ac *AccessControl) CleanSecurityGroup( logger.V(10).Info("Start cleaning") var ( - ipv4Prefixes = fnutil.Map(func(addr netip.Addr) string { return addr.String() }, dstIPv4Addresses) - ipv6Prefixes = fnutil.Map(func(addr netip.Addr) string { return addr.String() }, dstIPv6Addresses) + ipv4Prefixes = fnutil.Map(fnutil.AsString, dstIPv4Addresses) + ipv6Prefixes = fnutil.Map(fnutil.AsString, dstIPv6Addresses) ) protocols := []armnetwork.SecurityRuleProtocol{ diff --git a/pkg/provider/loadbalancer/fnutil/slice.go b/pkg/provider/loadbalancer/fnutil/slice.go index fb1b39e7a3..1a10238169 100644 --- a/pkg/provider/loadbalancer/fnutil/slice.go +++ b/pkg/provider/loadbalancer/fnutil/slice.go @@ -102,6 +102,12 @@ func (xs *IndexSetWithComparableIndex[I, D]) SubtractedBy(ys []D) []D { return rv } +// Intersection returns the elements that are in both xs and ys. func Intersection[D comparable](xs, ys []D) []D { return IndexSet(xs).Intersection(ys) } + +// Difference returns the elements in xs but not in ys. +func Difference[D comparable](xs, ys []D) []D { + return IndexSet(ys).SubtractedBy(xs) +} diff --git a/pkg/provider/loadbalancer/fnutil/string.go b/pkg/provider/loadbalancer/fnutil/string.go new file mode 100644 index 0000000000..da1aecdebc --- /dev/null +++ b/pkg/provider/loadbalancer/fnutil/string.go @@ -0,0 +1,9 @@ +package fnutil + +type Stringer interface { + String() string +} + +func AsString[T Stringer](v T) string { + return v.String() +} diff --git a/pkg/provider/loadbalancer/iputil/internal/prefix.go b/pkg/provider/loadbalancer/iputil/internal/prefix.go new file mode 100644 index 0000000000..91b6413d5f --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/internal/prefix.go @@ -0,0 +1,15 @@ +package internal + +import ( + "net/netip" +) + +func ListAddresses(prefixes ...netip.Prefix) []netip.Addr { + var rv []netip.Addr + for _, p := range prefixes { + for addr := p.Addr(); p.Contains(addr); addr = addr.Next() { + rv = append(rv, addr) + } + } + return rv +} diff --git a/pkg/provider/loadbalancer/iputil/internal/prefix_test.go b/pkg/provider/loadbalancer/iputil/internal/prefix_test.go new file mode 100644 index 0000000000..c5795e677c --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/internal/prefix_test.go @@ -0,0 +1,70 @@ +package internal + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestListAddresses(t *testing.T) { + tests := []struct { + Name string + Prefixes []netip.Prefix + Expected []netip.Addr + }{ + { + Name: "Empty", + }, + { + Name: "Single IPv4 Address", + Prefixes: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + Expected: []netip.Addr{netip.MustParseAddr("192.168.1.1")}, + }, + { + Name: "IPv4 Subnet", + Prefixes: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/30")}, + Expected: []netip.Addr{ + netip.MustParseAddr("192.168.1.0"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("192.168.1.2"), + netip.MustParseAddr("192.168.1.3"), + }, + }, + { + Name: "Single IPv6 Address", + Prefixes: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")}, + Expected: []netip.Addr{netip.MustParseAddr("2001:db8::1")}, + }, + { + Name: "IPv6 Subnet", + Prefixes: []netip.Prefix{netip.MustParsePrefix("2001:db8::/126")}, + Expected: []netip.Addr{ + netip.MustParseAddr("2001:db8::"), + netip.MustParseAddr("2001:db8::1"), + netip.MustParseAddr("2001:db8::2"), + netip.MustParseAddr("2001:db8::3"), + }, + }, + { + Name: "Multiple Prefixes", + Prefixes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/31"), + netip.MustParsePrefix("2001:db8::/127"), + }, + Expected: []netip.Addr{ + netip.MustParseAddr("192.168.1.0"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("2001:db8::"), + netip.MustParseAddr("2001:db8::1"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + actual := ListAddresses(tt.Prefixes...) + assert.Equal(t, tt.Expected, actual) + }) + } +} diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree.go b/pkg/provider/loadbalancer/iputil/internal/prefixtree.go similarity index 62% rename from pkg/provider/loadbalancer/iputil/prefix_tree.go rename to pkg/provider/loadbalancer/iputil/internal/prefixtree.go index 4135a35740..3a1ac96730 100644 --- a/pkg/provider/loadbalancer/iputil/prefix_tree.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefixtree.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package iputil +package internal import ( "net/netip" @@ -36,6 +36,8 @@ type prefixTreeNode struct { r *prefixTreeNode // right child node } +// NewLeftChild creates a new left child node for the current node. +// No checks are performed to see if the child already exists. func (n *prefixTreeNode) NewLeftChild() *prefixTreeNode { prefix := netip.PrefixFrom(n.prefix.Addr(), n.prefix.Bits()+1) n.l = &prefixTreeNode{ @@ -45,6 +47,8 @@ func (n *prefixTreeNode) NewLeftChild() *prefixTreeNode { return n.l } +// NewRightChild creates a new right child node for the current node. +// No checks are performed to see if the child already exists. func (n *prefixTreeNode) NewRightChild() *prefixTreeNode { prefixBytes := n.prefix.Addr().AsSlice() { @@ -63,11 +67,27 @@ func (n *prefixTreeNode) NewRightChild() *prefixTreeNode { return n.r } -// MaskAndPruneToRoot masks the current node and prunes the tree upwards. -// It recursively checks parent nodes, masking and pruning them if both -// children are masked. This process continues until reaching the root -// or a node that cannot be pruned. -func (n *prefixTreeNode) MaskAndPruneToRoot() { +// CondenseUntilRoot checks if the current node and its sibling are masked, +// and if so, marks their parent as masked and removes both children. +// This process is repeated up the tree until a node with an unmasked sibling is found. +// +// The process can be visualized as follows: +// +// Before: After: +// P P (masked) +// / \ / \ +// A B -> X X +// (M) (M) +// +// Where: +// +// P: Parent node +// A, B: Child nodes +// M: Masked +// X: Removed +// +// This method helps to optimize the tree structure by condensing fully masked subtrees. +func (n *prefixTreeNode) CondenseUntilRoot() { var node = n for node.p != nil { p := node.p @@ -83,13 +103,40 @@ func (n *prefixTreeNode) MaskAndPruneToRoot() { } } -type prefixTree struct { +// PrefixTree represents a tree structure for storing and managing IP prefixes. +// It efficiently handles prefix aggregation, merging of overlapping prefixes, +// and collapsing of neighboring prefixes. +// +// The tree is structured as follows: +// - Each node represents a bit in the IP address +// - Left child represents a 0 bit, right child represents a 1 bit +// - Masked nodes indicate the end of a prefix +// - Unused branches are represented by nil pointers +// +// Example tree for 128.0.0.0/4 (binary 1000 0000): +// +// 0 (0.0.0.0/0) +// / \ +// X 1 (128.0.0.0/1) +// / \ +// 0 X +// / \ +// 0 X +// / \ +// 0* X +// +// Where: +// * denotes a masked node (prefix end) +// X denotes an unused branch (nil pointer) +type PrefixTree struct { maxBits int root *prefixTreeNode } -func newPrefixTreeForIPv4() *prefixTree { - return &prefixTree{ +// NewPrefixTreeForIPv4 creates a new prefix tree for IPv4 addresses. +// The max depth of the tree is 32 + 1 (for the root). +func NewPrefixTreeForIPv4() *PrefixTree { + return &PrefixTree{ maxBits: 32, root: &prefixTreeNode{ prefix: netip.MustParsePrefix("0.0.0.0/0"), @@ -97,8 +144,10 @@ func newPrefixTreeForIPv4() *prefixTree { } } -func newPrefixTreeForIPv6() *prefixTree { - return &prefixTree{ +// NewPrefixTreeForIPv6 creates a new prefix tree for IPv6 addresses. +// The max depth of the tree is 128 + 1 (for the root). +func NewPrefixTreeForIPv6() *PrefixTree { + return &PrefixTree{ maxBits: 128, root: &prefixTreeNode{ prefix: netip.MustParsePrefix("::/0"), @@ -107,7 +156,8 @@ func newPrefixTreeForIPv6() *prefixTree { } // Add adds a prefix to the tree. -func (t *prefixTree) Add(prefix netip.Prefix) { +// It will merge overlapping prefixes and collapse neighboring prefixes if possible. +func (t *PrefixTree) Add(prefix netip.Prefix) { var ( n = t.root bytes = prefix.Addr().AsSlice() @@ -132,12 +182,12 @@ func (t *prefixTree) Add(prefix netip.Prefix) { n.masked = true n.l, n.r = nil, nil - n.MaskAndPruneToRoot() + n.CondenseUntilRoot() } // Remove removes a prefix from the tree. // If the prefix is not in the tree, it does nothing. -func (t *prefixTree) Remove(prefix netip.Prefix) { +func (t *PrefixTree) Remove(prefix netip.Prefix) { var ( n = t.root bytes = prefix.Addr().AsSlice() @@ -185,7 +235,7 @@ func (t *prefixTree) Remove(prefix netip.Prefix) { // Example: // - [192.168.0.0/16, 192.168.1.0/24, 192.168.0.1/32] -> [192.168.0.0/16] // - [192.168.0.0/32, 192.168.0.1/32] -> [192.168.0.0/31] -func (t *prefixTree) List() []netip.Prefix { +func (t *PrefixTree) List() []netip.Prefix { var ( rv []netip.Prefix q = []*prefixTreeNode{t.root} diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree_test.go b/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go similarity index 58% rename from pkg/provider/loadbalancer/iputil/prefix_tree_test.go rename to pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go index 72eddbbc12..0b0c797932 100644 --- a/pkg/provider/loadbalancer/iputil/prefix_tree_test.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go @@ -14,17 +14,44 @@ See the License for the specific language governing permissions and limitations under the License. */ -package iputil +package internal import ( + "fmt" "math" "net/netip" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" "sort" "testing" "github.com/stretchr/testify/assert" ) +func Test_bitAt(t *testing.T) { + bytes := []byte{0b1010_1010, 0b0101_0101} + assert.Equal(t, uint8(1), bitAt(bytes, 0)) + assert.Equal(t, uint8(0), bitAt(bytes, 1)) + assert.Equal(t, uint8(1), bitAt(bytes, 2)) + assert.Equal(t, uint8(0), bitAt(bytes, 3)) + + assert.Equal(t, uint8(1), bitAt(bytes, 4)) + assert.Equal(t, uint8(0), bitAt(bytes, 5)) + assert.Equal(t, uint8(1), bitAt(bytes, 6)) + assert.Equal(t, uint8(0), bitAt(bytes, 7)) + + assert.Equal(t, uint8(0), bitAt(bytes, 8)) + assert.Equal(t, uint8(1), bitAt(bytes, 9)) + assert.Equal(t, uint8(0), bitAt(bytes, 10)) + assert.Equal(t, uint8(1), bitAt(bytes, 11)) + + assert.Equal(t, uint8(0), bitAt(bytes, 12)) + assert.Equal(t, uint8(1), bitAt(bytes, 13)) + assert.Equal(t, uint8(0), bitAt(bytes, 14)) + assert.Equal(t, uint8(1), bitAt(bytes, 15)) + + assert.Panics(t, func() { bitAt(bytes, 16) }) +} + func TestPrefixTreeIPv4(t *testing.T) { tests := []struct { Name string @@ -86,7 +113,7 @@ func TestPrefixTreeIPv4(t *testing.T) { for _, tt := range tests { t.Run(tt.Name, func(t *testing.T) { - var tree = newPrefixTreeForIPv4() + var tree = NewPrefixTreeForIPv4() for _, ip := range tt.Input { p := netip.MustParsePrefix(ip) tree.Add(p) @@ -172,7 +199,7 @@ func TestPrefixTreeIPv6(t *testing.T) { for _, tt := range tests { t.Run(tt.Name, func(t *testing.T) { - var tree = newPrefixTreeForIPv6() + var tree = NewPrefixTreeForIPv6() for _, ip := range tt.Input { p := netip.MustParsePrefix(ip) tree.Add(p) @@ -191,76 +218,6 @@ func TestPrefixTreeIPv6(t *testing.T) { } } -func BenchmarkPrefixTree_Add(b *testing.B) { - b.Run("IPv4", func(b *testing.B) { - var tree = newPrefixTreeForIPv4() - for i := 0; i < b.N; i++ { - addr := netip.AddrFrom4([4]byte{ - byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), - }) - prefix, _ := addr.Prefix(32) - - tree.Add(prefix) - } - }) - - b.Run("IPv6", func(b *testing.B) { - var tree = newPrefixTreeForIPv6() - for i := 0; i < b.N; i++ { - addr := netip.AddrFrom16([16]byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - byte(i >> 56), byte(i >> 48), byte(i >> 40), byte(i >> 32), - byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), - }) - prefix, _ := addr.Prefix(128) - - tree.Add(prefix) - } - }) -} - -func BenchmarkPrefixTree_List(b *testing.B) { - - b.Run("IPv4", func(b *testing.B) { - b.StopTimer() - var tree = newPrefixTreeForIPv4() - for i := 0; i < math.MaxInt8; i++ { - addr := netip.AddrFrom4([4]byte{ - byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), - }) - prefix, err := addr.Prefix(32) - assert.NoError(b, err) - - tree.Add(prefix) - } - b.StartTimer() - for i := 0; i < b.N; i++ { - tree.List() - } - }) - - b.Run("IPv6", func(b *testing.B) { - b.StopTimer() - var tree = newPrefixTreeForIPv6() - for i := 0; i < math.MaxInt8; i++ { - addr := netip.AddrFrom16([16]byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - byte(i >> 56), byte(i >> 48), byte(i >> 40), byte(i >> 32), - byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), - }) - prefix, err := addr.Prefix(128) - assert.NoError(b, err) - - tree.Add(prefix) - } - b.StartTimer() - for i := 0; i < b.N; i++ { - tree.List() - } - }) -} func TestPrefixTree_Remove(t *testing.T) { tests := []struct { name string @@ -320,9 +277,9 @@ func TestPrefixTree_Remove(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tree := newPrefixTreeForIPv4() + tree := NewPrefixTreeForIPv4() if len(tt.add) > 0 && netip.MustParsePrefix(tt.add[0]).Addr().Is6() { - tree = newPrefixTreeForIPv6() + tree = NewPrefixTreeForIPv6() } for _, prefix := range tt.add { @@ -343,3 +300,161 @@ func TestPrefixTree_Remove(t *testing.T) { }) } } + +func benchmarkFixtureForIPv4(b *testing.B) *PrefixTree { + b.StopTimer() + defer b.StartTimer() + + var tree = NewPrefixTreeForIPv4() + for i := 0; i < math.MaxInt8; i++ { + prefix := netip.MustParsePrefix(fmt.Sprintf("%d.0.0.0/8", i)) + tree.Add(prefix) + } + return tree +} + +func benchmarkFixtureForIPv6(b *testing.B) *PrefixTree { + b.StopTimer() + defer b.StartTimer() + + var tree = NewPrefixTreeForIPv6() + for i := 0; i < math.MaxInt8; i++ { + prefix := netip.MustParsePrefix(fmt.Sprintf("2001:db8:%x::/64", i)) + tree.Add(prefix) + } + return tree +} + +func BenchmarkPrefixTree_AddAndRemove(b *testing.B) { + b.Run("IPv4", func(b *testing.B) { + tree := benchmarkFixtureForIPv4(b) + for i := 0; i < b.N; i++ { + prefix := netip.MustParsePrefix("10.10.10.0/24") + tree.Remove(prefix) + tree.Add(prefix) + } + }) + + b.Run("IPv6", func(b *testing.B) { + tree := benchmarkFixtureForIPv6(b) + for i := 0; i < b.N; i++ { + prefix := netip.MustParsePrefix("2001:db8:10:ff::/90") + tree.Remove(prefix) + tree.Add(prefix) + } + }) +} + +func BenchmarkPrefixTree_List(b *testing.B) { + + b.Run("IPv4", func(b *testing.B) { + tree := benchmarkFixtureForIPv4(b) + for i := 0; i < b.N; i++ { + tree.List() + } + }) + + b.Run("IPv6", func(b *testing.B) { + tree := benchmarkFixtureForIPv6(b) + for i := 0; i < b.N; i++ { + tree.List() + } + }) +} + +func FuzzPrefixTree(f *testing.F) { + // To reduce fuzzing time + const ( + MinIPv4Bits = 20 + MinIPv6Bits = 118 + ) + var ( + InitialPrefixIPv4 = netip.MustParsePrefix(fmt.Sprintf("0.0.0.0/%d", MinIPv4Bits)) + InitialPrefixIPv6 = netip.MustParsePrefix(fmt.Sprintf("::/%d", MinIPv6Bits)) + InitialIPv4 = fnutil.Map(fnutil.AsString, ListAddresses(InitialPrefixIPv4)) + InitialIPv6 = fnutil.Map(fnutil.AsString, ListAddresses(InitialPrefixIPv6)) + ) + + f.Add( + netip.MustParseAddr("192.168.0.0").AsSlice(), + 24, + ) + f.Add( + netip.MustParseAddr("2001:db8::").AsSlice(), + 64, + ) + + f.Fuzz(func(t *testing.T, ip []byte, bits int) { + var ( + targetPrefix netip.Prefix + targetAddresses []string + ) + { + addr, ok := netip.AddrFromSlice(ip) + if !ok || + (addr.Is4() && !InitialPrefixIPv4.Contains(addr)) || + (addr.Is6() && !InitialPrefixIPv6.Contains(addr)) { + // Skip invalid addresses + t.SkipNow() + return + } + if bits < 0 || + (addr.Is4() && (bits <= MinIPv4Bits || bits > 32)) || + (addr.Is6() && (bits <= MinIPv6Bits || bits > 128)) { + // Skip invalid bit lengths + t.SkipNow() + return + } + p, err := addr.Prefix(bits) + assert.NoError(t, err) + targetPrefix = p + targetAddresses = fnutil.Map(fnutil.AsString, ListAddresses(targetPrefix)) + } + fmt.Printf("target-prefix: %s\n", targetPrefix.String()) + + var ( + tree *PrefixTree + allAddresses []string + initPrefix netip.Prefix + ) + if targetPrefix.Addr().Is4() { + tree = NewPrefixTreeForIPv4() + tree.Add(InitialPrefixIPv4) + initPrefix = InitialPrefixIPv4 + allAddresses = InitialIPv4 + } else { + tree = NewPrefixTreeForIPv6() + tree.Add(InitialPrefixIPv6) + initPrefix = InitialPrefixIPv6 + allAddresses = InitialIPv6 + } + + tree.Remove(targetPrefix) + { + prefixes := tree.List() + addresses := fnutil.Map(fnutil.AsString, ListAddresses(prefixes...)) + + assert.Empty( + t, fnutil.Intersection(targetAddresses, addresses), + "actual-prefixes: %s, target-prefixes: %s", + fnutil.Map(fnutil.AsString, prefixes), + targetPrefix.String(), + ) + assert.ElementsMatch( + t, addresses, fnutil.Difference(allAddresses, targetAddresses), + "actual-prefixes: %s, target-prefixes: %s", + fnutil.Map(fnutil.AsString, prefixes), + targetPrefix.String(), + ) + } + + tree.Add(targetPrefix) + { + prefixes := tree.List() + addresses := fnutil.Map(fnutil.AsString, ListAddresses(prefixes...)) + + assert.Equal(t, []string{initPrefix.String()}, fnutil.Map(fnutil.AsString, prefixes)) + assert.ElementsMatch(t, addresses, allAddresses) + } + }) +} diff --git a/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/26ad160f482a1840 b/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/26ad160f482a1840 new file mode 100644 index 0000000000..f64a46cce3 --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/26ad160f482a1840 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("0000") +int(30) diff --git a/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3863b12be77d5992 b/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3863b12be77d5992 new file mode 100644 index 0000000000..31fa1df03d --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3863b12be77d5992 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("\x00\x00&0") +int(17) diff --git a/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3c2ebd299c6fbdb9 b/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3c2ebd299c6fbdb9 new file mode 100644 index 0000000000..31d641daba --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/internal/testdata/fuzz/FuzzPrefixTree/3c2ebd299c6fbdb9 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("\x00\x00\x00@") +int(24) diff --git a/pkg/provider/loadbalancer/iputil/prefix.go b/pkg/provider/loadbalancer/iputil/prefix.go index dd9fbedb82..32fb141fd7 100644 --- a/pkg/provider/loadbalancer/iputil/prefix.go +++ b/pkg/provider/loadbalancer/iputil/prefix.go @@ -19,6 +19,8 @@ package iputil import ( "fmt" "net/netip" + + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil/internal" ) // IsPrefixesAllowAll returns true if one of the prefixes allows all addresses. @@ -80,8 +82,8 @@ func GroupPrefixesByFamily(vs []netip.Prefix) ([]netip.Prefix, []netip.Prefix) { func AggregatePrefixes(prefixes []netip.Prefix) []netip.Prefix { var ( v4, v6 = GroupPrefixesByFamily(prefixes) - v4Tree = newPrefixTreeForIPv4() - v6Tree = newPrefixTreeForIPv6() + v4Tree = internal.NewPrefixTreeForIPv4() + v6Tree = internal.NewPrefixTreeForIPv6() ) for _, p := range v4 { @@ -97,8 +99,8 @@ func AggregatePrefixes(prefixes []netip.Prefix) []netip.Prefix { // ExcludePrefixes excludes prefixes from the given prefixes. func ExcludePrefixes(prefixes []netip.Prefix, exclude []netip.Prefix) []netip.Prefix { var ( - v4Tree = newPrefixTreeForIPv4() - v6Tree = newPrefixTreeForIPv6() + v4Tree = internal.NewPrefixTreeForIPv4() + v6Tree = internal.NewPrefixTreeForIPv6() ) // Build the prefix tree for the prefixes. diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup.go b/pkg/provider/loadbalancer/securitygroup/securitygroup.go index ed949c37a4..1c61fe7060 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup.go @@ -218,7 +218,7 @@ func (helper *RuleHelper) AddRuleForAllowedServiceTag( var ( ipFamily = iputil.FamilyOfAddr(dstAddresses[0]) srcPrefixes = []string{serviceTag} - dstPrefixes = fnutil.Map(func(ip netip.Addr) string { return ip.String() }, dstAddresses) + dstPrefixes = fnutil.Map(fnutil.AsString, dstAddresses) ) helper.logger.V(4).Info("Patching a rule for allowed service tag", "ip-family", ipFamily) @@ -245,8 +245,8 @@ func (helper *RuleHelper) AddRuleForAllowedIPRanges( var ( ipFamily = iputil.FamilyOfAddr(ipRanges[0].Addr()) - srcPrefixes = fnutil.Map(func(ip netip.Prefix) string { return ip.String() }, ipRanges) - dstPrefixes = fnutil.Map(func(ip netip.Addr) string { return ip.String() }, dstAddresses) + srcPrefixes = fnutil.Map(fnutil.AsString, ipRanges) + dstPrefixes = fnutil.Map(fnutil.AsString, dstAddresses) ) helper.logger.V(4).Info("Patching a rule for allowed IP ranges", "ip-family", ipFamily) @@ -284,7 +284,7 @@ func (helper *RuleHelper) AddRuleForDenyAll(dstAddresses []netip.Addr) error { } { // Destination - addresses := fnutil.Map(func(ip netip.Addr) string { return ip.String() }, dstAddresses) + addresses := fnutil.Map(fnutil.AsString, dstAddresses) // Tidy up and aggregate the destination prefixes. addressPrefixes := NewAddressPrefixes(append(addresses, ListDestinationPrefixes(rule)...)) From 005139a9e0021414617842758ea13e2fac03dca3 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Thu, 3 Oct 2024 12:44:48 +1000 Subject: [PATCH 4/5] Fix license heander --- .../loadbalancer/iputil/internal/prefix.go | 20 +++++++++++++++++++ .../iputil/internal/prefix_test.go | 16 +++++++++++++++ .../iputil/internal/prefixtree_test.go | 3 ++- 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/pkg/provider/loadbalancer/iputil/internal/prefix.go b/pkg/provider/loadbalancer/iputil/internal/prefix.go index 91b6413d5f..96fee0abf9 100644 --- a/pkg/provider/loadbalancer/iputil/internal/prefix.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefix.go @@ -1,9 +1,29 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package internal import ( "net/netip" ) +// ListAddresses returns all IP addresses contained within the given prefixes. +// Note: This function is not optimized for large address ranges. +// It may consume significant memory and perform poorly when listing +// a large number of addresses. Use with caution on large prefixes. func ListAddresses(prefixes ...netip.Prefix) []netip.Addr { var rv []netip.Addr for _, p := range prefixes { diff --git a/pkg/provider/loadbalancer/iputil/internal/prefix_test.go b/pkg/provider/loadbalancer/iputil/internal/prefix_test.go index c5795e677c..f422c2cad0 100644 --- a/pkg/provider/loadbalancer/iputil/internal/prefix_test.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefix_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package internal import ( diff --git a/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go b/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go index 0b0c797932..e2ae5308f4 100644 --- a/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go @@ -20,11 +20,12 @@ import ( "fmt" "math" "net/netip" - "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" "sort" "testing" "github.com/stretchr/testify/assert" + + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" ) func Test_bitAt(t *testing.T) { From 2628df2287733f27a5a078e3341def67c943c3e0 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Sat, 5 Oct 2024 21:25:38 +1000 Subject: [PATCH 5/5] Optimize AggregatePrefixes performance --- .../loadbalancer/iputil/internal/bits.go | 37 +++++++ .../loadbalancer/iputil/internal/bits_test.go | 103 ++++++++++++++++++ .../loadbalancer/iputil/internal/prefix.go | 82 ++++++++++++++ .../iputil/internal/prefix_test.go | 48 ++++++++ .../iputil/internal/prefixtree.go | 7 -- .../iputil/internal/prefixtree_test.go | 25 ----- pkg/provider/loadbalancer/iputil/prefix.go | 25 +++-- .../loadbalancer/iputil/prefix_test.go | 7 +- 8 files changed, 290 insertions(+), 44 deletions(-) create mode 100644 pkg/provider/loadbalancer/iputil/internal/bits.go create mode 100644 pkg/provider/loadbalancer/iputil/internal/bits_test.go diff --git a/pkg/provider/loadbalancer/iputil/internal/bits.go b/pkg/provider/loadbalancer/iputil/internal/bits.go new file mode 100644 index 0000000000..be36f55a7b --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/internal/bits.go @@ -0,0 +1,37 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +// setBitAt sets the bit at the i-th position in the byte slice to the given value. +// Panics if the index is out of bounds. +// For example, +// - setBitAt([0x00, 0x00], 8, 1) returns [0x00, 0b1000_0000]. +// - setBitAt([0xff, 0xff], 0, 0) returns [0b0111_1111, 0xff]. +func setBitAt(bytes []byte, i int, bit uint8) { + if bit == 1 { + bytes[i/8] |= 1 << (7 - i%8) + } else { + bytes[i/8] &^= 1 << (7 - i%8) + } +} + +// bitAt returns the bit at the i-th position in the byte slice. +// The return value is either 0 or 1 as uint8. +// Panics if the index is out of bounds. +func bitAt(bytes []byte, i int) uint8 { + return bytes[i/8] >> (7 - i%8) & 1 +} diff --git a/pkg/provider/loadbalancer/iputil/internal/bits_test.go b/pkg/provider/loadbalancer/iputil/internal/bits_test.go new file mode 100644 index 0000000000..fad72aabb2 --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/internal/bits_test.go @@ -0,0 +1,103 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_bitAt(t *testing.T) { + bytes := []byte{0b1010_1010, 0b0101_0101} + assert.Equal(t, uint8(1), bitAt(bytes, 0)) + assert.Equal(t, uint8(0), bitAt(bytes, 1)) + assert.Equal(t, uint8(1), bitAt(bytes, 2)) + assert.Equal(t, uint8(0), bitAt(bytes, 3)) + + assert.Equal(t, uint8(1), bitAt(bytes, 4)) + assert.Equal(t, uint8(0), bitAt(bytes, 5)) + assert.Equal(t, uint8(1), bitAt(bytes, 6)) + assert.Equal(t, uint8(0), bitAt(bytes, 7)) + + assert.Equal(t, uint8(0), bitAt(bytes, 8)) + assert.Equal(t, uint8(1), bitAt(bytes, 9)) + assert.Equal(t, uint8(0), bitAt(bytes, 10)) + assert.Equal(t, uint8(1), bitAt(bytes, 11)) + + assert.Equal(t, uint8(0), bitAt(bytes, 12)) + assert.Equal(t, uint8(1), bitAt(bytes, 13)) + assert.Equal(t, uint8(0), bitAt(bytes, 14)) + assert.Equal(t, uint8(1), bitAt(bytes, 15)) + + assert.Panics(t, func() { bitAt(bytes, 16) }) +} + +func Test_setBitAt(t *testing.T) { + tests := []struct { + name string + initial []byte + index int + bit uint8 + expected []byte + }{ + { + name: "Set first bit to 1", + initial: []byte{0b0000_0000}, + index: 0, + bit: 1, + expected: []byte{0b1000_0000}, + }, + { + name: "Set last bit to 1", + initial: []byte{0b0000_0000}, + index: 7, + bit: 1, + expected: []byte{0b0000_0001}, + }, + { + name: "Set middle bit to 1", + initial: []byte{0b0000_0000}, + index: 4, + bit: 1, + expected: []byte{0b0000_1000}, + }, + { + name: "Set bit to 0", + initial: []byte{0b1111_1111}, + index: 3, + bit: 0, + expected: []byte{0b1110_1111}, + }, + { + name: "Set bit in second byte", + initial: []byte{0b0000_0000, 0b0000_0000}, + index: 9, + bit: 1, + expected: []byte{0b0000_0000, 0b0100_0000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setBitAt(tt.initial, tt.index, tt.bit) + assert.Equal(t, tt.expected, tt.initial) + }) + } + + assert.Panics(t, func() { setBitAt([]byte{0x00}, 8, 1) }) +} diff --git a/pkg/provider/loadbalancer/iputil/internal/prefix.go b/pkg/provider/loadbalancer/iputil/internal/prefix.go index 96fee0abf9..956a90a3c1 100644 --- a/pkg/provider/loadbalancer/iputil/internal/prefix.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefix.go @@ -18,6 +18,7 @@ package internal import ( "net/netip" + "sort" ) // ListAddresses returns all IP addresses contained within the given prefixes. @@ -33,3 +34,84 @@ func ListAddresses(prefixes ...netip.Prefix) []netip.Addr { } return rv } + +// ContainsPrefix checks if prefix p fully contains prefix o. +// It returns true if o is a subset of p, meaning all addresses in o are also in p. +// This is true when p overlaps with o and p has fewer or equal number of bits than o. +func ContainsPrefix(p netip.Prefix, o netip.Prefix) bool { + return p.Overlaps(o) && p.Bits() <= o.Bits() +} + +// IsAdjacentPrefixes checks if two prefixes are adjacent and can be merged into a single prefix. +// Two prefixes are considered adjacent if they have the same length and their +// addresses are consecutive. +// +// Examples: +// - Adjacent: 192.168.0.0/32 and 192.168.0.1/32 +// - Adjacent: 192.168.0.0/24 and 192.168.1.0/24 +// - Not adjacent: 192.168.0.1/32 and 192.168.0.2/32 (cannot merge) +// - Not adjacent: 192.168.0.0/24 and 192.168.0.0/25 (different lengths) +func IsAdjacentPrefixes(p1, p2 netip.Prefix) bool { + if p1.Bits() != p2.Bits() { + return false + } + + p1Bytes := p1.Addr().AsSlice() + p2Bytes := p2.Addr().AsSlice() + + if bitAt(p1Bytes, p1.Bits()-1) == 0 { + setBitAt(p1Bytes, p1.Bits()-1, 1) + addr, _ := netip.AddrFromSlice(p1Bytes) + return addr == p2.Addr() + } else { + setBitAt(p2Bytes, p2.Bits()-1, 1) + addr, _ := netip.AddrFromSlice(p2Bytes) + return addr == p1.Addr() + } +} + +// AggregatePrefixesForSingleIPFamily merges overlapping or adjacent prefixes into a single prefix. +// The input prefixes must be the same IP family (IPv4 or IPv6). +// For example, +// - [192.168.0.0/32, 192.168.0.1/32] -> [192.168.0.0/31] (adjacent) +// - [192.168.0.0/24, 192.168.0.1/32] -> [192.168.1.0/24] (overlapping) +func AggregatePrefixesForSingleIPFamily(prefixes []netip.Prefix) []netip.Prefix { + if len(prefixes) <= 1 { + return prefixes + } + + sort.Slice(prefixes, func(i, j int) bool { + if prefixes[i].Addr() == prefixes[j].Addr() { + return prefixes[i].Bits() < prefixes[j].Bits() + } + return prefixes[i].Addr().Less(prefixes[j].Addr()) + }) + + var rv = []netip.Prefix{ + prefixes[0], + } + + for i := 1; i < len(prefixes); { + last, p := rv[len(rv)-1], prefixes[i] + if ContainsPrefix(last, p) { + // Skip overlapping prefixes + i++ + continue + } + rv = append(rv, p) + + // Merge adjacent prefixes if possible + for len(rv) >= 2 { + p1, p2 := rv[len(rv)-2], rv[len(rv)-1] + if !IsAdjacentPrefixes(p1, p2) { + break + } + + bits := p1.Bits() - 1 + p, _ := p1.Addr().Prefix(bits) + rv = rv[:len(rv)-2] + rv = append(rv, p) + } + } + return rv +} diff --git a/pkg/provider/loadbalancer/iputil/internal/prefix_test.go b/pkg/provider/loadbalancer/iputil/internal/prefix_test.go index f422c2cad0..694d18aeab 100644 --- a/pkg/provider/loadbalancer/iputil/internal/prefix_test.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefix_test.go @@ -17,6 +17,8 @@ limitations under the License. package internal import ( + "fmt" + "math/rand" "net/netip" "testing" @@ -84,3 +86,49 @@ func TestListAddresses(t *testing.T) { }) } } + +func benchmarkPrefixFixtures() []netip.Prefix { + var rv []netip.Prefix + for i := 0; i <= 255; i++ { + for j := 0; j <= 255; j++ { + rv = append(rv, netip.MustParsePrefix(fmt.Sprintf("192.168.%d.%d/32", i, j))) + } + } + rand.Shuffle(len(rv), func(i, j int) { + rv[i], rv[j] = rv[j], rv[i] + }) + + return rv +} + +func BenchmarkAggregatePrefixesDefault(b *testing.B) { + prefixes := benchmarkPrefixFixtures() + b.ResetTimer() + for i := 0; i < b.N; i++ { + actual := AggregatePrefixesForSingleIPFamily(prefixes) + assert.Len(b, actual, 1) + assert.Equal(b, []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, actual) + } +} + +func BenchmarkAggregatePrefixesUsingPrefixTree(b *testing.B) { + do := func(prefixes []netip.Prefix) []netip.Prefix { + tree := NewPrefixTreeForIPv4() + for _, p := range prefixes { + tree.Add(p) + } + return tree.List() + } + + prefixes := benchmarkPrefixFixtures() + b.ResetTimer() + for i := 0; i < b.N; i++ { + actual := do(prefixes) + assert.Len(b, actual, 1) + assert.Equal(b, []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, actual) + } +} diff --git a/pkg/provider/loadbalancer/iputil/internal/prefixtree.go b/pkg/provider/loadbalancer/iputil/internal/prefixtree.go index 3a1ac96730..b6795ff22b 100644 --- a/pkg/provider/loadbalancer/iputil/internal/prefixtree.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefixtree.go @@ -20,13 +20,6 @@ import ( "net/netip" ) -// bitAt returns the bit at the i-th position in the byte slice. -// The return value is either 0 or 1 as uint8. -// Panics if the index is out of bounds. -func bitAt(bytes []byte, i int) uint8 { - return bytes[i/8] >> (7 - i%8) & 1 -} - type prefixTreeNode struct { masked bool prefix netip.Prefix diff --git a/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go b/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go index e2ae5308f4..6c190d94a7 100644 --- a/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go +++ b/pkg/provider/loadbalancer/iputil/internal/prefixtree_test.go @@ -28,31 +28,6 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" ) -func Test_bitAt(t *testing.T) { - bytes := []byte{0b1010_1010, 0b0101_0101} - assert.Equal(t, uint8(1), bitAt(bytes, 0)) - assert.Equal(t, uint8(0), bitAt(bytes, 1)) - assert.Equal(t, uint8(1), bitAt(bytes, 2)) - assert.Equal(t, uint8(0), bitAt(bytes, 3)) - - assert.Equal(t, uint8(1), bitAt(bytes, 4)) - assert.Equal(t, uint8(0), bitAt(bytes, 5)) - assert.Equal(t, uint8(1), bitAt(bytes, 6)) - assert.Equal(t, uint8(0), bitAt(bytes, 7)) - - assert.Equal(t, uint8(0), bitAt(bytes, 8)) - assert.Equal(t, uint8(1), bitAt(bytes, 9)) - assert.Equal(t, uint8(0), bitAt(bytes, 10)) - assert.Equal(t, uint8(1), bitAt(bytes, 11)) - - assert.Equal(t, uint8(0), bitAt(bytes, 12)) - assert.Equal(t, uint8(1), bitAt(bytes, 13)) - assert.Equal(t, uint8(0), bitAt(bytes, 14)) - assert.Equal(t, uint8(1), bitAt(bytes, 15)) - - assert.Panics(t, func() { bitAt(bytes, 16) }) -} - func TestPrefixTreeIPv4(t *testing.T) { tests := []struct { Name string diff --git a/pkg/provider/loadbalancer/iputil/prefix.go b/pkg/provider/loadbalancer/iputil/prefix.go index 32fb141fd7..73a73d2dcf 100644 --- a/pkg/provider/loadbalancer/iputil/prefix.go +++ b/pkg/provider/loadbalancer/iputil/prefix.go @@ -77,26 +77,29 @@ func GroupPrefixesByFamily(vs []netip.Prefix) ([]netip.Prefix, []netip.Prefix) { return v4, v6 } -// AggregatePrefixes aggregates prefixes. -// Overlapping prefixes are merged. +// AggregatePrefixes merges overlapping or adjacent prefixes. +// It combines prefixes that can be represented by a single, larger prefix. +// +// Examples: +// - [192.168.0.0/32, 192.168.0.1/32] -> [192.168.0.0/31] (adjacent) +// - [192.168.0.0/24, 192.168.0.1/32] -> [192.168.0.0/24] (overlapping) +// - [10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16] -> [10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16] (non-overlapping) func AggregatePrefixes(prefixes []netip.Prefix) []netip.Prefix { var ( v4, v6 = GroupPrefixesByFamily(prefixes) - v4Tree = internal.NewPrefixTreeForIPv4() - v6Tree = internal.NewPrefixTreeForIPv6() ) - for _, p := range v4 { - v4Tree.Add(p) - } - for _, p := range v6 { - v6Tree.Add(p) - } + v4 = internal.AggregatePrefixesForSingleIPFamily(v4) + v6 = internal.AggregatePrefixesForSingleIPFamily(v6) - return append(v4Tree.List(), v6Tree.List()...) + return append(v4, v6...) } // ExcludePrefixes excludes prefixes from the given prefixes. +// +// Examples: +// - ([192.168.0.0/24], [192.168.0.0/25]) -> [192.168.0.128/25] +// - ([2001:db8::/64], [2001:db8::1/128, 2001:db8::2/128]) -> [2001:db8::/64] func ExcludePrefixes(prefixes []netip.Prefix, exclude []netip.Prefix) []netip.Prefix { var ( v4Tree = internal.NewPrefixTreeForIPv4() diff --git a/pkg/provider/loadbalancer/iputil/prefix_test.go b/pkg/provider/loadbalancer/iputil/prefix_test.go index 0c2c7651b8..a5a2587eb0 100644 --- a/pkg/provider/loadbalancer/iputil/prefix_test.go +++ b/pkg/provider/loadbalancer/iputil/prefix_test.go @@ -23,6 +23,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" ) func TestIsPrefixesAllowAll(t *testing.T) { @@ -265,7 +267,10 @@ func TestAggregatePrefixes(t *testing.T) { sort.Slice(tt.Output, func(i, j int) bool { return tt.Output[i].String() < tt.Output[j].String() }) - assert.Equal(t, tt.Output, got) + + expected := fnutil.Map(fnutil.AsString, tt.Output) + actual := fnutil.Map(fnutil.AsString, got) + assert.Equal(t, expected, actual) }) } }