diff --git a/chaoslib/litmus/pod-network-partition/lib/network-policy.go b/chaoslib/litmus/pod-network-partition/lib/network-policy.go index 786e32162..e3b272c74 100644 --- a/chaoslib/litmus/pod-network-partition/lib/network-policy.go +++ b/chaoslib/litmus/pod-network-partition/lib/network-policy.go @@ -2,12 +2,12 @@ package lib import ( "fmt" + "net" + "strings" + "github.com/litmuschaos/litmus-go/pkg/cerrors" - "github.com/litmuschaos/litmus-go/pkg/clients" "github.com/palantir/stacktrace" - "strings" - network_chaos "github.com/litmuschaos/litmus-go/chaoslib/litmus/network-chaos/lib" experimentTypes "github.com/litmuschaos/litmus-go/pkg/generic/pod-network-partition/types" "gopkg.in/yaml.v2" corev1 "k8s.io/api/core/v1" @@ -181,30 +181,51 @@ func getPort(port int32, protocol corev1.Protocol) networkv1.NetworkPolicyPort { // setExceptIPs sets all the destination ips // for which traffic should be blocked func (np *NetworkPolicy) setExceptIPs(experimentsDetails *experimentTypes.ExperimentDetails) error { - // get all the target ips - destinationIPs, err := network_chaos.GetTargetIps(experimentsDetails.DestinationIPs, experimentsDetails.DestinationHosts, clients.ClientSets{}, false) - if err != nil { - return stacktrace.Propagate(err, "could not get destination ips") - } - - ips := strings.Split(destinationIPs, ",") - var uniqueIps []string - // removing all the duplicates and ipv6 ips from the list, if any - for i := range ips { - isPresent := false - for j := range uniqueIps { - if ips[i] == uniqueIps[j] { - isPresent = true - } + ips := strings.Split(experimentsDetails.DestinationIPs, ",") + seen := make(map[string]struct{}) + var ordered []string + for _, raw := range ips { + norm, err := normalizeIPOrCIDR(raw) + if err != nil { + return err + } + if norm == "" { + continue } - if ips[i] != "" && !isPresent && !strings.Contains(ips[i], ":") { - uniqueIps = append(uniqueIps, ips[i]+"/32") + if _, ok := seen[norm]; ok { + continue } + seen[norm] = struct{}{} + ordered = append(ordered, norm) } - np.ExceptIPs = uniqueIps + np.ExceptIPs = ordered return nil } +// normalizeIPOrCIDR validates and normalizes IP addresses or CIDR blocks, +// adding appropriate subnet masks (/32 for IPv4, /128 for IPv6) to plain IP addresses. + +func normalizeIPOrCIDR(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", nil + } + if strings.Contains(raw, "/") { + if _, _, err := net.ParseCIDR(raw); err != nil { + return "", fmt.Errorf("invalid CIDR %q: %w", raw, err) + } + return raw, nil + } + ip := net.ParseIP(raw) + if ip == nil { + return "", fmt.Errorf("invalid IP %q", raw) + } + if ip.To4() != nil { + return raw + "/32", nil + } + return raw + "/128", nil +} + // setIngressRules sets the ingress traffic rules func (np *NetworkPolicy) setIngressRules() *NetworkPolicy { diff --git a/chaoslib/litmus/pod-network-partition/lib/network-policy_cidr_test.go b/chaoslib/litmus/pod-network-partition/lib/network-policy_cidr_test.go new file mode 100644 index 000000000..c145bb7ef --- /dev/null +++ b/chaoslib/litmus/pod-network-partition/lib/network-policy_cidr_test.go @@ -0,0 +1,66 @@ +package lib + +import ( + "testing" + + partitionTypes "github.com/litmuschaos/litmus-go/pkg/generic/pod-network-partition/types" +) + +func Test_SetExceptIPs_CIDRHandling(t *testing.T) { + np := &NetworkPolicy{} + exp := &partitionTypes.ExperimentDetails{ + DestinationIPs: "10.0.0.5,10.0.0.0/24,10.0.1.0/28, 2001:db8::1,2001:db8::/64,10.0.0.5,2001:db8::1/64", + } + if err := np.setExceptIPs(exp); err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := []string{ + "10.0.0.5/32", + "10.0.0.0/24", + "10.0.1.0/28", + "2001:db8::1/128", + "2001:db8::/64", + "2001:db8::1/64", + } + got := np.ExceptIPs + if len(got) != len(want) { + t.Fatalf("len mismatch got=%v want=%v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("index %d got=%s want=%s full=%v", i, got[i], want[i], got) + } + } +} + +func Test_normalizeIPOrCIDR(t *testing.T) { + cases := map[string]string{ + "10.1.1.1": "10.1.1.1/32", + "10.1.1.0/28": "10.1.1.0/28", + "2001:db8::1": "2001:db8::1/128", + "2001:db8::/64": "2001:db8::/64", + } + for in, expect := range cases { + out, err := normalizeIPOrCIDR(in) + if err != nil { + t.Fatalf("unexpected err for %s: %v", in, err) + } + if out != expect { + t.Fatalf("normalize %s got %s want %s", in, out, expect) + } + } + + // Empty input- expect no error and empty output + out, err := normalizeIPOrCIDR("") + if err != nil || out != "" { + t.Fatalf("empty input: got (%q,%v) want (\"\",nil)", out, err) + } + + // Invalid (non-empty) inputs must return error + invalid := []string{"foo", "10.0.0.0/33", "2001:db8::/129"} + for _, in := range invalid { + if _, err := normalizeIPOrCIDR(in); err == nil { + t.Fatalf("expected error for %q", in) + } + } +}