diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 0cdf365..30b73f7 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -44,7 +44,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install .[dev,server] + python -m pip install .[dev,server,validation] python -m pip install "pypgstac==${{ matrix.pypgstac }}" - name: Run test suite diff --git a/CHANGES.md b/CHANGES.md index 64a15e9..4f1f7cd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -48,6 +48,10 @@ ### Added +- add `validate_extensions` setting that enables validation of `stac_extensions` from submitted STAC objects + using the `stac_pydantic.extensions.validate_extensions` utility. Applicable only when `TransactionExtension` + is active. +- add `validation` extra requirement to install dependencies of `stac_pydantic` required for extension validation - add `write_connection_pool` option in `stac_fastapi.pgstac.db.connect_to_db` function - add `write_postgres_settings` option in `stac_fastapi.pgstac.db.connect_to_db` function to set specific settings for the `writer` DB connection pool - add specific error message when trying to create `Item` with null geometry (not supported by PgSTAC) diff --git a/setup.py b/setup.py index 8005f5e..4eb1998 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,9 @@ ], "server": ["uvicorn[standard]==0.35.0"], "awslambda": ["mangum"], + "validation": [ + "stac_pydantic[validation]", + ], } diff --git a/stac_fastapi/pgstac/config.py b/stac_fastapi/pgstac/config.py index 4f8f7da..00432af 100644 --- a/stac_fastapi/pgstac/config.py +++ b/stac_fastapi/pgstac/config.py @@ -170,6 +170,13 @@ class Settings(ApiSettings): invalid_id_chars: List[str] = DEFAULT_INVALID_ID_CHARS base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache + validate_extensions: bool = False + """ + Validate `stac_extensions` schemas against submitted data when creating or updated STAC objects. + + Implies that the `Transactions` extension is enabled. + """ + cors_origins: str = "*" cors_methods: str = "GET,POST,OPTIONS" cors_credentials: bool = False diff --git a/stac_fastapi/pgstac/transactions.py b/stac_fastapi/pgstac/transactions.py index d2fe968..bb2588e 100644 --- a/stac_fastapi/pgstac/transactions.py +++ b/stac_fastapi/pgstac/transactions.py @@ -2,7 +2,7 @@ import logging import re -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import attr import jsonpatch @@ -23,6 +23,7 @@ from stac_fastapi.types import stac as stac_types from stac_fastapi.types.errors import NotFoundError from stac_pydantic import Collection, Item, ItemCollection +from stac_pydantic.extensions import validate_extensions from starlette.responses import JSONResponse, Response from stac_fastapi.pgstac.config import Settings @@ -44,8 +45,38 @@ def _validate_id(self, id: str, settings: Settings): detail=f"ID ({id}) cannot contain the following characters: {' '.join(invalid_chars)}", ) + def _validate_extensions( + self, + stac_object: Union[ + stac_types.Item, stac_types.Collection, stac_types.Catalog, Dict[str, Any] + ], + settings: Settings, + ) -> None: + """Validate extensions of the STAC object data.""" + if not settings.validate_extensions: + return + + if isinstance(stac_object, dict): + if not stac_object.get("stac_extensions"): + return + else: + if not stac_object.stac_extensions: + return + + try: + validate_extensions( + stac_object, + reraise_exception=True, + ) + except Exception as err: + raise HTTPException( + status_code=422, + detail=f"STAC Extensions failed validation: {err!s}", + ) from err + def _validate_collection(self, request: Request, collection: stac_types.Collection): self._validate_id(collection["id"], request.app.state.settings) + self._validate_extensions(collection, request.app.state.settings) def _validate_item( self, @@ -59,6 +90,7 @@ def _validate_item( body_item_id = item.get("id") self._validate_id(body_item_id, request.app.state.settings) + self._validate_extensions(item, request.app.state.settings) if item.get("geometry", None) is None: raise HTTPException( @@ -180,6 +212,7 @@ async def update_collection( """Update collection.""" col = collection.model_dump(mode="json") + self._validate_collection(request, col) async with request.app.state.get_connection(request, "w") as conn: await dbfunc(conn, "update_collection", col) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index bd2a364..7f67e28 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -948,3 +948,39 @@ async def test_default_app(default_client, default_app, load_test_data): assert "https://api.stacspec.org/v1.0.0/collections" in conf assert "https://api.stacspec.org/v1.0.0/ogcapi-features#query" in conf assert "https://api.stacspec.org/v1.0.0/ogcapi-features#sort" in conf + + +async def test_app_transactions_validate_extension( + app_client_validate_ext, load_test_data +): + coll = load_test_data("test_collection.json") + # Add attribution extension + # https://github.com/stac-extensions/attribution + coll["stac_extensions"] = [ + "https://stac-extensions.github.io/attribution/v0.1.0/schema.json", + ] + + resp = await app_client_validate_ext.post("/collections", json=coll) + assert resp.status_code == 422 + assert "STAC Extensions failed validation:" in resp.json()["detail"] + + # add attribution + coll["attribution"] = "something" + resp = await app_client_validate_ext.post("/collections", json=coll) + assert resp.status_code == 201 + + item = load_test_data("test_item.json") + item["stac_extensions"].append( + "https://stac-extensions.github.io/attribution/v0.1.0/schema.json", + ) + resp = await app_client_validate_ext.post( + f"/collections/{coll['id']}/items", json=item + ) + assert resp.status_code == 422 + assert "STAC Extensions failed validation:" in resp.json()["detail"] + + item["properties"]["attribution"] = "something" + resp = await app_client_validate_ext.post( + f"/collections/{coll['id']}/items", json=item + ) + assert resp.status_code == 201 diff --git a/tests/conftest.py b/tests/conftest.py index d349593..dbbb597 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -454,3 +454,48 @@ async def app_client_advanced_freetext(app_advanced_freetext): transport=ASGITransport(app=app_advanced_freetext), base_url="http://test" ) as c: yield c + + +@pytest.fixture(scope="function") +async def app_transaction_validation_ext(database): + """Default stac-fastapi-pgstac application with extension validation in transaction.""" + api_settings = Settings(testing=True, validate_extensions=True) + api = StacApi( + settings=api_settings, + extensions=[ + TransactionExtension( + client=TransactionsClient(), + settings=api_settings, + ) + ], + client=CoreCrudClient(), + health_check=health_check, + ) + + postgres_settings = PostgresSettings( + pguser=database.user, + pgpassword=database.password, + pghost=database.host, + pgport=database.port, + pgdatabase=database.dbname, + ) + logger.info("Creating app Fixture") + await connect_to_db( + api.app, + postgres_settings=postgres_settings, + add_write_connection_pool=True, + ) + yield api.app + await close_db_connection(api.app) + + logger.info("Closed Pools.") + + +@pytest.fixture(scope="function") +async def app_client_validate_ext(app_transaction_validation_ext): + logger.info("creating app_client") + async with AsyncClient( + transport=ASGITransport(app=app_transaction_validation_ext), + base_url="http://test", + ) as c: + yield c