diff --git a/cmd/katana/main.go b/cmd/katana/main.go index 0880628d..258bf03c 100644 --- a/cmd/katana/main.go +++ b/cmd/katana/main.go @@ -174,6 +174,10 @@ pipelines offering both headless and non-headless crawling.`) flagSet.StringSliceVarP(&options.ExtensionFilter, "extension-filter", "ef", nil, "filter output for given extension (eg, -ef png,css)", goflags.CommaSeparatedStringSliceOptions), flagSet.StringVarP(&options.OutputMatchCondition, "match-condition", "mdc", "", "match response with dsl based condition"), flagSet.StringVarP(&options.OutputFilterCondition, "filter-condition", "fdc", "", "filter response with dsl based condition"), + flagSet.StringSliceVarP(&options.CountPathDepth, "count-path-depth", "cpd", nil, "filter urls by path depth count (e.g., '>=3', '==2', '3-5')", goflags.CommaSeparatedStringSliceOptions), + flagSet.StringSliceVarP(&options.CountQueryParams, "count-query-params", "cqp", nil, "filter urls by query parameter count (e.g., '>=3', '==2', '1-3')", goflags.CommaSeparatedStringSliceOptions), + flagSet.StringSliceVarP(&options.CountSubdomainDepth, "count-subdomain-depth", "csd", nil, "filter urls by subdomain depth count (e.g., '>=2', '==1', '1-3')", goflags.CommaSeparatedStringSliceOptions), + flagSet.BoolVar(&options.DepthFilterOrLogic, "depth-filter-or", false, "use OR logic between different depth filter types (default: AND logic)"), flagSet.BoolVarP(&options.DisableUniqueFilter, "disable-unique-filter", "duf", false, "disable duplicate content filtering"), ) diff --git a/internal/runner/options.go b/internal/runner/options.go index 5da77575..c239ffa7 100644 --- a/internal/runner/options.go +++ b/internal/runner/options.go @@ -11,6 +11,7 @@ import ( "github.com/projectdiscovery/gologger/formatter" "github.com/projectdiscovery/katana/pkg/types" "github.com/projectdiscovery/katana/pkg/utils" + "github.com/projectdiscovery/katana/pkg/utils/filters" errorutil "github.com/projectdiscovery/utils/errors" fileutil "github.com/projectdiscovery/utils/file" "gopkg.in/yaml.v3" @@ -58,6 +59,33 @@ func validateOptions(options *types.Options) error { } options.FilterRegex = append(options.FilterRegex, cr) } + + // Validate depth filter expressions + for _, filter := range options.CountPathDepth { + if filter == "" { + continue + } + if err := filters.ValidateAndSuggest("path depth", filter); err != nil { + return err + } + } + for _, filter := range options.CountQueryParams { + if filter == "" { + continue + } + if err := filters.ValidateAndSuggest("query parameter", filter); err != nil { + return err + } + } + for _, filter := range options.CountSubdomainDepth { + if filter == "" { + continue + } + if err := filters.ValidateAndSuggest("subdomain depth", filter); err != nil { + return err + } + } + if options.KnownFiles != "" && options.MaxDepth < 3 { gologger.Info().Msgf("Depth automatically set to 3 to accommodate the `--known-files` option (originally set to %d).", options.MaxDepth) options.MaxDepth = 3 diff --git a/pkg/output/options.go b/pkg/output/options.go index 2b4fe64b..e5dfe7a2 100644 --- a/pkg/output/options.go +++ b/pkg/output/options.go @@ -28,4 +28,8 @@ type Options struct { OutputTemplate string OutputMatchCondition string OutputFilterCondition string + CountPathDepth []string + CountQueryParams []string + CountSubdomainDepth []string + DepthFilterOrLogic bool } diff --git a/pkg/output/output.go b/pkg/output/output.go index b91f9190..98f01c65 100644 --- a/pkg/output/output.go +++ b/pkg/output/output.go @@ -3,6 +3,7 @@ package output import ( "errors" "fmt" + 
"net/url" "os" "path/filepath" "regexp" @@ -17,6 +18,7 @@ import ( "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/katana/pkg/navigation" "github.com/projectdiscovery/katana/pkg/utils/extensions" + "github.com/projectdiscovery/katana/pkg/utils/filters" errorutil "github.com/projectdiscovery/utils/errors" fileutil "github.com/projectdiscovery/utils/file" "github.com/stoewer/go-strcase" @@ -63,6 +65,7 @@ type StandardWriter struct { outputTemplate *fasttemplate.Template outputMatchCondition string outputFilterCondition string + depthValidator *filters.DepthFilterValidator } // New returns a new output writer instance @@ -85,6 +88,20 @@ func New(options Options) (Writer, error) { outputFilterCondition: options.OutputFilterCondition, } + // Initialize depth filter validator if depth filters are configured + if len(options.CountPathDepth) > 0 || len(options.CountQueryParams) > 0 || len(options.CountSubdomainDepth) > 0 { + depthValidator, err := filters.NewDepthFilterValidator( + options.CountPathDepth, + options.CountQueryParams, + options.CountSubdomainDepth, + options.DepthFilterOrLogic, + ) + if err != nil { + return nil, err + } + writer.depthValidator = depthValidator + } + if options.StoreFieldDir != "" { storeFieldDir = options.StoreFieldDir } @@ -353,10 +370,22 @@ func (w *StandardWriter) matchOutput(event *Result) bool { // filterOutput returns true if the event should be filtered out func (w *StandardWriter) filterOutput(event *Result) bool { - if w.filterRegex == nil && w.outputFilterCondition == "" { + if w.filterRegex == nil && w.outputFilterCondition == "" && w.depthValidator == nil { return false } + // Apply depth filtering if configured + if w.depthValidator != nil { + parsedURL, err := url.Parse(event.Request.URL) + if err != nil { + // If URL parsing fails, filter out the result + return true + } + if !w.depthValidator.ValidateURL(parsedURL) { + return true + } + } + for _, regex := range w.filterRegex { if regex.MatchString(event.Request.URL) { return true diff --git a/pkg/types/crawler_options.go b/pkg/types/crawler_options.go index 7d88b70c..cfd857fd 100644 --- a/pkg/types/crawler_options.go +++ b/pkg/types/crawler_options.go @@ -37,6 +37,8 @@ type CrawlerOptions struct { Dialer *fastdialer.Dialer // Wappalyzer instance for technologies detection Wappalyzer *wappalyzer.Wappalyze + // DepthValidator is a validator for URL depth filtering + DepthValidator *filters.DepthFilterValidator } // NewCrawlerOptions creates a new crawler options structure @@ -94,6 +96,10 @@ func NewCrawlerOptions(options *Options) (*CrawlerOptions, error) { OutputTemplate: options.OutputTemplate, OutputMatchCondition: options.OutputMatchCondition, OutputFilterCondition: options.OutputFilterCondition, + CountPathDepth: options.CountPathDepth, + CountQueryParams: options.CountQueryParams, + CountSubdomainDepth: options.CountSubdomainDepth, + DepthFilterOrLogic: options.DepthFilterOrLogic, } for _, mr := range options.OutputMatchRegex { @@ -116,6 +122,20 @@ func NewCrawlerOptions(options *Options) (*CrawlerOptions, error) { return nil, errorutil.NewWithErr(err).Msgf("could not create output writer") } + // Initialize depth filter validator if depth filters are configured + var depthValidator *filters.DepthFilterValidator + if len(options.CountPathDepth) > 0 || len(options.CountQueryParams) > 0 || len(options.CountSubdomainDepth) > 0 { + depthValidator, err = filters.NewDepthFilterValidator( + options.CountPathDepth, + options.CountQueryParams, + options.CountSubdomainDepth, + 
options.DepthFilterOrLogic, + ) + if err != nil { + return nil, errorutil.NewWithErr(err).Msgf("could not create depth filter validator") + } + } + crawlerOptions := &CrawlerOptions{ ExtensionsValidator: extensionsValidator, Parser: responseParser, @@ -124,6 +144,7 @@ func NewCrawlerOptions(options *Options) (*CrawlerOptions, error) { Options: options, Dialer: fastdialerInstance, OutputWriter: outputWriter, + DepthValidator: depthValidator, } if options.RateLimit > 0 { @@ -150,9 +171,16 @@ func (c *CrawlerOptions) Close() error { } func (c *CrawlerOptions) ValidatePath(path string) bool { + // First check extension validation if c.ExtensionsValidator != nil { - return c.ExtensionsValidator.ValidatePath(path) + if !c.ExtensionsValidator.ValidatePath(path) { + return false + } } + + // Note: Depth validation is handled at output stage to allow crawling + // but filter final results. This ensures we can discover URLs first. + return true } diff --git a/pkg/types/options.go b/pkg/types/options.go index 1918bd50..3e73d61e 100644 --- a/pkg/types/options.go +++ b/pkg/types/options.go @@ -42,6 +42,14 @@ type Options struct { OutputMatchCondition string // OutputFilterCondition is the condition to filter output OutputFilterCondition string + // CountPathDepth filters URLs by path depth count + CountPathDepth goflags.StringSlice + // CountQueryParams filters URLs by query parameter count + CountQueryParams goflags.StringSlice + // CountSubdomainDepth filters URLs by subdomain depth count + CountSubdomainDepth goflags.StringSlice + // DepthFilterOrLogic uses OR logic between depth filter types + DepthFilterOrLogic bool // MaxDepth is the maximum depth to crawl MaxDepth int // BodyReadSize is the maximum size of response body to read diff --git a/pkg/utils/filters/depth_filter.go b/pkg/utils/filters/depth_filter.go new file mode 100644 index 00000000..b884919e --- /dev/null +++ b/pkg/utils/filters/depth_filter.go @@ -0,0 +1,474 @@ +package filters + +import ( + "fmt" + "net/url" + "regexp" + "strconv" + "strings" + "sync" + "time" +) + +// DepthFilter represents a single depth filtering condition +type DepthFilter struct { + Operator string // "==", ">=", "<=", ">", "<", "range" + Value int // The count value to compare against + MaxValue int // For range operations (Value-MaxValue) +} + +// URLComponents represents cached URL components for performance +type URLComponents struct { + PathDepth int + QueryParams int + SubdomainDepth int + CachedAt time.Time +} + +// DepthFilterCache manages caching for URL components and filter results +type DepthFilterCache struct { + urlComponents map[string]*URLComponents + filterResults map[string]bool + mutex sync.RWMutex + maxSize int + ttl time.Duration +} + +// DepthFilterValidator manages depth filtering for URLs +type DepthFilterValidator struct { + pathFilters []DepthFilter + queryFilters []DepthFilter + subdomainFilters []DepthFilter + useOrLogic bool // If true, use OR logic between filter types; if false, use AND logic + cache *DepthFilterCache +} + +// NewDepthFilterValidator creates a new depth filter validator +func NewDepthFilterValidator(pathFilters, queryFilters, subdomainFilters []string, useOrLogic bool) (*DepthFilterValidator, error) { + validator := &DepthFilterValidator{ + useOrLogic: useOrLogic, + cache: NewDepthFilterCache(1000, 5*time.Minute), // Cache up to 1000 entries for 5 minutes + } + + // Parse path depth filters + for i, filter := range pathFilters { + if filter == "" { + continue + } + parsed, err := parseDepthFilter(filter) + if 
err != nil { + return nil, fmt.Errorf("āŒ Path depth filter #%d error:\n%w\n\nšŸ”§ Fix: Update your -cpd flag:\n katana -cpd \"[corrected_filter]\"", i+1, err) + } + validator.pathFilters = append(validator.pathFilters, parsed) + } + + // Parse query parameter filters + for i, filter := range queryFilters { + if filter == "" { + continue + } + parsed, err := parseDepthFilter(filter) + if err != nil { + return nil, fmt.Errorf("āŒ Query parameter filter #%d error:\n%w\n\nšŸ”§ Fix: Update your -cqp flag:\n katana -cqp \"[corrected_filter]\"", i+1, err) + } + validator.queryFilters = append(validator.queryFilters, parsed) + } + + // Parse subdomain depth filters + for i, filter := range subdomainFilters { + if filter == "" { + continue + } + parsed, err := parseDepthFilter(filter) + if err != nil { + return nil, fmt.Errorf("āŒ Subdomain depth filter #%d error:\n%w\n\nšŸ”§ Fix: Update your -csd flag:\n katana -csd \"[corrected_filter]\"", i+1, err) + } + validator.subdomainFilters = append(validator.subdomainFilters, parsed) + } + + return validator, nil +} + +// ValidateURL validates a URL against all configured depth filters with caching +func (d *DepthFilterValidator) ValidateURL(parsedURL *url.URL) bool { + if parsedURL == nil { + return false + } + + // Early exit: If no filters are configured, allow all URLs + if len(d.pathFilters) == 0 && len(d.queryFilters) == 0 && len(d.subdomainFilters) == 0 { + return true + } + + urlStr := parsedURL.String() + + // Get or compute URL components (cached) + components := d.cache.GetURLComponents(urlStr, parsedURL) + + // Check cache for filter result + cacheKey := d.generateCacheKey(components) + if result, exists := d.cache.GetFilterResult(cacheKey); exists { + return result + } + + // Compute filter result + var result bool + if d.useOrLogic { + result = d.evaluateOrLogic(components) + } else { + result = d.evaluateAndLogic(components) + } + + // Cache the result + d.cache.SetFilterResult(cacheKey, result) + + return result +} + +// evaluateOrLogic implements OR logic between filter types using cached components +func (d *DepthFilterValidator) evaluateOrLogic(components *URLComponents) bool { + // At least one configured filter type must pass + hasConfiguredFilters := len(d.pathFilters) > 0 || len(d.queryFilters) > 0 || len(d.subdomainFilters) > 0 + if !hasConfiguredFilters { + return true + } + + // Check each filter type - return true if any configured type passes + if len(d.pathFilters) > 0 && d.validatePathDepthWithComponents(components.PathDepth) { + return true + } + + if len(d.queryFilters) > 0 && d.validateQueryParamsWithComponents(components.QueryParams) { + return true + } + + if len(d.subdomainFilters) > 0 && d.validateSubdomainDepthWithComponents(components.SubdomainDepth) { + return true + } + + return false +} + +// evaluateAndLogic implements AND logic between filter types using cached components +func (d *DepthFilterValidator) evaluateAndLogic(components *URLComponents) bool { + // All configured filter types must pass + if len(d.pathFilters) > 0 && !d.validatePathDepthWithComponents(components.PathDepth) { + return false + } + + if len(d.queryFilters) > 0 && !d.validateQueryParamsWithComponents(components.QueryParams) { + return false + } + + if len(d.subdomainFilters) > 0 && !d.validateSubdomainDepthWithComponents(components.SubdomainDepth) { + return false + } + + return true +} + +// parseDepthFilter parses a filter expression like ">=3", "==2", "<=4", "3-5" +func parseDepthFilter(filter string) (DepthFilter, error) { + 
originalFilter := filter + filter = strings.TrimSpace(filter) + + if filter == "" { + return DepthFilter{}, fmt.Errorf("empty filter expression\n" + + "šŸ’” Tip: Use comparison operators like '>=3', '==2', '<=4' or ranges like '2-5'\n" + + "šŸ“– Examples:\n" + + " -cpd \">=3\" # Path depth 3 or more\n" + + " -cqp \"==2\" # Exactly 2 query parameters\n" + + " -csd \"1-3\" # Subdomain levels between 1 and 3") + } + + // Check for range syntax first (e.g., "3-5", "1-10") + rangeRe := regexp.MustCompile(`^(\d+)-(\d+)$`) + if rangeMatches := rangeRe.FindStringSubmatch(filter); len(rangeMatches) == 3 { + minValue, err1 := strconv.Atoi(rangeMatches[1]) + maxValue, err2 := strconv.Atoi(rangeMatches[2]) + + if err1 != nil || err2 != nil { + return DepthFilter{}, fmt.Errorf("invalid range values in '%s': numbers must be valid integers\n" + + "šŸ’” Tip: Use only non-negative integers in ranges\n" + + "šŸ“– Examples: '1-5', '0-3', '2-10'", filter) + } + + if minValue < 0 || maxValue < 0 { + return DepthFilter{}, fmt.Errorf("invalid range '%s': negative values not allowed (min=%d, max=%d)\n" + + "šŸ’” Tip: Use non-negative integers only\n" + + "šŸ“– Examples: '0-2' (0 to 2), '1-5' (1 to 5)", filter, minValue, maxValue) + } + + if minValue > maxValue { + return DepthFilter{}, fmt.Errorf("invalid range '%s': minimum value (%d) cannot be greater than maximum value (%d)\n" + + "šŸ’” Tip: Ensure the first number is smaller than or equal to the second\n" + + "āŒ Wrong: '%s'\n" + + "āœ… Correct: '%d-%d'", filter, minValue, maxValue, filter, maxValue, minValue) + } + + return DepthFilter{ + Operator: "range", + Value: minValue, + MaxValue: maxValue, + }, nil + } + + // Regular expression to match operator and value + re := regexp.MustCompile(`^(>=|<=|==|>|<)(\d+)$`) + matches := re.FindStringSubmatch(filter) + + if len(matches) != 3 { + // Provide specific guidance based on common mistakes + errorMsg := fmt.Sprintf("malformed filter expression '%s'\n", originalFilter) + + // Check for common mistakes and provide specific guidance + if strings.Contains(filter, " ") { + errorMsg += "šŸ’” Tip: Remove spaces from the filter expression\n" + + fmt.Sprintf("āŒ Wrong: '%s'\n", originalFilter) + + fmt.Sprintf("āœ… Correct: '%s'\n", strings.ReplaceAll(filter, " ", "")) + } else if regexp.MustCompile(`^\d+$`).MatchString(filter) { + errorMsg += "šŸ’” Tip: Add a comparison operator before the number\n" + + fmt.Sprintf("āŒ Wrong: '%s'\n", filter) + + fmt.Sprintf("āœ… Correct: '>=%s', '==%s', or '<=%s'\n", filter, filter, filter) + } else if regexp.MustCompile(`^[><=!]+$`).MatchString(filter) { + errorMsg += "šŸ’” Tip: Add a number after the operator\n" + + fmt.Sprintf("āŒ Wrong: '%s'\n", filter) + + fmt.Sprintf("āœ… Correct: '%s3', '%s2', or '%s5'\n", filter, filter, filter) + } else if strings.Contains(filter, "!=") || strings.Contains(filter, "<>") { + errorMsg += "šŸ’” Tip: Use '==' for equality, not '!=' or '<>'\n" + + "āŒ Wrong: '!=3', '<>2'\n" + + "āœ… Correct: '==3', '==2'\n" + + "ā„¹ļø Note: Use ranges for 'not equal' logic, e.g., '0-2' and '4-10' instead of '!=3'" + } else { + errorMsg += "šŸ’” Tip: Use valid comparison operators\n" + } + + errorMsg += "\nšŸ“– Valid formats:\n" + + " Comparisons: '>=3', '<=5', '==2', '>1', '<4'\n" + + " Ranges: '2-5', '0-3', '1-10'\n" + + "\nšŸŽÆ Common use cases:\n" + + " -cpd \">=2\" # URLs with path depth 2 or more\n" + + " -cqp \"1-3\" # URLs with 1 to 3 query parameters\n" + + " -csd \"==0\" # URLs with no subdomains" + + return DepthFilter{}, fmt.Errorf("%s", errorMsg) 
+ } + + operator := matches[1] + valueStr := matches[2] + + value, err := strconv.Atoi(valueStr) + if err != nil { + return DepthFilter{}, fmt.Errorf("invalid value '%s' in filter '%s': must be a valid integer\n" + + "šŸ’” Tip: Use only numeric values\n" + + "āŒ Wrong: '%s'\n" + + "āœ… Correct: '%s123', '%s0', '%s5'", valueStr, originalFilter, operator, operator, operator, operator) + } + + if value < 0 { + return DepthFilter{}, fmt.Errorf("invalid value '%d' in filter '%s': negative values not allowed\n" + + "šŸ’” Tip: Use non-negative integers (0, 1, 2, 3, ...)\n" + + "āŒ Wrong: '%s'\n" + + "āœ… Correct: '%s%d'", value, originalFilter, originalFilter, operator, -value) + } + + return DepthFilter{ + Operator: operator, + Value: value, + MaxValue: 0, // Not used for non-range operators + }, nil +} + +// evaluateCondition evaluates a condition (actual operator expected) +func evaluateCondition(actual int, filter DepthFilter) bool { + switch filter.Operator { + case "==": + return actual == filter.Value + case ">=": + return actual >= filter.Value + case "<=": + return actual <= filter.Value + case ">": + return actual > filter.Value + case "<": + return actual < filter.Value + case "range": + return actual >= filter.Value && actual <= filter.MaxValue + default: + return false + } +} + +// countPathSegments counts the number of path segments in a URL path +func countPathSegments(path string) int { + if path == "" || path == "/" { + return 0 + } + + // Remove leading and trailing slashes + path = strings.Trim(path, "/") + if path == "" { + return 0 + } + + // Split by slash and count non-empty segments + segments := strings.Split(path, "/") + count := 0 + for _, segment := range segments { + if segment != "" { + count++ + } + } + + return count +} + +// countQueryParams counts the number of query parameters in a query string +func countQueryParams(query string) int { + if query == "" { + return 0 + } + + // Split by & and count valid parameters + params := strings.Split(query, "&") + count := 0 + for _, param := range params { + param = strings.TrimSpace(param) + // Count parameters that have a key (with or without value) + if param != "" && !strings.HasPrefix(param, "=") { + count++ + } + } + + return count +} + +// countSubdomainLevels counts the number of subdomain levels in a hostname +func countSubdomainLevels(hostname string) int { + if hostname == "" { + return 0 + } + + // Remove any port number + if colonIndex := strings.LastIndex(hostname, ":"); colonIndex != -1 { + hostname = hostname[:colonIndex] + } + + // Split by dots + parts := strings.Split(hostname, ".") + if len(parts) <= 2 { + // No subdomains (e.g., "example.com" or "localhost") + return 0 + } + + // Count subdomain levels (total parts - 2 for domain.tld) + return len(parts) - 2 +} + +// NewDepthFilterCache creates a new cache for URL components and filter results +func NewDepthFilterCache(maxSize int, ttl time.Duration) *DepthFilterCache { + return &DepthFilterCache{ + urlComponents: make(map[string]*URLComponents), + filterResults: make(map[string]bool), + maxSize: maxSize, + ttl: ttl, + } +} + +// GetURLComponents retrieves cached URL components or computes and caches them +func (c *DepthFilterCache) GetURLComponents(urlStr string, parsedURL *url.URL) *URLComponents { + c.mutex.RLock() + if components, exists := c.urlComponents[urlStr]; exists { + // Check if cache entry is still valid + if time.Since(components.CachedAt) < c.ttl { + c.mutex.RUnlock() + return components + } + } + c.mutex.RUnlock() + + // Compute 
components + components := &URLComponents{ + PathDepth: countPathSegments(parsedURL.Path), + QueryParams: countQueryParams(parsedURL.RawQuery), + SubdomainDepth: countSubdomainLevels(parsedURL.Hostname()), + CachedAt: time.Now(), + } + + // Cache the result + c.mutex.Lock() + // Implement simple LRU by clearing cache when it gets too large + if len(c.urlComponents) >= c.maxSize { + // Clear oldest entries (simple approach - clear all) + c.urlComponents = make(map[string]*URLComponents) + c.filterResults = make(map[string]bool) + } + c.urlComponents[urlStr] = components + c.mutex.Unlock() + + return components +} + +// GetFilterResult retrieves cached filter result or returns false if not found +func (c *DepthFilterCache) GetFilterResult(key string) (bool, bool) { + c.mutex.RLock() + defer c.mutex.RUnlock() + result, exists := c.filterResults[key] + return result, exists +} + +// SetFilterResult caches a filter result +func (c *DepthFilterCache) SetFilterResult(key string, result bool) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.filterResults[key] = result +} + +// generateCacheKey creates a cache key for filter results +func (d *DepthFilterValidator) generateCacheKey(components *URLComponents) string { + return fmt.Sprintf("p%d:q%d:s%d:or%t", + components.PathDepth, + components.QueryParams, + components.SubdomainDepth, + d.useOrLogic) +} + +// validatePathDepthWithComponents checks path depth using cached components +func (d *DepthFilterValidator) validatePathDepthWithComponents(depth int) bool { + for _, filter := range d.pathFilters { + if !evaluateCondition(depth, filter) { + return false + } + } + return true +} + +// validateQueryParamsWithComponents checks query params using cached components +func (d *DepthFilterValidator) validateQueryParamsWithComponents(count int) bool { + for _, filter := range d.queryFilters { + if !evaluateCondition(count, filter) { + return false + } + } + return true +} + +// validateSubdomainDepthWithComponents checks subdomain depth using cached components +func (d *DepthFilterValidator) validateSubdomainDepthWithComponents(depth int) bool { + for _, filter := range d.subdomainFilters { + if !evaluateCondition(depth, filter) { + return false + } + } + return true +} + +// ValidateAndSuggest validates a filter and provides suggestions if invalid +func ValidateAndSuggest(filterType, filter string) error { + _, err := parseDepthFilter(filter) + if err != nil { + return fmt.Errorf("āŒ %s filter validation failed:\n%w", + strings.Title(filterType), err) + } + return nil +} \ No newline at end of file diff --git a/pkg/utils/filters/depth_filter_test.go b/pkg/utils/filters/depth_filter_test.go new file mode 100644 index 00000000..34626987 --- /dev/null +++ b/pkg/utils/filters/depth_filter_test.go @@ -0,0 +1,260 @@ +package filters + +import ( + "net/url" + "testing" +) + +func TestParseDepthFilter(t *testing.T) { + tests := []struct { + name string + filter string + expectError bool + expected DepthFilter + }{ + { + name: "Valid greater than or equal", + filter: ">=3", + expectError: false, + expected: DepthFilter{Operator: ">=", Value: 3, MaxValue: 0}, + }, + { + name: "Valid equal", + filter: "==2", + expectError: false, + expected: DepthFilter{Operator: "==", Value: 2, MaxValue: 0}, + }, + { + name: "Valid range", + filter: "2-5", + expectError: false, + expected: DepthFilter{Operator: "range", Value: 2, MaxValue: 5}, + }, + { + name: "Invalid operator", + filter: "!=3", + expectError: true, + }, + { + name: "Invalid format", + filter: "3", + 
+			expectError: true,
+		},
+		{
+			name:        "Negative value",
+			filter:      ">=-1",
+			expectError: true,
+		},
+		{
+			name:        "Invalid range",
+			filter:      "5-2",
+			expectError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := parseDepthFilter(tt.filter)
+
+			if tt.expectError {
+				if err == nil {
+					t.Errorf("Expected error for filter '%s', but got none", tt.filter)
+				}
+				return
+			}
+
+			if err != nil {
+				t.Errorf("Unexpected error for filter '%s': %v", tt.filter, err)
+				return
+			}
+
+			if result.Operator != tt.expected.Operator {
+				t.Errorf("Expected operator '%s', got '%s'", tt.expected.Operator, result.Operator)
+			}
+
+			if result.Value != tt.expected.Value {
+				t.Errorf("Expected value %d, got %d", tt.expected.Value, result.Value)
+			}
+
+			if result.MaxValue != tt.expected.MaxValue {
+				t.Errorf("Expected max value %d, got %d", tt.expected.MaxValue, result.MaxValue)
+			}
+		})
+	}
+}
+
+func TestCountPathSegments(t *testing.T) {
+	tests := []struct {
+		path     string
+		expected int
+	}{
+		{"/", 0},
+		{"", 0},
+		{"/api", 1},
+		{"/api/v1", 2},
+		{"/api/v1/users", 3},
+		{"/api/v1/users/", 3}, // trailing slash ignored
+		{"/api//v1", 2},       // empty segments ignored
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.path, func(t *testing.T) {
+			result := countPathSegments(tt.path)
+			if result != tt.expected {
+				t.Errorf("For path '%s', expected %d, got %d", tt.path, tt.expected, result)
+			}
+		})
+	}
+}
+
+func TestCountQueryParams(t *testing.T) {
+	tests := []struct {
+		query    string
+		expected int
+	}{
+		{"", 0},
+		{"user=admin", 1},
+		{"user=admin&pass=secret", 2},
+		{"user=admin&pass=secret&role=user", 3},
+		{"user=admin&&pass=secret", 2}, // empty params ignored
+		{"=value", 0},                  // invalid param ignored
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.query, func(t *testing.T) {
+			result := countQueryParams(tt.query)
+			if result != tt.expected {
+				t.Errorf("For query '%s', expected %d, got %d", tt.query, tt.expected, result)
+			}
+		})
+	}
+}
+
+func TestCountSubdomainLevels(t *testing.T) {
+	tests := []struct {
+		hostname string
+		expected int
+	}{
+		{"example.com", 0},
+		{"api.example.com", 1},
+		{"api.v1.example.com", 2},
+		{"cdn.assets.api.v1.example.com", 4},
+		{"localhost", 0},
+		{"example.com:8080", 0}, // port ignored
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.hostname, func(t *testing.T) {
+			result := countSubdomainLevels(tt.hostname)
+			if result != tt.expected {
+				t.Errorf("For hostname '%s', expected %d, got %d", tt.hostname, tt.expected, result)
+			}
+		})
+	}
+}
+
+func TestDepthFilterValidator(t *testing.T) {
+	validator, err := NewDepthFilterValidator(
+		[]string{">=2"},
+		[]string{"<=3"},
+		[]string{"==1"},
+		false, // AND logic
+	)
+	if err != nil {
+		t.Fatalf("Failed to create validator: %v", err)
+	}
+
+	tests := []struct {
+		name     string
+		url      string
+		expected bool
+	}{
+		{
+			name:     "Valid URL matching all filters",
+			url:      "https://api.example.com/v1/users?id=1&format=json&sort=name",
+			expected: true, // path=2, query=3, subdomain=1
+		},
+		{
+			name:     "URL with insufficient path depth",
+			url:      "https://api.example.com/users",
+			expected: false, // path=1 (fails >=2)
+		},
+		{
+			name:     "URL with too many query params",
+			url:      "https://api.example.com/v1/users?a=1&b=2&c=3&d=4&e=5",
+			expected: false, // query=5 (fails <=3)
+		},
+		{
+			name:     "URL with wrong subdomain count",
+			url:      "https://cdn.api.example.com/v1/users?id=1",
+			expected: false, // subdomain=2 (fails ==1)
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parsedURL, err := url.Parse(tt.url)
+			if err != nil {
+				t.Fatalf("Failed to parse URL: %v", err)
+			}
+
+			result := validator.ValidateURL(parsedURL)
+			if result != tt.expected {
+				t.Errorf("For URL '%s', expected %t, got %t", tt.url, tt.expected, result)
+			}
+		})
+	}
+}
+
+func TestDepthFilterValidatorORLogic(t *testing.T) {
+	validator, err := NewDepthFilterValidator(
+		[]string{">=4"},
+		[]string{">=3"},
+		[]string{">=2"},
+		true, // OR logic
+	)
+	if err != nil {
+		t.Fatalf("Failed to create validator: %v", err)
+	}
+
+	tests := []struct {
+		name     string
+		url      string
+		expected bool
+	}{
+		{
+			name:     "URL matching path filter only",
+			url:      "https://example.com/a/b/c/d/e", // path=5, query=0, subdomain=0
+			expected: true,                            // passes path filter (>=4)
+		},
+		{
+			name:     "URL matching query filter only",
+			url:      "https://example.com/?a=1&b=2&c=3&d=4", // path=0, query=4, subdomain=0
+			expected: true,                                   // passes query filter (>=3)
+		},
+		{
+			name:     "URL matching subdomain filter only",
+			url:      "https://a.b.example.com/", // path=0, query=0, subdomain=2
+			expected: true,                       // passes subdomain filter (>=2)
+		},
+		{
+			name:     "URL matching no filters",
+			url:      "https://example.com/api", // path=1, query=0, subdomain=0
+			expected: false,                     // fails all filters
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parsedURL, err := url.Parse(tt.url)
+			if err != nil {
+				t.Fatalf("Failed to parse URL: %v", err)
+			}
+
+			result := validator.ValidateURL(parsedURL)
+			if result != tt.expected {
+				t.Errorf("For URL '%s', expected %t, got %t", tt.url, tt.expected, result)
+			}
+		})
+	}
+}
\ No newline at end of file
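
Reviewer note: below is a minimal usage sketch of the filters API introduced in this diff. It calls filters.NewDepthFilterValidator and ValidateURL exactly as declared in pkg/utils/filters/depth_filter.go; the package main harness, the target URL, and the chosen filter expressions are illustrative assumptions, not part of the change.

package main

import (
	"fmt"
	"net/url"

	"github.com/projectdiscovery/katana/pkg/utils/filters"
)

func main() {
	// Keep URLs with path depth >= 2 AND 1-3 query parameters AND exactly one subdomain level.
	// Passing true for the last argument would switch to OR logic between the three filter types.
	validator, err := filters.NewDepthFilterValidator(
		[]string{">=2"}, // path depth filters (-cpd)
		[]string{"1-3"}, // query parameter filters (-cqp)
		[]string{"==1"}, // subdomain depth filters (-csd)
		false,           // AND logic (default); -depth-filter-or enables OR logic
	)
	if err != nil {
		panic(err)
	}

	u, _ := url.Parse("https://api.example.com/v1/users?id=1&sort=name")
	fmt.Println(validator.ValidateURL(u)) // true: path=2, query=2, subdomain=1
}

A roughly equivalent CLI invocation, using the flags added in cmd/katana/main.go, would be: katana -u https://example.com -cpd ">=2" -cqp "1-3" -csd "==1".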