From 78e23cbcece0905cae1df55c208f92ae34b575a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Thu, 20 Nov 2025 13:31:59 +0100 Subject: [PATCH 01/12] feat(attack-surfaces): add model for overviews --- .../0060_attack_surface_overview.py | 90 +++++++++++++++++++ api/src/backend/api/models.py | 60 +++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 api/src/backend/api/migrations/0060_attack_surface_overview.py diff --git a/api/src/backend/api/migrations/0060_attack_surface_overview.py b/api/src/backend/api/migrations/0060_attack_surface_overview.py new file mode 100644 index 0000000000..134c9306eb --- /dev/null +++ b/api/src/backend/api/migrations/0060_attack_surface_overview.py @@ -0,0 +1,90 @@ +# Generated by Django 5.1.14 on 2025-11-19 13:03 + +import uuid + +import django.db.models.deletion +from django.db import migrations, models + +import api.rls + + +class Migration(migrations.Migration): + + dependencies = [ + ("api", "0059_compliance_overview_summary"), + ] + + operations = [ + migrations.CreateModel( + name="AttackSurfaceOverview", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("inserted_at", models.DateTimeField(auto_now_add=True)), + ( + "attack_surface_type", + models.CharField( + choices=[ + ("internet-exposed", "Internet Exposed"), + ("secrets", "Exposed Secrets"), + ("privilege-escalation", "Privilege Escalation"), + ("ec2-imdsv1", "EC2 IMDSv1 Enabled"), + ], + max_length=50, + ), + ), + ("total_findings", models.IntegerField(default=0)), + ("failed_findings", models.IntegerField(default=0)), + ("muted_failed_findings", models.IntegerField(default=0)), + ], + options={ + "db_table": "attack_surface_overviews", + "abstract": False, + }, + ), + migrations.AddField( + model_name="attacksurfaceoverview", + name="scan", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="attack_surface_overviews", + related_query_name="attack_surface_overview", + to="api.scan", + ), + ), + migrations.AddField( + model_name="attacksurfaceoverview", + name="tenant", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="api.tenant" + ), + ), + migrations.AddIndex( + model_name="attacksurfaceoverview", + index=models.Index( + fields=["tenant_id", "scan_id"], name="attack_surf_tenant_scan_idx" + ), + ), + migrations.AddConstraint( + model_name="attacksurfaceoverview", + constraint=models.UniqueConstraint( + fields=("tenant_id", "scan_id", "attack_surface_type"), + name="unique_attack_surface_per_scan", + ), + ), + migrations.AddConstraint( + model_name="attacksurfaceoverview", + constraint=api.rls.RowLevelSecurityConstraint( + "tenant_id", + name="rls_on_attacksurfaceoverview", + statements=["SELECT", "INSERT", "UPDATE", "DELETE"], + ), + ), + ] diff --git a/api/src/backend/api/models.py b/api/src/backend/api/models.py index 1191350ac6..e2bb410894 100644 --- a/api/src/backend/api/models.py +++ b/api/src/backend/api/models.py @@ -2405,3 +2405,63 @@ class Meta(RowLevelSecurityProtectedModel.Meta): class JSONAPIMeta: resource_name = "threatscore-snapshots" + + +class AttackSurfaceOverview(RowLevelSecurityProtectedModel): + """ + Pre-aggregated attack surface metrics per scan. + + Stores counts for each attack surface type (internet-exposed, secrets, + privilege-escalation, ec2-imdsv1) to enable fast overview queries. 
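+
+    A minimal read sketch (illustrative; assumes the caller already holds
+    tenant_id and scan_id):
+
+        AttackSurfaceOverview.objects.filter(
+            tenant_id=tenant_id, scan_id=scan_id
+        ).values_list("attack_surface_type", "failed_findings")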
+ """ + + class AttackSurfaceTypeChoices(models.TextChoices): + INTERNET_EXPOSED = "internet-exposed", _("Internet Exposed") + SECRETS = "secrets", _("Exposed Secrets") + PRIVILEGE_ESCALATION = "privilege-escalation", _("Privilege Escalation") + EC2_IMDSV1 = "ec2-imdsv1", _("EC2 IMDSv1 Enabled") + + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) + inserted_at = models.DateTimeField(auto_now_add=True, editable=False) + + scan = models.ForeignKey( + Scan, + on_delete=models.CASCADE, + related_name="attack_surface_overviews", + related_query_name="attack_surface_overview", + ) + + attack_surface_type = models.CharField( + max_length=50, + choices=AttackSurfaceTypeChoices.choices, + ) + + # Finding counts + total_findings = models.IntegerField(default=0) # All findings (PASS + FAIL) + failed_findings = models.IntegerField(default=0) # Non-muted failed findings + muted_failed_findings = models.IntegerField(default=0) # Muted failed findings + + class Meta(RowLevelSecurityProtectedModel.Meta): + db_table = "attack_surface_overviews" + + constraints = [ + models.UniqueConstraint( + fields=("tenant_id", "scan_id", "attack_surface_type"), + name="unique_attack_surface_per_scan", + ), + RowLevelSecurityConstraint( + field="tenant_id", + name="rls_on_%(class)s", + statements=["SELECT", "INSERT", "UPDATE", "DELETE"], + ), + ] + + indexes = [ + models.Index( + fields=["tenant_id", "scan_id"], + name="attack_surf_tenant_scan_idx", + ), + ] + + class JSONAPIMeta: + resource_name = "attack-surface-overviews" From e909e68920e16ae18f8e075d3fab46acea0def74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Thu, 20 Nov 2025 13:32:41 +0100 Subject: [PATCH 02/12] feat(attack-surfaces): add serializer and DRY base classes --- api/src/backend/api/v1/serializers.py | 142 +++++++++++++------------- 1 file changed, 72 insertions(+), 70 deletions(-) diff --git a/api/src/backend/api/v1/serializers.py b/api/src/backend/api/v1/serializers.py index b2d95a2609..98836f55d6 100644 --- a/api/src/backend/api/v1/serializers.py +++ b/api/src/backend/api/v1/serializers.py @@ -72,6 +72,42 @@ from api.v1.serializer_utils.providers import ProviderSecretField from prowler.lib.mutelist.mutelist import Mutelist +# Base + + +class BaseModelSerializerV1(serializers.ModelSerializer): + def get_root_meta(self, _resource, _many): + return {"version": "v1"} + + +class BaseSerializerV1(serializers.Serializer): + def get_root_meta(self, _resource, _many): + return {"version": "v1"} + + +class BaseWriteSerializer(BaseModelSerializerV1): + def validate(self, data): + if hasattr(self, "initial_data"): + initial_data = set(self.initial_data.keys()) - {"id", "type"} + unknown_keys = initial_data - set(self.fields.keys()) + if unknown_keys: + raise ValidationError(f"Invalid fields: {unknown_keys}") + return data + + +class RLSSerializer(BaseModelSerializerV1): + def create(self, validated_data): + tenant_id = self.context.get("tenant_id") + validated_data["tenant_id"] = tenant_id + return super().create(validated_data) + + +class StateEnumSerializerField(serializers.ChoiceField): + def __init__(self, **kwargs): + kwargs["choices"] = StateChoices.choices + super().__init__(**kwargs) + + # Tokens @@ -179,7 +215,7 @@ def validate(self, attrs): # TODO: Check if we can change the parent class to TokenRefreshSerializer from rest_framework_simplejwt.serializers -class TokenRefreshSerializer(serializers.Serializer): +class TokenRefreshSerializer(BaseSerializerV1): refresh = 
serializers.CharField() # Output token @@ -213,7 +249,7 @@ def validate(self, attrs): raise ValidationError({"refresh": "Invalid or expired token"}) -class TokenSwitchTenantSerializer(serializers.Serializer): +class TokenSwitchTenantSerializer(BaseSerializerV1): tenant_id = serializers.UUIDField( write_only=True, help_text="The tenant ID for which to request a new token." ) @@ -237,41 +273,10 @@ def validate(self, attrs): return generate_tokens(user, tenant_id) -# Base - - -class BaseSerializerV1(serializers.ModelSerializer): - def get_root_meta(self, _resource, _many): - return {"version": "v1"} - - -class BaseWriteSerializer(BaseSerializerV1): - def validate(self, data): - if hasattr(self, "initial_data"): - initial_data = set(self.initial_data.keys()) - {"id", "type"} - unknown_keys = initial_data - set(self.fields.keys()) - if unknown_keys: - raise ValidationError(f"Invalid fields: {unknown_keys}") - return data - - -class RLSSerializer(BaseSerializerV1): - def create(self, validated_data): - tenant_id = self.context.get("tenant_id") - validated_data["tenant_id"] = tenant_id - return super().create(validated_data) - - -class StateEnumSerializerField(serializers.ChoiceField): - def __init__(self, **kwargs): - kwargs["choices"] = StateChoices.choices - super().__init__(**kwargs) - - # Users -class UserSerializer(BaseSerializerV1): +class UserSerializer(BaseModelSerializerV1): """ Serializer for the User model. """ @@ -402,7 +407,7 @@ def update(self, instance, validated_data): return super().update(instance, validated_data) -class RoleResourceIdentifierSerializer(serializers.Serializer): +class RoleResourceIdentifierSerializer(BaseSerializerV1): resource_type = serializers.CharField(source="type") id = serializers.UUIDField() @@ -585,7 +590,7 @@ def get_json_field(obj, field_name): # Tenants -class TenantSerializer(BaseSerializerV1): +class TenantSerializer(BaseModelSerializerV1): """ Serializer for the Tenant model. 
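+
+    As a BaseModelSerializerV1 subclass, its responses carry
+    {"version": "v1"} in the top-level JSON:API meta.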
""" @@ -597,7 +602,7 @@ class Meta: fields = ["id", "name", "memberships"] -class TenantIncludeSerializer(BaseSerializerV1): +class TenantIncludeSerializer(BaseModelSerializerV1): class Meta: model = Tenant fields = ["id", "name"] @@ -773,7 +778,7 @@ def update(self, instance, validated_data): return super().update(instance, validated_data) -class ProviderResourceIdentifierSerializer(serializers.Serializer): +class ProviderResourceIdentifierSerializer(BaseSerializerV1): resource_type = serializers.CharField(source="type") id = serializers.UUIDField() @@ -1110,7 +1115,7 @@ class Meta: ] -class ScanReportSerializer(serializers.Serializer): +class ScanReportSerializer(BaseSerializerV1): id = serializers.CharField(source="scan") class Meta: @@ -1118,7 +1123,7 @@ class Meta: fields = ["id"] -class ScanComplianceReportSerializer(serializers.Serializer): +class ScanComplianceReportSerializer(BaseSerializerV1): id = serializers.CharField(source="scan") name = serializers.CharField() @@ -1267,7 +1272,7 @@ def get_fields(self): return fields -class ResourceMetadataSerializer(serializers.Serializer): +class ResourceMetadataSerializer(BaseSerializerV1): services = serializers.ListField(child=serializers.CharField(), allow_empty=True) regions = serializers.ListField(child=serializers.CharField(), allow_empty=True) types = serializers.ListField(child=serializers.CharField(), allow_empty=True) @@ -1337,7 +1342,7 @@ class Meta: # To be removed when the related endpoint is removed as well -class FindingDynamicFilterSerializer(serializers.Serializer): +class FindingDynamicFilterSerializer(BaseSerializerV1): services = serializers.ListField(child=serializers.CharField(), allow_empty=True) regions = serializers.ListField(child=serializers.CharField(), allow_empty=True) @@ -1345,7 +1350,7 @@ class Meta: resource_name = "finding-dynamic-filters" -class FindingMetadataSerializer(serializers.Serializer): +class FindingMetadataSerializer(BaseSerializerV1): services = serializers.ListField(child=serializers.CharField(), allow_empty=True) regions = serializers.ListField(child=serializers.CharField(), allow_empty=True) resource_types = serializers.ListField( @@ -2039,7 +2044,7 @@ class Meta: # Compliance overview -class ComplianceOverviewSerializer(serializers.Serializer): +class ComplianceOverviewSerializer(BaseSerializerV1): """ Serializer for compliance requirement status aggregated by compliance framework. @@ -2061,7 +2066,7 @@ class JSONAPIMeta: resource_name = "compliance-overviews" -class ComplianceOverviewDetailSerializer(serializers.Serializer): +class ComplianceOverviewDetailSerializer(BaseSerializerV1): """ Serializer for detailed compliance requirement information. 
@@ -2090,7 +2095,7 @@ class ComplianceOverviewDetailThreatscoreSerializer(ComplianceOverviewDetailSeri total_findings = serializers.IntegerField() -class ComplianceOverviewAttributesSerializer(serializers.Serializer): +class ComplianceOverviewAttributesSerializer(BaseSerializerV1): id = serializers.CharField() compliance_name = serializers.CharField() framework_description = serializers.CharField() @@ -2104,7 +2109,7 @@ class JSONAPIMeta: resource_name = "compliance-requirements-attributes" -class ComplianceOverviewMetadataSerializer(serializers.Serializer): +class ComplianceOverviewMetadataSerializer(BaseSerializerV1): regions = serializers.ListField(child=serializers.CharField(), allow_empty=True) class JSONAPIMeta: @@ -2114,7 +2119,7 @@ class JSONAPIMeta: # Overviews -class OverviewProviderSerializer(serializers.Serializer): +class OverviewProviderSerializer(BaseSerializerV1): id = serializers.CharField(source="provider") findings = serializers.SerializerMethodField(read_only=True) resources = serializers.SerializerMethodField(read_only=True) @@ -2122,9 +2127,6 @@ class OverviewProviderSerializer(serializers.Serializer): class JSONAPIMeta: resource_name = "providers-overview" - def get_root_meta(self, _resource, _many): - return {"version": "v1"} - @extend_schema_field( { "type": "object", @@ -2158,18 +2160,15 @@ def get_resources(self, obj): } -class OverviewProviderCountSerializer(serializers.Serializer): +class OverviewProviderCountSerializer(BaseSerializerV1): id = serializers.CharField(source="provider") count = serializers.IntegerField() class JSONAPIMeta: resource_name = "providers-count-overview" - def get_root_meta(self, _resource, _many): - return {"version": "v1"} - -class OverviewFindingSerializer(serializers.Serializer): +class OverviewFindingSerializer(BaseSerializerV1): id = serializers.CharField(default="n/a") new = serializers.IntegerField() changed = serializers.IntegerField() @@ -2188,15 +2187,12 @@ class OverviewFindingSerializer(serializers.Serializer): class JSONAPIMeta: resource_name = "findings-overview" - def get_root_meta(self, _resource, _many): - return {"version": "v1"} - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.fields["pass"] = self.fields.pop("_pass") -class OverviewSeveritySerializer(serializers.Serializer): +class OverviewSeveritySerializer(BaseSerializerV1): id = serializers.CharField(default="n/a") critical = serializers.IntegerField() high = serializers.IntegerField() @@ -2207,11 +2203,8 @@ class OverviewSeveritySerializer(serializers.Serializer): class JSONAPIMeta: resource_name = "findings-severity-overview" - def get_root_meta(self, _resource, _many): - return {"version": "v1"} - -class OverviewServiceSerializer(serializers.Serializer): +class OverviewServiceSerializer(BaseSerializerV1): id = serializers.CharField(source="service") total = serializers.IntegerField() _pass = serializers.IntegerField() @@ -2225,14 +2218,23 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.fields["pass"] = self.fields.pop("_pass") - def get_root_meta(self, _resource, _many): - return {"version": "v1"} + +class AttackSurfaceOverviewSerializer(BaseSerializerV1): + """Serializer for attack surface overview aggregations.""" + + id = serializers.CharField(source="attack_surface_type") + total_findings = serializers.IntegerField() + failed_findings = serializers.IntegerField() + muted_failed_findings = serializers.IntegerField() + + class JSONAPIMeta: + resource_name = "attack-surface-overviews" # Schedules 
-class ScheduleDailyCreateSerializer(serializers.Serializer): +class ScheduleDailyCreateSerializer(BaseSerializerV1): provider_id = serializers.UUIDField(required=True) class JSONAPIMeta: @@ -2568,7 +2570,7 @@ def to_representation(self, instance): return representation -class IntegrationJiraDispatchSerializer(serializers.Serializer): +class IntegrationJiraDispatchSerializer(BaseSerializerV1): """ Serializer for dispatching findings to JIRA integration. """ @@ -2731,14 +2733,14 @@ def validate_mutelist_configuration(self, configuration): # SSO -class SamlInitiateSerializer(serializers.Serializer): +class SamlInitiateSerializer(BaseSerializerV1): email_domain = serializers.CharField() class JSONAPIMeta: resource_name = "saml-initiate" -class SamlMetadataSerializer(serializers.Serializer): +class SamlMetadataSerializer(BaseSerializerV1): class JSONAPIMeta: resource_name = "saml-meta" From 4a028ec1d3c747d52d0b411cd4f79908268094c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Thu, 20 Nov 2025 13:33:40 +0100 Subject: [PATCH 03/12] feat(attack-surfaces): add task --- api/src/backend/tasks/jobs/scan.py | 145 ++++++++++++++++++++++++++++- api/src/backend/tasks/tasks.py | 19 ++++ 2 files changed, 163 insertions(+), 1 deletion(-) diff --git a/api/src/backend/tasks/jobs/scan.py b/api/src/backend/tasks/jobs/scan.py index c4cf413301..66b57ff985 100644 --- a/api/src/backend/tasks/jobs/scan.py +++ b/api/src/backend/tasks/jobs/scan.py @@ -12,7 +12,7 @@ from config.env import env from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS from django.db import IntegrityError, OperationalError -from django.db.models import Case, Count, IntegerField, Prefetch, Sum, When +from django.db.models import Case, Count, IntegerField, Prefetch, Q, Sum, When from tasks.utils import CustomEncoder from api.compliance import PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE @@ -26,6 +26,7 @@ ) from api.exceptions import ProviderConnectionError from api.models import ( + AttackSurfaceOverview, ComplianceOverviewSummary, ComplianceRequirementOverview, Finding, @@ -43,6 +44,7 @@ from api.models import StatusChoices as FindingStatus from api.utils import initialize_prowler_provider, return_prowler_provider from api.v1.serializers import ScanTaskSerializer +from prowler.lib.check.models import CheckMetadata from prowler.lib.outputs.finding import Finding as ProwlerFinding from prowler.lib.scan.scan import Scan as ProwlerScan @@ -75,6 +77,35 @@ SCAN_DB_BATCH_SIZE = env.int("DJANGO_SCAN_DB_BATCH_SIZE", default=500) +ATTACK_SURFACE_PROVIDER_COMPATIBILITY = { + "internet-exposed": None, # Compatible with all providers + "secrets": None, # Compatible with all providers + "privilege-escalation": ["aws", "kubernetes"], + "ec2-imdsv1": ["aws"], +} + + +def _get_attack_surface_mapping_from_provider(provider_type: str) -> dict: + attack_surface_check_mappings = { + "internet-exposed": None, + "secrets": None, + "privilege-escalation": { + "iam_policy_allows_privilege_escalation", + "iam_inline_policy_allows_privilege_escalation", + }, + "ec2-imdsv1": { + "ec2_instance_imdsv2_enabled" + }, # AWS only - IMDSv1 enabled findings + } + for category_name, check_ids in attack_surface_check_mappings.items(): + if check_ids is None: + sdk_check_ids = CheckMetadata.list( + provider=provider_type, category=category_name + ) + attack_surface_check_mappings[category_name] = sdk_check_ids + return attack_surface_check_mappings + + def _create_finding_delta( last_status: FindingStatus | None | str, new_status: 
FindingStatus | None ) -> Finding.DeltaChoices: @@ -1191,3 +1222,115 @@ def create_compliance_requirements(tenant_id: str, scan_id: str): except Exception as e: logger.error(f"Error creating compliance requirements for scan {scan_id}: {e}") raise e + + +def aggregate_attack_surface(tenant_id: str, scan_id: str): + """ + Aggregate findings into attack surface overview records. + + Creates one AttackSurfaceOverview record per attack surface type + for the given scan, based on check_id mappings. + + Args: + tenant_id: Tenant that owns the scan. + scan_id: Scan UUID whose findings should be aggregated. + """ + with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): + scan_instance = Scan.all_objects.get(pk=scan_id) + provider_type = scan_instance.provider.provider + + provider_attack_surface_mapping = _get_attack_surface_mapping_from_provider( + provider_type=provider_type + ) + + # Filter out attack surfaces that are not compatible or have no resolved check IDs + supported_mappings: dict[str, list[str]] = {} + for attack_surface_type, check_ids in provider_attack_surface_mapping.items(): + compatible_providers = ATTACK_SURFACE_PROVIDER_COMPATIBILITY.get( + attack_surface_type + ) + if ( + compatible_providers is not None + and provider_type not in compatible_providers + ): + logger.info( + f"Skipping {attack_surface_type} - not supported for {provider_type}" + ) + continue + + if not check_ids: + logger.info( + f"Skipping {attack_surface_type} - no check IDs resolved for {provider_type}" + ) + continue + + supported_mappings[attack_surface_type] = list(check_ids) + + if not supported_mappings: + logger.info( + f"No attack surface mappings available for scan {scan_id} and provider {provider_type}" + ) + logger.info(f"No attack surface overview records created for scan {scan_id}") + return + + # Map every check_id to its attack surface, so we can aggregate with a single query + check_id_to_surface: dict[str, str] = {} + for attack_surface_type, check_ids in supported_mappings.items(): + for check_id in check_ids: + check_id_to_surface[check_id] = attack_surface_type + + aggregated_counts = { + attack_surface_type: {"total": 0, "failed": 0, "muted": 0} + for attack_surface_type in supported_mappings.keys() + } + + with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): + finding_stats = ( + Finding.all_objects.filter( + tenant_id=tenant_id, + scan_id=scan_id, + check_id__in=list(check_id_to_surface.keys()), + ) + .values("check_id") + .annotate( + total=Count("id"), + failed=Count("id", filter=Q(status="FAIL", muted=False)), + muted=Count("id", filter=Q(status="FAIL", muted=True)), + ) + ) + + for stats in finding_stats: + attack_surface_type = check_id_to_surface.get(stats["check_id"]) + if not attack_surface_type: + continue + + aggregated_counts[attack_surface_type]["total"] += stats["total"] or 0 + aggregated_counts[attack_surface_type]["failed"] += stats["failed"] or 0 + aggregated_counts[attack_surface_type]["muted"] += stats["muted"] or 0 + + overview_objects = [] + for attack_surface_type, counts in aggregated_counts.items(): + total = counts["total"] + if not total: + continue + + overview_objects.append( + AttackSurfaceOverview( + tenant_id=tenant_id, + scan_id=scan_id, + attack_surface_type=attack_surface_type, + total_findings=total, + failed_findings=counts["failed"], + muted_failed_findings=counts["muted"], + ) + ) + + # Bulk create overview records + if overview_objects: + with rls_transaction(tenant_id): + AttackSurfaceOverview.objects.bulk_create(overview_objects, 
batch_size=500) + logger.info( + f"Created {len(overview_objects)} attack surface overview records for scan {scan_id}" + ) + else: + logger.info(f"No attack surface overview records created for scan {scan_id}") diff --git a/api/src/backend/tasks/tasks.py b/api/src/backend/tasks/tasks.py index 06a35676a5..52c16c2b3a 100644 --- a/api/src/backend/tasks/tasks.py +++ b/api/src/backend/tasks/tasks.py @@ -37,6 +37,7 @@ from tasks.jobs.muting import mute_historical_findings from tasks.jobs.report import generate_compliance_reports_job from tasks.jobs.scan import ( + aggregate_attack_surface, aggregate_findings, create_compliance_requirements, perform_prowler_scan, @@ -69,6 +70,9 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str) create_compliance_requirements_task.apply_async( kwargs={"tenant_id": tenant_id, "scan_id": scan_id} ) + aggregate_attack_surface_task.apply_async( + kwargs={"tenant_id": tenant_id, "scan_id": scan_id} + ) chain( perform_scan_summary_task.si(tenant_id=tenant_id, scan_id=scan_id), generate_outputs_task.si( @@ -529,6 +533,21 @@ def create_compliance_requirements_task(tenant_id: str, scan_id: str): return create_compliance_requirements(tenant_id=tenant_id, scan_id=scan_id) +@shared_task(name="scan-attack-surface-overviews", queue="overview") +def aggregate_attack_surface_task(tenant_id: str, scan_id: str): + """ + Creates attack surface overview records for a scan. + + This task processes findings and aggregates them into attack surface categories + (internet-exposed, secrets, privilege-escalation, ec2-imdsv1) for quick overview queries. + + Args: + tenant_id (str): The tenant ID for which to create records. + scan_id (str): The ID of the scan for which to create records. + """ + return aggregate_attack_surface(tenant_id=tenant_id, scan_id=scan_id) + + @shared_task(base=RLSTask, name="lighthouse-connection-check") @set_tenant def check_lighthouse_connection_task(lighthouse_config_id: str, tenant_id: str = None): From bb3f4be9314209c46c625617445f784cc1c34253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Thu, 20 Nov 2025 13:44:47 +0100 Subject: [PATCH 04/12] feat(attack-surfaces): add endpoint to retrieve overview --- api/src/backend/api/v1/views.py | 192 +++++++++++++++++++++++++++----- 1 file changed, 164 insertions(+), 28 deletions(-) diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 6d2787b41b..f88413bc10 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -126,6 +126,7 @@ UserFilter, ) from api.models import ( + AttackSurfaceOverview, ComplianceOverviewSummary, ComplianceRequirementOverview, Finding, @@ -172,6 +173,7 @@ from api.uuid_utils import datetime_to_uuid7, uuid7_start from api.v1.mixins import DisablePaginationMixin, PaginateByPkMixin, TaskManagementMixin from api.v1.serializers import ( + AttackSurfaceOverviewSerializer, ComplianceOverviewAttributesSerializer, ComplianceOverviewDetailSerializer, ComplianceOverviewDetailThreatscoreSerializer, @@ -4035,6 +4037,8 @@ def get_serializer_class(self): return OverviewServiceSerializer elif self.action == "threatscore": return ThreatScoreSnapshotSerializer + elif self.action == "attack_surface": + return AttackSurfaceOverviewSerializer return super().get_serializer_class() def get_filterset_class(self): @@ -4048,6 +4052,65 @@ def get_filterset_class(self): return ScanSummaryFilter return None + def _get_rbac_provider_filter(self): + """Get RBAC provider filter dict. 
Ensures get_queryset() is called.""" + if not hasattr(self, "allowed_providers"): + self.get_queryset() + return ( + {"provider__in": self.allowed_providers} + if hasattr(self, "allowed_providers") + else {} + ) + + def _get_latest_scan_ids(self, additional_filters=None): + """Get latest completed scan IDs per provider with RBAC + optional filters.""" + scan_filter = { + "tenant_id": self.request.tenant_id, + "state": StateChoices.COMPLETED, + **self._get_rbac_provider_filter(), + } + if additional_filters: + scan_filter.update(additional_filters) + + return ( + Scan.all_objects.filter(**scan_filter) + .order_by("provider_id", "-inserted_at") + .distinct("provider_id") + .values_list("id", flat=True) + ) + + def _parse_provider_filters(self, request): + """Parse provider filters from JSON:API query params.""" + normalized_params = QueryDict(mutable=True) + allowed_filter_keys = { + "provider_id", + "provider_id__in", + "provider_type", + "provider_type__in", + } + for param_key, values in request.query_params.lists(): + if not (param_key.startswith("filter[") and param_key.endswith("]")): + continue + normalized_key = param_key[7:-1] + if normalized_key in allowed_filter_keys: + normalized_params.setlist(normalized_key, values) + + scan_filter = {} + if provider_id := normalized_params.get("provider_id"): + scan_filter["provider_id"] = provider_id + if provider_ids := normalized_params.get("provider_id__in"): + scan_filter["provider_id__in"] = [ + pid.strip() for pid in provider_ids.split(",") if pid.strip() + ] + if provider_type := normalized_params.get("provider_type"): + scan_filter["provider__provider"] = provider_type + if provider_types := normalized_params.get("provider_type__in"): + scan_filter["provider__provider__in"] = [ + pt.strip() for pt in provider_types.split(",") if pt.strip() + ] + + return scan_filter + @extend_schema(exclude=True) def list(self, request, *args, **kwargs): raise MethodNotAllowed(method="GET") @@ -4060,20 +4123,7 @@ def retrieve(self, request, *args, **kwargs): def providers(self, request): tenant_id = self.request.tenant_id queryset = self.get_queryset() - provider_filter = ( - {"provider__in": self.allowed_providers} - if hasattr(self, "allowed_providers") - else {} - ) - - latest_scan_ids = ( - Scan.all_objects.filter( - tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter - ) - .order_by("provider_id", "-inserted_at") - .distinct("provider_id") - .values_list("id", flat=True) - ) + latest_scan_ids = self._get_latest_scan_ids() findings_aggregated = ( queryset.filter(scan_id__in=latest_scan_ids) @@ -4151,20 +4201,7 @@ def findings(self, request): tenant_id = self.request.tenant_id queryset = self.get_queryset() filtered_queryset = self.filter_queryset(queryset) - provider_filter = ( - {"provider__in": self.allowed_providers} - if hasattr(self, "allowed_providers") - else {} - ) - - latest_scan_ids = ( - Scan.all_objects.filter( - tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter - ) - .order_by("provider_id", "-inserted_at") - .distinct("provider_id") - .values_list("id", flat=True) - ) + latest_scan_ids = self._get_latest_scan_ids() filtered_queryset = filtered_queryset.filter( tenant_id=tenant_id, scan_id__in=latest_scan_ids ) @@ -4602,6 +4639,105 @@ def requirement_sort_key(item): return aggregated_snapshot + @extend_schema( + tags=["Overview"], + summary="Attack surface overview", + description="Retrieve aggregated attack surface metrics from latest completed scans per provider. 
" + "Always returns all 4 attack surface types with zero counts if no data exists.", + parameters=[ + OpenApiParameter( + name="filter[provider_id]", + type=OpenApiTypes.UUID, + location=OpenApiParameter.QUERY, + description="Filter by specific provider ID", + ), + OpenApiParameter( + name="filter[provider_id__in]", + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + description="Filter by multiple provider IDs (comma-separated UUIDs)", + ), + OpenApiParameter( + name="filter[provider_type]", + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + description="Filter by provider type (aws, azure, gcp, etc.)", + ), + OpenApiParameter( + name="filter[provider_type__in]", + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + description="Filter by multiple provider types (comma-separated)", + ), + ], + ) + @action( + detail=False, + methods=["get"], + url_name="attack-surface", + url_path="attack-surfaces", + ) + def attack_surface(self, request): + """ + Aggregate attack surface metrics from latest completed scans. + + Returns counts for all 4 attack surface types: + - internet-exposed: Internet-facing resources + - secrets: Exposed secrets/credentials + - privilege-escalation: IAM privilege escalation paths + - ec2-imdsv1: EC2 instances with IMDSv1 enabled + + Note: Provider-specific attack surfaces (e.g., ec2-imdsv1 only for AWS) + will return zero counts if filtered by incompatible provider type. + """ + tenant_id = self.request.tenant_id + + # Parse provider filters and get latest scans + provider_filters = self._parse_provider_filters(request) + latest_scan_ids = self._get_latest_scan_ids(additional_filters=provider_filters) + + # Query attack surface overviews for latest scans + queryset = AttackSurfaceOverview.objects.filter( + tenant_id=tenant_id, + scan_id__in=latest_scan_ids, + ) + + # Aggregate by attack surface type + aggregation = queryset.values("attack_surface_type").annotate( + total_findings=Sum("total_findings"), + failed_findings=Sum("failed_findings"), + muted_failed_findings=Sum("muted_failed_findings"), + ) + + # Convert to dict for easy lookup + results_by_type = {item["attack_surface_type"]: item for item in aggregation} + + # Always return all 4 attack surface types (fill with zeros if missing) + all_types = [ + "internet-exposed", + "secrets", + "privilege-escalation", + "ec2-imdsv1", + ] + complete_results = [] + + for attack_surface_type in all_types: + if attack_surface_type in results_by_type: + complete_results.append(results_by_type[attack_surface_type]) + else: + # No data for this type - return zeros + complete_results.append( + { + "attack_surface_type": attack_surface_type, + "total_findings": 0, + "failed_findings": 0, + "muted_failed_findings": 0, + } + ) + + serializer = self.get_serializer(complete_results, many=True) + return Response(data=serializer.data, status=status.HTTP_200_OK) + @extend_schema(tags=["Schedule"]) @extend_schema_view( From 820c79ba0f866e0a3aeb4bc415242b050c8f9893 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Tue, 25 Nov 2025 10:59:22 +0100 Subject: [PATCH 05/12] fix(tests): tasks unit tests --- api/src/backend/tasks/tests/test_tasks.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api/src/backend/tasks/tests/test_tasks.py b/api/src/backend/tasks/tests/test_tasks.py index 938290d947..619caa3b29 100644 --- a/api/src/backend/tasks/tests/test_tasks.py +++ b/api/src/backend/tasks/tests/test_tasks.py @@ -529,6 +529,7 @@ def 
test_generate_outputs_filters_enabled_s3_integrations( class TestScanCompleteTasks: + @patch("tasks.tasks.aggregate_attack_surface_task.apply_async") @patch("tasks.tasks.create_compliance_requirements_task.apply_async") @patch("tasks.tasks.perform_scan_summary_task.si") @patch("tasks.tasks.generate_outputs_task.si") @@ -541,6 +542,7 @@ def test_scan_complete_tasks( mock_outputs_task, mock_scan_summary_task, mock_compliance_requirements_task, + mock_attack_surface_task, ): """Test that scan complete tasks are properly orchestrated with optimized reports.""" _perform_scan_complete_tasks("tenant-id", "scan-id", "provider-id") @@ -550,6 +552,11 @@ def test_scan_complete_tasks( kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"}, ) + # Verify attack surface task is called + mock_attack_surface_task.assert_called_once_with( + kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"}, + ) + # Verify scan summary task is called mock_scan_summary_task.assert_called_once_with( scan_id="scan-id", From d9f83fdace2cb393fbfdbc07e96c0dd641346694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Tue, 25 Nov 2025 12:17:01 +0100 Subject: [PATCH 06/12] refactor(attack-surfaces): improve view and task logic --- api/src/backend/api/v1/views.py | 5 ++--- api/src/backend/tasks/jobs/scan.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 02c44629e0..148577104d 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -4116,10 +4116,9 @@ def attributes(self, request): ], ), attack_surface=extend_schema( - summary="Attack surface overview", + summary="Get Attack surface overview", description=( - "Retrieve aggregated attack surface metrics from latest completed scans per provider. " - "Always returns all 4 attack surface types with zero counts if no data exists." + "Retrieve aggregated attack surface metrics from latest completed scans per provider." ), parameters=[ OpenApiParameter( diff --git a/api/src/backend/tasks/jobs/scan.py b/api/src/backend/tasks/jobs/scan.py index 66b57ff985..e059c21fe9 100644 --- a/api/src/backend/tasks/jobs/scan.py +++ b/api/src/backend/tasks/jobs/scan.py @@ -84,8 +84,15 @@ "ec2-imdsv1": ["aws"], } +_ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {} + def _get_attack_surface_mapping_from_provider(provider_type: str) -> dict: + global _ATTACK_SURFACE_MAPPING_CACHE + + if provider_type in _ATTACK_SURFACE_MAPPING_CACHE: + return _ATTACK_SURFACE_MAPPING_CACHE[provider_type] + attack_surface_check_mappings = { "internet-exposed": None, "secrets": None, @@ -103,6 +110,8 @@ def _get_attack_surface_mapping_from_provider(provider_type: str) -> dict: provider=provider_type, category=category_name ) attack_surface_check_mappings[category_name] = sdk_check_ids + + _ATTACK_SURFACE_MAPPING_CACHE[provider_type] = attack_surface_check_mappings return attack_surface_check_mappings @@ -1236,7 +1245,7 @@ def aggregate_attack_surface(tenant_id: str, scan_id: str): scan_id: Scan UUID whose findings should be aggregated. 
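+
+    Example row (illustrative values; tenant/scan FKs omitted):
+
+        AttackSurfaceOverview(attack_surface_type="secrets",
+                              total_findings=12, failed_findings=4,
+                              muted_failed_findings=1)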
""" with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): - scan_instance = Scan.all_objects.get(pk=scan_id) + scan_instance = Scan.all_objects.select_related("provider").get(pk=scan_id) provider_type = scan_instance.provider.provider provider_attack_surface_mapping = _get_attack_surface_mapping_from_provider( From 9aec34609290046dba8b2f80890afea7d704ba4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Tue, 25 Nov 2025 12:17:33 +0100 Subject: [PATCH 07/12] test(attack-surfaces): add new unit tests --- api/src/backend/api/tests/test_views.py | 352 +++++++++++++++++++++++ api/src/backend/conftest.py | 16 ++ api/src/backend/tasks/tests/test_scan.py | 282 ++++++++++++++++++ 3 files changed, 650 insertions(+) diff --git a/api/src/backend/api/tests/test_views.py b/api/src/backend/api/tests/test_views.py index 99d92805ea..2fc7c840c5 100644 --- a/api/src/backend/api/tests/test_views.py +++ b/api/src/backend/api/tests/test_views.py @@ -35,6 +35,7 @@ from api.compliance import get_compliance_frameworks from api.db_router import MainRouter from api.models import ( + AttackSurfaceOverview, ComplianceOverviewSummary, ComplianceRequirementOverview, Finding, @@ -7003,6 +7004,357 @@ def test_overview_findings_severity_provider_id_in_filter( assert combined_attributes["medium"] == 4 assert combined_attributes["critical"] == 3 + def test_overview_attack_surface_no_data(self, authenticated_client): + response = authenticated_client.get(reverse("overview-attack-surface")) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + assert len(data) == 4 + for item in data: + assert item["attributes"]["total_findings"] == 0 + assert item["attributes"]["failed_findings"] == 0 + assert item["attributes"]["muted_failed_findings"] == 0 + + def test_overview_attack_surface_with_data( + self, + authenticated_client, + tenants_fixture, + providers_fixture, + create_attack_surface_overview, + ): + tenant = tenants_fixture[0] + provider = providers_fixture[0] + + scan = Scan.objects.create( + name="attack-surface-scan", + provider=provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant=tenant, + ) + + create_attack_surface_overview( + tenant, + scan, + AttackSurfaceOverview.AttackSurfaceTypeChoices.INTERNET_EXPOSED, + total=20, + failed=10, + muted_failed=3, + ) + create_attack_surface_overview( + tenant, + scan, + AttackSurfaceOverview.AttackSurfaceTypeChoices.SECRETS, + total=15, + failed=8, + muted_failed=2, + ) + + response = authenticated_client.get(reverse("overview-attack-surface")) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + assert len(data) == 4 + + results_by_type = {item["id"]: item["attributes"] for item in data} + assert results_by_type["internet-exposed"]["total_findings"] == 20 + assert results_by_type["internet-exposed"]["failed_findings"] == 10 + assert results_by_type["secrets"]["total_findings"] == 15 + assert results_by_type["secrets"]["failed_findings"] == 8 + assert results_by_type["privilege-escalation"]["total_findings"] == 0 + assert results_by_type["ec2-imdsv1"]["total_findings"] == 0 + + def test_overview_attack_surface_provider_filter( + self, + authenticated_client, + tenants_fixture, + providers_fixture, + create_attack_surface_overview, + ): + tenant = tenants_fixture[0] + provider1, provider2, *_ = providers_fixture + + scan1 = Scan.objects.create( + name="attack-surface-scan-1", + provider=provider1, + trigger=Scan.TriggerChoices.MANUAL, + 
state=StateChoices.COMPLETED, + tenant=tenant, + ) + scan2 = Scan.objects.create( + name="attack-surface-scan-2", + provider=provider2, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant=tenant, + ) + + create_attack_surface_overview( + tenant, + scan1, + AttackSurfaceOverview.AttackSurfaceTypeChoices.INTERNET_EXPOSED, + total=10, + failed=5, + muted_failed=1, + ) + create_attack_surface_overview( + tenant, + scan2, + AttackSurfaceOverview.AttackSurfaceTypeChoices.INTERNET_EXPOSED, + total=20, + failed=15, + muted_failed=3, + ) + + response = authenticated_client.get( + reverse("overview-attack-surface"), + {"filter[provider_id]": str(provider1.id)}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + results_by_type = {item["id"]: item["attributes"] for item in data} + assert results_by_type["internet-exposed"]["total_findings"] == 10 + assert results_by_type["internet-exposed"]["failed_findings"] == 5 + + def test_overview_services_region_filter( + self, authenticated_client, scan_summaries_fixture + ): + response = authenticated_client.get( + reverse("overview-services"), + {"filter[region]": "region1"}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + assert len(data) == 2 + service_ids = {item["id"] for item in data} + assert service_ids == {"service1", "service2"} + + def test_overview_services_provider_type_filter( + self, authenticated_client, tenants_fixture, providers_fixture + ): + tenant = tenants_fixture[0] + aws_provider, _, gcp_provider, *_ = providers_fixture + + aws_scan = Scan.objects.create( + name="aws-scan", + provider=aws_provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant=tenant, + ) + gcp_scan = Scan.objects.create( + name="gcp-scan", + provider=gcp_provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant=tenant, + ) + + ScanSummary.objects.create( + tenant=tenant, + scan=aws_scan, + check_id="aws-check", + service="aws-service", + severity="high", + region="us-east-1", + _pass=5, + fail=2, + muted=1, + total=8, + ) + ScanSummary.objects.create( + tenant=tenant, + scan=gcp_scan, + check_id="gcp-check", + service="gcp-service", + severity="medium", + region="us-central1", + _pass=3, + fail=1, + muted=0, + total=4, + ) + + response = authenticated_client.get( + reverse("overview-services"), + {"filter[provider_type]": "aws"}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + service_ids = [item["id"] for item in data] + assert "aws-service" in service_ids + assert "gcp-service" not in service_ids + + @pytest.mark.parametrize( + "status_filter,field_to_check", + [ + ("FAIL", "fail"), + ("PASS", "_pass"), + ], + ) + def test_overview_findings_severity_status_filter( + self, + authenticated_client, + tenants_fixture, + providers_fixture, + status_filter, + field_to_check, + ): + tenant = tenants_fixture[0] + provider = providers_fixture[0] + + scan = Scan.objects.create( + name="status-filter-scan", + provider=provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant=tenant, + ) + + ScanSummary.objects.create( + tenant=tenant, + scan=scan, + check_id="status-check-high", + service="service-a", + severity="high", + region="us-east-1", + _pass=10, + fail=5, + muted=3, + total=18, + ) + ScanSummary.objects.create( + tenant=tenant, + scan=scan, + check_id="status-check-medium", + service="service-a", + severity="medium", + 
region="us-east-1", + _pass=8, + fail=2, + muted=1, + total=11, + ) + + response = authenticated_client.get( + reverse("overview-findings_severity"), + { + "filter[provider_id]": str(provider.id), + "filter[status]": status_filter, + }, + ) + assert response.status_code == status.HTTP_200_OK + attrs = response.json()["data"]["attributes"] + if status_filter == "FAIL": + assert attrs["high"] == 5 + assert attrs["medium"] == 2 + else: + assert attrs["high"] == 10 + assert attrs["medium"] == 8 + + def test_overview_threatscore_compliance_id_filter( + self, authenticated_client, tenants_fixture, providers_fixture + ): + tenant = tenants_fixture[0] + provider = providers_fixture[0] + scan = self._create_scan(tenant, provider, "compliance-filter-scan") + + self._create_threatscore_snapshot( + tenant, + scan, + provider, + compliance_id="prowler_threatscore_aws", + overall_score="75.00", + score_delta="2.00", + section_scores={"1. IAM": "70.00"}, + critical_requirements=[], + total_requirements=50, + passed_requirements=35, + failed_requirements=15, + manual_requirements=0, + total_findings=30, + passed_findings=20, + failed_findings=10, + ) + self._create_threatscore_snapshot( + tenant, + scan, + provider, + compliance_id="cis_1.4_aws", + overall_score="65.00", + score_delta="1.00", + section_scores={"1. IAM": "60.00"}, + critical_requirements=[], + total_requirements=40, + passed_requirements=25, + failed_requirements=15, + manual_requirements=0, + total_findings=25, + passed_findings=15, + failed_findings=10, + ) + + response = authenticated_client.get( + reverse("overview-threatscore"), + {"filter[compliance_id]": "prowler_threatscore_aws"}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["attributes"]["overall_score"] == "75.00" + assert data[0]["attributes"]["compliance_id"] == "prowler_threatscore_aws" + + def test_overview_threatscore_provider_type_filter( + self, authenticated_client, tenants_fixture, providers_fixture + ): + tenant = tenants_fixture[0] + aws_provider, _, gcp_provider, *_ = providers_fixture + + aws_scan = self._create_scan(tenant, aws_provider, "aws-threatscore-scan") + gcp_scan = self._create_scan(tenant, gcp_provider, "gcp-threatscore-scan") + + self._create_threatscore_snapshot( + tenant, + aws_scan, + aws_provider, + compliance_id="prowler_threatscore_aws", + overall_score="80.00", + score_delta="3.00", + section_scores={"1. IAM": "75.00"}, + critical_requirements=[], + total_requirements=60, + passed_requirements=45, + failed_requirements=15, + manual_requirements=0, + total_findings=40, + passed_findings=30, + failed_findings=10, + ) + self._create_threatscore_snapshot( + tenant, + gcp_scan, + gcp_provider, + compliance_id="prowler_threatscore_gcp", + overall_score="70.00", + score_delta="2.00", + section_scores={"1. 
IAM": "65.00"}, + critical_requirements=[], + total_requirements=50, + passed_requirements=35, + failed_requirements=15, + manual_requirements=0, + total_findings=35, + passed_findings=25, + failed_findings=10, + ) + + response = authenticated_client.get( + reverse("overview-threatscore"), + {"filter[provider_type]": "aws"}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["attributes"]["overall_score"] == "80.00" + @pytest.mark.django_db class TestScheduleViewSet: diff --git a/api/src/backend/conftest.py b/api/src/backend/conftest.py index 2988822da2..153487a83c 100644 --- a/api/src/backend/conftest.py +++ b/api/src/backend/conftest.py @@ -15,6 +15,7 @@ from api.db_utils import rls_transaction from api.models import ( + AttackSurfaceOverview, ComplianceOverview, ComplianceRequirementOverview, Finding, @@ -1469,6 +1470,21 @@ def mute_rules_fixture(tenants_fixture, create_test_user, findings_fixture): return mute_rule1, mute_rule2 +@pytest.fixture +def create_attack_surface_overview(): + def _create(tenant, scan, attack_surface_type, total=10, failed=5, muted_failed=2): + return AttackSurfaceOverview.objects.create( + tenant=tenant, + scan=scan, + attack_surface_type=attack_surface_type, + total_findings=total, + failed_findings=failed, + muted_failed_findings=muted_failed, + ) + + return _create + + def get_authorization_header(access_token: str) -> dict: return {"Authorization": f"Bearer {access_token}"} diff --git a/api/src/backend/tasks/tests/test_scan.py b/api/src/backend/tasks/tests/test_scan.py index f2724de4e6..c278b58496 100644 --- a/api/src/backend/tasks/tests/test_scan.py +++ b/api/src/backend/tasks/tests/test_scan.py @@ -9,14 +9,17 @@ import pytest from tasks.jobs.scan import ( + _ATTACK_SURFACE_MAPPING_CACHE, _aggregate_findings_by_region, _copy_compliance_requirement_rows, _create_compliance_summaries, _create_finding_delta, + _get_attack_surface_mapping_from_provider, _normalized_compliance_key, _persist_compliance_requirement_rows, _process_finding_micro_batch, _store_resources, + aggregate_attack_surface, aggregate_findings, create_compliance_requirements, perform_prowler_scan, @@ -3471,3 +3474,282 @@ def test_aggregate_findings_by_region_empty_findings( assert check_status_by_region == {} assert findings_count_by_compliance == {} + + +@pytest.mark.django_db +class TestAggregateAttackSurface: + """Test aggregate_attack_surface function and related caching.""" + + def setup_method(self): + """Clear cache before each test.""" + _ATTACK_SURFACE_MAPPING_CACHE.clear() + + def teardown_method(self): + """Clear cache after each test.""" + _ATTACK_SURFACE_MAPPING_CACHE.clear() + + @patch("tasks.jobs.scan.CheckMetadata.list") + def test_get_attack_surface_mapping_caches_result(self, mock_check_metadata_list): + """Test that _get_attack_surface_mapping_from_provider caches results.""" + mock_check_metadata_list.return_value = {"check_internet_exposed_1"} + + # First call should hit CheckMetadata.list + result1 = _get_attack_surface_mapping_from_provider("aws") + assert mock_check_metadata_list.call_count == 2 # internet-exposed, secrets + + # Second call should use cache + result2 = _get_attack_surface_mapping_from_provider("aws") + assert mock_check_metadata_list.call_count == 2 # No additional calls + + assert result1 is result2 + assert "aws" in _ATTACK_SURFACE_MAPPING_CACHE + + @patch("tasks.jobs.scan.CheckMetadata.list") + def test_get_attack_surface_mapping_different_providers( + self, 
mock_check_metadata_list + ): + """Test caching works independently for different providers.""" + mock_check_metadata_list.return_value = {"check_1"} + + _get_attack_surface_mapping_from_provider("aws") + aws_call_count = mock_check_metadata_list.call_count + + _get_attack_surface_mapping_from_provider("gcp") + gcp_call_count = mock_check_metadata_list.call_count + + # Both providers should have made calls + assert gcp_call_count > aws_call_count + assert "aws" in _ATTACK_SURFACE_MAPPING_CACHE + assert "gcp" in _ATTACK_SURFACE_MAPPING_CACHE + + @patch("tasks.jobs.scan.CheckMetadata.list") + def test_get_attack_surface_mapping_returns_hardcoded_checks( + self, mock_check_metadata_list + ): + """Test that hardcoded check IDs are returned for privilege-escalation and ec2-imdsv1.""" + mock_check_metadata_list.return_value = set() + + result = _get_attack_surface_mapping_from_provider("aws") + + # Hardcoded checks should be present + assert ( + "iam_policy_allows_privilege_escalation" in result["privilege-escalation"] + ) + assert ( + "iam_inline_policy_allows_privilege_escalation" + in result["privilege-escalation"] + ) + assert "ec2_instance_imdsv2_enabled" in result["ec2-imdsv1"] + + @patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create") + @patch("tasks.jobs.scan.Finding.all_objects.filter") + @patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider") + @patch("tasks.jobs.scan.rls_transaction") + def test_aggregate_attack_surface_creates_overview_records( + self, + mock_rls_transaction, + mock_get_mapping, + mock_findings_filter, + mock_bulk_create, + tenants_fixture, + scans_fixture, + ): + """Test that aggregate_attack_surface creates AttackSurfaceOverview records.""" + tenant = tenants_fixture[0] + scan = scans_fixture[0] + scan.provider.provider = "aws" + scan.provider.save() + + mock_get_mapping.return_value = { + "internet-exposed": {"check_internet_1", "check_internet_2"}, + "secrets": {"check_secrets_1"}, + "privilege-escalation": {"check_privesc_1"}, + "ec2-imdsv1": {"check_imdsv1_1"}, + } + + # Mock findings aggregation + mock_queryset = MagicMock() + mock_queryset.values.return_value = mock_queryset + mock_queryset.annotate.return_value = [ + {"check_id": "check_internet_1", "total": 10, "failed": 3, "muted": 1}, + {"check_id": "check_secrets_1", "total": 5, "failed": 2, "muted": 0}, + ] + + ctx = MagicMock() + ctx.__enter__.return_value = None + ctx.__exit__.return_value = False + mock_rls_transaction.return_value = ctx + mock_findings_filter.return_value = mock_queryset + + aggregate_attack_surface(str(tenant.id), str(scan.id)) + + mock_bulk_create.assert_called_once() + args, kwargs = mock_bulk_create.call_args + objects = args[0] + + # Should create records for internet-exposed and secrets (the ones with findings) + assert len(objects) == 2 + assert kwargs["batch_size"] == 500 + + @patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create") + @patch("tasks.jobs.scan.Finding.all_objects.filter") + @patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider") + @patch("tasks.jobs.scan.rls_transaction") + def test_aggregate_attack_surface_skips_unsupported_provider( + self, + mock_rls_transaction, + mock_get_mapping, + mock_findings_filter, + mock_bulk_create, + tenants_fixture, + scans_fixture, + ): + """Test that ec2-imdsv1 is skipped for non-AWS providers.""" + tenant = tenants_fixture[0] + scan = scans_fixture[0] + scan.provider.provider = "gcp" + scan.provider.uid = "gcp-test-project-id" + scan.provider.save() + + 
mock_get_mapping.return_value = { + "internet-exposed": {"check_internet_1"}, + "secrets": {"check_secrets_1"}, + "privilege-escalation": set(), # Not supported for GCP + "ec2-imdsv1": {"check_imdsv1_1"}, # Should be skipped for GCP + } + + mock_queryset = MagicMock() + mock_queryset.values.return_value = mock_queryset + mock_queryset.annotate.return_value = [ + {"check_id": "check_internet_1", "total": 5, "failed": 1, "muted": 0}, + ] + + ctx = MagicMock() + ctx.__enter__.return_value = None + ctx.__exit__.return_value = False + mock_rls_transaction.return_value = ctx + mock_findings_filter.return_value = mock_queryset + + aggregate_attack_surface(str(tenant.id), str(scan.id)) + + # ec2-imdsv1 check_ids should not be in the filter + filter_call = mock_findings_filter.call_args + check_ids_in_filter = filter_call[1]["check_id__in"] + assert "check_imdsv1_1" not in check_ids_in_filter + + @patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create") + @patch("tasks.jobs.scan.Finding.all_objects.filter") + @patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider") + @patch("tasks.jobs.scan.rls_transaction") + def test_aggregate_attack_surface_no_findings( + self, + mock_rls_transaction, + mock_get_mapping, + mock_findings_filter, + mock_bulk_create, + tenants_fixture, + scans_fixture, + ): + """Test that no records are created when there are no findings.""" + tenant = tenants_fixture[0] + scan = scans_fixture[0] + + mock_get_mapping.return_value = { + "internet-exposed": {"check_1"}, + "secrets": {"check_2"}, + "privilege-escalation": set(), + "ec2-imdsv1": set(), + } + + mock_queryset = MagicMock() + mock_queryset.values.return_value = mock_queryset + mock_queryset.annotate.return_value = [] # No findings + + ctx = MagicMock() + ctx.__enter__.return_value = None + ctx.__exit__.return_value = False + mock_rls_transaction.return_value = ctx + mock_findings_filter.return_value = mock_queryset + + aggregate_attack_surface(str(tenant.id), str(scan.id)) + + mock_bulk_create.assert_not_called() + + @patch("tasks.jobs.scan.AttackSurfaceOverview.objects.bulk_create") + @patch("tasks.jobs.scan.Finding.all_objects.filter") + @patch("tasks.jobs.scan._get_attack_surface_mapping_from_provider") + @patch("tasks.jobs.scan.rls_transaction") + def test_aggregate_attack_surface_aggregates_counts_correctly( + self, + mock_rls_transaction, + mock_get_mapping, + mock_findings_filter, + mock_bulk_create, + tenants_fixture, + scans_fixture, + ): + """Test that counts from multiple check_ids are aggregated per attack surface type.""" + tenant = tenants_fixture[0] + scan = scans_fixture[0] + scan.provider.provider = "aws" + scan.provider.save() + + mock_get_mapping.return_value = { + "internet-exposed": {"check_internet_1", "check_internet_2"}, + "secrets": set(), + "privilege-escalation": set(), + "ec2-imdsv1": set(), + } + + mock_queryset = MagicMock() + mock_queryset.values.return_value = mock_queryset + mock_queryset.annotate.return_value = [ + {"check_id": "check_internet_1", "total": 10, "failed": 3, "muted": 1}, + {"check_id": "check_internet_2", "total": 5, "failed": 2, "muted": 0}, + ] + + ctx = MagicMock() + ctx.__enter__.return_value = None + ctx.__exit__.return_value = False + mock_rls_transaction.return_value = ctx + mock_findings_filter.return_value = mock_queryset + + aggregate_attack_surface(str(tenant.id), str(scan.id)) + + args, kwargs = mock_bulk_create.call_args + objects = args[0] + + assert len(objects) == 1 + overview = objects[0] + assert overview.attack_surface_type == 
"internet-exposed" + assert overview.total_findings == 15 # 10 + 5 + assert overview.failed_findings == 5 # 3 + 2 + assert overview.muted_failed_findings == 1 # 1 + 0 + + @patch("tasks.jobs.scan.Scan.all_objects.select_related") + @patch("tasks.jobs.scan.rls_transaction") + def test_aggregate_attack_surface_uses_select_related( + self, mock_rls_transaction, mock_select_related, tenants_fixture, scans_fixture + ): + """Test that select_related is used to avoid N+1 query.""" + tenant = tenants_fixture[0] + scan = scans_fixture[0] + + mock_scan = MagicMock() + mock_scan.provider.provider = "aws" + + mock_select_related.return_value.get.return_value = mock_scan + + ctx = MagicMock() + ctx.__enter__.return_value = None + ctx.__exit__.return_value = False + mock_rls_transaction.return_value = ctx + + with patch( + "tasks.jobs.scan._get_attack_surface_mapping_from_provider" + ) as mock_map: + mock_map.return_value = {} + + aggregate_attack_surface(str(tenant.id), str(scan.id)) + + mock_select_related.assert_called_once_with("provider") From e0fdbfa7e328760cc541326b0e6fa3b7a6ba4cc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Tue, 25 Nov 2025 12:17:59 +0100 Subject: [PATCH 08/12] style: apply ruff --- api/src/backend/api/migrations/0060_attack_surface_overview.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/src/backend/api/migrations/0060_attack_surface_overview.py b/api/src/backend/api/migrations/0060_attack_surface_overview.py index 134c9306eb..8007d49a70 100644 --- a/api/src/backend/api/migrations/0060_attack_surface_overview.py +++ b/api/src/backend/api/migrations/0060_attack_surface_overview.py @@ -9,7 +9,6 @@ class Migration(migrations.Migration): - dependencies = [ ("api", "0059_compliance_overview_summary"), ] From b731d9509f85524efa5bfb2e51c549bbce2d288c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Tue, 25 Nov 2025 12:22:44 +0100 Subject: [PATCH 09/12] chore: update changelog and API version --- api/CHANGELOG.md | 7 + api/pyproject.toml | 2 +- api/src/backend/api/specs/v1.yaml | 229 ++++++++++++++++++++++-------- api/src/backend/api/v1/views.py | 2 +- 4 files changed, 175 insertions(+), 65 deletions(-) diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index b96a946094..db95086930 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -2,6 +2,13 @@ All notable changes to the **Prowler API** are documented in this file. +## [1.16.0] (Prowler v5.15.0) + +### Added +- New endpoint to retrieve an overview of the attack surfaces [(#9309)](https://github.com/prowler-cloud/prowler/pull/9309) + +--- + ## [1.15.0] (Prowler v5.14.0) ### Added diff --git a/api/pyproject.toml b/api/pyproject.toml index ffc8413489..122144c4e1 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -43,7 +43,7 @@ name = "prowler-api" package-mode = false # Needed for the SDK compatibility requires-python = ">=3.11,<3.13" -version = "1.15.0" +version = "1.16.0" [project.scripts] celery = "src.backend.config.settings.celery" diff --git a/api/src/backend/api/specs/v1.yaml b/api/src/backend/api/specs/v1.yaml index 9d8a8b9310..b1c6e642dc 100644 --- a/api/src/backend/api/specs/v1.yaml +++ b/api/src/backend/api/specs/v1.yaml @@ -1,7 +1,7 @@ openapi: 3.0.3 info: title: Prowler API - version: 1.15.0 + version: 1.16.0 description: |- Prowler API specification. 
@@ -4597,6 +4597,59 @@ paths: responses: '204': description: No response body + /api/v1/overviews/attack-surfaces: + get: + operationId: overviews_attack_surfaces_retrieve + description: Retrieve aggregated attack surface metrics from latest completed + scans per provider. + summary: Get Attack surface overview + parameters: + - in: query + name: fields[attack-surface-overviews] + schema: + type: array + items: + type: string + enum: + - id + - total_findings + - failed_findings + - muted_failed_findings + description: endpoint return only specific fields in the response on a per-type + basis by including a fields[TYPE] query parameter. + explode: false + - in: query + name: filter[provider_id] + schema: + type: string + format: uuid + description: Filter by specific provider ID + - in: query + name: filter[provider_id__in] + schema: + type: string + description: Filter by multiple provider IDs (comma-separated UUIDs) + - in: query + name: filter[provider_type] + schema: + type: string + description: Filter by provider type (aws, azure, gcp, etc.) + - in: query + name: filter[provider_type__in] + schema: + type: string + description: Filter by multiple provider types (comma-separated) + tags: + - Overview + security: + - JWT or API Key: [] + responses: + '200': + content: + application/vnd.api+json: + schema: + $ref: '#/components/schemas/AttackSurfaceOverviewResponse' + description: '' /api/v1/overviews/findings: get: operationId: overviews_findings_retrieve @@ -5068,6 +5121,8 @@ paths: type: string enum: - id + - provider_type + - region - total - fail - muted @@ -5200,6 +5255,10 @@ paths: enum: - id - -id + - provider_type + - -provider_type + - region + - -region - total - -total - fail @@ -8984,50 +9043,12 @@ paths: description: CSV file containing the compliance report '404': description: Compliance report not found - /api/v1/scans/{id}/report: - get: - operationId: scans_report_retrieve - description: Returns a ZIP file containing the requested report - summary: Download ZIP report - parameters: - - in: query - name: fields[scan-reports] - schema: - type: array - items: - type: string - enum: - - id - description: endpoint return only specific fields in the response on a per-type - basis by including a fields[TYPE] query parameter. - explode: false - - in: path - name: id - schema: - type: string - format: uuid - description: A UUID string identifying this scan. - required: true - tags: - - Scan - security: - - JWT or API Key: [] - responses: - '200': - description: Report obtained successfully - '202': - description: The task is in progress - '403': - description: There is a problem with credentials - '404': - description: The scan has no reports, or the report generation task has - not started yet - /api/v1/scans/{id}/threatscore: + /api/v1/scans/{id}/ens: get: - operationId: scans_threatscore_retrieve - description: Download a specific threatscore report (e.g., 'prowler_threatscore_aws') + operationId: scans_ens_retrieve + description: Download ENS RD2022 compliance report (e.g., 'ens_rd2022_aws') as a PDF file. 
- summary: Retrieve threatscore report + summary: Retrieve ENS RD2022 compliance report parameters: - in: query name: fields[scans] @@ -9078,7 +9099,7 @@ paths: - JWT or API Key: [] responses: '200': - description: PDF file containing the threatscore report + description: PDF file containing the ENS compliance report '202': description: The task is in progress '401': @@ -9086,14 +9107,14 @@ paths: '403': description: There is a problem with credentials '404': - description: The scan has no threatscore reports, or the threatscore report - generation task has not started yet - /api/v1/scans/{id}/ens: + description: The scan has no ENS reports, or the ENS report generation task + has not started yet + /api/v1/scans/{id}/nis2: get: - operationId: scans_ens_retrieve - description: Download a specific ENS compliance report (e.g., 'prowler_ens_aws') - as a PDF file. - summary: Retrieve ENS compliance report + operationId: scans_nis2_retrieve + description: Download NIS2 compliance report (Directive (EU) 2022/2555) as a + PDF file. + summary: Retrieve NIS2 compliance report parameters: - in: query name: fields[scans] @@ -9144,7 +9165,7 @@ paths: - JWT or API Key: [] responses: '200': - description: PDF file containing the ENS compliance report + description: PDF file containing the NIS2 compliance report '202': description: The task is in progress '401': @@ -9152,14 +9173,52 @@ paths: '403': description: There is a problem with credentials '404': - description: The scan has no ENS reports, or the ENS report generation task - has not started yet - /api/v1/scans/{id}/nis2: + description: The scan has no NIS2 reports, or the NIS2 report generation + task has not started yet + /api/v1/scans/{id}/report: get: - operationId: scans_nis2_retrieve - description: Download NIS2 compliance report (Directive (EU) 2022/2555) as a - PDF file. - summary: Retrieve NIS2 compliance report + operationId: scans_report_retrieve + description: Returns a ZIP file containing the requested report + summary: Download ZIP report + parameters: + - in: query + name: fields[scan-reports] + schema: + type: array + items: + type: string + enum: + - id + description: endpoint return only specific fields in the response on a per-type + basis by including a fields[TYPE] query parameter. + explode: false + - in: path + name: id + schema: + type: string + format: uuid + description: A UUID string identifying this scan. + required: true + tags: + - Scan + security: + - JWT or API Key: [] + responses: + '200': + description: Report obtained successfully + '202': + description: The task is in progress + '403': + description: There is a problem with credentials + '404': + description: The scan has no reports, or the report generation task has + not started yet + /api/v1/scans/{id}/threatscore: + get: + operationId: scans_threatscore_retrieve + description: Download a specific threatscore report (e.g., 'prowler_threatscore_aws') + as a PDF file. 
+ summary: Retrieve threatscore report parameters: - in: query name: fields[scans] @@ -9210,7 +9269,7 @@ paths: - JWT or API Key: [] responses: '200': - description: PDF file containing the NIS2 compliance report + description: PDF file containing the threatscore report '202': description: The task is in progress '401': @@ -9218,8 +9277,8 @@ paths: '403': description: There is a problem with credentials '404': - description: The scan has no NIS2 reports, or the NIS2 report generation - task has not started yet + description: The scan has no threatscore reports, or the threatscore report + generation task has not started yet /api/v1/schedules/daily: post: operationId: schedules_daily_create @@ -10712,6 +10771,44 @@ paths: description: '' components: schemas: + AttackSurfaceOverview: + type: object + required: + - type + - id + additionalProperties: false + properties: + type: + type: string + description: The [type](https://jsonapi.org/format/#document-resource-object-identification) + member is used to describe resource objects that share common attributes + and relationships. + enum: + - attack-surface-overviews + id: {} + attributes: + type: object + properties: + id: + type: string + total_findings: + type: integer + failed_findings: + type: integer + muted_failed_findings: + type: integer + required: + - id + - total_findings + - failed_findings + - muted_failed_findings + AttackSurfaceOverviewResponse: + type: object + properties: + data: + $ref: '#/components/schemas/AttackSurfaceOverview' + required: + - data ComplianceOverview: type: object required: @@ -13558,6 +13655,11 @@ components: properties: id: type: string + readOnly: true + provider_type: + type: string + region: + type: string total: type: integer fail: @@ -13567,7 +13669,8 @@ components: pass: type: integer required: - - id + - provider_type + - region - total - fail - muted diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 148577104d..4d02ebf65a 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -352,7 +352,7 @@ class SchemaView(SpectacularAPIView): def get(self, request, *args, **kwargs): spectacular_settings.TITLE = "Prowler API" - spectacular_settings.VERSION = "1.15.0" + spectacular_settings.VERSION = "1.16.0" spectacular_settings.DESCRIPTION = ( "Prowler API specification.\n\nThis file is auto-generated." 
)

From 443e5c179b1c1c07e98b43c8dbbd1fbcd24bc1bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?=
Date: Wed, 26 Nov 2025 16:17:04 +0100
Subject: [PATCH 10/12] fix(attack-surfaces): view logic

---
 api/src/backend/api/filters.py  |  20 +
 api/src/backend/api/v1/views.py | 804 ++++++++++++++++----------------
 2 files changed, 423 insertions(+), 401 deletions(-)

diff --git a/api/src/backend/api/filters.py b/api/src/backend/api/filters.py
index 2b93926f6e..ec11744e09 100644
--- a/api/src/backend/api/filters.py
+++ b/api/src/backend/api/filters.py
@@ -23,6 +23,7 @@
     StatusEnumField,
 )
 from api.models import (
+    AttackSurfaceOverview,
     ComplianceRequirementOverview,
     Finding,
     Integration,
@@ -1021,3 +1022,22 @@ class Meta:
         "inserted_at": ["date", "gte", "lte"],
         "overall_score": ["exact", "gte", "lte"],
     }
+
+
+class AttackSurfaceOverviewFilter(FilterSet):
+    """Filter for attack surface overview aggregations by provider."""
+
+    provider_id = UUIDFilter(field_name="scan__provider__id", lookup_expr="exact")
+    provider_id__in = UUIDInFilter(field_name="scan__provider__id", lookup_expr="in")
+    provider_type = ChoiceFilter(
+        field_name="scan__provider__provider", choices=Provider.ProviderChoices.choices
+    )
+    provider_type__in = ChoiceInFilter(
+        field_name="scan__provider__provider",
+        choices=Provider.ProviderChoices.choices,
+        lookup_expr="in",
+    )
+
+    class Meta:
+        model = AttackSurfaceOverview
+        fields = {}
diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py
index 4d02ebf65a..8ceb857a91 100644
--- a/api/src/backend/api/v1/views.py
+++ b/api/src/backend/api/v1/views.py
@@ -4076,50 +4076,10 @@ def attributes(self, request):
             ),
             filters=True,
         ),
-        threatscore=extend_schema(
-            summary="Get ThreatScore snapshots",
-            description=(
-                "Retrieve ThreatScore metrics. By default, returns the latest snapshot for each provider. "
-                "Use snapshot_id to retrieve a specific historical snapshot."
-            ),
-            parameters=[
-                OpenApiParameter(
-                    name="snapshot_id",
-                    type=OpenApiTypes.UUID,
-                    location=OpenApiParameter.QUERY,
-                    description="Retrieve a specific snapshot by ID. If not provided, returns latest snapshots.",
-                ),
-                OpenApiParameter(
-                    name="provider_id",
-                    type=OpenApiTypes.UUID,
-                    location=OpenApiParameter.QUERY,
-                    description="Filter by specific provider ID",
-                ),
-                OpenApiParameter(
-                    name="provider_id__in",
-                    type=OpenApiTypes.STR,
-                    location=OpenApiParameter.QUERY,
-                    description="Filter by multiple provider IDs (comma-separated UUIDs)",
-                ),
-                OpenApiParameter(
-                    name="provider_type",
-                    type=OpenApiTypes.STR,
-                    location=OpenApiParameter.QUERY,
-                    description="Filter by provider type (aws, azure, gcp, etc.)",
-                ),
-                OpenApiParameter(
-                    name="provider_type__in",
-                    type=OpenApiTypes.STR,
-                    location=OpenApiParameter.QUERY,
-                    description="Filter by multiple provider types (comma-separated)",
-                ),
-            ],
-        ),
         attack_surface=extend_schema(
-            summary="Get Attack surface overview",
-            description=(
-                "Retrieve aggregated attack surface metrics from latest completed scans per provider."
- ), + summary="Retrieve attack surface overview", + description="Returns aggregated attack surface metrics from latest completed scans per provider.", + tags=["Overview"], parameters=[ OpenApiParameter( name="filter[provider_id]", @@ -4128,7 +4088,7 @@ def attributes(self, request): description="Filter by specific provider ID", ), OpenApiParameter( - name="filter[provider_id__in]", + name="filter[provider_id.in]", type=OpenApiTypes.STR, location=OpenApiParameter.QUERY, description="Filter by multiple provider IDs (comma-separated UUIDs)", @@ -4140,13 +4100,13 @@ def attributes(self, request): description="Filter by provider type (aws, azure, gcp, etc.)", ), OpenApiParameter( - name="filter[provider_type__in]", + name="filter[provider_type.in]", type=OpenApiTypes.STR, location=OpenApiParameter.QUERY, description="Filter by multiple provider types (comma-separated)", ), ], - ), + ) ) @method_decorator(CACHE_DECORATOR, name="list") class OverviewViewSet(BaseRLSViewSet): @@ -4194,355 +4154,112 @@ def get_filterset_class(self): return ScanSummarySeverityFilter return None - def _get_rbac_provider_filter(self): - """Get RBAC provider filter dict. Ensures get_queryset() is called.""" - if not hasattr(self, "allowed_providers"): - self.get_queryset() - return ( - {"provider__in": self.allowed_providers} - if hasattr(self, "allowed_providers") - else {} - ) - - def _get_latest_scan_ids(self, additional_filters=None): - """Get latest completed scan IDs per provider with RBAC + optional filters.""" - scan_filter = { - "tenant_id": self.request.tenant_id, - "state": StateChoices.COMPLETED, - **self._get_rbac_provider_filter(), - } - if additional_filters: - scan_filter.update(additional_filters) - - return ( - Scan.all_objects.filter(**scan_filter) - .order_by("provider_id", "-inserted_at") - .distinct("provider_id") - .values_list("id", flat=True) - ) - - def _parse_provider_filters(self, request): - """Parse provider filters from JSON:API query params.""" - normalized_params = QueryDict(mutable=True) - allowed_filter_keys = { - "provider_id", - "provider_id__in", - "provider_type", - "provider_type__in", - } - for param_key, values in request.query_params.lists(): - if not (param_key.startswith("filter[") and param_key.endswith("]")): - continue - normalized_key = param_key[7:-1] - if normalized_key in allowed_filter_keys: - normalized_params.setlist(normalized_key, values) - - scan_filter = {} - if provider_id := normalized_params.get("provider_id"): - scan_filter["provider_id"] = provider_id - if provider_ids := normalized_params.get("provider_id__in"): - scan_filter["provider_id__in"] = [ - pid.strip() for pid in provider_ids.split(",") if pid.strip() - ] - if provider_type := normalized_params.get("provider_type"): - scan_filter["provider__provider"] = provider_type - if provider_types := normalized_params.get("provider_type__in"): - scan_filter["provider__provider__in"] = [ - pt.strip() for pt in provider_types.split(",") if pt.strip() - ] + @extend_schema(exclude=True) + def list(self, request, *args, **kwargs): + raise MethodNotAllowed(method="GET") - return scan_filter + @extend_schema(exclude=True) + def retrieve(self, request, *args, **kwargs): + raise MethodNotAllowed(method="GET") - def _get_latest_scans_queryset(self, additional_filters=None): + def _get_latest_scans_queryset(self): """ Get filtered queryset for the latest completed scans per provider. 
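+
+        The "latest" scan for each provider is resolved with PostgreSQL's
+        DISTINCT ON: order by (provider_id, -inserted_at) and keep the first
+        row per provider_id.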
- Args: - additional_filters: Optional dict of extra Scan filters - Returns: Filtered ScanSummary queryset with latest scan IDs applied. """ tenant_id = self.request.tenant_id queryset = self.get_queryset() filtered_queryset = self.filter_queryset(queryset) - latest_scan_ids = self._get_latest_scan_ids(additional_filters) + provider_filter = ( + {"provider__in": self.allowed_providers} + if hasattr(self, "allowed_providers") + else {} + ) + + latest_scan_ids = ( + Scan.all_objects.filter( + tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter + ) + .order_by("provider_id", "-inserted_at") + .distinct("provider_id") + .values_list("id", flat=True) + ) return filtered_queryset.filter( tenant_id=tenant_id, scan_id__in=latest_scan_ids ) - def _build_threatscore_overview_snapshot(self, snapshot_ids, tenant_id): - """ - Aggregate the latest snapshots into a single overview snapshot for the tenant. - """ - if not snapshot_ids: - raise ValueError( - "Snapshot id list cannot be empty when aggregating threatscore overview" - ) - - base_queryset = ThreatScoreSnapshot.objects.filter( - tenant_id=tenant_id, id__in=snapshot_ids + @action(detail=False, methods=["get"], url_name="providers") + def providers(self, request): + tenant_id = self.request.tenant_id + queryset = self.get_queryset() + provider_filter = ( + {"provider__in": self.allowed_providers} + if hasattr(self, "allowed_providers") + else {} ) - annotated_queryset = ( - base_queryset.annotate( - active_requirements=ExpressionWrapper( - F("total_requirements") - F("manual_requirements"), - output_field=IntegerField(), - ) - ) - .annotate( - weight=Case( - When(total_findings__gt=0, then=F("total_findings")), - When( - active_requirements__gt=0, - then=F("active_requirements"), - ), - default=Value(1, output_field=IntegerField()), - output_field=IntegerField(), - ) + latest_scan_ids = ( + Scan.all_objects.filter( + tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter ) - .order_by() + .order_by("provider_id", "-inserted_at") + .distinct("provider_id") + .values_list("id", flat=True) ) - aggregated_metrics = annotated_queryset.aggregate( - total_requirements=Sum("total_requirements"), - passed_requirements=Sum("passed_requirements"), - failed_requirements=Sum("failed_requirements"), - manual_requirements=Sum("manual_requirements"), - total_findings=Sum("total_findings"), - passed_findings=Sum("passed_findings"), - failed_findings=Sum("failed_findings"), - weighted_overall_sum=Sum( - ExpressionWrapper( - F("overall_score") * F("weight"), - output_field=DecimalField(max_digits=14, decimal_places=4), - ) - ), - overall_weight=Sum("weight"), - unweighted_overall_sum=Sum("overall_score"), - weighted_delta_sum=Sum( - Case( - When( - score_delta__isnull=False, - then=ExpressionWrapper( - F("score_delta") * F("weight"), - output_field=DecimalField(max_digits=14, decimal_places=4), - ), - ), - default=Value( - Decimal("0"), - output_field=DecimalField(max_digits=14, decimal_places=4), - ), - output_field=DecimalField(max_digits=14, decimal_places=4), - ) - ), - delta_weight=Sum( - Case( - When(score_delta__isnull=False, then=F("weight")), - default=Value(0, output_field=IntegerField()), - output_field=IntegerField(), - ) - ), - provider_count=Count("id"), - latest_inserted_at=Max("inserted_at"), + findings_aggregated = ( + queryset.filter(scan_id__in=latest_scan_ids) + .values(provider=F("scan__provider__provider")) + .annotate( + findings_passed=Coalesce(Sum("_pass"), 0), + findings_failed=Coalesce(Sum("fail"), 0), + 
findings_muted=Coalesce(Sum("muted"), 0), + total_findings=Coalesce(Sum("total"), 0), + ) ) - total_requirements = aggregated_metrics["total_requirements"] or 0 - passed_requirements = aggregated_metrics["passed_requirements"] or 0 - failed_requirements = aggregated_metrics["failed_requirements"] or 0 - manual_requirements = aggregated_metrics["manual_requirements"] or 0 - total_findings = aggregated_metrics["total_findings"] or 0 - passed_findings = aggregated_metrics["passed_findings"] or 0 - failed_findings = aggregated_metrics["failed_findings"] or 0 - - weighted_overall_sum = aggregated_metrics["weighted_overall_sum"] - if weighted_overall_sum is None: - weighted_overall_sum = Decimal("0") - unweighted_overall_sum = aggregated_metrics["unweighted_overall_sum"] - if unweighted_overall_sum is None: - unweighted_overall_sum = Decimal("0") - - overall_weight = aggregated_metrics["overall_weight"] or 0 - provider_count = aggregated_metrics["provider_count"] or 0 - - weighted_delta_sum = aggregated_metrics["weighted_delta_sum"] - if weighted_delta_sum is None: - weighted_delta_sum = Decimal("0") - delta_weight = aggregated_metrics["delta_weight"] or 0 - - if overall_weight > 0: - overall_score = (weighted_overall_sum / Decimal(overall_weight)).quantize( - Decimal("0.01"), rounding=ROUND_HALF_UP - ) - elif provider_count > 0: - overall_score = (unweighted_overall_sum / Decimal(provider_count)).quantize( - Decimal("0.01"), rounding=ROUND_HALF_UP + resources_queryset = Resource.all_objects.filter(tenant_id=tenant_id) + if hasattr(self, "allowed_providers"): + resources_queryset = resources_queryset.filter( + provider__in=self.allowed_providers ) - else: - overall_score = Decimal("0.00") + resources_aggregated = resources_queryset.values( + provider_type=F("provider__provider") + ).annotate(total_resources=Count("id")) + resource_map = { + row["provider_type"]: row["total_resources"] for row in resources_aggregated + } - if delta_weight > 0: - score_delta = (weighted_delta_sum / Decimal(delta_weight)).quantize( - Decimal("0.01"), rounding=ROUND_HALF_UP + overview = [] + for row in findings_aggregated: + overview.append( + { + "provider": row["provider"], + "total_resources": resource_map.get(row["provider"], 0), + "total_findings": row["total_findings"], + "findings_passed": row["findings_passed"], + "findings_failed": row["findings_failed"], + "findings_muted": row["findings_muted"], + } ) - else: - score_delta = None - - section_weighted_sums = defaultdict(lambda: Decimal("0")) - section_weights = defaultdict(lambda: Decimal("0")) - combined_critical_requirements = {} + return Response( + self.get_serializer(overview, many=True).data, + status=status.HTTP_200_OK, + ) - snapshots_with_weight = list(annotated_queryset) - - for snapshot in snapshots_with_weight: - weight_value = getattr(snapshot, "weight", None) - try: - weight_decimal = Decimal(weight_value) - except (InvalidOperation, TypeError): - weight_decimal = Decimal("1") - if weight_decimal <= 0: - weight_decimal = Decimal("1") - - section_scores = snapshot.section_scores or {} - for section, score in section_scores.items(): - try: - score_decimal = Decimal(str(score)) - except (InvalidOperation, TypeError): - continue - section_weighted_sums[section] += score_decimal * weight_decimal - section_weights[section] += weight_decimal - - for requirement in snapshot.critical_requirements or []: - key = requirement.get("requirement_id") or requirement.get("title") - if not key: - continue - existing = combined_critical_requirements.get(key) - 
- def requirement_sort_key(item): - return ( - item.get("risk_level") or 0, - item.get("weight") or 0, - ) - - if existing is None or requirement_sort_key( - requirement - ) > requirement_sort_key(existing): - combined_critical_requirements[key] = deepcopy(requirement) - - aggregated_section_scores = {} - for section, total in section_weighted_sums.items(): - weight_total = section_weights[section] - if weight_total > 0: - aggregated_section_scores[section] = str( - (total / weight_total).quantize( - Decimal("0.01"), rounding=ROUND_HALF_UP - ) - ) - - aggregated_section_scores = dict(sorted(aggregated_section_scores.items())) - - aggregated_critical_requirements = sorted( - combined_critical_requirements.values(), - key=lambda item: ( - item.get("risk_level") or 0, - item.get("weight") or 0, - ), - reverse=True, - ) - - aggregated_snapshot = ThreatScoreSnapshot( - tenant_id=tenant_id, - scan=None, - provider=None, - compliance_id="prowler_threatscore_overview", - overall_score=overall_score, - score_delta=score_delta, - section_scores=aggregated_section_scores, - critical_requirements=aggregated_critical_requirements, - total_requirements=total_requirements, - passed_requirements=passed_requirements, - failed_requirements=failed_requirements, - manual_requirements=manual_requirements, - total_findings=total_findings, - passed_findings=passed_findings, - failed_findings=failed_findings, - ) - - latest_inserted_at = aggregated_metrics["latest_inserted_at"] - if latest_inserted_at is not None: - aggregated_snapshot.inserted_at = latest_inserted_at - - aggregated_snapshot._aggregated = True - - return aggregated_snapshot - - @extend_schema(exclude=True) - def list(self, request, *args, **kwargs): - raise MethodNotAllowed(method="GET") - - @extend_schema(exclude=True) - def retrieve(self, request, *args, **kwargs): - raise MethodNotAllowed(method="GET") - - @action(detail=False, methods=["get"], url_name="providers") - def providers(self, request): - tenant_id = self.request.tenant_id - queryset = self.get_queryset() - latest_scan_ids = self._get_latest_scan_ids() - - findings_aggregated = ( - queryset.filter(scan_id__in=latest_scan_ids) - .values(provider=F("scan__provider__provider")) - .annotate( - findings_passed=Coalesce(Sum("_pass"), 0), - findings_failed=Coalesce(Sum("fail"), 0), - findings_muted=Coalesce(Sum("muted"), 0), - total_findings=Coalesce(Sum("total"), 0), - ) - ) - - resources_queryset = Resource.all_objects.filter(tenant_id=tenant_id) - if hasattr(self, "allowed_providers"): - resources_queryset = resources_queryset.filter( - provider__in=self.allowed_providers - ) - resources_aggregated = resources_queryset.values( - provider_type=F("provider__provider") - ).annotate(total_resources=Count("id")) - resource_map = { - row["provider_type"]: row["total_resources"] for row in resources_aggregated - } - - overview = [] - for row in findings_aggregated: - overview.append( - { - "provider": row["provider"], - "total_resources": resource_map.get(row["provider"], 0), - "total_findings": row["total_findings"], - "findings_passed": row["findings_passed"], - "findings_failed": row["findings_failed"], - "findings_muted": row["findings_muted"], - } - ) - - return Response( - self.get_serializer(overview, many=True).data, - status=status.HTTP_200_OK, - ) - - @action( - detail=False, - methods=["get"], - url_path="providers/count", - url_name="providers-count", - ) - def providers_count(self, request): - tenant_id = self.request.tenant_id - providers_qs = 
Provider.objects.filter(tenant_id=tenant_id) + @action( + detail=False, + methods=["get"], + url_path="providers/count", + url_name="providers-count", + ) + def providers_count(self, request): + tenant_id = self.request.tenant_id + providers_qs = Provider.objects.filter(tenant_id=tenant_id) if hasattr(self, "allowed_providers"): allowed_ids = list(self.allowed_providers.values_list("id", flat=True)) @@ -4651,6 +4368,46 @@ def regions(self, request): return Response(serializer.data, status=status.HTTP_200_OK) + @extend_schema( + summary="Get ThreatScore snapshots", + description=( + "Retrieve ThreatScore metrics. By default, returns the latest snapshot for each provider. " + "Use snapshot_id to retrieve a specific historical snapshot." + ), + tags=["Overview"], + parameters=[ + OpenApiParameter( + name="snapshot_id", + type=OpenApiTypes.UUID, + location=OpenApiParameter.QUERY, + description="Retrieve a specific snapshot by ID. If not provided, returns latest snapshots.", + ), + OpenApiParameter( + name="provider_id", + type=OpenApiTypes.UUID, + location=OpenApiParameter.QUERY, + description="Filter by specific provider ID", + ), + OpenApiParameter( + name="provider_id__in", + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + description="Filter by multiple provider IDs (comma-separated UUIDs)", + ), + OpenApiParameter( + name="provider_type", + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + description="Filter by provider type (aws, azure, gcp, etc.)", + ), + OpenApiParameter( + name="provider_type__in", + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + description="Filter by multiple provider types (comma-separated)", + ), + ], + ) @action(detail=False, methods=["get"], url_name="threatscore") def threatscore(self, request): """ @@ -4730,6 +4487,213 @@ def threatscore(self, request): ) return Response(serializer.data, status=status.HTTP_200_OK) + def _build_threatscore_overview_snapshot(self, snapshot_ids, tenant_id): + """ + Aggregate the latest snapshots into a single overview snapshot for the tenant. 
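+
+        Scores combine as a weighted average: each snapshot is weighted by its
+        total findings, falling back to its active (non-manual) requirement
+        count, and finally to 1, so providers with more findings weigh more.
+        With illustrative numbers, snapshots scoring 82.50 (weight 120) and
+        64.00 (weight 30) combine to (82.50*120 + 64.00*30) / 150 = 78.80.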
+ """ + if not snapshot_ids: + raise ValueError( + "Snapshot id list cannot be empty when aggregating threatscore overview" + ) + + base_queryset = ThreatScoreSnapshot.objects.filter( + tenant_id=tenant_id, id__in=snapshot_ids + ) + + annotated_queryset = ( + base_queryset.annotate( + active_requirements=ExpressionWrapper( + F("total_requirements") - F("manual_requirements"), + output_field=IntegerField(), + ) + ) + .annotate( + weight=Case( + When(total_findings__gt=0, then=F("total_findings")), + When( + active_requirements__gt=0, + then=F("active_requirements"), + ), + default=Value(1, output_field=IntegerField()), + output_field=IntegerField(), + ) + ) + .order_by() + ) + + aggregated_metrics = annotated_queryset.aggregate( + total_requirements=Sum("total_requirements"), + passed_requirements=Sum("passed_requirements"), + failed_requirements=Sum("failed_requirements"), + manual_requirements=Sum("manual_requirements"), + total_findings=Sum("total_findings"), + passed_findings=Sum("passed_findings"), + failed_findings=Sum("failed_findings"), + weighted_overall_sum=Sum( + ExpressionWrapper( + F("overall_score") * F("weight"), + output_field=DecimalField(max_digits=14, decimal_places=4), + ) + ), + overall_weight=Sum("weight"), + unweighted_overall_sum=Sum("overall_score"), + weighted_delta_sum=Sum( + Case( + When( + score_delta__isnull=False, + then=ExpressionWrapper( + F("score_delta") * F("weight"), + output_field=DecimalField(max_digits=14, decimal_places=4), + ), + ), + default=Value( + Decimal("0"), + output_field=DecimalField(max_digits=14, decimal_places=4), + ), + output_field=DecimalField(max_digits=14, decimal_places=4), + ) + ), + delta_weight=Sum( + Case( + When(score_delta__isnull=False, then=F("weight")), + default=Value(0, output_field=IntegerField()), + output_field=IntegerField(), + ) + ), + provider_count=Count("id"), + latest_inserted_at=Max("inserted_at"), + ) + + total_requirements = aggregated_metrics["total_requirements"] or 0 + passed_requirements = aggregated_metrics["passed_requirements"] or 0 + failed_requirements = aggregated_metrics["failed_requirements"] or 0 + manual_requirements = aggregated_metrics["manual_requirements"] or 0 + total_findings = aggregated_metrics["total_findings"] or 0 + passed_findings = aggregated_metrics["passed_findings"] or 0 + failed_findings = aggregated_metrics["failed_findings"] or 0 + + weighted_overall_sum = aggregated_metrics["weighted_overall_sum"] + if weighted_overall_sum is None: + weighted_overall_sum = Decimal("0") + unweighted_overall_sum = aggregated_metrics["unweighted_overall_sum"] + if unweighted_overall_sum is None: + unweighted_overall_sum = Decimal("0") + + overall_weight = aggregated_metrics["overall_weight"] or 0 + provider_count = aggregated_metrics["provider_count"] or 0 + + weighted_delta_sum = aggregated_metrics["weighted_delta_sum"] + if weighted_delta_sum is None: + weighted_delta_sum = Decimal("0") + delta_weight = aggregated_metrics["delta_weight"] or 0 + + if overall_weight > 0: + overall_score = (weighted_overall_sum / Decimal(overall_weight)).quantize( + Decimal("0.01"), rounding=ROUND_HALF_UP + ) + elif provider_count > 0: + overall_score = (unweighted_overall_sum / Decimal(provider_count)).quantize( + Decimal("0.01"), rounding=ROUND_HALF_UP + ) + else: + overall_score = Decimal("0.00") + + if delta_weight > 0: + score_delta = (weighted_delta_sum / Decimal(delta_weight)).quantize( + Decimal("0.01"), rounding=ROUND_HALF_UP + ) + else: + score_delta = None + + section_weighted_sums = 
defaultdict(lambda: Decimal("0")) + section_weights = defaultdict(lambda: Decimal("0")) + + combined_critical_requirements = {} + + snapshots_with_weight = list(annotated_queryset) + + for snapshot in snapshots_with_weight: + weight_value = getattr(snapshot, "weight", None) + try: + weight_decimal = Decimal(weight_value) + except (InvalidOperation, TypeError): + weight_decimal = Decimal("1") + if weight_decimal <= 0: + weight_decimal = Decimal("1") + + section_scores = snapshot.section_scores or {} + for section, score in section_scores.items(): + try: + score_decimal = Decimal(str(score)) + except (InvalidOperation, TypeError): + continue + section_weighted_sums[section] += score_decimal * weight_decimal + section_weights[section] += weight_decimal + + for requirement in snapshot.critical_requirements or []: + key = requirement.get("requirement_id") or requirement.get("title") + if not key: + continue + existing = combined_critical_requirements.get(key) + + def requirement_sort_key(item): + return ( + item.get("risk_level") or 0, + item.get("weight") or 0, + ) + + if existing is None or requirement_sort_key( + requirement + ) > requirement_sort_key(existing): + combined_critical_requirements[key] = deepcopy(requirement) + + aggregated_section_scores = {} + for section, total in section_weighted_sums.items(): + weight_total = section_weights[section] + if weight_total > 0: + aggregated_section_scores[section] = str( + (total / weight_total).quantize( + Decimal("0.01"), rounding=ROUND_HALF_UP + ) + ) + + aggregated_section_scores = dict(sorted(aggregated_section_scores.items())) + + aggregated_critical_requirements = sorted( + combined_critical_requirements.values(), + key=lambda item: ( + item.get("risk_level") or 0, + item.get("weight") or 0, + ), + reverse=True, + ) + + aggregated_snapshot = ThreatScoreSnapshot( + tenant_id=tenant_id, + scan=None, + provider=None, + compliance_id="prowler_threatscore_overview", + overall_score=overall_score, + score_delta=score_delta, + section_scores=aggregated_section_scores, + critical_requirements=aggregated_critical_requirements, + total_requirements=total_requirements, + passed_requirements=passed_requirements, + failed_requirements=failed_requirements, + manual_requirements=manual_requirements, + total_findings=total_findings, + passed_findings=passed_findings, + failed_findings=failed_findings, + ) + + latest_inserted_at = aggregated_metrics["latest_inserted_at"] + if latest_inserted_at is not None: + aggregated_snapshot.inserted_at = latest_inserted_at + + aggregated_snapshot._aggregated = True + + return aggregated_snapshot + @action( detail=False, methods=["get"], @@ -4737,48 +4701,86 @@ def threatscore(self, request): url_path="attack-surfaces", ) def attack_surface(self, request): - tenant_id = self.request.tenant_id - - # Parse provider filters and get latest scans - provider_filters = self._parse_provider_filters(request) - latest_scan_ids = self._get_latest_scan_ids(additional_filters=provider_filters) + tenant_id = request.tenant_id + self.get_queryset() # Triggers RBAC setup (sets self.allowed_providers) - # Query attack surface overviews for latest scans - queryset = AttackSurfaceOverview.objects.filter( - tenant_id=tenant_id, - scan_id__in=latest_scan_ids, + # RBAC provider filter + provider_filter = ( + {"provider__in": self.allowed_providers} + if hasattr(self, "allowed_providers") + else {} ) - # Aggregate by attack surface type - aggregation = queryset.values("attack_surface_type").annotate( - total_findings=Sum("total_findings"), 
- failed_findings=Sum("failed_findings"), - muted_failed_findings=Sum("muted_failed_findings"), + # Parse filter params (filter[provider_id] -> provider_id) + normalized_params = QueryDict(mutable=True) + for key, values in request.query_params.lists(): + normalized_key = ( + key[7:-1] if key.startswith("filter[") and key.endswith("]") else key + ) + normalized_params.setlist(normalized_key, values) + + # Build provider filter from user params + user_provider_filter = {} + if normalized_params.get("provider_id"): + user_provider_filter["provider_id"] = normalized_params.get("provider_id") + if normalized_params.getlist("provider_id__in"): + user_provider_filter["provider_id__in"] = normalized_params.getlist( + "provider_id__in" + ) + if normalized_params.get("provider_type"): + user_provider_filter["provider__provider"] = normalized_params.get( + "provider_type" + ) + if normalized_params.getlist("provider_type__in"): + user_provider_filter["provider__provider__in"] = normalized_params.getlist( + "provider_type__in" + ) + + # Merge RBAC filter with user filter for scans + scan_filter = {**provider_filter, **user_provider_filter} + + # Get latest completed scan per provider + latest_scan_ids = ( + Scan.all_objects.filter( + tenant_id=tenant_id, state=StateChoices.COMPLETED, **scan_filter + ) + .order_by("provider_id", "-inserted_at") + .distinct("provider_id") + .values_list("id", flat=True) ) - # Convert to dict for easy lookup - results_by_type = {item["attack_surface_type"]: item for item in aggregation} + # Aggregate attack surface data + aggregation = ( + AttackSurfaceOverview.objects.filter( + tenant_id=tenant_id, scan_id__in=latest_scan_ids + ) + .values("attack_surface_type") + .annotate( + total_findings=Coalesce(Sum("total_findings"), 0), + failed_findings=Coalesce(Sum("failed_findings"), 0), + muted_failed_findings=Coalesce(Sum("muted_failed_findings"), 0), + ) + ) - # Always return all attack surface types (fill with zeros if missing) - all_types = AttackSurfaceOverview.AttackSurfaceTypeChoices.values - complete_results = [] + results = { + attack_surface_type: {"total_findings": 0, "failed_findings": 0, "muted_failed_findings": 0} + for attack_surface_type in AttackSurfaceOverview.AttackSurfaceTypeChoices.values + } + for item in aggregation: + results[item["attack_surface_type"]] = { + "total_findings": item["total_findings"], + "failed_findings": item["failed_findings"], + "muted_failed_findings": item["muted_failed_findings"], + } - for attack_surface_type in all_types: - if attack_surface_type in results_by_type: - complete_results.append(results_by_type[attack_surface_type]) - else: - # No data for this type - return zeros - complete_results.append( - { - "attack_surface_type": attack_surface_type, - "total_findings": 0, - "failed_findings": 0, - "muted_failed_findings": 0, - } - ) + response_data = [ + {"attack_surface_type": key, **value} for key, value in results.items() + ] - serializer = self.get_serializer(complete_results, many=True) - return Response(data=serializer.data, status=status.HTTP_200_OK) + return Response( + self.get_serializer(response_data, many=True).data, + status=status.HTTP_200_OK, + ) @extend_schema(tags=["Schedule"]) From 1362c2560430a4df784668f42b288cb980831e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Wed, 26 Nov 2025 16:56:02 +0100 Subject: [PATCH 11/12] refactor(attack-surfaces): DRY overviews --- api/src/backend/api/rbac/permissions.py | 4 +- api/src/backend/api/v1/views.py | 143 
++++++++++++------------ 2 files changed, 74 insertions(+), 73 deletions(-) diff --git a/api/src/backend/api/rbac/permissions.py b/api/src/backend/api/rbac/permissions.py index 6a95e82932..97d7d785e0 100644 --- a/api/src/backend/api/rbac/permissions.py +++ b/api/src/backend/api/rbac/permissions.py @@ -65,11 +65,11 @@ def get_providers(role: Role) -> QuerySet[Provider]: A QuerySet of Provider objects filtered by the role's provider groups. If the role has no provider groups, returns an empty queryset. """ - tenant = role.tenant + tenant_id = role.tenant_id provider_groups = role.provider_groups.all() if not provider_groups.exists(): return Provider.objects.none() return Provider.objects.filter( - tenant=tenant, provider_groups__in=provider_groups + tenant_id=tenant_id, provider_groups__in=provider_groups ).distinct() diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 8ceb857a91..2578b6551d 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -98,6 +98,7 @@ from api.db_utils import rls_transaction from api.exceptions import TaskFailedException from api.filters import ( + AttackSurfaceOverviewFilter, ComplianceOverviewFilter, CustomDjangoFilterBackend, FindingFilter, @@ -4106,7 +4107,7 @@ def attributes(self, request): description="Filter by multiple provider types (comma-separated)", ), ], - ) + ), ) @method_decorator(CACHE_DECORATOR, name="list") class OverviewViewSet(BaseRLSViewSet): @@ -4191,6 +4192,55 @@ def _get_latest_scans_queryset(self): tenant_id=tenant_id, scan_id__in=latest_scan_ids ) + def _normalize_jsonapi_params(self, query_params, exclude_keys=None): + """Convert JSON:API filter params (filter[X]) to flat params (X).""" + exclude_keys = exclude_keys or set() + normalized = QueryDict(mutable=True) + for key, values in query_params.lists(): + normalized_key = ( + key[7:-1] if key.startswith("filter[") and key.endswith("]") else key + ) + if normalized_key not in exclude_keys: + normalized.setlist(normalized_key, values) + return normalized + + def _ensure_allowed_providers(self): + """Populate allowed providers for RBAC-aware queries once per request.""" + if getattr(self, "_providers_initialized", False): + return + self.get_queryset() + self._providers_initialized = True + + def _get_provider_filter(self, provider_field="provider"): + self._ensure_allowed_providers() + if hasattr(self, "allowed_providers"): + return {f"{provider_field}__in": self.allowed_providers} + return {} + + def _apply_provider_filter(self, queryset, provider_field="provider"): + provider_filter = self._get_provider_filter(provider_field) + if provider_filter: + return queryset.filter(**provider_filter) + return queryset + + def _apply_filterset(self, queryset, filterset_class, exclude_keys=None): + normalized_params = self._normalize_jsonapi_params( + self.request.query_params, exclude_keys=set(exclude_keys or []) + ) + filterset = filterset_class(normalized_params, queryset=queryset) + return filterset.qs + + def _latest_scan_ids_for_allowed_providers(self, tenant_id): + provider_filter = self._get_provider_filter() + return ( + Scan.all_objects.filter( + tenant_id=tenant_id, state=StateChoices.COMPLETED, **provider_filter + ) + .order_by("provider_id", "-inserted_at") + .distinct("provider_id") + .values_list("id", flat=True) + ) + @action(detail=False, methods=["get"], url_name="providers") def providers(self, request): tenant_id = self.request.tenant_id @@ -4420,11 +4470,9 @@ def threatscore(self, request): snapshot_id = 
request.query_params.get("snapshot_id") # Base queryset with RLS - base_queryset = ThreatScoreSnapshot.objects.filter(tenant_id=tenant_id) - - # Apply RBAC filtering - if hasattr(self, "allowed_providers"): - base_queryset = base_queryset.filter(provider__in=self.allowed_providers) + base_queryset = self._apply_provider_filter( + ThreatScoreSnapshot.objects.filter(tenant_id=tenant_id) + ) # Case 1: Specific snapshot requested if snapshot_id: @@ -4440,17 +4488,9 @@ def threatscore(self, request): # Case 2: Latest snapshot per provider (default) # Apply filters manually: this @action is outside the standard list endpoint flow, # so DRF's filter backends don't execute and we must flatten JSON:API params ourselves. - normalized_params = QueryDict(mutable=True) - for param_key, values in request.query_params.lists(): - normalized_key = param_key - if param_key.startswith("filter[") and param_key.endswith("]"): - normalized_key = param_key[7:-1] - if normalized_key == "snapshot_id": - continue - normalized_params.setlist(normalized_key, values) - - filterset = ThreatScoreSnapshotFilter(normalized_params, queryset=base_queryset) - filtered_queryset = filterset.qs + filtered_queryset = self._apply_filterset( + base_queryset, ThreatScoreSnapshotFilter, exclude_keys={"snapshot_id"} + ) # Get distinct provider IDs from filtered queryset # Pick the latest snapshot per provider using Postgres DISTINCT ON pattern. @@ -4702,68 +4742,29 @@ def requirement_sort_key(item): ) def attack_surface(self, request): tenant_id = request.tenant_id - self.get_queryset() # Triggers RBAC setup (sets self.allowed_providers) + latest_scan_ids = self._latest_scan_ids_for_allowed_providers(tenant_id) - # RBAC provider filter - provider_filter = ( - {"provider__in": self.allowed_providers} - if hasattr(self, "allowed_providers") - else {} + # Build base queryset and apply user filters via FilterSet + base_queryset = AttackSurfaceOverview.objects.filter( + tenant_id=tenant_id, scan_id__in=latest_scan_ids ) - - # Parse filter params (filter[provider_id] -> provider_id) - normalized_params = QueryDict(mutable=True) - for key, values in request.query_params.lists(): - normalized_key = ( - key[7:-1] if key.startswith("filter[") and key.endswith("]") else key - ) - normalized_params.setlist(normalized_key, values) - - # Build provider filter from user params - user_provider_filter = {} - if normalized_params.get("provider_id"): - user_provider_filter["provider_id"] = normalized_params.get("provider_id") - if normalized_params.getlist("provider_id__in"): - user_provider_filter["provider_id__in"] = normalized_params.getlist( - "provider_id__in" - ) - if normalized_params.get("provider_type"): - user_provider_filter["provider__provider"] = normalized_params.get( - "provider_type" - ) - if normalized_params.getlist("provider_type__in"): - user_provider_filter["provider__provider__in"] = normalized_params.getlist( - "provider_type__in" - ) - - # Merge RBAC filter with user filter for scans - scan_filter = {**provider_filter, **user_provider_filter} - - # Get latest completed scan per provider - latest_scan_ids = ( - Scan.all_objects.filter( - tenant_id=tenant_id, state=StateChoices.COMPLETED, **scan_filter - ) - .order_by("provider_id", "-inserted_at") - .distinct("provider_id") - .values_list("id", flat=True) + filtered_queryset = self._apply_filterset( + base_queryset, AttackSurfaceOverviewFilter ) # Aggregate attack surface data - aggregation = ( - AttackSurfaceOverview.objects.filter( - tenant_id=tenant_id, 
scan_id__in=latest_scan_ids - ) - .values("attack_surface_type") - .annotate( - total_findings=Coalesce(Sum("total_findings"), 0), - failed_findings=Coalesce(Sum("failed_findings"), 0), - muted_failed_findings=Coalesce(Sum("muted_failed_findings"), 0), - ) + aggregation = filtered_queryset.values("attack_surface_type").annotate( + total_findings=Coalesce(Sum("total_findings"), 0), + failed_findings=Coalesce(Sum("failed_findings"), 0), + muted_failed_findings=Coalesce(Sum("muted_failed_findings"), 0), ) results = { - attack_surface_type: {"total_findings": 0, "failed_findings": 0, "muted_failed_findings": 0} + attack_surface_type: { + "total_findings": 0, + "failed_findings": 0, + "muted_failed_findings": 0, + } for attack_surface_type in AttackSurfaceOverview.AttackSurfaceTypeChoices.values } for item in aggregation: From 5f3dd4438a4b31edd512910038e65691c0939745 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Tue, 2 Dec 2025 11:52:21 +0100 Subject: [PATCH 12/12] feat(attack-surfaces): add check ID list to response --- api/src/backend/api/apps.py | 11 +++++++ api/src/backend/api/specs/v1.yaml | 22 ++++++++----- api/src/backend/api/tests/test_views.py | 43 ++++++++++++++++++++++--- api/src/backend/api/v1/serializers.py | 3 ++ api/src/backend/api/v1/views.py | 34 ++++++++++++++++--- 5 files changed, 96 insertions(+), 17 deletions(-) diff --git a/api/src/backend/api/apps.py b/api/src/backend/api/apps.py index add97cf376..887e41fcf3 100644 --- a/api/src/backend/api/apps.py +++ b/api/src/backend/api/apps.py @@ -40,6 +40,7 @@ def ready(self): self._ensure_crypto_keys() load_prowler_compliance() + self._initialize_attack_surface_mapping() def _ensure_crypto_keys(self): """ @@ -167,3 +168,13 @@ def _generate_jwt_keys(self): f"Error generating JWT keys: {e}. Please set '{SIGNING_KEY_ENV}' and '{VERIFYING_KEY_ENV}' manually." ) raise e + + def _initialize_attack_surface_mapping(self): + from tasks.jobs.scan import ( # noqa: F401 + _get_attack_surface_mapping_from_provider, + ) + + from api.models import Provider # noqa: F401 + + for provider_type, _label in Provider.ProviderChoices.choices: + _get_attack_surface_mapping_from_provider(provider_type) diff --git a/api/src/backend/api/specs/v1.yaml b/api/src/backend/api/specs/v1.yaml index dcb62d3156..638ec7cc52 100644 --- a/api/src/backend/api/specs/v1.yaml +++ b/api/src/backend/api/specs/v1.yaml @@ -4502,7 +4502,7 @@ paths: operationId: overviews_attack_surfaces_retrieve description: Retrieve aggregated attack surface metrics from latest completed scans per provider. - summary: Get Attack surface overview + summary: Get attack surface overview parameters: - in: query name: fields[attack-surface-overviews] @@ -4515,9 +4515,15 @@ paths: - total_findings - failed_findings - muted_failed_findings + - check_ids description: endpoint return only specific fields in the response on a per-type basis by including a fields[TYPE] query parameter. 
explode: false + - in: query + name: filter[provider_id.in] + schema: + type: string + description: Filter by multiple provider IDs (comma-separated UUIDs) - in: query name: filter[provider_id] schema: @@ -4525,20 +4531,15 @@ paths: format: uuid description: Filter by specific provider ID - in: query - name: filter[provider_id__in] + name: filter[provider_type.in] schema: type: string - description: Filter by multiple provider IDs (comma-separated UUIDs) + description: Filter by multiple provider types (comma-separated) - in: query name: filter[provider_type] schema: type: string description: Filter by provider type (aws, azure, gcp, etc.) - - in: query - name: filter[provider_type__in] - schema: - type: string - description: Filter by multiple provider types (comma-separated) tags: - Overview security: @@ -10697,6 +10698,11 @@ components: type: integer muted_failed_findings: type: integer + check_ids: + type: array + items: + type: string + readOnly: true required: - id - total_findings diff --git a/api/src/backend/api/tests/test_views.py b/api/src/backend/api/tests/test_views.py index c60915d530..f2b32f23ee 100644 --- a/api/src/backend/api/tests/test_views.py +++ b/api/src/backend/api/tests/test_views.py @@ -6866,6 +6866,7 @@ def test_overview_attack_surface_no_data(self, authenticated_client): assert item["attributes"]["total_findings"] == 0 assert item["attributes"]["failed_findings"] == 0 assert item["attributes"]["muted_failed_findings"] == 0 + assert item["attributes"]["check_ids"] == [] def test_overview_attack_surface_with_data( self, @@ -6877,6 +6878,13 @@ def test_overview_attack_surface_with_data( tenant = tenants_fixture[0] provider = providers_fixture[0] + mapping = { + "internet-exposed": {"aws-check-1", "aws-check-2"}, + "secrets": {"aws-secret-check"}, + "privilege-escalation": {"aws-priv-check"}, + "ec2-imdsv1": {"aws-imdsv1-check"}, + } + scan = Scan.objects.create( name="attack-surface-scan", provider=provider, @@ -6902,7 +6910,11 @@ def test_overview_attack_surface_with_data( muted_failed=2, ) - response = authenticated_client.get(reverse("overview-attack-surface")) + with patch( + "api.v1.views._get_attack_surface_mapping_from_provider", + return_value=mapping, + ): + response = authenticated_client.get(reverse("overview-attack-surface")) assert response.status_code == status.HTTP_200_OK data = response.json()["data"] assert len(data) == 4 @@ -6910,10 +6922,19 @@ def test_overview_attack_surface_with_data( results_by_type = {item["id"]: item["attributes"] for item in data} assert results_by_type["internet-exposed"]["total_findings"] == 20 assert results_by_type["internet-exposed"]["failed_findings"] == 10 + assert set(results_by_type["internet-exposed"]["check_ids"]) == { + "aws-check-1", + "aws-check-2", + } assert results_by_type["secrets"]["total_findings"] == 15 assert results_by_type["secrets"]["failed_findings"] == 8 + assert set(results_by_type["secrets"]["check_ids"]) == {"aws-secret-check"} assert results_by_type["privilege-escalation"]["total_findings"] == 0 + assert set(results_by_type["privilege-escalation"]["check_ids"]) == { + "aws-priv-check" + } assert results_by_type["ec2-imdsv1"]["total_findings"] == 0 + assert set(results_by_type["ec2-imdsv1"]["check_ids"]) == {"aws-imdsv1-check"} def test_overview_attack_surface_provider_filter( self, @@ -6940,6 +6961,13 @@ def test_overview_attack_surface_provider_filter( tenant=tenant, ) + mapping = { + "internet-exposed": {"shared-check", "shared-check"}, + "secrets": set(), + "privilege-escalation": 
{"priv-check"}, + "ec2-imdsv1": {"imdsv1-check"}, + } + create_attack_surface_overview( tenant, scan1, @@ -6957,15 +6985,20 @@ def test_overview_attack_surface_provider_filter( muted_failed=3, ) - response = authenticated_client.get( - reverse("overview-attack-surface"), - {"filter[provider_id]": str(provider1.id)}, - ) + with patch( + "api.v1.views._get_attack_surface_mapping_from_provider", + return_value=mapping, + ): + response = authenticated_client.get( + reverse("overview-attack-surface"), + {"filter[provider_id]": str(provider1.id)}, + ) assert response.status_code == status.HTTP_200_OK data = response.json()["data"] results_by_type = {item["id"]: item["attributes"] for item in data} assert results_by_type["internet-exposed"]["total_findings"] == 10 assert results_by_type["internet-exposed"]["failed_findings"] == 5 + assert results_by_type["internet-exposed"]["check_ids"] == ["shared-check"] def test_overview_services_region_filter( self, authenticated_client, scan_summaries_fixture diff --git a/api/src/backend/api/v1/serializers.py b/api/src/backend/api/v1/serializers.py index cdec9fc87a..c35d2f91db 100644 --- a/api/src/backend/api/v1/serializers.py +++ b/api/src/backend/api/v1/serializers.py @@ -2226,6 +2226,9 @@ class AttackSurfaceOverviewSerializer(BaseSerializerV1): total_findings = serializers.IntegerField() failed_findings = serializers.IntegerField() muted_failed_findings = serializers.IntegerField() + check_ids = serializers.ListField( + child=serializers.CharField(), allow_empty=True, default=list, read_only=True + ) class JSONAPIMeta: resource_name = "attack-surface-overviews" diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index ddb5fc9053..5cd8dab7dc 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -74,6 +74,7 @@ from rest_framework_simplejwt.exceptions import InvalidToken, TokenError from tasks.beat import schedule_provider_scan from tasks.jobs.export import get_s3_client +from tasks.jobs.scan import _get_attack_surface_mapping_from_provider from tasks.tasks import ( backfill_scan_resource_summaries_task, check_integration_connection_task, @@ -3893,8 +3894,8 @@ def attributes(self, request): filters=True, ), attack_surface=extend_schema( - summary="Retrieve attack surface overview", - description="Returns aggregated attack surface metrics from latest completed scans per provider.", + summary="Get attack surface overview", + description="Retrieve aggregated attack surface metrics from latest completed scans per provider.", tags=["Overview"], parameters=[ OpenApiParameter( @@ -4056,6 +4057,19 @@ def _latest_scan_ids_for_allowed_providers(self, tenant_id): .values_list("id", flat=True) ) + def _attack_surface_check_ids_by_provider_types(self, provider_types): + check_ids_by_type = { + attack_surface_type: set() + for attack_surface_type in AttackSurfaceOverview.AttackSurfaceTypeChoices.values + } + for provider_type in provider_types: + attack_surface_mapping = _get_attack_surface_mapping_from_provider( + provider_type=provider_type + ) + for attack_surface_type, check_ids in attack_surface_mapping.items(): + check_ids_by_type[attack_surface_type].update(check_ids) + return check_ids_by_type + @action(detail=False, methods=["get"], url_name="providers") def providers(self, request): tenant_id = self.request.tenant_id @@ -4566,7 +4580,14 @@ def attack_surface(self, request): filtered_queryset = self._apply_filterset( base_queryset, AttackSurfaceOverviewFilter ) - + provider_types = list( + 
filtered_queryset.values_list( + "scan__provider__provider", flat=True + ).distinct() + ) + attack_surface_check_ids = self._attack_surface_check_ids_by_provider_types( + provider_types + ) # Aggregate attack surface data aggregation = filtered_queryset.values("attack_surface_type").annotate( total_findings=Coalesce(Sum("total_findings"), 0), @@ -4590,7 +4611,12 @@ def attack_surface(self, request): } response_data = [ - {"attack_surface_type": key, **value} for key, value in results.items() + { + "attack_surface_type": key, + **value, + "check_ids": attack_surface_check_ids.get(key, []), + } + for key, value in results.items() ] return Response(