diff --git a/docs/guides/customization.md b/docs/guides/customization.md new file mode 100644 index 00000000..c0add3e8 --- /dev/null +++ b/docs/guides/customization.md @@ -0,0 +1,326 @@ +# Customizing Pulumi Resource Properties + +Stelvio provides high-level abstractions for AWS resources, exposing the most commonly used configuration options through component constructors. However, your architecture might require fine-grained control over the underlying Pulumi resources that Stelvio creates. + +The `customize` parameter allows you to override or extend default Pulumi resource properties without modifying Stelvio's source code. + +## When to Use Customization + +Use the `customize` parameter when you need to: + +- Set Pulumi properties not exposed by Stelvio's API (e.g., `force_destroy` on S3 buckets) +- Override default values that Stelvio sets internally +- Add tags, encryption settings, or other resource-specific configurations +- Configure advanced features like VPC settings or custom IAM policies + +## Basic Usage + +Pass a `customize` dictionary to any Stelvio component. The dictionary keys correspond to the underlying resources that the component creates: + +```python +from stelvio.aws.s3 import Bucket + +@app.run +def run() -> None: + bucket = Bucket( + "example-bucket", + customize={ + "bucket": {"force_destroy": True} + } + ) +``` + +In this example, `"bucket"` refers to the S3 bucket resource created by the `Bucket` component, and `force_destroy` is a Pulumi property that allows the bucket to be deleted even when it contains objects. + +## Understanding Resource Keys + +Each Stelvio component creates one or more underlying Pulumi resources. The `customize` dictionary keys match the resource names defined in the component's resources dataclass. + +### Example: S3 Bucket + +The `Bucket` component creates these resources: + +| Resource Key | Pulumi Resource Type | Description | +|------------------------|-----------------------------------------|-----------------------------------------------| +| `bucket` | `pulumi_aws.s3.Bucket` | The S3 bucket itself | +| `public_access_block` | `pulumi_aws.s3.BucketPublicAccessBlock` | Public access block settings | +| `bucket_policy` | `pulumi_aws.s3.BucketPolicy` | Bucket policy (when `access="public"`) | + +You can customize any of these: + +```python +bucket = Bucket( + "my-bucket", + customize={ + "bucket": { + "force_destroy": True, + "tags": {"Environment": "dev"}, + }, + "public_access_block": { + "block_public_acls": True, + }, + } +) +``` + +### Example: Lambda Function + +The `Function` component creates these resources: + +| Resource Key | Pulumi Resource Type | Description | +|-----------------|----------------------------------|------------------------------------| +| `function` | `pulumi_aws.lambda_.Function` | The Lambda function | +| `role` | `pulumi_aws.iam.Role` | IAM execution role | +| `policy` | `pulumi_aws.iam.Policy` | IAM policy attached to the role | +| `function_url` | `pulumi_aws.lambda_.FunctionUrl` | Function URL (when configured) | + +```python +from stelvio.aws.function import Function + +fn = Function( + "my-function", + handler="functions/handler.main", + customize={ + "function": { + "reserved_concurrent_executions": 10, + "tracing_config": {"mode": "Active"}, + } + } +) +``` + +### Example: DynamoDB Table + +The `DynamoTable` component creates: + +| Resource Key | Pulumi Resource Type | Description | +|--------------|-------------------------------|-----------------------| +| `table` | `pulumi_aws.dynamodb.Table` | The DynamoDB table | + +```python +from stelvio.aws.dynamo_db import DynamoTable + +table = DynamoTable( + name="orders", + fields={"id": "string"}, + partition_key="id", + customize={ + "table": { + "tags": {"Project": "my-app"}, + "point_in_time_recovery": {"enabled": True}, + } + } +) +``` + +### Example: SQS Queue + +The `Queue` component creates: + +| Resource Key | Pulumi Resource Type | Description | +|--------------|-------------------------|-------------------| +| `queue` | `pulumi_aws.sqs.Queue` | The SQS queue | + +```python +from stelvio.aws.queue import Queue + +queue = Queue( + "my-queue", + customize={ + "queue": { + "tags": {"Team": "backend"}, + "kms_master_key_id": "alias/my-key", + } + } +) +``` + +### Example: SNS Topic + +The `Topic` component creates: + +| Resource Key | Pulumi Resource Type | Description | +|--------------|-------------------------|-------------------| +| `topic` | `pulumi_aws.sns.Topic` | The SNS topic | + +```python +from stelvio.aws.topic import Topic + +topic = Topic( + "my-topic", + customize={ + "topic": { + "tags": {"Service": "notifications"}, + "kms_master_key_id": "alias/my-key", + } + } +) +``` + +### Example: Cron (Scheduled Lambda) + +The `Cron` component creates these resources: + +| Resource Key | Pulumi Resource Type | Description | +|--------------|------------------------------------|--------------------------------------| +| `rule` | `pulumi_aws.cloudwatch.EventRule` | The EventBridge rule with schedule | +| `target` | `pulumi_aws.cloudwatch.EventTarget`| The target linking rule to Lambda | +| `function` | (nested `FunctionCustomizationDict`) | The Lambda function (see Function) | + +```python +from stelvio.aws.cron import Cron + +cron = Cron( + "my-cron", + "rate(1 hour)", + "functions/cleanup.handler", + customize={ + "rule": { + "tags": {"Schedule": "hourly"}, + }, + "target": { + "retry_policy": {"maximum_event_age_in_seconds": 3600}, + } + } +) +``` + +### Example: Email (SES) + +The `Email` component creates these resources: + +| Resource Key | Pulumi Resource Type | Description | +|-----------------------|---------------------------------------------------|---------------------------------------| +| `identity` | `pulumi_aws.sesv2.EmailIdentity` | The SES email identity | +| `configuration_set` | `pulumi_aws.sesv2.ConfigurationSet` | SES configuration set | +| `verification` | `pulumi_aws.ses.DomainIdentityVerification` | Domain verification (for domains) | +| `event_destinations` | `pulumi_aws.sesv2.ConfigurationSetEventDestination` | Event destination (when configured) | + +```python +from stelvio.aws.email import Email + +email = Email( + "my-email", + "notifications@example.com", + customize={ + "identity": { + "tags": {"Service": "notifications"}, + }, + "configuration_set": { + "tags": {"Environment": "production"}, + } + } +) +``` + +## How Customization Works + +When you provide customizations, Stelvio merges your values with its default configuration: + +1. **Stelvio defaults** are applied first +2. **Your customizations** override or extend those defaults + +This means you only need to specify the properties you want to change—Stelvio's sensible defaults remain in place for everything else. + +!!! note "Shallow Merge" + The merge is shallow at each property level. If you customize a nested object (like `tags`), your entire object replaces the default, rather than being deep-merged. + +## Global Customization + +Apply customizations to all instances of a component type using the `customize` option in `StelvioAppConfig`: + +```python +from stelvio.app import StelvioApp +from stelvio.config import StelvioAppConfig +from stelvio.aws.s3 import Bucket +from stelvio.aws.function import Function + +app = StelvioApp("my-project") + +@app.config +def configuration(env: str) -> StelvioAppConfig: + return StelvioAppConfig( + customize={ + Bucket: { + "bucket": {"force_destroy": True} + }, + Function: { + "function": { + "tracing_config": {"mode": "Active"} + } + } + } + ) + +@app.run +def run() -> None: + # Both buckets inherit force_destroy=True + bucket1 = Bucket("bucket-one") + bucket2 = Bucket("bucket-two") + + # All functions have X-Ray tracing enabled + fn = Function("my-fn", handler="functions/handler.main") +``` + +The global `customize` dictionary uses **component types** as keys (e.g., `Bucket`, `Function`) and the same resource customization dictionaries as values. + +### Combining Global and Per-Instance Customization + +When both global and per-instance customizations are provided, they are merged with the following precedence (highest to lowest): + +1. **Per-instance** `customize` parameter +2. **Global** `customize` from `StelvioAppConfig` +3. **Stelvio defaults** + +```python +@app.config +def configuration(env: str) -> StelvioAppConfig: + return StelvioAppConfig( + customize={ + Bucket: {"bucket": {"force_destroy": True}} + } + ) + +@app.run +def run() -> None: + # Uses global customization: force_destroy=True + bucket1 = Bucket("bucket-one") + + # Per-instance overrides global: force_destroy=False + bucket2 = Bucket( + "bucket-two", + customize={"bucket": {"force_destroy": False}} + ) +``` + +## Environment-Specific Customization + +Combine customization with environment-based configuration for environment-specific settings: + +```python +@app.config +def configuration(env: str) -> StelvioAppConfig: + if env == "dev": + return StelvioAppConfig( + customize={ + Bucket: {"bucket": {"force_destroy": True}}, + } + ) + else: + # Production: keep default safe behavior + return StelvioAppConfig() +``` + +## Finding Available Properties + +To discover which properties you can customize for each resource, refer to the Pulumi AWS provider documentation: + +- [S3 Bucket](https://www.pulumi.com/registry/packages/aws/api-docs/s3/bucket/) +- [Lambda Function](https://www.pulumi.com/registry/packages/aws/api-docs/lambda/function/) +- [DynamoDB Table](https://www.pulumi.com/registry/packages/aws/api-docs/dynamodb/table/) +- [SQS Queue](https://www.pulumi.com/registry/packages/aws/api-docs/sqs/queue/) +- [SNS Topic](https://www.pulumi.com/registry/packages/aws/api-docs/sns/topic/) +- [API Gateway REST API](https://www.pulumi.com/registry/packages/aws/api-docs/apigateway/restapi/) + +!!! tip "IDE Support" + If you're using an IDE with Python type checking, the customization dictionaries are fully typed. Your IDE can provide autocompletion and validation for available properties. \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 75a83ff5..13c920c6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -127,6 +127,7 @@ Stelvio is released under the Apache 2.0 License. See the LICENSE file for detai - [Email](guides/email.md) - Send emails using SES - [Project Structure](guides/project-structure.md) - Organizing your code - [State Management](guides/state.md) - Understand Deployment State +- [Parameter Customization](guides/customization.md) - Customize internals of cloud primitives - [Troubleshooting](guides/troubleshooting.md) - Common misconceptions ### Reference diff --git a/mkdocs.yml b/mkdocs.yml index 0c251749..544db730 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -79,6 +79,7 @@ nav: - Routing components on one domain: guides/cloudfront-router.md - Project Structure: guides/project-structure.md - State Management: guides/state.md + - Parameter Customization: guides/customization.md - Troubleshooting: guides/troubleshooting.md markdown_extensions: diff --git a/stelvio/app.py b/stelvio/app.py index fc0d05f2..dc5c4bf7 100644 --- a/stelvio/app.py +++ b/stelvio/app.py @@ -2,7 +2,7 @@ from collections.abc import Callable from importlib import import_module from pathlib import Path -from typing import ClassVar, TypeVar, final +from typing import Any, ClassVar, TypeVar, final from pulumi import Resource as PulumiResource @@ -29,7 +29,7 @@ def __init__( self, name: str, modules: list[str] | None = None, - link_configs: dict[type[Component[T]], Callable[[T], LinkConfig]] | None = None, + link_configs: dict[type[Component[T, Any]], Callable[[T], LinkConfig]] | None = None, ): if StelvioApp.__instance is not None: raise RuntimeError("StelvioApp has already been instantiated.") @@ -88,7 +88,7 @@ def run() -> None: @staticmethod def set_user_link_for( - component_type: type[Component[T]], func: Callable[[T], LinkConfig] + component_type: type[Component[T, Any]], func: Callable[[T], LinkConfig] ) -> None: """Register a user-defined link creator that overrides defaults""" ComponentRegistry.register_user_link_creator(component_type, func) diff --git a/stelvio/aws/acm.py b/stelvio/aws/acm.py index 5ff99919..0ecc1342 100644 --- a/stelvio/aws/acm.py +++ b/stelvio/aws/acm.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import final +from typing import Any, TypedDict, final import pulumi import pulumi_aws @@ -17,11 +17,24 @@ class AcmValidatedDomainResources: cert_validation: pulumi_aws.acm.CertificateValidation +class AcmValidatedDomainCustomizationDict(TypedDict, total=False): + certificate: pulumi_aws.acm.CertificateArgs | dict[str, Any] | None + validation_record: dict[str, Any] | None # TODO + cert_validation: pulumi_aws.acm.CertificateValidationArgs | dict[str, Any] | None + + @final -class AcmValidatedDomain(Component[AcmValidatedDomainResources]): - def __init__(self, name: str, domain_name: str): +class AcmValidatedDomain( + Component[AcmValidatedDomainResources, AcmValidatedDomainCustomizationDict] +): + def __init__( + self, + name: str, + domain_name: str, + customize: AcmValidatedDomainCustomizationDict | None = None, + ): self.domain_name = domain_name - super().__init__(name) + super().__init__(name, customize=customize) def _create_resources(self) -> AcmValidatedDomainResources: dns = context().dns @@ -34,8 +47,13 @@ def _create_resources(self) -> AcmValidatedDomainResources: # 1 - Issue Certificate certificate = pulumi_aws.acm.Certificate( context().prefix(f"{self.name}-certificate"), - domain_name=self.domain_name, - validation_method="DNS", + **self._customizer( + "certificate", + { + "domain_name": self.domain_name, + "validation_method": "DNS", + }, + ), ) # 2 - Validate Certificate with DNS PROVIDER @@ -43,17 +61,27 @@ def _create_resources(self) -> AcmValidatedDomainResources: validation_record = dns.create_caa_record( resource_name=context().prefix(f"{self.name}-certificate-validation-record"), name=first_option.apply(lambda opt: opt["resource_record_name"]), - record_type=first_option.apply(lambda opt: opt["resource_record_type"]), - content=first_option.apply(lambda opt: opt["resource_record_value"]), - ttl=1, + **self._customizer( + "validation_record", + { + "record_type": first_option.apply(lambda opt: opt["resource_record_type"]), + "content": first_option.apply(lambda opt: opt["resource_record_value"]), + "ttl": 1, + }, + ), ) # 3 - Wait for validation - use the validation record's FQDN to ensure it exists cert_validation = pulumi_aws.acm.CertificateValidation( context().prefix(f"{self.name}-certificate-validation"), - certificate_arn=certificate.arn, - # This ensures validation_record exists - validation_record_fqdns=[validation_record.name], + **self._customizer( + "cert_validation", + { + "certificate_arn": certificate.arn, + # This ensures validation_record exists + "validation_record_fqdns": [validation_record.name], + }, + ), opts=pulumi.ResourceOptions( depends_on=[certificate, validation_record.pulumi_resource] ), diff --git a/stelvio/aws/api_gateway/api.py b/stelvio/aws/api_gateway/api.py index 334f51d9..40e9479e 100644 --- a/stelvio/aws/api_gateway/api.py +++ b/stelvio/aws/api_gateway/api.py @@ -1,8 +1,10 @@ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, Unpack, final +from typing import Any, Literal, TypedDict, Unpack, final import pulumi -from pulumi import Output, ResourceOptions +import pulumi_aws +from pulumi import Input, Output, ResourceOptions from pulumi_aws import get_caller_identity, get_region from pulumi_aws.apigateway import ( Authorizer as PulumiAuthorizer, @@ -38,7 +40,7 @@ create_cors_gateway_responses, create_cors_options_methods, ) -from stelvio.aws.api_gateway.deployment import _calculate_deployment_hash, _create_deployment +from stelvio.aws.api_gateway.deployment import _calculate_deployment_hash from stelvio.aws.api_gateway.iam import _create_api_gateway_account_and_role from stelvio.aws.api_gateway.routing import ( _get_group_config_map, @@ -58,8 +60,14 @@ class ApiResources: stage: Stage +class ApiCustomizationDict(TypedDict, total=False): + rest_api: pulumi_aws.apigateway.RestApiArgs | dict[str, Any] | None + deployment: pulumi_aws.apigateway.DeploymentArgs | dict[str, Any] | None + stage: pulumi_aws.apigateway.StageArgs | dict[str, Any] | None + + @final -class Api(Component[ApiResources]): +class Api(Component[ApiResources, ApiCustomizationDict]): _routes: list[_ApiRoute] _config: ApiConfig _authorizers: list[_Authorizer] @@ -69,6 +77,7 @@ def __init__( self, name: str, config: ApiConfig | None = None, + customize: ApiCustomizationDict | None = None, **opts: Unpack[ApiConfigDict], ) -> None: self._routes = [] @@ -76,7 +85,7 @@ def __init__( self._default_auth = None self._config = self._parse_config(config, opts) self._validate_cors_for_rest_api() - super().__init__(name) + super().__init__(name, customize=customize) @staticmethod def _parse_config(config: ApiConfig | ApiConfigDict | None, opts: ApiConfigDict) -> ApiConfig: @@ -527,6 +536,30 @@ def _create_authorizers(self, rest_api: RestApi) -> dict[str, PulumiAuthorizer]: return authorizer_resources + def _create_deployment( + self, + api: RestApi, + api_name: str, + trigger_hash: str, + depends_on: Input[Sequence[Input[Resource]] | Resource] | None = None, + ) -> Deployment: + """Creates the API deployment, triggering redeployment based on config changes.""" + pulumi.log.debug(f"API '{api_name}' deployment trigger hash: {trigger_hash}") + + return Deployment( + context().prefix(f"{api_name}-deployment"), + **self._customizer( + "deployment", + { + "rest_api": api.id, + # Trigger new deployment only when API route config changes + "triggers": {"configuration_hash": trigger_hash}, + }, + ), + # Ensure deployment happens after all resources/methods/integrations are created + opts=ResourceOptions(depends_on=depends_on), + ) + def _create_resources(self) -> ApiResources: # This is what needs to be done: # 1. create rest api @@ -552,7 +585,13 @@ def _create_resources(self) -> ApiResources: # c. create base path mapping endpoint_type = self._config.endpoint_type or DEFAULT_ENDPOINT_TYPE rest_api = RestApi( - context().prefix(self.name), endpoint_configuration={"types": endpoint_type.upper()} + context().prefix(self.name), + **self._customizer( + "rest_api", + { + "endpoint_configuration": {"types": endpoint_type.upper()}, + }, + ), ) account = _create_api_gateway_account_and_role() @@ -604,31 +643,36 @@ def _create_resources(self) -> ApiResources: ) trigger_hash = _calculate_deployment_hash(self._routes, self._default_auth, cors_config) - deployment = _create_deployment( + deployment = self._create_deployment( rest_api, self.name, trigger_hash, depends_on=all_deployment_dependencies ) stage_name = self._config.stage_name or DEFAULT_STAGE_NAME stage = Stage( safe_name(context().prefix(), f"{self.name}-stage-{stage_name}", 128), - rest_api=rest_api.id, - deployment=deployment.id, - stage_name=stage_name, - # xray_tracing_enabled=True, - access_log_settings={ - "destination_arn": rest_api.name.apply( - lambda name: f"arn:aws:logs:{get_region().name}:" - f"{get_caller_identity().account_id}" - f":log-group:/aws/apigateway/{name}" - ), - "format": '{"requestId":"$context.requestId", "ip": "$context.identity.sourceIp", ' - '"caller":"$context.identity.caller", "user":"$context.identity.user",' - '"requestTime":"$context.requestTime", "httpMethod":' - '"$context.httpMethod","resourcePath":"$context.resourcePath", ' - '"status":"$context.status","protocol":"$context.protocol", ' - '"responseLength":"$context.responseLength"}', - }, - variables={"loggingLevel": "INFO"}, + **self._customizer( + "stage", + { + "rest_api": rest_api.id, + "deployment": deployment.id, + "stage_name": stage_name, + # xray_tracing_enabled=True, + "access_log_settings": { + "destination_arn": rest_api.name.apply( + lambda name: f"arn:aws:logs:{get_region().name}:" + f"{get_caller_identity().account_id}" + f":log-group:/aws/apigateway/{name}" + ), + "format": '{"requestId":"$context.requestId", "ip": "$context.identity.sourceIp", ' # noqa: E501 + '"caller":"$context.identity.caller", "user":"$context.identity.user",' + '"requestTime":"$context.requestTime", "httpMethod":' + '"$context.httpMethod","resourcePath":"$context.resourcePath", ' + '"status":"$context.status","protocol":"$context.protocol", ' + '"responseLength":"$context.responseLength"}', + }, + "variables": {"loggingLevel": "INFO"}, + }, + ), opts=ResourceOptions(depends_on=[account]), ) diff --git a/stelvio/aws/api_gateway/deployment.py b/stelvio/aws/api_gateway/deployment.py index fab69b2a..251d9b5b 100644 --- a/stelvio/aws/api_gateway/deployment.py +++ b/stelvio/aws/api_gateway/deployment.py @@ -1,13 +1,7 @@ import json -from collections.abc import Sequence from hashlib import sha256 from typing import TYPE_CHECKING, Literal -import pulumi -from pulumi import Input, ResourceOptions -from pulumi_aws.apigateway import Deployment, Resource, RestApi - -from stelvio import context from stelvio.aws.api_gateway.config import _ApiRoute, _Authorizer from stelvio.aws.function import Function from stelvio.aws.function.config import FunctionConfig @@ -82,22 +76,3 @@ def get_effective_auth(route: _ApiRoute) -> "_Authorizer | Literal['IAM', False] } return sha256(json.dumps(config, sort_keys=True).encode()).hexdigest() - - -def _create_deployment( - api: RestApi, - api_name: str, - trigger_hash: str, - depends_on: Input[Sequence[Input[Resource]] | Resource] | None = None, -) -> Deployment: - """Creates the API deployment, triggering redeployment based on config changes.""" - pulumi.log.debug(f"API '{api_name}' deployment trigger hash: {trigger_hash}") - - return Deployment( - context().prefix(f"{api_name}-deployment"), - rest_api=api.id, - # Trigger new deployment only when API route config changes - triggers={"configuration_hash": trigger_hash}, - # Ensure deployment happens after all resources/methods/integrations are created - opts=ResourceOptions(depends_on=depends_on), - ) diff --git a/stelvio/aws/cloudfront/cloudfront.py b/stelvio/aws/cloudfront/cloudfront.py index a3726c2a..fd1a7311 100644 --- a/stelvio/aws/cloudfront/cloudfront.py +++ b/stelvio/aws/cloudfront/cloudfront.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, TypedDict, final +from typing import TYPE_CHECKING, Any, Literal, TypedDict, final import pulumi import pulumi_aws @@ -37,17 +37,29 @@ class CloudFrontDistributionResources: function_associations: list[FunctionAssociation] | None +class CloudFrontDistributionCustomizationDict(TypedDict, total=False): + distribution: pulumi_aws.cloudfront.DistributionArgs | dict[str, Any] | None + origin_access_control: pulumi_aws.cloudfront.OriginAccessControlArgs | dict[str, Any] | None + record: dict[str, Any] | None + bucket_policy: pulumi_aws.s3.BucketPolicyArgs | dict[str, Any] | None + # TODO + # function_associations: pulumi_aws.cloudfront.DistributionDefaultCacheBehaviorFunctionAssociationArgs | dict[str, Any] | None # noqa: E501 + + @final -class CloudFrontDistribution(Component[CloudFrontDistributionResources]): - def __init__( +class CloudFrontDistribution( + Component[CloudFrontDistributionResources, CloudFrontDistributionCustomizationDict] +): + def __init__( # noqa: PLR0913 self, name: str, bucket: Bucket, price_class: CloudfrontPriceClass = "PriceClass_100", custom_domain: str | None = None, function_associations: list[FunctionAssociation] | None = None, + customize: CloudFrontDistributionCustomizationDict | None = None, ): - super().__init__(name) + super().__init__(name, customize=customize) self.bucket = bucket self.custom_domain = custom_domain self.price_class = price_class @@ -63,103 +75,125 @@ def _create_resources(self) -> CloudFrontDistributionResources: acm_validated_domain = AcmValidatedDomain( f"{self.name}-acm-validated-domain", domain_name=self.custom_domain, + customize=self._customize, ) # Create Origin Access Control for S3 origin_access_control = pulumi_aws.cloudfront.OriginAccessControl( context().prefix(f"{self.name}-oac"), - description=f"Origin Access Control for {self.name}", - origin_access_control_origin_type="s3", - signing_behavior="always", - signing_protocol="sigv4", + **self._customizer( + "origin_access_control", + { + "description": f"Origin Access Control for {self.name}", + "origin_access_control_origin_type": "s3", + "signing_behavior": "always", + "signing_protocol": "sigv4", + }, + ), ) # Create CloudFront Distribution distribution = pulumi_aws.cloudfront.Distribution( context().prefix(self.name), - aliases=[self.custom_domain] if self.custom_domain else None, - origins=[ - { - "domain_name": self.bucket.resources.bucket.bucket_regional_domain_name, - "origin_id": f"{self.name}-S3-Origin", - "origin_access_control_id": origin_access_control.id, - } - ], - enabled=True, - is_ipv6_enabled=True, - default_root_object="index.html", - default_cache_behavior={ - "allowed_methods": ["GET", "HEAD", "OPTIONS"], # Reduced to read-only methods - "cached_methods": ["GET", "HEAD"], - "target_origin_id": f"{self.name}-S3-Origin", - "compress": True, - "viewer_protocol_policy": "redirect-to-https", - "forwarded_values": { - "query_string": False, - "cookies": {"forward": "none"}, - "headers": ["If-Modified-Since"], # Forward cache validation headers - }, - "min_ttl": 0, - "default_ttl": 300, # Reduce default TTL to 5 minutes for faster updates - "max_ttl": 3600, # Reduce max TTL to 1 hour - "function_associations": self.function_associations, - }, - price_class=self.price_class, - restrictions={ - "geo_restriction": { - "restriction_type": "none", - } - }, - viewer_certificate={ - "acm_certificate_arn": acm_validated_domain.resources.certificate.arn, - "ssl_support_method": "sni-only", - "minimum_protocol_version": "TLSv1.2_2021", - } - if self.custom_domain - else { - "cloudfront_default_certificate": True, - }, - custom_error_responses=[ - { - "error_code": 403, - "response_code": 404, - "response_page_path": "/error.html", - "error_caching_min_ttl": 0, # Don't cache 403 errors - }, + **self._customizer( + "distribution", { - "error_code": 404, - "response_code": 404, - "response_page_path": "/error.html", - "error_caching_min_ttl": 300, # Cache 404s for only 5 minutes + "aliases": [self.custom_domain] if self.custom_domain else None, + "origins": [ + { + "domain_name": self.bucket.resources.bucket.bucket_regional_domain_name, # noqa: E501 + "origin_id": f"{self.name}-S3-Origin", + "origin_access_control_id": origin_access_control.id, + } + ], + "enabled": True, + "is_ipv6_enabled": True, + "default_root_object": "index.html", + "default_cache_behavior": { + "allowed_methods": [ + "GET", + "HEAD", + "OPTIONS", + ], # Reduced to read-only methods + "cached_methods": ["GET", "HEAD"], + "target_origin_id": f"{self.name}-S3-Origin", + "compress": True, + "viewer_protocol_policy": "redirect-to-https", + "forwarded_values": { + "query_string": False, + "cookies": {"forward": "none"}, + "headers": ["If-Modified-Since"], # Forward cache validation headers + }, + "min_ttl": 0, + "default_ttl": 300, # Reduce default TTL to 5 minutes for faster updates + "max_ttl": 3600, # Reduce max TTL to 1 hour + "function_associations": self.function_associations, + }, + "price_class": self.price_class, + "restrictions": { + "geo_restriction": { + "restriction_type": "none", + } + }, + "viewer_certificate": { + "acm_certificate_arn": acm_validated_domain.resources.certificate.arn, + "ssl_support_method": "sni-only", + "minimum_protocol_version": "TLSv1.2_2021", + } + if self.custom_domain + else { + "cloudfront_default_certificate": True, + }, + "custom_error_responses": [ + { + "error_code": 403, + "response_code": 404, + "response_page_path": "/error.html", + "error_caching_min_ttl": 0, # Don't cache 403 errors + }, + { + "error_code": 404, + "response_code": 404, + "response_page_path": "/error.html", + "error_caching_min_ttl": 300, # Cache 404s for only 5 minutes + }, + ], }, - ], + ), ) # Update S3 bucket policy to allow CloudFront access bucket_policy = pulumi_aws.s3.BucketPolicy( context().prefix(f"{self.name}-bucket-policy"), - bucket=self.bucket.resources.bucket.id, - policy=pulumi.Output.all( - distribution_arn=distribution.arn, - bucket_arn=self.bucket.arn, - ).apply( - lambda args: pulumi.Output.json_dumps( - { - "Version": "2012-10-17", - "Statement": [ + **self._customizer( + "bucket_policy", + { + "bucket": self.bucket.resources.bucket.id, + "policy": pulumi.Output.all( + distribution_arn=distribution.arn, + bucket_arn=self.bucket.arn, + ).apply( + lambda args: pulumi.Output.json_dumps( { - "Sid": "AllowCloudFrontServicePrincipal", - "Effect": "Allow", - "Principal": {"Service": "cloudfront.amazonaws.com"}, - "Action": "s3:GetObject", - "Resource": f"{args['bucket_arn']}/*", - "Condition": { - "StringEquals": {"AWS:SourceArn": args["distribution_arn"]} - }, + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "AllowCloudFrontServicePrincipal", + "Effect": "Allow", + "Principal": {"Service": "cloudfront.amazonaws.com"}, + "Action": "s3:GetObject", + "Resource": f"{args['bucket_arn']}/*", + "Condition": { + "StringEquals": { + "AWS:SourceArn": args["distribution_arn"] + } + }, + } + ], } - ], - } - ) + ) + ), + }, ), opts=pulumi.ResourceOptions( depends_on=[distribution] @@ -171,9 +205,15 @@ def _create_resources(self) -> CloudFrontDistributionResources: record = context().dns.create_record( resource_name=context().prefix(f"{self.name}-cloudfront-record"), name=self.custom_domain, - record_type="CNAME", - value=distribution.domain_name, - ttl=1, + # TODO + **self._customizer( + "record", + { + "record_type": "CNAME", + "value": distribution.domain_name, + "ttl": 1, + }, + ), ) pulumi.export(f"cloudfront_{self.name}_domain_name", distribution.domain_name) diff --git a/stelvio/aws/cloudfront/origins/components/url.py b/stelvio/aws/cloudfront/origins/components/url.py index 35be2d89..7889a842 100644 --- a/stelvio/aws/cloudfront/origins/components/url.py +++ b/stelvio/aws/cloudfront/origins/components/url.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import final +from typing import Any, final from urllib.parse import urlparse import pulumi @@ -21,7 +21,7 @@ class UrlResources: @final -class Url(Component[UrlResources], LinkableMixin): +class Url(Component[UrlResources, Any], LinkableMixin): def __init__(self, name: str, url: str): super().__init__(name) self._validate_url(url) diff --git a/stelvio/aws/cloudfront/router.py b/stelvio/aws/cloudfront/router.py index ed7c455c..4738c66e 100644 --- a/stelvio/aws/cloudfront/router.py +++ b/stelvio/aws/cloudfront/router.py @@ -1,6 +1,6 @@ import hashlib from dataclasses import dataclass -from typing import final +from typing import Any, TypedDict, final import pulumi import pulumi_aws @@ -27,16 +27,26 @@ class RouterResources: record: Record | None +class RouterCustomizationDict(TypedDict, total=False): + distribution: pulumi_aws.cloudfront.DistributionArgs | dict[str, Any] | None + origin_access_controls: pulumi_aws.cloudfront.OriginAccessControlArgs | dict[str, Any] | None + access_policies: pulumi_aws.s3.BucketPolicyArgs | dict[str, Any] | None + cloudfront_functions: pulumi_aws.cloudfront.FunctionArgs | dict[str, Any] | None + acm_validated_domain: dict[str, Any] | None + record: dict[str, Any] | None # TODO + + @final -class Router(Component[RouterResources]): +class Router(Component[RouterResources, RouterCustomizationDict]): def __init__( self, name: str, routes: list[Route] | None = None, price_class: CloudfrontPriceClass = "PriceClass_100", custom_domain: str | None = None, + customize: RouterCustomizationDict | None = None, ): - super().__init__(name) + super().__init__(name, customize=customize) self.routes = routes or [] self.price_class = price_class self.custom_domain = custom_domain @@ -50,6 +60,7 @@ def _create_resources(self) -> RouterResources: acm_validated_domain = AcmValidatedDomain( f"{self.name}-acm-validated-domain", domain_name=self.custom_domain, + customize=self._customize, ) if not self.routes: @@ -72,6 +83,7 @@ def _create_resources(self) -> RouterResources: default_404_function = pulumi_aws.cloudfront.Function( context().prefix(f"{self.name}-default-404"), + # TODO: No _customizer here as inconsistent with resource namings runtime="cloudfront-js-2.0", code=default_404_function_code, comment="Return 404 for unmatched routes", @@ -131,27 +143,32 @@ def _create_resources(self) -> RouterResources: distribution = pulumi_aws.cloudfront.Distribution( context().prefix(self.name), - aliases=[self.custom_domain] if self.custom_domain else None, - origins=[rc.origins for rc in route_configs], - enabled=True, - is_ipv6_enabled=True, - default_cache_behavior=default_cache_behavior, - ordered_cache_behaviors=ordered_cache_behaviors or None, - price_class=self.price_class, - restrictions={ - "geo_restriction": { - "restriction_type": "none", - } - }, - viewer_certificate={ - "acm_certificate_arn": acm_validated_domain.resources.certificate.arn, - "ssl_support_method": "sni-only", - "minimum_protocol_version": "TLSv1.2_2021", - } - if self.custom_domain - else { - "cloudfront_default_certificate": True, - }, + **self._customizer( + "distribution", + { + "aliases": [self.custom_domain] if self.custom_domain else None, + "origins": [rc.origins for rc in route_configs], + "enabled": True, + "is_ipv6_enabled": True, + "default_cache_behavior": default_cache_behavior, + "ordered_cache_behaviors": ordered_cache_behaviors or None, + "price_class": self.price_class, + "restrictions": { + "geo_restriction": { + "restriction_type": "none", + } + }, + "viewer_certificate": { + "acm_certificate_arn": acm_validated_domain.resources.certificate.arn, + "ssl_support_method": "sni-only", + "minimum_protocol_version": "TLSv1.2_2021", + } + if self.custom_domain + else { + "cloudfront_default_certificate": True, + }, + }, + ), ) # Create bucket policies to allow CloudFront access for each S3 bucket @@ -166,6 +183,14 @@ def _create_resources(self) -> RouterResources: record = context().dns.create_record( resource_name=context().prefix(f"{self.name}-cloudfront-record"), name=self.custom_domain, + # **self._customizer( + # "record", + # { + # "record_type": "CNAME", + # "value": distribution.domain_name, + # "ttl": 1, + # }, + # ), record_type="CNAME", value=distribution.domain_name, ttl=1, diff --git a/stelvio/aws/cron.py b/stelvio/aws/cron.py index fc5b86ee..87e613c2 100644 --- a/stelvio/aws/cron.py +++ b/stelvio/aws/cron.py @@ -2,13 +2,14 @@ import json from dataclasses import dataclass -from typing import Any, Unpack, final +from typing import Any, TypedDict, Unpack, final import pulumi from pulumi_aws import cloudwatch, lambda_ from stelvio import context from stelvio.aws.function import Function, FunctionConfig, FunctionConfigDict +from stelvio.aws.function.function import FunctionCustomizationDict from stelvio.component import Component, safe_name @@ -116,7 +117,13 @@ class CronResources: function: lambda_.Function -class Cron(Component[CronResources]): +class CronCustomizationDict(TypedDict, total=False): + rule: cloudwatch.EventRuleArgs | dict[str, Any] | None + target: cloudwatch.EventTargetArgs | dict[str, Any] | None + function: FunctionCustomizationDict | dict[str, Any] | None # TODO + + +class Cron(Component[CronResources, CronCustomizationDict]): """Schedule Lambda function execution using EventBridge Rules. Creates an EventBridge Rule with a schedule expression (rate or cron) that @@ -160,7 +167,7 @@ class Cron(Component[CronResources]): ) """ - def __init__( + def __init__( # noqa: PLR0913 self, name: str, schedule: str, @@ -169,9 +176,10 @@ def __init__( *, enabled: bool = True, payload: dict[str, Any] | None = None, + customize: CronCustomizationDict | None = None, **opts: Unpack[FunctionConfigDict], ): - super().__init__(name) + super().__init__(name, customize=customize) # Validate and parse inputs using pure functions _validate_schedule(schedule) @@ -188,23 +196,37 @@ def _create_resources(self) -> CronResources: if isinstance(self._handler_config, Function): stelvio_function = self._handler_config else: - stelvio_function = Function(f"{self.name}-fn", config=self._handler_config) + stelvio_function = Function( + f"{self.name}-fn", + config=self._handler_config, + customize=self._customize.get("function"), + ) lambda_function = stelvio_function.resources.function # Create EventBridge Rule with schedule rule = cloudwatch.EventRule( safe_name(context().prefix(), f"{self.name}-rule", 64), - schedule_expression=self._schedule, - state="ENABLED" if self._enabled else "DISABLED", + **self._customizer( + "rule", + { + "schedule_expression": self._schedule, + "state": "ENABLED" if self._enabled else "DISABLED", + }, + ), ) # Create EventBridge Target linking rule to Lambda target = cloudwatch.EventTarget( safe_name(context().prefix(), f"{self.name}-target", 64), - rule=rule.name, - arn=lambda_function.arn, - input=json.dumps(self._payload) if self._payload is not None else None, + **self._customizer( + "target", + { + "rule": rule.name, + "arn": lambda_function.arn, + "input": json.dumps(self._payload) if self._payload is not None else None, + }, + ), ) # Create Lambda Permission for EventBridge to invoke the function diff --git a/stelvio/aws/dynamo_db.py b/stelvio/aws/dynamo_db.py index 20042a8c..86c4704f 100644 --- a/stelvio/aws/dynamo_db.py +++ b/stelvio/aws/dynamo_db.py @@ -1,14 +1,20 @@ from dataclasses import dataclass, field, replace from enum import Enum -from typing import Literal, TypedDict, Unpack, final +from typing import Any, Literal, TypedDict, Unpack, final import pulumi from pulumi import Output -from pulumi_aws.dynamodb import Table -from pulumi_aws.lambda_ import EventSourceMapping +from pulumi_aws.dynamodb import Table, TableArgs +from pulumi_aws.lambda_ import EventSourceMapping, EventSourceMappingArgs from stelvio import context -from stelvio.aws.function import Function, FunctionConfig, FunctionConfigDict, parse_handler_config +from stelvio.aws.function import ( + Function, + FunctionConfig, + FunctionConfigDict, + FunctionCustomizationDict, + parse_handler_config, +) from stelvio.aws.permission import AwsPermission from stelvio.component import Component, link_config_creator, safe_name from stelvio.link import Link, LinkableMixin, LinkConfig @@ -214,8 +220,19 @@ class DynamoTableResources: table: Table +class DynamoSubscriptionCustomizationDict(TypedDict, total=False): + function: FunctionCustomizationDict | dict[str, Any] | None # TODO! + event_source_mapping: EventSourceMappingArgs | dict[str, Any] | None + + +class DynamoTableCustomizationDict(TypedDict, total=False): + table: TableArgs | dict[str, Any] | None + + @final -class DynamoSubscription(Component[DynamoSubscriptionResources]): +class DynamoSubscription( + Component[DynamoSubscriptionResources, DynamoSubscriptionCustomizationDict] +): def __init__( # noqa: PLR0913 self, name: str, @@ -224,9 +241,10 @@ def __init__( # noqa: PLR0913 filters: list[dict] | None, batch_size: int | None, opts: FunctionConfigDict, + customize: DynamoSubscriptionCustomizationDict | None = None, ): # Add suffix because we want to use 'name' for Function, avoiding component name conflicts - super().__init__(f"{name}-subscription") + super().__init__(f"{name}-subscription", customize=customize) self.table = table self.function_name = name # Function gets the original name @@ -246,17 +264,26 @@ def _create_resources(self) -> DynamoSubscriptionResources: config_with_merged_links = replace(self.handler, links=merged_links) # Create function with merged permissions - function = Function(self.function_name, config_with_merged_links) + function = Function( + self.function_name, + config_with_merged_links, + customize=self._customize.get("function", {}), + ) # Create EventSourceMapping - table.stream_arn triggers table creation naturally mapping = EventSourceMapping( - context().prefix(f"{self.name}-mapping"), - event_source_arn=self.table.stream_arn, - function_name=function.function_name, - starting_position="LATEST", - batch_size=self.batch_size or 100, - maximum_batching_window_in_seconds=0, - filter_criteria={"filters": self.filters} if self.filters else None, + safe_name(context().prefix(), f"{self.name}-mapping", 128), + **self._customizer( + "event_source_mapping", + { + "event_source_arn": self.table.stream_arn, + "function_name": function.function_name, + "starting_position": "LATEST", + "batch_size": self.batch_size or 100, + "maximum_batching_window_in_seconds": 0, + "filter_criteria": {"filters": self.filters} if self.filters else None, + }, + ), ) return DynamoSubscriptionResources(function, mapping) @@ -281,7 +308,7 @@ def _create_stream_link(self) -> Link: @final -class DynamoTable(Component[DynamoTableResources], LinkableMixin): +class DynamoTable(Component[DynamoTableResources, DynamoTableCustomizationDict], LinkableMixin): _subscriptions: list[DynamoSubscription] def __init__( @@ -289,9 +316,10 @@ def __init__( name: str, *, config: DynamoTableConfig | DynamoTableConfigDict | None = None, + customize: DynamoTableCustomizationDict | None = None, **opts: Unpack[DynamoTableConfigDict], ): - super().__init__(name) + super().__init__(name, customize=customize) self._config = self._parse_config(config, opts) self._subscriptions = [] @@ -397,7 +425,9 @@ def subscribe( if any(sub.name == expected_subscription_name for sub in self._subscriptions): raise ValueError(f"Subscription '{name}' already exists for table '{self.name}'") - subscription = DynamoSubscription(function_name, self, handler, filters, batch_size, opts) + subscription = DynamoSubscription( + function_name, self, handler, filters, batch_size, opts, customize=self._customize + ) self._subscriptions.append(subscription) return subscription @@ -407,14 +437,21 @@ def _create_resources(self) -> DynamoTableResources: table = Table( safe_name(context().prefix(), self.name, TABLE_NAME_MAX_LENGTH), - billing_mode="PAY_PER_REQUEST", - hash_key=self.partition_key, - range_key=self.sort_key, - attributes=[{"name": k, "type": v} for k, v in self._config.normalized_fields.items()], - local_secondary_indexes=local_indexes or None, - global_secondary_indexes=global_indexes or None, - stream_enabled=self._config.stream_enabled, - stream_view_type=self._config.normalized_stream_view_type, + **self._customizer( + "table", + { + "billing_mode": "PAY_PER_REQUEST", + "hash_key": self.partition_key, + "range_key": self.sort_key, + "attributes": [ + {"name": k, "type": v} for k, v in self._config.normalized_fields.items() + ], + "local_secondary_indexes": local_indexes or None, + "global_secondary_indexes": global_indexes or None, + "stream_enabled": self._config.stream_enabled, + "stream_view_type": self._config.normalized_stream_view_type, + }, + ), ) pulumi.export(f"dynamotable_{self.name}_arn", table.arn) pulumi.export(f"dynamotable_{self.name}_name", table.name) diff --git a/stelvio/aws/email.py b/stelvio/aws/email.py index 3bcff444..8d810ac1 100644 --- a/stelvio/aws/email.py +++ b/stelvio/aws/email.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal, TypedDict, Unpack, final +from typing import Any, Literal, TypedDict, Unpack, final import pulumi import pulumi_aws @@ -44,6 +44,17 @@ class EmailResources: event_destinations: list[pulumi_aws.sesv2.ConfigurationSetEventDestination] | None = None +class EmailCustomizationDict(TypedDict, total=False): + identity: pulumi_aws.sesv2.EmailIdentityArgs | dict[str, Any] | None + configuration_set: pulumi_aws.sesv2.ConfigurationSetArgs | dict[str, Any] | None + dkim_records: dict[str, Any] | None # TODO + dmarc_record: dict[str, Any] | None # TODO + verification: pulumi_aws.ses.DomainIdentityVerificationArgs | dict[str, Any] | None + event_destinations: ( + pulumi_aws.sesv2.ConfigurationSetEventDestinationArgs | dict[str, Any] | None + ) + + class EventConfiguration(TypedDict, total=False): name: str types: list[EventType] @@ -72,16 +83,17 @@ class EmailConfig: @final -class Email(Component[EmailResources], LinkableMixin): +class Email(Component[EmailResources, EmailCustomizationDict], LinkableMixin): _config: EmailConfig def __init__( self, name: str, config: EmailConfig | EmailConfigDict | None = None, + customize: EmailCustomizationDict | None = None, **opts: Unpack[EmailConfigDict], ): - super().__init__(name) + super().__init__(name, customize=customize) self._config = self._parse_config(config, opts) self.is_domain = "@" not in self.config.sender # We allow passing in a DNS provider since email verification may @@ -183,14 +195,24 @@ def check_email(self, email: str) -> None: def _create_resources(self) -> EmailResources: configuration_set = pulumi_aws.sesv2.ConfigurationSet( - resource_name=context().prefix(f"{self.name}-config-set"), - configuration_set_name=f"{self.name}-config-set", + **self._customizer( + "configuration_set", + { + "resource_name": context().prefix(f"{self.name}-config-set"), + "configuration_set_name": f"{self.name}-config-set", + }, + ), ) identity = pulumi_aws.sesv2.EmailIdentity( resource_name=context().prefix(f"{self.name}-identity"), - email_identity=self.sender, - configuration_set_name=configuration_set.configuration_set_name, + **self._customizer( + "identity", + { + "email_identity": self.sender, + "configuration_set_name": configuration_set.configuration_set_name, + }, + ), ) pulumi.export(f"{self.name}-ses-configuration-set-arn", configuration_set.arn) @@ -208,10 +230,19 @@ def _create_resources(self) -> EmailResources: ) record = self.dns.create_record( resource_name=context().prefix(f"{self.name}-dkim-record-{i}"), - name=token.apply(lambda t: f"{t}._domainkey.{self.sender}"), - record_type="CNAME", - value=token.apply(lambda t: f"{t}.dkim.amazonses.com"), - ttl=600, + # name=token.apply(lambda t: f"{t}._domainkey.{self.sender}"), + # value=token.apply(lambda t: f"{t}.dkim.amazonses.com"), + # ttl=600, + # record_type="CNAME", + **self._customizer( # TODO + "dkim_records", + { + "name": token.apply(lambda t: f"{t}._domainkey.{self.sender}"), + "value": token.apply(lambda t: f"{t}.dkim.amazonses.com"), + "record_type": "CNAME", + "ttl": 600, + }, + ), ) dkim_records.append(record) pulumi.export(f"{self.name}-dkim-record-{i}-name", record.name) @@ -221,15 +252,25 @@ def _create_resources(self) -> EmailResources: dmarc_record = self.dns.create_record( resource_name=context().prefix(f"{self.name}-dmarc-record"), name=f"_dmarc.{self.sender}", - record_type="TXT", - value=self.dmarc, - ttl=600, + **self._customizer( + "dmarc_record", + { + "record_type": "TXT", + "value": self.dmarc, + "ttl": 600, + }, + ), ) pulumi.export(f"{self.name}-dmarc-record-name", dmarc_record.name) pulumi.export(f"{self.name}-dmarc-record-value", dmarc_record.value) verification = pulumi_aws.ses.DomainIdentityVerification( resource_name=context().prefix(f"{self.name}-identity-verification"), - domain=identity.email_identity, + **self._customizer( + "verification", + { + "domain": identity.email_identity, + }, + ), opts=pulumi.ResourceOptions(depends_on=[identity]), ) pulumi.export(f"{self.name}-ses-domain-verification-token-arn", verification.arn) @@ -238,14 +279,19 @@ def _create_resources(self) -> EmailResources: for event in self.events: event_destination = pulumi_aws.sesv2.ConfigurationSetEventDestination( resource_name=context().prefix(f"{self.name}-event-{event['name']}"), - configuration_set_name=configuration_set.configuration_set_name, - event_destination_name=event["name"], - event_destination=pulumi_aws.sesv2.ConfigurationSetEventDestinationEventDestinationArgs( - enabled=True, - matching_event_types=event["types"], - sns_destination=pulumi_aws.sesv2.ConfigurationSetEventDestinationEventDestinationSnsDestinationArgs( - topic_arn=event["topic_arn"] - ), + **self._customizer( + "event_destinations", + { + "configuration_set_name": configuration_set.configuration_set_name, + "event_destination_name": event["name"], + "event_destination": pulumi_aws.sesv2.ConfigurationSetEventDestinationEventDestinationArgs( # noqa: E501 + enabled=True, + matching_event_types=event["types"], + sns_destination=pulumi_aws.sesv2.ConfigurationSetEventDestinationEventDestinationSnsDestinationArgs( + topic_arn=event["topic_arn"] + ), + ), + }, ), ) event_destinations.append(event_destination) diff --git a/stelvio/aws/function/__init__.py b/stelvio/aws/function/__init__.py index 47868e87..43c852c7 100644 --- a/stelvio/aws/function/__init__.py +++ b/stelvio/aws/function/__init__.py @@ -1,5 +1,5 @@ from .config import FunctionConfig, FunctionConfigDict, FunctionUrlConfig, FunctionUrlConfigDict -from .function import Function, FunctionResources +from .function import Function, FunctionCustomizationDict, FunctionResources def parse_handler_config( @@ -57,6 +57,7 @@ def parse_handler_config( "Function", "FunctionConfig", "FunctionConfigDict", + "FunctionCustomizationDict", "FunctionResources", "FunctionUrlConfig", "FunctionUrlConfigDict", diff --git a/stelvio/aws/function/function.py b/stelvio/aws/function/function.py index 984d5d1e..4a54448a 100644 --- a/stelvio/aws/function/function.py +++ b/stelvio/aws/function/function.py @@ -11,13 +11,13 @@ from dataclasses import dataclass from hashlib import sha256 from pathlib import Path -from typing import ClassVar, Unpack, final +from typing import Any, ClassVar, TypedDict, Unpack, final import pulumi from awslambdaric.lambda_context import LambdaContext from pulumi import Input, Output, ResourceOptions from pulumi_aws import lambda_ -from pulumi_aws.iam import GetPolicyDocumentStatementArgs, Policy, Role +from pulumi_aws.iam import GetPolicyDocumentStatementArgs, Policy, PolicyArgs, Role, RoleArgs from pulumi_aws.lambda_ import FunctionUrl, FunctionUrlCorsArgs from stelvio import context @@ -62,8 +62,17 @@ class FunctionResources: function_url: FunctionUrl | None = None +class FunctionCustomizationDict(TypedDict, total=False): + function: lambda_.FunctionArgs | dict[str, Any] | None + role: RoleArgs | dict[str, Any] | None + policy: PolicyArgs | dict[str, Any] | None + function_url: lambda_.FunctionUrlArgs | dict[str, Any] | None + + @final -class Function(Component[FunctionResources], BridgeableMixin, LinkableMixin): +class Function( + Component[FunctionResources, FunctionCustomizationDict], BridgeableMixin, LinkableMixin +): """AWS Lambda function component with automatic resource discovery. Args: @@ -92,9 +101,10 @@ def __init__( self, name: str, config: None | FunctionConfig | FunctionConfigDict = None, + customize: FunctionCustomizationDict | None = None, **opts: Unpack[FunctionConfigDict], ): - super().__init__(name) + super().__init__(name, customize=customize) self._config = self._parse_config(config, opts) self._dev_endpoint_id = f"{self.name}-{sha256(uuid.uuid4().bytes).hexdigest()[:8]}" @@ -228,15 +238,22 @@ def _create_resources(self) -> FunctionResources: else: function_resource = lambda_.Function( safe_name(context().prefix(), self.name, 64), - role=lambda_role.arn, - architectures=[function_architecture], - runtime=function_runtime, - code=_create_lambda_archive(self.config, lambda_resource_file_content), - handler=self.config.handler_format, - environment={"variables": env_vars}, - memory_size=self.config.memory or DEFAULT_MEMORY, - timeout=self.config.timeout or DEFAULT_TIMEOUT, - layers=[layer.arn for layer in self.config.layers] if self.config.layers else None, + **self._customizer( + "function", + { + "role": lambda_role.arn, + "architectures": [function_architecture], + "runtime": function_runtime, + "code": _create_lambda_archive(self.config, lambda_resource_file_content), + "handler": self.config.handler_format, + "environment": {"variables": env_vars}, + "memory_size": self.config.memory or DEFAULT_MEMORY, + "timeout": self.config.timeout or DEFAULT_TIMEOUT, + "layers": [layer.arn for layer in self.config.layers] + if self.config.layers + else None, + }, + ), # Technically this is necessary only for tests as otherwise it's ok if role # attachments are created after functions opts=ResourceOptions(depends_on=role_attachments), diff --git a/stelvio/aws/layer.py b/stelvio/aws/layer.py index a9e3e935..32fdf209 100644 --- a/stelvio/aws/layer.py +++ b/stelvio/aws/layer.py @@ -1,11 +1,11 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Final, final +from typing import Any, Final, TypedDict, final import pulumi from pulumi import Archive, Asset, AssetArchive, FileArchive, Output -from pulumi_aws.lambda_ import LayerVersion +from pulumi_aws.lambda_ import LayerVersion, LayerVersionArgs from stelvio import context from stelvio.aws._packaging.dependencies import ( @@ -37,8 +37,12 @@ class LayerResources: layer_version: LayerVersion +class LayerCustomizationDict(TypedDict, total=False): + layer_version: LayerVersionArgs | dict[str, Any] | None + + @final -class Layer(Component[LayerResources]): +class Layer(Component[LayerResources, LayerCustomizationDict]): """ Represents an AWS Lambda Layer, enabling code and dependency sharing. @@ -68,7 +72,7 @@ class Layer(Component[LayerResources]): _architecture: AwsArchitecture | None _runtime: AwsLambdaRuntime | None - def __init__( + def __init__( # noqa: PLR0913 self, name: str, *, @@ -76,8 +80,9 @@ def __init__( requirements: str | list[str] | bool | None = None, runtime: AwsLambdaRuntime | None = None, architecture: AwsArchitecture | None = None, + customize: LayerCustomizationDict | None = None, ): - super().__init__(name) + super().__init__(name, customize=customize) self._code = code self._requirements = requirements self._runtime = runtime @@ -138,10 +143,15 @@ def _create_resources(self) -> LayerResources: layer_version_resource = LayerVersion( context().prefix(self.name), - layer_name=context().prefix(self.name), - code=asset_archive, - compatible_runtimes=[runtime], - compatible_architectures=[architecture], + **self._customizer( + "layer_version", + { + "layer_name": context().prefix(self.name), + "code": asset_archive, + "compatible_runtimes": [runtime], + "compatible_architectures": [architecture], + }, + ), ) pulumi.export(f"layer_{self.name}_name", layer_version_resource.layer_name) diff --git a/stelvio/aws/queue.py b/stelvio/aws/queue.py index 891b136f..d40647bc 100644 --- a/stelvio/aws/queue.py +++ b/stelvio/aws/queue.py @@ -4,11 +4,18 @@ import pulumi from pulumi import Output -from pulumi_aws.lambda_ import EventSourceMapping +from pulumi_aws.lambda_ import EventSourceMapping, EventSourceMappingArgs from pulumi_aws.sqs import Queue as SqsQueue +from pulumi_aws.sqs import QueueArgs from stelvio import context -from stelvio.aws.function import Function, FunctionConfig, FunctionConfigDict, parse_handler_config +from stelvio.aws.function import ( + Function, + FunctionConfig, + FunctionConfigDict, + FunctionCustomizationDict, + parse_handler_config, +) from stelvio.aws.permission import AwsPermission from stelvio.component import Component, link_config_creator, safe_name from stelvio.link import Link, LinkableMixin, LinkConfig @@ -91,8 +98,13 @@ class QueueSubscriptionResources: event_source_mapping: EventSourceMapping +class QueueSubscriptionCustomizationDict(TypedDict, total=False): + function: FunctionCustomizationDict | dict[str, Any] | None + event_source_mapping: EventSourceMappingArgs | dict[str, Any] | None + + @final -class QueueSubscription(Component[QueueSubscriptionResources]): +class QueueSubscription(Component[QueueSubscriptionResources, QueueSubscriptionCustomizationDict]): """Lambda function subscription to an SQS queue.""" def __init__( # noqa: PLR0913 @@ -103,9 +115,10 @@ def __init__( # noqa: PLR0913 batch_size: int | None, filters: list[SqsFilterDict] | None, opts: FunctionConfigDict, + customize: QueueSubscriptionCustomizationDict | None = None, ): # Add suffix because we want to use 'name' for Function, avoiding component name conflicts - super().__init__(f"{name}-subscription") + super().__init__(f"{name}-subscription", customize=customize) self.queue = queue self.function_name = name # Function gets the original name @@ -171,20 +184,29 @@ def _create_resources(self) -> QueueSubscriptionResources: config_with_merged_links = replace(self.handler, links=merged_links) # Create function with merged permissions - function = Function(self.function_name, config_with_merged_links) + function = Function( + self.function_name, + config_with_merged_links, + customize=self._customize.get("function"), + ) # Create EventSourceMapping for SQS mapping = EventSourceMapping( safe_name(context().prefix(), f"{self.name}-mapping", 128), - event_source_arn=self.queue.arn, - function_name=function.function_name, - batch_size=self.batch_size or DEFAULT_QUEUE_BATCH_SIZE, - filter_criteria=( - {"filters": [{"pattern": json.dumps(f)} for f in self.filters]} - if self.filters - else None + **self._customizer( + "event_source_mapping", + { + "event_source_arn": self.queue.arn, + "function_name": function.function_name, + "batch_size": self.batch_size or DEFAULT_QUEUE_BATCH_SIZE, + "filter_criteria": ( + {"filters": [{"pattern": json.dumps(f)} for f in self.filters]} + if self.filters + else None + ), + "enabled": True, + }, ), - enabled=True, ) return QueueSubscriptionResources(function=function, event_source_mapping=mapping) @@ -207,13 +229,18 @@ def _create_sqs_link(self) -> Link: ) +class QueueCustomizationDict(TypedDict, total=False): + queue: QueueArgs | dict[str, Any] | None + + @final -class Queue(Component[QueueResources], LinkableMixin): +class Queue(Component[QueueResources, QueueCustomizationDict], LinkableMixin): """AWS SQS Queue component. Args: name: Queue name config: Complete queue configuration as QueueConfig or dict + customize: Customization dictionary **opts: Individual queue configuration parameters You can configure the queue in two ways: @@ -238,9 +265,10 @@ def __init__( /, *, config: QueueConfig | QueueConfigDict | None = None, + customize: QueueCustomizationDict | None = None, **opts: Unpack[QueueConfigDict], ): - super().__init__(name) + super().__init__(name, customize=customize) self._config = self._parse_config(config, opts) self._subscriptions = [] @@ -324,13 +352,18 @@ def _create_resources(self) -> QueueResources: queue = SqsQueue( safe_name(context().prefix(), f"{self.name}", 128), - name=queue_name, - delay_seconds=self.config.delay, - visibility_timeout_seconds=self.config.visibility_timeout, - message_retention_seconds=self.config.retention, - fifo_queue=self.config.fifo if self.config.fifo else None, - content_based_deduplication=True if self.config.fifo else None, - redrive_policy=redrive_policy, + **self._customizer( + "queue", + { + "name": queue_name, + "delay_seconds": self.config.delay, + "visibility_timeout_seconds": self.config.visibility_timeout, + "message_retention_seconds": self.config.retention, + "fifo_queue": self.config.fifo if self.config.fifo else None, + "content_based_deduplication": True if self.config.fifo else None, + "redrive_policy": redrive_policy, + }, + ), ) pulumi.export(f"queue_{self.name}_arn", queue.arn) diff --git a/stelvio/aws/s3/s3.py b/stelvio/aws/s3/s3.py index 356b7a98..e02e5cc4 100644 --- a/stelvio/aws/s3/s3.py +++ b/stelvio/aws/s3/s3.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal, final +from typing import Any, Literal, TypedDict, final import pulumi import pulumi_aws @@ -18,12 +18,22 @@ class S3BucketResources: bucket_policy: pulumi_aws.s3.BucketPolicy | None +class S3BucketCustomizationDict(TypedDict, total=False): + bucket: pulumi_aws.s3.BucketArgs | dict[str, Any] | None + public_access_block: pulumi_aws.s3.BucketPublicAccessBlockArgs | dict[str, Any] | None + bucket_policy: pulumi_aws.s3.BucketPolicyArgs | dict[str, Any] | None + + @final -class Bucket(Component[S3BucketResources], LinkableMixin): +class Bucket(Component[S3BucketResources, S3BucketCustomizationDict], LinkableMixin): def __init__( - self, name: str, versioning: bool = False, access: Literal["public"] | None = None + self, + name: str, + versioning: bool = False, + access: Literal["public"] | None = None, + customize: S3BucketCustomizationDict | None = None, ): - super().__init__(name) + super().__init__(name, customize=customize) self.versioning = versioning self.access = access self._resources = None @@ -31,8 +41,13 @@ def __init__( def _create_resources(self) -> S3BucketResources: bucket = pulumi_aws.s3.Bucket( context().prefix(self.name), - bucket=context().prefix(self.name), - versioning={"enabled": self.versioning}, + **self._customizer( + "bucket", + { + "bucket": context().prefix(self.name), + "versioning": {"enabled": self.versioning}, + }, + ), ) # Configure public access block @@ -40,11 +55,16 @@ def _create_resources(self) -> S3BucketResources: # setup readonly configuration public_access_block = pulumi_aws.s3.BucketPublicAccessBlock( context().prefix(f"{self.name}-pab"), - bucket=bucket.id, - block_public_acls=False, - block_public_policy=False, - ignore_public_acls=False, - restrict_public_buckets=False, + **self._customizer( + "public_access_block", + { + "bucket": bucket.id, + "block_public_acls": False, + "block_public_policy": False, + "ignore_public_acls": False, + "restrict_public_buckets": False, + }, + ), ) public_read_policy = pulumi_aws.iam.get_policy_document( statements=[ @@ -70,11 +90,16 @@ def _create_resources(self) -> S3BucketResources: else: public_access_block = pulumi_aws.s3.BucketPublicAccessBlock( context().prefix(f"{self.name}-pab"), - bucket=bucket.id, - block_public_acls=True, - block_public_policy=True, - ignore_public_acls=True, - restrict_public_buckets=True, + **self._customizer( + "public_access_block", + { + "bucket": bucket.id, + "block_public_acls": True, + "block_public_policy": True, + "ignore_public_acls": True, + "restrict_public_buckets": True, + }, + ), ) bucket_policy = None diff --git a/stelvio/aws/s3/s3_static_website.py b/stelvio/aws/s3/s3_static_website.py index 6b63b3b2..141c14ad 100644 --- a/stelvio/aws/s3/s3_static_website.py +++ b/stelvio/aws/s3/s3_static_website.py @@ -2,13 +2,14 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import final +from typing import Any, TypedDict, final import pulumi import pulumi_aws from stelvio import context from stelvio.aws.cloudfront import CloudFrontDistribution +from stelvio.aws.cloudfront.cloudfront import CloudFrontDistributionCustomizationDict from stelvio.aws.s3.s3 import Bucket from stelvio.component import Component, safe_name @@ -21,6 +22,12 @@ class S3StaticWebsiteResources: cloudfront_distribution: CloudFrontDistribution +class S3StaticWebsiteCustomizationDict(TypedDict, total=False): + bucket: pulumi_aws.s3.BucketArgs | dict[str, Any] | None + files: pulumi_aws.s3.BucketObjectArgs | dict[str, Any] | None + cloudfront_distribution: CloudFrontDistributionCustomizationDict | None + + REQUEST_INDEX_HTML_FUNCTION_JS = """ function handler(event) { var request = event.request; @@ -39,15 +46,16 @@ class S3StaticWebsiteResources: @final -class S3StaticWebsite(Component[S3StaticWebsiteResources]): +class S3StaticWebsite(Component[S3StaticWebsiteResources, S3StaticWebsiteCustomizationDict]): def __init__( self, name: str, custom_domain: str | None = None, directory: Path | str | None = None, default_cache_ttl: int = 120, + customize: S3StaticWebsiteCustomizationDict | None = None, ): - super().__init__(name) + super().__init__(name, customize=customize) self.directory = Path(directory) if isinstance(directory, str) else directory self.custom_domain = custom_domain self.default_cache_ttl = default_cache_ttl @@ -58,7 +66,7 @@ def _create_resources(self) -> S3StaticWebsiteResources: if self.directory is not None and not self.directory.exists(): raise FileNotFoundError(f"Directory does not exist: {self.directory}") - bucket = Bucket(f"{self.name}-bucket") + bucket = Bucket(f"{self.name}-bucket", customize=self._customize) # Create CloudFront Function to handle directory index rewriting viewer_request_function = pulumi_aws.cloudfront.Function( context().prefix(f"{self.name}-viewer-request"), @@ -69,6 +77,19 @@ def _create_resources(self) -> S3StaticWebsiteResources: ) cloudfront_distribution = CloudFrontDistribution( name=f"{self.name}-cloudfront", + # **self._customizer( + # "cloudfront_distribution", + # { + # "bucket": bucket, + # "custom_domain": self.custom_domain, + # "function_associations": [ + # { + # "event_type": "viewer-request", + # "function_arn": viewer_request_function.arn, + # } + # ], + # }, + # ), bucket=bucket, custom_domain=self.custom_domain, function_associations=[ @@ -77,6 +98,7 @@ def _create_resources(self) -> S3StaticWebsiteResources: "function_arn": viewer_request_function.arn, } ], + customize=self._customize.get("cloudfront_distribution", {}), ) # Upload files from directory to S3 bucket @@ -127,11 +149,16 @@ def _create_s3_bucket_object( return pulumi_aws.s3.BucketObject( safe_name(context().prefix(), resource_name, 128, "-p"), - bucket=bucket.resources.bucket.id, - key=str(key), - source=pulumi.FileAsset(file_path), - content_type=mimetype, - cache_control=cache_control, + **self._customizer( + "files", + { + "bucket": bucket.resources.bucket.id, + "key": str(key), + "source": pulumi.FileAsset(file_path), + "content_type": mimetype, + "cache_control": cache_control, + }, + ), ) def _process_directory_and_upload_files( diff --git a/stelvio/aws/topic.py b/stelvio/aws/topic.py index c2a0acd4..ba7e6ebb 100644 --- a/stelvio/aws/topic.py +++ b/stelvio/aws/topic.py @@ -1,13 +1,19 @@ import json from dataclasses import dataclass -from typing import Unpack, final +from typing import Any, TypedDict, Unpack, final import pulumi from pulumi import Input, Output, ResourceOptions from pulumi_aws import lambda_, sns, sqs from stelvio import context -from stelvio.aws.function import Function, FunctionConfig, FunctionConfigDict, parse_handler_config +from stelvio.aws.function import ( + Function, + FunctionConfig, + FunctionConfigDict, + FunctionCustomizationDict, + parse_handler_config, +) from stelvio.aws.permission import AwsPermission from stelvio.aws.queue import Queue from stelvio.component import Component, link_config_creator, safe_name @@ -44,41 +50,62 @@ class TopicQueueSubscriptionResources: queue_policy: sqs.QueuePolicy | None # None if ARN string was passed +class TopicSubscriptionCustomizationDict(TypedDict, total=False): + function: FunctionCustomizationDict | dict[str, Any] | None + subscription: sns.TopicSubscriptionArgs | dict[str, Any] | None + permission: lambda_.PermissionArgs | dict[str, Any] | None + + @final -class TopicSubscription(Component[TopicSubscriptionResources]): +class TopicSubscription(Component[TopicSubscriptionResources, TopicSubscriptionCustomizationDict]): """Lambda function subscription to an SNS topic.""" - def __init__( + def __init__( # noqa: PLR0913 self, name: str, topic: "Topic", handler: str | FunctionConfig | FunctionConfigDict | None, filter_: dict[str, list] | None, opts: FunctionConfigDict, + customize: TopicSubscriptionCustomizationDict | None = None, ): - super().__init__(f"{name}-subscription") + super().__init__(f"{name}-subscription", customize=customize) self.topic = topic self.function_name = name self.filter_ = filter_ self.handler = parse_handler_config(handler, opts) def _create_resources(self) -> TopicSubscriptionResources: - function = Function(self.function_name, self.handler) + function = Function( + self.function_name, + self.handler, + customize=self._customize.get("function"), + ) subscription = sns.TopicSubscription( safe_name(context().prefix(), self.name, MAX_TOPIC_NAME_LENGTH), - topic=self.topic.arn, - protocol="lambda", - endpoint=function.resources.function.arn, - filter_policy=json.dumps(self.filter_) if self.filter_ else None, + **self._customizer( + "subscription", + { + "topic": self.topic.arn, + "protocol": "lambda", + "endpoint": function.resources.function.arn, + "filter_policy": json.dumps(self.filter_) if self.filter_ else None, + }, + ), ) permission = lambda_.Permission( safe_name(context().prefix(), f"{self.name}-perm", 100), - action="lambda:InvokeFunction", - function=function.function_name, - principal="sns.amazonaws.com", - source_arn=self.topic.arn, + **self._customizer( + "permission", + { + "action": "lambda:InvokeFunction", + "function": function.function_name, + "principal": "sns.amazonaws.com", + "source_arn": self.topic.arn, + }, + ), ) return TopicSubscriptionResources( @@ -88,19 +115,27 @@ def _create_resources(self) -> TopicSubscriptionResources: ) +class TopicQueueSubscriptionCustomizationDict(TypedDict, total=False): + subscription: sns.TopicSubscriptionArgs | dict[str, Any] | None + queue_policy: sqs.QueuePolicyArgs | dict[str, Any] | None + + @final -class TopicQueueSubscription(Component[TopicQueueSubscriptionResources]): +class TopicQueueSubscription( + Component[TopicQueueSubscriptionResources, TopicQueueSubscriptionCustomizationDict] +): """SQS queue subscription to an SNS topic.""" - def __init__( + def __init__( # noqa: PLR0913 self, name: str, topic: "Topic", queue: Queue | Input[str], filter_: dict[str, list] | None, raw_message_delivery: bool, + customize: TopicQueueSubscriptionCustomizationDict | None = None, ): - super().__init__(name) + super().__init__(name, customize=customize) self.topic = topic self.queue = queue self.filter_ = filter_ @@ -116,11 +151,16 @@ def _create_resources(self) -> TopicQueueSubscriptionResources: subscription = sns.TopicSubscription( safe_name(context().prefix(), self.name, MAX_TOPIC_NAME_LENGTH), - topic=self.topic.arn, - protocol="sqs", - endpoint=queue_arn, - filter_policy=json.dumps(self.filter_) if self.filter_ else None, - raw_message_delivery=self.raw_message_delivery, + **self._customizer( + "subscription", + { + "topic": self.topic.arn, + "protocol": "sqs", + "endpoint": queue_arn, + "filter_policy": json.dumps(self.filter_) if self.filter_ else None, + "raw_message_delivery": self.raw_message_delivery, + }, + ), opts=ResourceOptions(depends_on=[queue_policy]) if queue_policy else None, ) @@ -160,18 +200,28 @@ def _create_queue_policy(self) -> sqs.QueuePolicy: f"{queue.name}-{self.topic.name}-sns-policy", MAX_TOPIC_NAME_LENGTH, ), - queue_url=queue.url, - policy=policy_document, + **self._customizer( + "queue_policy", + { + "queue_url": queue.url, + "policy": policy_document, + }, + ), ) +class TopicCustomizationDict(TypedDict, total=False): + topic: sns.TopicArgs | dict[str, Any] | None + + @final -class Topic(Component[TopicResources], LinkableMixin): +class Topic(Component[TopicResources, TopicCustomizationDict], LinkableMixin): """AWS SNS Topic component. Args: name: Topic name fifo: Whether this is a FIFO topic (default: False) + customize: Customization dictionary Examples: # Standard topic @@ -196,8 +246,9 @@ def __init__( /, *, fifo: bool = False, + customize: TopicCustomizationDict | None = None, ): - super().__init__(name) + super().__init__(name, customize=customize) self._fifo = fifo self._subscriptions = [] self._queue_subscriptions = [] @@ -224,9 +275,14 @@ def _create_resources(self) -> TopicResources: topic = sns.Topic( topic_name, - name=topic_name, - fifo_topic=self._fifo if self._fifo else None, - content_based_deduplication=self._fifo if self._fifo else None, + **self._customizer( + "topic", + { + "name": topic_name, + "fifo_topic": self._fifo if self._fifo else None, + "content_based_deduplication": self._fifo if self._fifo else None, + }, + ), ) pulumi.export(f"topic_{self.name}_arn", topic.arn) @@ -241,6 +297,7 @@ def subscribe( /, *, filter_: dict[str, list] | None = None, + customize: TopicSubscriptionCustomizationDict | None = None, **opts: Unpack[FunctionConfigDict], ) -> TopicSubscription: """Subscribe a Lambda function to this topic. @@ -249,6 +306,7 @@ def subscribe( name: Name for the subscription (used in Lambda function naming) handler: Lambda handler specification filter_: SNS filter policy for message filtering + customize: Customization dictionary **opts: Lambda function configuration (memory, timeout, etc.) Raises: @@ -268,7 +326,14 @@ def subscribe( if any(sub.name == subscription_name for sub in self._subscriptions): raise ValueError(f"Subscription '{name}' already exists for topic '{self.name}'") - subscription = TopicSubscription(function_name, self, handler, filter_, opts) + subscription = TopicSubscription( + function_name, + self, + handler, + filter_, + opts, + customize=customize, + ) self._subscriptions.append(subscription) return subscription @@ -280,6 +345,7 @@ def subscribe_queue( *, filter_: dict[str, list] | None = None, raw_message_delivery: bool = False, + customize: TopicQueueSubscriptionCustomizationDict | None = None, ) -> TopicQueueSubscription: """Subscribe an SQS queue to this topic. @@ -290,6 +356,7 @@ def subscribe_queue( queue: Queue component or queue ARN filter_: SNS filter policy for message filtering raw_message_delivery: If True, send raw message without SNS envelope + customize: Customization dictionary Raises: ValueError: If a subscription with the same name already exists @@ -305,6 +372,7 @@ def subscribe_queue( queue, filter_, raw_message_delivery, + customize=customize, ) self._queue_subscriptions.append(subscription) return subscription diff --git a/stelvio/command_run.py b/stelvio/command_run.py index ba6a5d61..f45faffa 100644 --- a/stelvio/command_run.py +++ b/stelvio/command_run.py @@ -210,6 +210,7 @@ def _load_stlv_app(env: str, dev_mode: bool) -> None: aws=config.aws, dns=config.dns, home=config.home, + customize=config.customize, dev_mode=dev_mode, ) ) diff --git a/stelvio/component.py b/stelvio/component.py index 67067939..21825cc2 100644 --- a/stelvio/component.py +++ b/stelvio/component.py @@ -6,6 +6,8 @@ from hashlib import sha256 from typing import TYPE_CHECKING, Any, ClassVar, Protocol +from stelvio import context + if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -13,13 +15,17 @@ from stelvio.link import LinkConfig -class Component[ResourcesT](ABC): +class Component[ResourcesT, CustomizationT](ABC): _name: str _resources: ResourcesT | None + _customize: CustomizationT | None = None - def __init__(self, name: str): + def __init__(self, name: str, customize: CustomizationT | None = None): self._name = name self._resources = None + self._customize = customize + if self._customize is None: + self._customize = {} ComponentRegistry.add_instance(self) @property @@ -37,6 +43,25 @@ def _create_resources(self) -> ResourcesT: """Implement actual resource creation logic""" raise NotImplementedError + def _customizer(self, resource_name: str, default_props: dict[str, dict]) -> dict: + global_customize = context().customize.get(type(self), {}) + + # Convert Pulumi Input Args to dict + def _(val: object) -> dict: + if val is None: + return {} + if isinstance(val, dict): + return val + if hasattr(val, "__dict__"): + return vars(val) + raise ValueError(f"Cannot convert customization value to dict: {val}") + + return { + **default_props, + **_(global_customize.get(resource_name)), + **_(self._customize.get(resource_name)), + } + class Bridgeable(Protocol): _dev_endpoint_id: str | None diff --git a/stelvio/config.py b/stelvio/config.py index f090d5e8..ed43b488 100644 --- a/stelvio/config.py +++ b/stelvio/config.py @@ -1,8 +1,11 @@ from dataclasses import dataclass, field -from typing import Literal +from typing import TYPE_CHECKING, Any, Literal from stelvio.dns import Dns +if TYPE_CHECKING: + from stelvio.component import Component + @dataclass(frozen=True, kw_only=True) class AwsConfig: @@ -100,12 +103,14 @@ class StelvioAppConfig: dns: DNS provider configuration for custom domains. environments: List of shared environment names (e.g., ["staging", "production"]). home: State storage backend. Currently only "aws" is supported. + customize: Customization dictionary for Pulumi resources. """ aws: AwsConfig = field(default_factory=AwsConfig) dns: Dns | None = None environments: list[str] = field(default_factory=list) home: Literal["aws"] = "aws" + customize: dict[type["Component[Any, Any]"], dict[str, dict]] = field(default_factory=dict) def is_valid_environment(self, env: str, username: str) -> bool: return env == username or env in self.environments diff --git a/stelvio/context.py b/stelvio/context.py index 51726385..66d12f96 100644 --- a/stelvio/context.py +++ b/stelvio/context.py @@ -1,9 +1,12 @@ -from dataclasses import dataclass -from typing import ClassVar, Literal +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar, Literal from stelvio.config import AwsConfig from stelvio.dns import Dns +if TYPE_CHECKING: + from stelvio.component import Component + @dataclass(frozen=True) class AppContext: @@ -15,6 +18,7 @@ class AppContext: home: Literal["aws"] dns: Dns | None = None dev_mode: bool = False + customize: dict[type["Component[Any, Any]"], dict[str, dict]] = field(default_factory=dict) def prefix(self, name: str | None = None) -> str: """Get resource name prefix or prefixed name. diff --git a/tests/aws/pulumi_mocks.py b/tests/aws/pulumi_mocks.py index bfb1d079..69398ceb 100644 --- a/tests/aws/pulumi_mocks.py +++ b/tests/aws/pulumi_mocks.py @@ -299,6 +299,10 @@ def created_event_source_mappings(self, name: str | None = None) -> list[MockRes def created_queues(self, name: str | None = None) -> list[MockResourceArgs]: return self._filter_created("aws:sqs/queue:Queue", name) + def created_sqs_queues(self, name: str | None = None) -> list[MockResourceArgs]: + """Alias for created_queues for clarity.""" + return self.created_queues(name) + def created_queue_policies(self, name: str | None = None) -> list[MockResourceArgs]: return self._filter_created("aws:sqs/queuePolicy:QueuePolicy", name) @@ -306,9 +310,25 @@ def created_queue_policies(self, name: str | None = None) -> list[MockResourceAr def created_topics(self, name: str | None = None) -> list[MockResourceArgs]: return self._filter_created("aws:sns/topic:Topic", name) + def created_sns_topics(self, name: str | None = None) -> list[MockResourceArgs]: + """Alias for created_topics for clarity.""" + return self.created_topics(name) + def created_topic_subscriptions(self, name: str | None = None) -> list[MockResourceArgs]: return self._filter_created("aws:sns/topicSubscription:TopicSubscription", name) + # DynamoDB resource helpers + def created_dynamodb_tables(self, name: str | None = None) -> list[MockResourceArgs]: + """Alias for created_dynamo_tables for clarity.""" + return self.created_dynamo_tables(name) + + # SES resource helpers + def created_email_identities(self, name: str | None = None) -> list[MockResourceArgs]: + return self._filter_created("aws:sesv2/emailIdentity:EmailIdentity", name) + + def created_configuration_sets(self, name: str | None = None) -> list[MockResourceArgs]: + return self._filter_created("aws:sesv2/configurationSet:ConfigurationSet", name) + class MockDns(Dns): """Mock DNS provider that mimics CloudflareDns interface""" diff --git a/tests/aws/test_customization.py b/tests/aws/test_customization.py new file mode 100644 index 00000000..6e1d6f41 --- /dev/null +++ b/tests/aws/test_customization.py @@ -0,0 +1,866 @@ +"""Tests for the customize parameter across all components. + +These tests verify that the customize parameter is properly passed through to the +underlying Pulumi resources for each component. +""" + +import shutil +import tempfile +from pathlib import Path +from unittest.mock import Mock + +import pulumi +import pytest +from pulumi.runtime import set_mocks + +from stelvio.aws.api_gateway import Api +from stelvio.aws.cloudfront import CloudFrontDistribution +from stelvio.aws.cloudfront.router import Router +from stelvio.aws.cron import Cron +from stelvio.aws.dynamo_db import DynamoTable +from stelvio.aws.email import Email +from stelvio.aws.function import Function +from stelvio.aws.queue import Queue +from stelvio.aws.s3 import Bucket, S3StaticWebsite +from stelvio.aws.topic import Topic +from stelvio.config import AwsConfig +from stelvio.context import AppContext, _ContextStore +from stelvio.dns import Dns + +from .pulumi_mocks import MockDns, PulumiTestMocks + +# Test prefix +TP = "test-test-" + + +class EmailTestMocks(PulumiTestMocks): + """Extended mocks for Email tests that add DKIM tokens.""" + + def new_resource(self, args): + id_, props = super().new_resource(args) + if args.typ == "aws:sesv2/emailIdentity:EmailIdentity": + props["dkim_signing_attributes"] = {"tokens": ["token1", "token2", "token3"]} + props["arn"] = ( + f"arn:aws:ses:us-east-1:123456789012:identity/{args.inputs['emailIdentity']}" + ) + if args.typ == "aws:sesv2/configurationSet:ConfigurationSet": + props["arn"] = ( + f"arn:aws:ses:us-east-1:123456789012:configuration-set/" + f"{args.inputs['configurationSetName']}" + ) + return id_, props + + +@pytest.fixture +def pulumi_mocks(): + mocks = PulumiTestMocks() + set_mocks(mocks) + return mocks + + +@pytest.fixture +def email_mocks(): + mocks = EmailTestMocks() + set_mocks(mocks) + return mocks + + +@pytest.fixture +def mock_dns(): + dns = Mock(spec=Dns) + dns.create_record.return_value = Mock() + return dns + + +@pytest.fixture +def app_context_with_dns(): + """Fixture that provides an app context with DNS configured.""" + _ContextStore.clear() + mock_dns = MockDns() + _ContextStore.set( + AppContext( + name="test", + env="test", + aws=AwsConfig(profile="default", region="us-east-1"), + home="aws", + dns=mock_dns, + ) + ) + yield mock_dns + _ContextStore.clear() + _ContextStore.set( + AppContext( + name="test", + env="test", + aws=AwsConfig(profile="default", region="us-east-1"), + home="aws", + ) + ) + + +def delete_files(directory: Path, filename: str): + """Helper to clean up generated files.""" + for file_path in directory.rglob(filename): + file_path.unlink(missing_ok=True) + + +@pytest.fixture +def project_cwd(monkeypatch, pytestconfig, tmp_path): + from stelvio.project import get_project_root + + get_project_root.cache_clear() + rootpath = pytestconfig.rootpath + source_project_dir = rootpath / "tests" / "aws" / "sample_test_project" + temp_project_dir = tmp_path / "sample_project_copy" + + shutil.copytree(source_project_dir, temp_project_dir, dirs_exist_ok=True) + monkeypatch.chdir(temp_project_dir) + yield temp_project_dir + delete_files(temp_project_dir, "stlv_resources.py") + + +# ============================================================================= +# S3 Bucket Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_bucket_customize_bucket_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to S3 bucket resource.""" + # Arrange + bucket = Bucket( + "my-bucket", + customize={ + "bucket": { + "force_destroy": True, + "tags": {"Environment": "test"}, + } + }, + ) + + # Act + _ = bucket.resources + + # Assert + def check_resources(_): + buckets = pulumi_mocks.created_s3_buckets(TP + "my-bucket") + assert len(buckets) == 1 + created_bucket = buckets[0] + + # Check customization was applied + assert created_bucket.inputs.get("forceDestroy") is True + assert created_bucket.inputs.get("tags") == {"Environment": "test"} + + bucket.resources.bucket.id.apply(check_resources) + + +@pulumi.runtime.test +def test_bucket_customize_public_access_block(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to public access block resource.""" + # Arrange - private bucket (access=None) with customize + bucket = Bucket( + "my-bucket", + customize={ + "public_access_block": { + "block_public_acls": False, # Override default True + } + }, + ) + + # Act + _ = bucket.resources + + # Assert + def check_resources(_): + pabs = pulumi_mocks.created_s3_public_access_blocks(TP + "my-bucket-pab") + assert len(pabs) == 1 + pab = pabs[0] + + # Customization should override the default + assert pab.inputs.get("blockPublicAcls") is False + # Other defaults should remain + assert pab.inputs.get("blockPublicPolicy") is True + + bucket.resources.public_access_block.id.apply(check_resources) + + +# ============================================================================= +# Function Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_function_customize_function_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to Lambda function resource.""" + # Arrange + fn = Function( + "my-function", + handler="functions/simple.handler", + customize={ + "function": { + "reserved_concurrent_executions": 10, + } + }, + ) + + # Act + _ = fn.resources + + # Assert + def check_resources(_): + functions = pulumi_mocks.created_functions(TP + "my-function") + assert len(functions) == 1 + created_fn = functions[0] + + # Check customization was applied + assert created_fn.inputs.get("reservedConcurrentExecutions") == 10 + + fn.resources.function.id.apply(check_resources) + + +# ============================================================================= +# Queue Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_queue_customize_queue_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to SQS queue resource.""" + # Arrange + queue = Queue( + "my-queue", + customize={ + "queue": { + "tags": {"Team": "backend"}, + } + }, + ) + + # Act + _ = queue.resources + + # Assert + def check_resources(_): + queues = pulumi_mocks.created_sqs_queues(TP + "my-queue") + assert len(queues) == 1 + created_queue = queues[0] + + # Check customization was applied + assert created_queue.inputs.get("tags") == {"Team": "backend"} + + queue.resources.queue.id.apply(check_resources) + + +# ============================================================================= +# Topic Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_topic_customize_topic_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to SNS topic resource.""" + # Arrange + topic = Topic( + "my-topic", + customize={ + "topic": { + "tags": {"Project": "stelvio"}, + } + }, + ) + + # Act + _ = topic.resources + + # Assert + def check_resources(_): + topics = pulumi_mocks.created_sns_topics() + assert len(topics) >= 1 + + # Find our topic + matching_topics = [t for t in topics if "my-topic" in t.name] + assert len(matching_topics) == 1 + created_topic = matching_topics[0] + + # Check customization was applied + assert created_topic.inputs.get("tags") == {"Project": "stelvio"} + + topic.resources.topic.id.apply(check_resources) + + +# ============================================================================= +# DynamoDB Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_dynamo_table_customize_table_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to DynamoDB table resource.""" + # Arrange + table = DynamoTable( + "my-table", + fields={"id": "string"}, + partition_key="id", + customize={ + "table": { + "tags": {"Service": "orders"}, + } + }, + ) + + # Act + _ = table.resources + + # Assert + def check_resources(_): + tables = pulumi_mocks.created_dynamodb_tables() + assert len(tables) >= 1 + + # Find our table + matching_tables = [t for t in tables if "my-table" in t.name] + assert len(matching_tables) == 1 + created_table = matching_tables[0] + + # Check customization was applied + assert created_table.inputs.get("tags") == {"Service": "orders"} + + table.resources.table.id.apply(check_resources) + + +# ============================================================================= +# Cron Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_cron_customize_rule_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to EventBridge rule resource.""" + # Arrange + cron = Cron( + "my-cron", + "rate(1 hour)", + "functions/simple.handler", + customize={ + "rule": { + "tags": {"Schedule": "hourly"}, + } + }, + ) + + # Act + _ = cron.resources + + # Assert + def check_resources(_): + rules = pulumi_mocks.created_event_rules() + assert len(rules) >= 1 + + # Find our rule + matching_rules = [r for r in rules if "my-cron" in r.name] + assert len(matching_rules) == 1 + created_rule = matching_rules[0] + + # Check customization was applied + assert created_rule.inputs.get("tags") == {"Schedule": "hourly"} + + cron.resources.rule.id.apply(check_resources) + + +@pulumi.runtime.test +def test_cron_customize_target_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to EventBridge target resource.""" + # Arrange + cron = Cron( + "my-cron", + "rate(1 hour)", + "functions/simple.handler", + customize={ + "target": { + "retry_policy": {"maximum_event_age_in_seconds": 60}, + } + }, + ) + + # Act + _ = cron.resources + + # Assert + def check_resources(_): + targets = pulumi_mocks.created_event_targets() + assert len(targets) >= 1 + + # Find our target + matching_targets = [t for t in targets if "my-cron" in t.name] + assert len(matching_targets) == 1 + created_target = matching_targets[0] + + # Check customization was applied + retry_policy = created_target.inputs.get("retryPolicy") + assert retry_policy is not None + assert retry_policy.get("maximumEventAgeInSeconds") == 60 + + cron.resources.target.id.apply(check_resources) + + +# ============================================================================= +# Test customization merging behavior +# ============================================================================= + + +@pulumi.runtime.test +def test_customize_merges_with_defaults(pulumi_mocks, project_cwd): + """Test that customize merges with defaults instead of replacing them.""" + # Arrange + bucket = Bucket( + "my-bucket", + versioning=True, # Default param + customize={ + "bucket": { + "force_destroy": True, # Customization + } + }, + ) + + # Act + _ = bucket.resources + + # Assert + def check_resources(_): + buckets = pulumi_mocks.created_s3_buckets(TP + "my-bucket") + assert len(buckets) == 1 + created_bucket = buckets[0] + + # Both default and customization should be present + assert created_bucket.inputs.get("versioning", {}).get("enabled") is True + assert created_bucket.inputs.get("forceDestroy") is True + + bucket.resources.bucket.id.apply(check_resources) + + +@pulumi.runtime.test +def test_customize_can_override_defaults(pulumi_mocks, project_cwd): + """Test that customize can override default values.""" + # Arrange - Override the default memory size + fn = Function( + "my-function", + handler="functions/simple.handler", + memory=256, # Default param + customize={ + "function": { + "memory_size": 512, # Override via customize + } + }, + ) + + # Act + _ = fn.resources + + # Assert + def check_resources(_): + functions = pulumi_mocks.created_functions(TP + "my-function") + assert len(functions) == 1 + created_fn = functions[0] + + # Customization should override the config value + assert created_fn.inputs.get("memorySize") == 512 + + fn.resources.function.id.apply(check_resources) + + +@pulumi.runtime.test +def test_customize_empty_dict_uses_defaults(pulumi_mocks, project_cwd): + """Test that empty customize dict still uses defaults.""" + # Arrange + bucket = Bucket( + "my-bucket", + versioning=True, + customize={}, # Empty customize + ) + + # Act + _ = bucket.resources + + # Assert + def check_resources(_): + buckets = pulumi_mocks.created_s3_buckets(TP + "my-bucket") + assert len(buckets) == 1 + created_bucket = buckets[0] + + # Defaults should still be applied + assert created_bucket.inputs.get("versioning", {}).get("enabled") is True + + bucket.resources.bucket.id.apply(check_resources) + + +@pulumi.runtime.test +def test_customize_none_uses_defaults(pulumi_mocks, project_cwd): + """Test that None customize uses defaults.""" + # Arrange + bucket = Bucket( + "my-bucket", + versioning=True, + customize=None, + ) + + # Act + _ = bucket.resources + + # Assert + def check_resources(_): + buckets = pulumi_mocks.created_s3_buckets(TP + "my-bucket") + assert len(buckets) == 1 + created_bucket = buckets[0] + + # Defaults should still be applied + assert created_bucket.inputs.get("versioning", {}).get("enabled") is True + + bucket.resources.bucket.id.apply(check_resources) + + +# ============================================================================= +# Email Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_email_customize_identity_resource(email_mocks, project_cwd, mock_dns): + """Test that customize parameter is applied to SES email identity resource.""" + # Arrange + email = Email( + "my-email", + "test@example.com", + dmarc=None, + customize={ + "identity": { + "tags": {"Service": "notifications"}, + } + }, + ) + + # Act + _ = email.resources + + # Assert + def check_resources(_): + identities = email_mocks.created_email_identities() + assert len(identities) >= 1 + + # Find our identity + matching_identities = [i for i in identities if "my-email" in i.name] + assert len(matching_identities) == 1 + created_identity = matching_identities[0] + + # Check customization was applied + assert created_identity.inputs.get("tags") == {"Service": "notifications"} + + email.resources.identity.id.apply(check_resources) + + +@pulumi.runtime.test +def test_email_customize_configuration_set(email_mocks, project_cwd, mock_dns): + """Test that customize parameter is applied to SES configuration set.""" + # Arrange - Domain email which creates configuration set + email = Email( + "my-domain-email", + "example.com", + dmarc=None, + dns=mock_dns, + customize={ + "configuration_set": { + "tags": {"Environment": "production"}, + } + }, + ) + + # Act + _ = email.resources + + # Assert + def check_resources(_): + config_sets = email_mocks.created_configuration_sets() + assert len(config_sets) >= 1 + + # Find our configuration set + matching_sets = [cs for cs in config_sets if "my-domain-email" in cs.name] + assert len(matching_sets) == 1 + created_config_set = matching_sets[0] + + # Check customization was applied + assert created_config_set.inputs.get("tags") == {"Environment": "production"} + + email.resources.configuration_set.id.apply(check_resources) + + +# ============================================================================= +# Api Gateway Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_api_customize_rest_api_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to API Gateway REST API resource.""" + # Arrange + api = Api( + "my-api", + customize={ + "rest_api": { + "description": "Custom API description", + } + }, + ) + api.route("GET", "/", "functions/simple.handler") + + # Act + _ = api.resources + + # Assert + def check_resources(_): + rest_apis = pulumi_mocks.created_rest_apis() + assert len(rest_apis) >= 1 + + # Find our REST API + matching_apis = [a for a in rest_apis if "my-api" in a.name] + assert len(matching_apis) == 1 + created_api = matching_apis[0] + + # Check customization was applied + assert created_api.inputs.get("description") == "Custom API description" + + api.resources.rest_api.id.apply(check_resources) + + +@pulumi.runtime.test +def test_api_customize_stage_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to API Gateway stage resource.""" + # Arrange + api = Api( + "my-api", + customize={ + "stage": { + "description": "Custom stage description", + } + }, + ) + api.route("GET", "/", "functions/simple.handler") + + # Act + _ = api.resources + + # Assert + def check_resources(_): + stages = pulumi_mocks.created_stages() + assert len(stages) >= 1 + + # Find our stage + matching_stages = [s for s in stages if "my-api" in s.name] + assert len(matching_stages) == 1 + created_stage = matching_stages[0] + + # Check customization was applied + assert created_stage.inputs.get("description") == "Custom stage description" + + api.resources.stage.id.apply(check_resources) + + +# ============================================================================= +# CloudFront Distribution Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_cloudfront_customize_distribution_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to CloudFront distribution resource.""" + # Arrange + bucket = Bucket("my-bucket") + _ = bucket.resources + + cf = CloudFrontDistribution( + "my-cf", + bucket=bucket, + customize={ + "distribution": { + "comment": "Custom CloudFront comment", + } + }, + ) + + # Act + _ = cf.resources + + # Assert + def check_resources(_): + distributions = pulumi_mocks.created_cloudfront_distributions() + assert len(distributions) >= 1 + + # Find our distribution + matching_dists = [d for d in distributions if "my-cf" in d.name] + assert len(matching_dists) == 1 + created_dist = matching_dists[0] + + # Check customization was applied + assert created_dist.inputs.get("comment") == "Custom CloudFront comment" + + cf.resources.distribution.id.apply(check_resources) + + +@pulumi.runtime.test +def test_cloudfront_customize_origin_access_control(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to origin access control resource.""" + # Arrange + bucket = Bucket("my-bucket") + _ = bucket.resources + + cf = CloudFrontDistribution( + "my-cf", + bucket=bucket, + customize={ + "origin_access_control": { + "description": "Custom OAC description", + } + }, + ) + + # Act + _ = cf.resources + + # Assert + def check_resources(_): + oacs = pulumi_mocks.created_origin_access_controls() + assert len(oacs) >= 1 + + # Find our OAC + matching_oacs = [o for o in oacs if "my-cf" in o.name] + assert len(matching_oacs) == 1 + created_oac = matching_oacs[0] + + # Check customization was applied + assert created_oac.inputs.get("description") == "Custom OAC description" + + cf.resources.origin_access_control.id.apply(check_resources) + + +# ============================================================================= +# Router Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_router_customize_distribution_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to Router CloudFront distribution.""" + # Arrange + bucket = Bucket("static-bucket") + _ = bucket.resources + + router = Router( + "my-router", + customize={ + "distribution": { + "comment": "Custom Router comment", + } + }, + ) + router.route("/static", bucket) + + # Act + _ = router.resources + + # Assert + def check_resources(_): + distributions = pulumi_mocks.created_cloudfront_distributions() + assert len(distributions) >= 1 + + # Find our distribution + matching_dists = [d for d in distributions if "my-router" in d.name] + assert len(matching_dists) == 1 + created_dist = matching_dists[0] + + # Check customization was applied + assert created_dist.inputs.get("comment") == "Custom Router comment" + + router.resources.distribution.id.apply(check_resources) + + +# ============================================================================= +# S3StaticWebsite Customization Tests +# ============================================================================= + + +@pulumi.runtime.test +def test_s3_static_website_customize_bucket_resource(pulumi_mocks, project_cwd): + """Test that customize parameter is applied to S3StaticWebsite bucket resource.""" + # Arrange + with tempfile.TemporaryDirectory() as tmpdir: + static_dir = Path(tmpdir) / "static" + static_dir.mkdir() + (static_dir / "index.html").write_text("Hello") + + website = S3StaticWebsite( + "my-website", + directory=str(static_dir), + customize={ + "bucket": { + "force_destroy": True, + } + }, + ) + + # Act + _ = website.resources + + # Assert + def check_resources(_): + buckets = pulumi_mocks.created_s3_buckets() + assert len(buckets) >= 1 + + # Find our bucket + matching_buckets = [b for b in buckets if "my-website" in b.name] + assert len(matching_buckets) == 1 + created_bucket = matching_buckets[0] + + # Check customization was applied + assert created_bucket.inputs.get("forceDestroy") is True + + website.resources.bucket.id.apply(check_resources) + + +@pulumi.runtime.test +def test_s3_static_website_customize_cloudfront_distribution( + pulumi_mocks, project_cwd, app_context_with_dns +): + """Test that customize parameter is applied to S3StaticWebsite CloudFront.""" + # Arrange + with tempfile.TemporaryDirectory() as tmpdir: + static_dir = Path(tmpdir) / "static" + static_dir.mkdir() + (static_dir / "index.html").write_text("Hello") + + website = S3StaticWebsite( + "my-website", + directory=str(static_dir), + custom_domain="example.com", + customize={ + "cloudfront_distribution": { + "distribution": { + "comment": "Custom Website CDN", + } + } + }, + ) + + # Act + _ = website.resources + + # Assert + def check_resources(_): + distributions = pulumi_mocks.created_cloudfront_distributions() + assert len(distributions) >= 1 + + # Find our distribution + matching_dists = [d for d in distributions if "my-website" in d.name] + assert len(matching_dists) == 1 + created_dist = matching_dists[0] + + # Check customization was applied + assert created_dist.inputs.get("comment") == "Custom Website CDN" + + website.resources.cloudfront_distribution.resources.distribution.id.apply(check_resources) diff --git a/tests/conftest.py b/tests/conftest.py index e887ef51..41f64480 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,6 +53,7 @@ def app_context(): env="test", aws=AwsConfig(profile="default", region="us-east-1"), home="aws", + customize={}, ) ) diff --git a/tests/test_component.py b/tests/test_component.py index 0864d01e..6f5375a7 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -19,9 +19,14 @@ class MockComponentResources: # Concrete implementation of Component for testing -class MockComponent(Component[MockComponentResources]): - def __init__(self, name: str, resource: MockResource = None): - super().__init__(name) +class MockComponent(Component[MockComponentResources, dict]): + def __init__( + self, + name: str, + resource: MockResource = None, + customize: dict[str, dict] | None = None, + ): + super().__init__(name, customize=customize) self._mock_resource = resource or MockResource(name) # Track if _create_resource was called self.create_resources_called = False @@ -167,3 +172,137 @@ def test_creator(r): # Test that the wrapper preserves function metadata assert creator.__name__ == test_creator.__name__ + + +# Customizer tests + + +def test_customizer_returns_default_props_when_no_customization(clear_registry): + """Test that _customizer returns default props when no customization is provided.""" + component = MockComponent("test-component") + + default_props = {"key1": "value1", "key2": "value2"} + result = component._customizer("some_resource", default_props) + + assert result == default_props + + +def test_customizer_returns_default_props_when_resource_not_in_customize(clear_registry): + """Test that _customizer returns default props when resource name is not in customize dict.""" + component = MockComponent( + "test-component", + customize={"other_resource": {"key1": "override1"}}, + ) + + default_props = {"key1": "value1", "key2": "value2"} + result = component._customizer("some_resource", default_props) + + assert result == default_props + + +def test_customizer_merges_customization_with_defaults(clear_registry): + """Test that _customizer merges customization overrides with default props.""" + component = MockComponent( + "test-component", + customize={"bucket": {"key1": "override1", "key3": "new_value"}}, + ) + + default_props = {"key1": "value1", "key2": "value2"} + result = component._customizer("bucket", default_props) + + # Customization should override key1 and add key3 + assert result == {"key1": "override1", "key2": "value2", "key3": "new_value"} + + +def test_customizer_overrides_take_precedence(clear_registry): + """Test that customization values take precedence over defaults.""" + component = MockComponent( + "test-component", + customize={"resource": {"setting": "custom"}}, + ) + + default_props = {"setting": "default"} + result = component._customizer("resource", default_props) + + assert result["setting"] == "custom" + + +def test_customizer_with_empty_defaults(clear_registry): + """Test that _customizer works with empty default props.""" + component = MockComponent( + "test-component", + customize={"resource": {"key1": "value1"}}, + ) + + result = component._customizer("resource", {}) + + assert result == {"key1": "value1"} + + +def test_customizer_with_empty_customization_for_resource(clear_registry): + """Test that _customizer handles empty customization for a specific resource.""" + component = MockComponent( + "test-component", + customize={"resource": {}}, + ) + + default_props = {"key1": "value1"} + result = component._customizer("resource", default_props) + + assert result == default_props + + +def test_customizer_with_nested_dict_values(clear_registry): + """Test that _customizer works with nested dictionary values.""" + component = MockComponent( + "test-component", + customize={"bucket": {"versioning": {"enabled": False}}}, + ) + + default_props = {"bucket": "my-bucket", "versioning": {"enabled": True}} + result = component._customizer("bucket", default_props) + + # Note: dict merge is shallow, so nested dict is completely replaced + assert result == { + "bucket": "my-bucket", + "versioning": {"enabled": False}, + } + + +def test_customizer_with_multiple_resources(clear_registry): + """Test that _customizer correctly selects the right resource configuration.""" + component = MockComponent( + "test-component", + customize={ + "bucket": {"key": "bucket_value"}, + "policy": {"key": "policy_value"}, + "role": {"key": "role_value"}, + }, + ) + + bucket_result = component._customizer("bucket", {"key": "default"}) + policy_result = component._customizer("policy", {"key": "default"}) + role_result = component._customizer("role", {"key": "default"}) + other_result = component._customizer("other", {"key": "default"}) + + assert bucket_result == {"key": "bucket_value"} + assert policy_result == {"key": "policy_value"} + assert role_result == {"key": "role_value"} + assert other_result == {"key": "default"} + + +def test_customize_defaults_to_empty_dict(clear_registry): + """Test that customize defaults to an empty dict when None is passed.""" + component = MockComponent("test-component", customize=None) + + # Should not raise, and should return defaults + result = component._customizer("resource", {"key": "value"}) + assert result == {"key": "value"} + + +def test_customize_initialization_without_parameter(clear_registry): + """Test that component can be created without customize parameter.""" + component = MockComponent("test-component") + + # Internal _customize should be an empty dict + assert component._customize == {} diff --git a/tests/test_link.py b/tests/test_link.py index ac6a276b..1c04218a 100644 --- a/tests/test_link.py +++ b/tests/test_link.py @@ -32,7 +32,7 @@ def __init__(self, name="test-resource"): self.arn = f"arn:aws:mock:::{name}" -class MockComponent(Component[MockResource], Linkable): +class MockComponent(Component[MockResource, dict], Linkable): def __init__(self, name): super().__init__(name) self._mock_resource = MockResource(name)