From 7debaabaa690b4afe083f7eb35916856232cb3cd Mon Sep 17 00:00:00 2001 From: stacknil Date: Fri, 15 May 2026 15:40:20 +0800 Subject: [PATCH] Validate rule threshold config --- src/telemetry_window_demo/cli.py | 161 ++++++++++++++++++++++++++-- tests/test_run_config_validation.py | 54 ++++++++++ 2 files changed, 207 insertions(+), 8 deletions(-) diff --git a/src/telemetry_window_demo/cli.py b/src/telemetry_window_demo/cli.py index c4d032b..6b4f2cc 100644 --- a/src/telemetry_window_demo/cli.py +++ b/src/telemetry_window_demo/cli.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import math from collections.abc import Mapping, Sequence from pathlib import Path from typing import Any @@ -29,6 +30,16 @@ "source_spread_spike", "rare_event_repeat", ) +RUN_RULE_CONFIG_FIELDS = { + "high_error_rate": frozenset(("threshold", "severity")), + "login_fail_burst": frozenset(("threshold", "severity")), + "high_severity_spike": frozenset(("threshold", "severity")), + "persistent_high_error": frozenset( + ("threshold", "consecutive_windows", "severity") + ), + "source_spread_spike": frozenset(("absolute_threshold", "multiplier", "severity")), + "rare_event_repeat": frozenset(("threshold", "event_types", "severity")), +} def main() -> None: @@ -293,6 +304,15 @@ def _validate_rules_config(raw_rules_config: Any) -> dict[str, Any]: if raw_rules_config is None else dict(_optional_mapping(raw_rules_config, "rules")) ) + allowed_rule_keys = {"cooldown_seconds", *RUN_RULE_SECTION_NAMES} + unknown_rule_keys = sorted( + str(key) for key in rules_config if key not in allowed_rule_keys + ) + if unknown_rule_keys: + raise ValueError( + "Unknown config field(s) under 'rules': " + ", ".join(unknown_rule_keys) + ) + rules_config["cooldown_seconds"] = _int_config_value( rules_config.get("cooldown_seconds", 0), "rules.cooldown_seconds", @@ -301,18 +321,101 @@ def _validate_rules_config(raw_rules_config: Any) -> dict[str, Any]: for rule_name in RUN_RULE_SECTION_NAMES: if rule_name in rules_config: - rules_config[rule_name] = dict( - _optional_mapping(rules_config[rule_name], f"rules.{rule_name}") + rule_config = dict( + _optional_mapping( + rules_config[rule_name], + f"rules.{rule_name}", + ) ) + rules_config[rule_name] = _validate_rule_section_config( + rule_name, + rule_config, + ) + + return rules_config + - rare_event_repeat = rules_config.get("rare_event_repeat") - if isinstance(rare_event_repeat, dict) and "event_types" in rare_event_repeat: - rare_event_repeat["event_types"] = _string_sequence( - rare_event_repeat["event_types"], - "rules.rare_event_repeat.event_types", +def _validate_rule_section_config( + rule_name: str, + rule_config: dict[str, Any], +) -> dict[str, Any]: + allowed_fields = RUN_RULE_CONFIG_FIELDS[rule_name] + unknown_fields = sorted( + str(key) for key in rule_config if key not in allowed_fields + ) + if unknown_fields: + raise ValueError( + f"Unknown config field(s) under 'rules.{rule_name}': " + + ", ".join(unknown_fields) ) - return rules_config + if "severity" in rule_config: + rule_config["severity"] = _string_config_value( + rule_config["severity"], + f"rules.{rule_name}.severity", + ) + + if rule_name == "high_error_rate": + _normalize_optional_float( + rule_config, + "threshold", + "rules.high_error_rate.threshold", + minimum=0.0, + ) + elif rule_name == "login_fail_burst": + _normalize_optional_int( + rule_config, + "threshold", + "rules.login_fail_burst.threshold", + minimum=1, + ) + elif rule_name == "high_severity_spike": + _normalize_optional_int( + rule_config, + "threshold", + "rules.high_severity_spike.threshold", + minimum=1, + ) + elif rule_name == "persistent_high_error": + _normalize_optional_float( + rule_config, + "threshold", + "rules.persistent_high_error.threshold", + minimum=0.0, + ) + _normalize_optional_int( + rule_config, + "consecutive_windows", + "rules.persistent_high_error.consecutive_windows", + minimum=1, + ) + elif rule_name == "source_spread_spike": + _normalize_optional_int( + rule_config, + "absolute_threshold", + "rules.source_spread_spike.absolute_threshold", + minimum=1, + ) + _normalize_optional_float( + rule_config, + "multiplier", + "rules.source_spread_spike.multiplier", + minimum=1.0, + ) + elif rule_name == "rare_event_repeat": + _normalize_optional_int( + rule_config, + "threshold", + "rules.rare_event_repeat.threshold", + minimum=1, + ) + if "event_types" in rule_config: + rule_config["event_types"] = _string_sequence( + rule_config["event_types"], + "rules.rare_event_repeat.event_types", + ) + + return rule_config def _optional_mapping(value: Any, field_name: str) -> Mapping[str, Any]: @@ -349,6 +452,48 @@ def _int_config_value(value: Any, field_name: str, *, minimum: int) -> int: return parsed +def _float_config_value(value: Any, field_name: str, *, minimum: float) -> float: + if isinstance(value, bool): + raise ValueError(f"Config field '{field_name}' must be a number.") + if isinstance(value, (int, float)): + parsed = float(value) + elif isinstance(value, str): + try: + parsed = float(value.strip()) + except ValueError as exc: + raise ValueError(f"Config field '{field_name}' must be a number.") from exc + else: + raise ValueError(f"Config field '{field_name}' must be a number.") + + if not math.isfinite(parsed): + raise ValueError(f"Config field '{field_name}' must be a finite number.") + if parsed < minimum: + raise ValueError(f"Config field '{field_name}' must be at least {minimum:g}.") + return parsed + + +def _normalize_optional_int( + config: dict[str, Any], + key: str, + field_name: str, + *, + minimum: int, +) -> None: + if key in config: + config[key] = _int_config_value(config[key], field_name, minimum=minimum) + + +def _normalize_optional_float( + config: dict[str, Any], + key: str, + field_name: str, + *, + minimum: float, +) -> None: + if key in config: + config[key] = _float_config_value(config[key], field_name, minimum=minimum) + + def _optional_string_sequence(value: Any, field_name: str) -> list[str] | None: if value is None: return None diff --git a/tests/test_run_config_validation.py b/tests/test_run_config_validation.py index 744b1ea..0b95cca 100644 --- a/tests/test_run_config_validation.py +++ b/tests/test_run_config_validation.py @@ -89,3 +89,57 @@ def test_run_config_rejects_string_rare_event_types(tmp_path) -> None: with pytest.raises(ValueError, match="rules.rare_event_repeat.event_types"): run_command(Namespace(config=str(config_path))) + + +def test_run_config_rejects_unknown_rule_name(tmp_path) -> None: + config = _base_config(tmp_path) + config["rules"]["high_error_rates"] = {"threshold": 0.30} + config_path = _write_config(tmp_path, config) + + with pytest.raises(ValueError, match="high_error_rates"): + run_command(Namespace(config=str(config_path))) + + +def test_run_config_rejects_unknown_rule_field(tmp_path) -> None: + config = _base_config(tmp_path) + config["rules"]["high_error_rate"]["thresholds"] = 0.30 + config_path = _write_config(tmp_path, config) + + with pytest.raises(ValueError, match="rules.high_error_rate"): + run_command(Namespace(config=str(config_path))) + + +def test_run_config_rejects_boolean_rule_threshold(tmp_path) -> None: + config = _base_config(tmp_path) + config["rules"]["high_error_rate"]["threshold"] = True + config_path = _write_config(tmp_path, config) + + with pytest.raises(ValueError, match="rules.high_error_rate.threshold"): + run_command(Namespace(config=str(config_path))) + + +def test_run_config_rejects_non_positive_count_threshold(tmp_path) -> None: + config = _base_config(tmp_path) + config["rules"]["login_fail_burst"]["threshold"] = 0 + config_path = _write_config(tmp_path, config) + + with pytest.raises(ValueError, match="rules.login_fail_burst.threshold"): + run_command(Namespace(config=str(config_path))) + + +def test_run_config_rejects_source_spread_multiplier_below_one(tmp_path) -> None: + config = _base_config(tmp_path) + config["rules"]["source_spread_spike"]["multiplier"] = 0.5 + config_path = _write_config(tmp_path, config) + + with pytest.raises(ValueError, match="rules.source_spread_spike.multiplier"): + run_command(Namespace(config=str(config_path))) + + +def test_run_config_rejects_empty_rule_severity(tmp_path) -> None: + config = _base_config(tmp_path) + config["rules"]["persistent_high_error"]["severity"] = "" + config_path = _write_config(tmp_path, config) + + with pytest.raises(ValueError, match="rules.persistent_high_error.severity"): + run_command(Namespace(config=str(config_path)))