diff --git a/.github/workflows/api-deployer.yml b/.github/workflows/api-deployer.yml index 457600dc7..23f7fb1b6 100644 --- a/.github/workflows/api-deployer.yml +++ b/.github/workflows/api-deployer.yml @@ -199,7 +199,7 @@ jobs: - uses: actions/download-artifact@v4 with: name: database_gen - path: api/src/database_gen/ + path: api/src/shared/database_gen/ # api schema was generated and uploaded in api-build-test job above. - uses: actions/download-artifact@v4 @@ -219,7 +219,7 @@ jobs: - name: Build & Publish Docker Image run: | # We want to generate the image even if it's the same commit that has been tagged. So use the version - # (coming from the tag) in the docker image tag (If the docket tag does not change it's won't be uploaded) + # (coming from the tag) in the docker image tag (If the docker tag does not change it won't be uploaded) DOCKER_IMAGE_VERSION=$EXTRACTED_VERSION.$FEED_API_IMAGE_VERSION scripts/docker-build-push.sh -project_id $PROJECT_ID -repo_name feeds-$ENVIRONMENT -service feed-api -region $REGION -version $DOCKER_IMAGE_VERSION @@ -243,7 +243,7 @@ jobs: - uses: actions/download-artifact@v4 with: name: database_gen - path: api/src/database_gen/ + path: api/src/shared/database_gen/ # api schema was generated and uploaded in api-build-test job above. - uses: actions/download-artifact@v4 @@ -318,6 +318,13 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PLAN_OUTPUT: ${{ steps.plan.outputs.stdout }} + - name: Persist TF plan + uses: actions/upload-artifact@v4 + with: + name: terraform-plan.txt + path: infra/terraform-plan.txt + overwrite: true + - name: Terraform Apply if: ${{ inputs.TF_APPLY }} run: | @@ -325,11 +332,4 @@ jobs: terraform apply -auto-approve tf.plan env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PLAN_OUTPUT: ${{ steps.plan.outputs.stdout }} - - - name: Persist TF plan - uses: actions/upload-artifact@v4 - with: - name: terraform-plan.txt - path: infra/terraform-plan.txt - overwrite: true + PLAN_OUTPUT: ${{ steps.plan.outputs.stdout }} \ No newline at end of file diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 96b990940..1481497e6 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -105,7 +105,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: database_gen - path: api/src/database_gen/ + path: api/src/shared/database_gen/ overwrite: true - name: Upload API generated code diff --git a/.github/workflows/datasets-batch-deployer.yml b/.github/workflows/datasets-batch-deployer.yml index 6e07b58d6..c30e8d508 100644 --- a/.github/workflows/datasets-batch-deployer.yml +++ b/.github/workflows/datasets-batch-deployer.yml @@ -119,7 +119,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: database_gen - path: api/src/database_gen/ + path: api/src/shared/database_gen/ - name: Build python functions run: | diff --git a/api/.flake8 b/api/.flake8 index 7a387a63c..3e458cb40 100644 --- a/api/.flake8 +++ b/api/.flake8 @@ -1,5 +1,5 @@ [flake8] max-line-length = 120 -exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,src/feeds_gen,src/database_gen +exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,feeds_gen,database_gen # Ignored because conflict with black extend-ignore = E203 \ No newline at end of file diff --git a/api/src/feeds/impl/datasets_api_impl.py b/api/src/feeds/impl/datasets_api_impl.py index 8d32bf769..cf018c8cf 100644 --- a/api/src/feeds/impl/datasets_api_impl.py +++ b/api/src/feeds/impl/datasets_api_impl.py @@ -1,20 +1,17 @@ from 
typing import List from typing import Tuple -from geoalchemy2 import WKTElement -from sqlalchemy import or_ from sqlalchemy.orm import Query, Session -from database.database import Database, with_db_session -from database_gen.sqlacodegen_models import ( +from shared.database.database import Database, with_db_session +from shared.database_gen.sqlacodegen_models import ( Gtfsdataset, Feed, ) from feeds.impl.error_handling import ( - invalid_bounding_coordinates, - invalid_bounding_method, - raise_http_validation_error, raise_http_error, +) +from shared.common.error_handling import ( dataset_not_found, ) from feeds.impl.models.gtfs_dataset_impl import GtfsDatasetImpl @@ -39,59 +36,6 @@ def create_dataset_query() -> Query: ] ).join(Feed, Feed.id == Gtfsdataset.feed_id) - @staticmethod - def apply_bounding_filtering( - query: Query, - bounding_latitudes: str, - bounding_longitudes: str, - bounding_filter_method: str, - ) -> Query: - """Create a new query based on the bounding parameters.""" - - if not bounding_latitudes or not bounding_longitudes or not bounding_filter_method: - return query - - if ( - len(bounding_latitudes_tokens := bounding_latitudes.split(",")) != 2 - or len(bounding_longitudes_tokens := bounding_longitudes.split(",")) != 2 - ): - raise_http_validation_error(invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes)) - min_latitude, max_latitude = bounding_latitudes_tokens - min_longitude, max_longitude = bounding_longitudes_tokens - try: - min_latitude = float(min_latitude) - max_latitude = float(max_latitude) - min_longitude = float(min_longitude) - max_longitude = float(max_longitude) - except ValueError: - raise_http_validation_error(invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes)) - points = [ - (min_longitude, min_latitude), - (min_longitude, max_latitude), - (max_longitude, max_latitude), - (max_longitude, min_latitude), - (min_longitude, min_latitude), - ] - wkt_polygon = f"POLYGON(({', '.join(f'{lon} {lat}' for lon, lat in points)}))" - bounding_box = WKTElement( - wkt_polygon, - srid=Gtfsdataset.bounding_box.type.srid, - ) - - if bounding_filter_method == "partially_enclosed": - return query.filter( - or_( - Gtfsdataset.bounding_box.ST_Overlaps(bounding_box), - Gtfsdataset.bounding_box.ST_Contains(bounding_box), - ) - ) - elif bounding_filter_method == "completely_enclosed": - return query.filter(bounding_box.ST_Covers(Gtfsdataset.bounding_box)) - elif bounding_filter_method == "disjoint": - return query.filter(Gtfsdataset.bounding_box.ST_Disjoint(bounding_box)) - else: - raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method)) - @staticmethod def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]: # Results are sorted by stable_id because Database.select(group_by=) requires it so diff --git a/api/src/feeds/impl/error_handling.py b/api/src/feeds/impl/error_handling.py index ecb33e3ea..d664d5929 100644 --- a/api/src/feeds/impl/error_handling.py +++ b/api/src/feeds/impl/error_handling.py @@ -1,16 +1,15 @@ -from typing import Final - from fastapi import HTTPException -invalid_date_message: Final[ - str -] = "Invalid date format for '{}'. 
Expected ISO 8601 format, example: '2021-01-01T00:00:00Z'" -invalid_bounding_coordinates: Final[str] = "Invalid bounding coordinates {} {}" -invalid_bounding_method: Final[str] = "Invalid bounding_filter_method {}" -feed_not_found: Final[str] = "Feed '{}' not found" -gtfs_feed_not_found: Final[str] = "GTFS feed '{}' not found" -gtfs_rt_feed_not_found: Final[str] = "GTFS realtime Feed '{}' not found" -dataset_not_found: Final[str] = "Dataset '{}' not found" +from shared.common.error_handling import InternalHTTPException + + +def convert_exception(input_exception: InternalHTTPException) -> HTTPException: + """Convert an InternalHTTPException to an HTTPException. + HTTPException is dependent on fastapi, and we don't necessarily want to deploy it with python functions. + That's why InternalHTTPException (a class that we deploy) is thrown instead of HTTPException. + Since InternalHTTPException is internal, it needs to be converted before being sent up. + """ + return HTTPException(status_code=input_exception.status_code, detail=input_exception.detail) def raise_http_error(status_code: int, error: str): diff --git a/api/src/feeds/impl/feeds_api_impl.py b/api/src/feeds/impl/feeds_api_impl.py index 0ca336bdd..be3f24db5 100644 --- a/api/src/feeds/impl/feeds_api_impl.py +++ b/api/src/feeds/impl/feeds_api_impl.py @@ -6,8 +6,9 @@ from sqlalchemy.orm import joinedload, Session from sqlalchemy.orm.query import Query -from database.database import Database, with_db_session -from database_gen.sqlacodegen_models import ( +from shared.common.db_utils import get_gtfs_feeds_query, get_gtfs_rt_feeds_query, get_joinedload_options +from shared.database.database import Database, with_db_session +from shared.database_gen.sqlacodegen_models import ( Feed, Gtfsdataset, Gtfsfeed, @@ -17,18 +18,17 @@ t_location_with_translations_en, Entitytype, ) -from feeds.filters.feed_filter import FeedFilter -from feeds.filters.gtfs_dataset_filter import GtfsDatasetFilter -from feeds.filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter -from feeds.filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter +from shared.feed_filters.feed_filter import FeedFilter +from shared.feed_filters.gtfs_dataset_filter import GtfsDatasetFilter +from shared.feed_filters.gtfs_feed_filter import LocationFilter +from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter from feeds.impl.datasets_api_impl import DatasetsApiImpl -from feeds.impl.error_handling import ( - raise_http_validation_error, +from shared.common.error_handling import ( invalid_date_message, - raise_http_error, feed_not_found, gtfs_feed_not_found, gtfs_rt_feed_not_found, + InternalHTTPException, ) from feeds.impl.models.basic_feed_impl import BasicFeedImpl from feeds.impl.models.entity_type_enum import EntityType @@ -39,6 +39,7 @@ from feeds_gen.models.gtfs_dataset import GtfsDataset from feeds_gen.models.gtfs_feed import GtfsFeed from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed +from feeds.impl.error_handling import raise_http_error, raise_http_validation_error, convert_exception from middleware.request_context import is_user_email_restricted from utils.date_utils import valid_iso_date from utils.location_translation import ( @@ -117,7 +118,7 @@ def get_feeds( ) # Results are sorted by provider feed_query = feed_query.order_by(Feed.provider, Feed.stable_id) - feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options()) + feed_query = feed_query.options(*get_joinedload_options()) if limit is not None: feed_query = 
feed_query.limit(limit) if offset is not None: @@ -158,7 +159,7 @@ def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None joinedload(Gtfsfeed.gtfsdatasets) .joinedload(Gtfsdataset.validation_reports) .joinedload(Validationreport.notices), - *BasicFeedImpl.get_joinedload_options(), + *get_joinedload_options(), ) ).all() if len(results) > 0 and results[0].Gtfsfeed: @@ -237,46 +238,29 @@ def get_gtfs_feeds( is_official: bool, db_session: Session, ) -> List[GtfsFeed]: - """Get some (or all) GTFS feeds from the Mobility Database.""" - gtfs_feed_filter = GtfsFeedFilter( - stable_id=None, - provider__ilike=provider, - producer_url__ilike=producer_url, - location=LocationFilter( + try: + include_wip = not is_user_email_restricted() + feed_query = get_gtfs_feeds_query( + limit=limit, + offset=offset, + provider=provider, + producer_url=producer_url, country_code=country_code, - subdivision_name__ilike=subdivision_name, - municipality__ilike=municipality, - ), - ) - - subquery = gtfs_feed_filter.filter(select(Gtfsfeed.id).join(Location, Gtfsfeed.locations)) - subquery = DatasetsApiImpl.apply_bounding_filtering( - subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method - ).subquery() - - is_email_restricted = is_user_email_restricted() - self.logger.info(f"User email is restricted: {is_email_restricted}") - feed_query = ( - db_session.query(Gtfsfeed) - .filter(Gtfsfeed.id.in_(subquery)) - .filter( - or_( - Gtfsfeed.operational_status == None, # noqa: E711 - Gtfsfeed.operational_status != "wip", - not is_email_restricted, # Allow all feeds to be returned if the user is not restricted - ) + subdivision_name=subdivision_name, + municipality=municipality, + dataset_latitudes=dataset_latitudes, + dataset_longitudes=dataset_longitudes, + bounding_filter_method=bounding_filter_method, + is_official=is_official, + include_wip=include_wip, + db_session=db_session, ) - .options( - joinedload(Gtfsfeed.gtfsdatasets) - .joinedload(Gtfsdataset.validation_reports) - .joinedload(Validationreport.notices), - *BasicFeedImpl.get_joinedload_options(), - ) - .order_by(Gtfsfeed.provider, Gtfsfeed.stable_id) - ) - if is_official: - feed_query = feed_query.filter(Feed.official) - feed_query = feed_query.limit(limit).offset(offset) + except InternalHTTPException as e: + # get_gtfs_feeds_query cannot throw HTTPException since it's part of fastapi and it's + # not necessarily deployed (e.g. for python functions). Instead it throws an InternalHTTPException + # that needs to be converted to HTTPException before being thrown. 
+ raise convert_exception(e) + return self._get_response(feed_query, GtfsFeedImpl, db_session) @with_db_session @@ -303,7 +287,7 @@ def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed: .options( joinedload(Gtfsrealtimefeed.entitytypes), joinedload(Gtfsrealtimefeed.gtfs_feeds), - *BasicFeedImpl.get_joinedload_options(), + *get_joinedload_options(), ) ).all() @@ -328,6 +312,26 @@ def get_gtfs_rt_feeds( db_session: Session, ) -> List[GtfsRTFeed]: """Get some (or all) GTFS Realtime feeds from the Mobility Database.""" + try: + include_wip = not is_user_email_restricted() + feed_query = get_gtfs_rt_feeds_query( + limit=limit, + offset=offset, + provider=provider, + producer_url=producer_url, + entity_types=entity_types, + country_code=country_code, + subdivision_name=subdivision_name, + municipality=municipality, + is_official=is_official, + include_wip=include_wip, + db_session=db_session, + ) + except InternalHTTPException as e: + raise convert_exception(e) + + return self._get_response(feed_query, GtfsRTFeedImpl, db_session) + entity_types_list = entity_types.split(",") if entity_types else None # Validate entity types using the EntityType enum @@ -369,7 +373,7 @@ def get_gtfs_rt_feeds( .options( joinedload(Gtfsrealtimefeed.entitytypes), joinedload(Gtfsrealtimefeed.gtfs_feeds), - *BasicFeedImpl.get_joinedload_options(), + *get_joinedload_options(), ) .order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id) ) diff --git a/api/src/feeds/impl/models/basic_feed_impl.py b/api/src/feeds/impl/models/basic_feed_impl.py index 19b2537c8..db9a3be0a 100644 --- a/api/src/feeds/impl/models/basic_feed_impl.py +++ b/api/src/feeds/impl/models/basic_feed_impl.py @@ -1,7 +1,4 @@ -from sqlalchemy.orm import joinedload -from sqlalchemy.orm.strategy_options import _AbstractLoad - -from database_gen.sqlacodegen_models import Feed +from shared.database_gen.sqlacodegen_models import Feed from feeds.impl.models.external_id_impl import ExternalIdImpl from feeds.impl.models.redirect_impl import RedirectImpl from feeds_gen.models.basic_feed import BasicFeed @@ -47,16 +44,6 @@ def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None: redirects=sorted([RedirectImpl.from_orm(item) for item in feed.redirectingids], key=lambda x: x.target_id), ) - @staticmethod - def get_joinedload_options() -> [_AbstractLoad]: - """Returns common joinedload options for feeds queries.""" - return [ - joinedload(Feed.locations), - joinedload(Feed.externalids), - joinedload(Feed.redirectingids), - joinedload(Feed.officialstatushistories), - ] - class BasicFeedImpl(BaseFeedImpl, BasicFeed): """Implementation of the `BasicFeed` model. 
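
The convert_exception helper added in api/src/feeds/impl/error_handling.py above only copies status_code and detail across from the fastapi-free exception to the fastapi one. A minimal, hypothetical pytest sketch of that behaviour (not part of this diff; it assumes api/src is on the import path and fastapi is installed):

from fastapi import HTTPException

from feeds.impl.error_handling import convert_exception
from shared.common.error_handling import InternalHTTPException


def test_convert_exception_preserves_fields():
    # Build the fastapi-free exception that shared code raises...
    internal = InternalHTTPException(status_code=422, detail="Invalid bounding_filter_method foo")
    # ...and check the API layer turns it into the equivalent fastapi HTTPException.
    converted = convert_exception(internal)
    assert isinstance(converted, HTTPException)
    assert converted.status_code == 422
    assert converted.detail == "Invalid bounding_filter_method foo"
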
diff --git a/api/src/feeds/impl/models/external_id_impl.py b/api/src/feeds/impl/models/external_id_impl.py index 6617d79db..5fa70133a 100644 --- a/api/src/feeds/impl/models/external_id_impl.py +++ b/api/src/feeds/impl/models/external_id_impl.py @@ -1,4 +1,4 @@ -from database_gen.sqlacodegen_models import Externalid +from shared.database_gen.sqlacodegen_models import Externalid from feeds_gen.models.external_id import ExternalId diff --git a/api/src/feeds/impl/models/gtfs_dataset_impl.py b/api/src/feeds/impl/models/gtfs_dataset_impl.py index 1fe87da88..3d253284a 100644 --- a/api/src/feeds/impl/models/gtfs_dataset_impl.py +++ b/api/src/feeds/impl/models/gtfs_dataset_impl.py @@ -1,7 +1,7 @@ from functools import reduce from typing import List -from database_gen.sqlacodegen_models import Gtfsdataset, Validationreport +from shared.database_gen.sqlacodegen_models import Gtfsdataset, Validationreport from feeds.impl.models.bounding_box_impl import BoundingBoxImpl from feeds.impl.models.validation_report_impl import ValidationReportImpl from feeds_gen.models.gtfs_dataset import GtfsDataset diff --git a/api/src/feeds/impl/models/gtfs_feed_impl.py b/api/src/feeds/impl/models/gtfs_feed_impl.py index ded4d00ff..f78c5080b 100644 --- a/api/src/feeds/impl/models/gtfs_feed_impl.py +++ b/api/src/feeds/impl/models/gtfs_feed_impl.py @@ -1,6 +1,6 @@ from typing import Dict -from database_gen.sqlacodegen_models import Gtfsfeed as GtfsfeedOrm +from shared.database_gen.sqlacodegen_models import Gtfsfeed as GtfsfeedOrm from feeds.impl.models.basic_feed_impl import BaseFeedImpl from feeds.impl.models.latest_dataset_impl import LatestDatasetImpl from feeds.impl.models.location_impl import LocationImpl diff --git a/api/src/feeds/impl/models/gtfs_rt_feed_impl.py b/api/src/feeds/impl/models/gtfs_rt_feed_impl.py index d19956762..5c4905b04 100644 --- a/api/src/feeds/impl/models/gtfs_rt_feed_impl.py +++ b/api/src/feeds/impl/models/gtfs_rt_feed_impl.py @@ -1,6 +1,6 @@ from typing import Dict -from database_gen.sqlacodegen_models import Gtfsrealtimefeed as GtfsRTFeedOrm +from shared.database_gen.sqlacodegen_models import Gtfsrealtimefeed as GtfsRTFeedOrm from feeds.impl.models.basic_feed_impl import BaseFeedImpl from feeds.impl.models.location_impl import LocationImpl from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed diff --git a/api/src/feeds/impl/models/latest_dataset_impl.py b/api/src/feeds/impl/models/latest_dataset_impl.py index 500046ad4..4835adbc1 100644 --- a/api/src/feeds/impl/models/latest_dataset_impl.py +++ b/api/src/feeds/impl/models/latest_dataset_impl.py @@ -1,6 +1,6 @@ from functools import reduce -from database_gen.sqlacodegen_models import Gtfsdataset +from shared.database_gen.sqlacodegen_models import Gtfsdataset from feeds.impl.models.bounding_box_impl import BoundingBoxImpl from feeds.impl.models.validation_report_impl import ValidationReportImpl from feeds_gen.models.latest_dataset import LatestDataset diff --git a/api/src/feeds/impl/models/location_impl.py b/api/src/feeds/impl/models/location_impl.py index 385aab593..dc8684334 100644 --- a/api/src/feeds/impl/models/location_impl.py +++ b/api/src/feeds/impl/models/location_impl.py @@ -1,5 +1,5 @@ from feeds_gen.models.location import Location -from database_gen.sqlacodegen_models import Location as LocationOrm +from shared.database_gen.sqlacodegen_models import Location as LocationOrm class LocationImpl(Location): diff --git a/api/src/feeds/impl/models/redirect_impl.py b/api/src/feeds/impl/models/redirect_impl.py index 2ee21b5ae..6964d9839 100644 
--- a/api/src/feeds/impl/models/redirect_impl.py +++ b/api/src/feeds/impl/models/redirect_impl.py @@ -1,4 +1,4 @@ -from database_gen.sqlacodegen_models import Redirectingid +from shared.database_gen.sqlacodegen_models import Redirectingid from feeds_gen.models.redirect import Redirect diff --git a/api/src/feeds/impl/models/validation_report_impl.py b/api/src/feeds/impl/models/validation_report_impl.py index d4403ba18..9a8673761 100644 --- a/api/src/feeds/impl/models/validation_report_impl.py +++ b/api/src/feeds/impl/models/validation_report_impl.py @@ -1,4 +1,4 @@ -from database_gen.sqlacodegen_models import Validationreport +from shared.database_gen.sqlacodegen_models import Validationreport from feeds_gen.models.validation_report import ValidationReport from utils.logger import Logger diff --git a/api/src/feeds/impl/search_api_impl.py b/api/src/feeds/impl/search_api_impl.py index 59a894b19..df3a1afa2 100644 --- a/api/src/feeds/impl/search_api_impl.py +++ b/api/src/feeds/impl/search_api_impl.py @@ -3,9 +3,9 @@ from sqlalchemy import func, select from sqlalchemy.orm import Query, Session -from database.database import Database, with_db_session -from database.sql_functions.unaccent import unaccent -from database_gen.sqlacodegen_models import t_feedsearch +from shared.database.database import Database, with_db_session +from shared.database.sql_functions.unaccent import unaccent +from shared.database_gen.sqlacodegen_models import t_feedsearch from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl from feeds_gen.apis.search_api_base import BaseSearchApi from feeds_gen.models.search_feeds200_response import SearchFeeds200Response diff --git a/api/src/scripts/gbfs_utils/comparison.py b/api/src/scripts/gbfs_utils/comparison.py index 5c2795c2a..22691d5a2 100644 --- a/api/src/scripts/gbfs_utils/comparison.py +++ b/api/src/scripts/gbfs_utils/comparison.py @@ -1,6 +1,6 @@ import pandas as pd from sqlalchemy.orm import joinedload -from database_gen.sqlacodegen_models import Gbfsfeed +from shared.database_gen.sqlacodegen_models import Gbfsfeed def generate_system_csv_from_db(df, db_session): diff --git a/api/src/scripts/load_dataset_on_create.py b/api/src/scripts/load_dataset_on_create.py index 9bd78cc41..2e6e31a03 100644 --- a/api/src/scripts/load_dataset_on_create.py +++ b/api/src/scripts/load_dataset_on_create.py @@ -9,7 +9,7 @@ from google.cloud import pubsub_v1 from google.cloud.pubsub_v1.futures import Future -from database_gen.sqlacodegen_models import Feed +from shared.database_gen.sqlacodegen_models import Feed from utils.logger import Logger # Lazy create so we won't try to connect to google cloud when the file is imported. 
diff --git a/api/src/scripts/populate_db.py b/api/src/scripts/populate_db.py index ab2c4842d..024c89ffe 100644 --- a/api/src/scripts/populate_db.py +++ b/api/src/scripts/populate_db.py @@ -7,8 +7,8 @@ import pandas from dotenv import load_dotenv -from database.database import Database -from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed +from shared.database.database import Database +from shared.database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed from utils.logger import Logger if TYPE_CHECKING: diff --git a/api/src/scripts/populate_db_gbfs.py b/api/src/scripts/populate_db_gbfs.py index 1b87e7448..b8a4fc9d5 100644 --- a/api/src/scripts/populate_db_gbfs.py +++ b/api/src/scripts/populate_db_gbfs.py @@ -4,8 +4,8 @@ import pytz import pycountry -from database.database import generate_unique_id, configure_polymorphic_mappers -from database_gen.sqlacodegen_models import Gbfsfeed, Location, Gbfsversion, Externalid +from shared.database.database import generate_unique_id, configure_polymorphic_mappers +from shared.database_gen.sqlacodegen_models import Gbfsfeed, Location, Gbfsversion, Externalid from scripts.gbfs_utils.comparison import generate_system_csv_from_db, compare_db_to_csv from scripts.gbfs_utils.fetching import fetch_data, get_data_content, get_gbfs_versions from scripts.gbfs_utils.license import get_license_url diff --git a/api/src/scripts/populate_db_gtfs.py b/api/src/scripts/populate_db_gtfs.py index 19a6322e9..744ca0099 100644 --- a/api/src/scripts/populate_db_gtfs.py +++ b/api/src/scripts/populate_db_gtfs.py @@ -6,8 +6,8 @@ import pytz from sqlalchemy import text -from database.database import generate_unique_id, configure_polymorphic_mappers -from database_gen.sqlacodegen_models import ( +from shared.database.database import generate_unique_id, configure_polymorphic_mappers +from shared.database_gen.sqlacodegen_models import ( Entitytype, Externalid, Gtfsrealtimefeed, diff --git a/api/src/scripts/populate_db_test_data.py b/api/src/scripts/populate_db_test_data.py index 093bedc26..21cdc7cca 100644 --- a/api/src/scripts/populate_db_test_data.py +++ b/api/src/scripts/populate_db_test_data.py @@ -5,8 +5,8 @@ from google.cloud.sql.connector.instance import logger from sqlalchemy import text -from database.database import with_db_session -from database_gen.sqlacodegen_models import ( +from shared.database.database import with_db_session +from shared.database_gen.sqlacodegen_models import ( Gtfsdataset, Validationreport, Gtfsfeed, diff --git a/api/src/shared/common/db_utils.py b/api/src/shared/common/db_utils.py new file mode 100644 index 000000000..1f1858eda --- /dev/null +++ b/api/src/shared/common/db_utils.py @@ -0,0 +1,252 @@ +from geoalchemy2 import WKTElement +from sqlalchemy import select +from sqlalchemy.orm import joinedload, Session +from sqlalchemy.orm.query import Query +from sqlalchemy.orm.strategy_options import _AbstractLoad + +from shared.database_gen.sqlacodegen_models import ( + Feed, + Gtfsdataset, + Gtfsfeed, + Location, + Validationreport, + Gtfsrealtimefeed, + Entitytype, + Redirectingid, +) + +from shared.feed_filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter +from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter + +from .entity_type_enum import EntityType + +from sqlalchemy import or_ + +from .error_handling import raise_internal_http_validation_error, invalid_bounding_coordinates, invalid_bounding_method + + +def get_gtfs_feeds_query( + limit: int | 
None, + offset: int | None, + provider: str | None, + producer_url: str | None, + country_code: str | None, + subdivision_name: str | None, + municipality: str | None, + dataset_latitudes: str | None, + dataset_longitudes: str | None, + bounding_filter_method: str | None, + is_official: bool = False, + include_wip: bool = False, + db_session: Session = None, +) -> Query[any]: + """Get the DB query to use to retrieve the GTFS feeds..""" + gtfs_feed_filter = GtfsFeedFilter( + stable_id=None, + provider__ilike=provider, + producer_url__ilike=producer_url, + location=LocationFilter( + country_code=country_code, + subdivision_name__ilike=subdivision_name, + municipality__ilike=municipality, + ), + ) + + subquery = gtfs_feed_filter.filter(select(Gtfsfeed.id).join(Location, Gtfsfeed.locations)) + subquery = apply_bounding_filtering( + subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method + ).subquery() + + feed_query = db_session.query(Gtfsfeed).filter(Gtfsfeed.id.in_(subquery)) + if not include_wip: + feed_query = feed_query.filter( + or_(Gtfsfeed.operational_status == None, Gtfsfeed.operational_status != "wip") # noqa: E711 + ) + + feed_query = feed_query.options( + joinedload(Gtfsfeed.gtfsdatasets) + .joinedload(Gtfsdataset.validation_reports) + .joinedload(Validationreport.notices), + *get_joinedload_options(), + ).order_by(Gtfsfeed.provider, Gtfsfeed.stable_id) + if is_official: + feed_query = feed_query.filter(Feed.official) + feed_query = feed_query.limit(limit).offset(offset) + return feed_query + + +def get_all_gtfs_feeds_query( + include_wip: bool = False, + db_session: Session = None, +) -> Query[any]: + """Get the DB query to use to retrieve all the GTFS feeds, filtering out the WIP if needed""" + + feed_query = db_session.query(Gtfsfeed) + + if not include_wip: + feed_query = feed_query.filter( + or_(Gtfsfeed.operational_status == None, Gtfsfeed.operational_status != "wip") # noqa: E711 + ) + + feed_query = feed_query.options( + joinedload(Gtfsfeed.gtfsdatasets) + .joinedload(Gtfsdataset.validation_reports) + .joinedload(Validationreport.features), + *get_joinedload_options(), + ).order_by(Gtfsfeed.stable_id) + + return feed_query + + +def get_gtfs_rt_feeds_query( + limit: int | None, + offset: int | None, + provider: str | None, + producer_url: str | None, + entity_types: str | None, + country_code: str | None, + subdivision_name: str | None, + municipality: str | None, + is_official: bool | None, + include_wip: bool = False, + db_session: Session = None, +) -> Query: + """Get some (or all) GTFS Realtime feeds from the Mobility Database.""" + entity_types_list = entity_types.split(",") if entity_types else None + + # Validate entity types using the EntityType enum + if entity_types_list: + try: + entity_types_list = [EntityType(et.strip()).value for et in entity_types_list] + except ValueError: + raise_internal_http_validation_error( + "Entity types must be the value 'vp', 'sa', or 'tu'. " + "When provided a list values must be separated by commas." 
+ ) + + gtfs_rt_feed_filter = GtfsRtFeedFilter( + stable_id=None, + provider__ilike=provider, + producer_url__ilike=producer_url, + entity_types=EntityTypeFilter(name__in=entity_types_list), + location=LocationFilter( + country_code=country_code, + subdivision_name__ilike=subdivision_name, + municipality__ilike=municipality, + ), + ) + subquery = gtfs_rt_feed_filter.filter( + select(Gtfsrealtimefeed.id) + .join(Location, Gtfsrealtimefeed.locations) + .join(Entitytype, Gtfsrealtimefeed.entitytypes) + ).subquery() + feed_query = db_session.query(Gtfsrealtimefeed).filter(Gtfsrealtimefeed.id.in_(subquery)) + + if not include_wip: + feed_query = feed_query.filter( + or_( + Gtfsrealtimefeed.operational_status == None, # noqa: E711 + Gtfsrealtimefeed.operational_status != "wip", + ) + ) + + feed_query = feed_query.options( + joinedload(Gtfsrealtimefeed.entitytypes), + joinedload(Gtfsrealtimefeed.gtfs_feeds), + *get_joinedload_options(), + ) + if is_official: + feed_query = feed_query.filter(Feed.official) + feed_query = feed_query.limit(limit).offset(offset) + return feed_query + + +def get_all_gtfs_rt_feeds_query( + include_wip: bool = False, + db_session: Session = None, +) -> Query: + """Get the DB query to use to retrieve all the GTFS rt feeds, filtering out the WIP if needed""" + feed_query = db_session.query(Gtfsrealtimefeed) + + if not include_wip: + feed_query = feed_query.filter( + or_( + Gtfsrealtimefeed.operational_status == None, # noqa: E711 + Gtfsrealtimefeed.operational_status != "wip", + ) + ) + + feed_query = feed_query.options( + joinedload(Gtfsrealtimefeed.entitytypes), + joinedload(Gtfsrealtimefeed.gtfs_feeds), + *get_joinedload_options(), + ).order_by(Gtfsfeed.stable_id) + + return feed_query + + +def apply_bounding_filtering( + query: Query, + bounding_latitudes: str, + bounding_longitudes: str, + bounding_filter_method: str, +) -> Query: + """Create a new query based on the bounding parameters.""" + + if not bounding_latitudes or not bounding_longitudes or not bounding_filter_method: + return query + + if ( + len(bounding_latitudes_tokens := bounding_latitudes.split(",")) != 2 + or len(bounding_longitudes_tokens := bounding_longitudes.split(",")) != 2 + ): + raise_internal_http_validation_error( + invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes) + ) + min_latitude, max_latitude = bounding_latitudes_tokens + min_longitude, max_longitude = bounding_longitudes_tokens + try: + min_latitude = float(min_latitude) + max_latitude = float(max_latitude) + min_longitude = float(min_longitude) + max_longitude = float(max_longitude) + except ValueError: + raise_internal_http_validation_error( + invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes) + ) + points = [ + (min_longitude, min_latitude), + (min_longitude, max_latitude), + (max_longitude, max_latitude), + (max_longitude, min_latitude), + (min_longitude, min_latitude), + ] + wkt_polygon = f"POLYGON(({', '.join(f'{lon} {lat}' for lon, lat in points)}))" + bounding_box = WKTElement( + wkt_polygon, + srid=Gtfsdataset.bounding_box.type.srid, + ) + + if bounding_filter_method == "partially_enclosed": + return query.filter( + or_( + Gtfsdataset.bounding_box.ST_Overlaps(bounding_box), + Gtfsdataset.bounding_box.ST_Contains(bounding_box), + ) + ) + elif bounding_filter_method == "completely_enclosed": + return query.filter(bounding_box.ST_Covers(Gtfsdataset.bounding_box)) + elif bounding_filter_method == "disjoint": + return 
query.filter(Gtfsdataset.bounding_box.ST_Disjoint(bounding_box)) + else: + raise_internal_http_validation_error(invalid_bounding_method.format(bounding_filter_method)) + + +def get_joinedload_options() -> [_AbstractLoad]: + """Returns common joinedload options for feeds queries.""" + return [ + joinedload(Feed.locations), + joinedload(Feed.externalids), + joinedload(Feed.redirectingids).joinedload(Redirectingid.target), + joinedload(Feed.officialstatushistories), + ] diff --git a/api/src/shared/common/entity_type_enum.py b/api/src/shared/common/entity_type_enum.py new file mode 100644 index 000000000..a7cc72600 --- /dev/null +++ b/api/src/shared/common/entity_type_enum.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class EntityType(Enum): + """ + Enum for the entity type + """ + + VP = "vp" + SA = "sa" + TU = "tu" diff --git a/api/src/shared/common/error_handling.py b/api/src/shared/common/error_handling.py new file mode 100644 index 000000000..9ffec4e50 --- /dev/null +++ b/api/src/shared/common/error_handling.py @@ -0,0 +1,49 @@ +from typing import Final + +invalid_date_message: Final[ + str +] = "Invalid date format for '{}'. Expected ISO 8601 format, example: '2021-01-01T00:00:00Z'" +invalid_bounding_coordinates: Final[str] = "Invalid bounding coordinates {} {}" +invalid_bounding_method: Final[str] = "Invalid bounding_filter_method {}" +feed_not_found: Final[str] = "Feed '{}' not found" +gtfs_feed_not_found: Final[str] = "GTFS feed '{}' not found" +gtfs_rt_feed_not_found: Final[str] = "GTFS realtime Feed '{}' not found" +dataset_not_found: Final[str] = "Dataset '{}' not found" + + +class InternalHTTPException(Exception): + """ + This class is used instead of the HTTPException because we don't want to depend on fastapi and have to deploy it. + At one point this exception needs to be caught and converted to a fastapi HTTPException, + """ + + def __init__(self, status_code: int, detail: str) -> None: + self.status_code = status_code + self.detail = detail + super().__init__(f"Status Code: {status_code}, Detail: {detail}") + + +def raise_internal_http_error(status_code: int, error: str): + """Raise a InternalHTTPException. + :param status_code: The status code of the error. + :param error: The error message to be raised. + example of output: + { + "detail": "Invalid date format for 'field_name'. Expected ISO 8601 format, example: '2021-01-01T00:00:00Z'" + } + """ + raise InternalHTTPException( + status_code=status_code, + detail=error, + ) + + +def raise_internal_http_validation_error(error: str): + """Raise a InternalHTTPException with status code 422 and the error message. + :param error: The error message to be raised. + example of output: + { + "detail": "Invalid date format for 'field_name'. 
Expected ISO 8601 format, example: '2021-01-01T00:00:00Z'" + } + """ + raise_internal_http_error(422, error) diff --git a/api/src/database/__init__.py b/api/src/shared/database/__init__.py similarity index 100% rename from api/src/database/__init__.py rename to api/src/shared/database/__init__.py diff --git a/api/src/database/database.py b/api/src/shared/database/database.py similarity index 98% rename from api/src/database/database.py rename to api/src/shared/database/database.py index 25dcfff50..204d1377c 100644 --- a/api/src/database/database.py +++ b/api/src/shared/database/database.py @@ -7,7 +7,7 @@ from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy.orm import load_only, Query, class_mapper, Session -from database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed +from shared.database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed from sqlalchemy.orm import sessionmaker import logging diff --git a/api/src/database/sql_functions/unaccent.py b/api/src/shared/database/sql_functions/unaccent.py similarity index 100% rename from api/src/database/sql_functions/unaccent.py rename to api/src/shared/database/sql_functions/unaccent.py diff --git a/api/src/feeds/filters/__init__.py b/api/src/shared/feed_filters/__init__.py similarity index 100% rename from api/src/feeds/filters/__init__.py rename to api/src/shared/feed_filters/__init__.py diff --git a/api/src/feeds/filters/feed_filter.py b/api/src/shared/feed_filters/feed_filter.py similarity index 86% rename from api/src/feeds/filters/feed_filter.py rename to api/src/shared/feed_filters/feed_filter.py index 2ee7f481c..16058efda 100644 --- a/api/src/feeds/filters/feed_filter.py +++ b/api/src/shared/feed_filters/feed_filter.py @@ -2,8 +2,8 @@ from fastapi_filter.contrib.sqlalchemy import Filter -from database_gen.sqlacodegen_models import Feed -from utils.param_utils import normalize_str_parameter +from shared.database_gen.sqlacodegen_models import Feed +from shared.feed_filters.param_utils import normalize_str_parameter class FeedFilter(Filter): diff --git a/api/src/feeds/filters/gtfs_dataset_filter.py b/api/src/shared/feed_filters/gtfs_dataset_filter.py similarity index 81% rename from api/src/feeds/filters/gtfs_dataset_filter.py rename to api/src/shared/feed_filters/gtfs_dataset_filter.py index 0165c0bf5..9ec5af5ce 100644 --- a/api/src/feeds/filters/gtfs_dataset_filter.py +++ b/api/src/shared/feed_filters/gtfs_dataset_filter.py @@ -2,8 +2,8 @@ from datetime import datetime from fastapi_filter.contrib.sqlalchemy import Filter -from database_gen.sqlacodegen_models import Gtfsdataset -from utils.param_utils import normalize_str_parameter +from shared.database_gen.sqlacodegen_models import Gtfsdataset +from shared.feed_filters.param_utils import normalize_str_parameter class GtfsDatasetFilter(Filter): diff --git a/api/src/feeds/filters/gtfs_feed_filter.py b/api/src/shared/feed_filters/gtfs_feed_filter.py similarity index 87% rename from api/src/feeds/filters/gtfs_feed_filter.py rename to api/src/shared/feed_filters/gtfs_feed_filter.py index a1b404108..b4e3e6ae9 100644 --- a/api/src/feeds/filters/gtfs_feed_filter.py +++ b/api/src/shared/feed_filters/gtfs_feed_filter.py @@ -2,8 +2,8 @@ from fastapi_filter.contrib.sqlalchemy import Filter -from database_gen.sqlacodegen_models import Location, Gtfsfeed -from utils.param_utils import normalize_str_parameter +from shared.database_gen.sqlacodegen_models import Location, Gtfsfeed +from 
shared.feed_filters.param_utils import normalize_str_parameter class LocationFilter(Filter): diff --git a/api/src/feeds/filters/gtfs_rt_feed_filter.py b/api/src/shared/feed_filters/gtfs_rt_feed_filter.py similarity index 83% rename from api/src/feeds/filters/gtfs_rt_feed_filter.py rename to api/src/shared/feed_filters/gtfs_rt_feed_filter.py index 15fc3dca1..ae7cabfa4 100644 --- a/api/src/feeds/filters/gtfs_rt_feed_filter.py +++ b/api/src/shared/feed_filters/gtfs_rt_feed_filter.py @@ -2,9 +2,9 @@ from fastapi_filter.contrib.sqlalchemy import Filter -from database_gen.sqlacodegen_models import Gtfsrealtimefeed, Entitytype -from feeds.filters.gtfs_feed_filter import LocationFilter -from utils.param_utils import normalize_str_parameter +from shared.database_gen.sqlacodegen_models import Gtfsrealtimefeed, Entitytype +from shared.feed_filters.gtfs_feed_filter import LocationFilter +from shared.feed_filters.param_utils import normalize_str_parameter class EntityTypeFilter(Filter): diff --git a/api/src/utils/param_utils.py b/api/src/shared/feed_filters/param_utils.py similarity index 100% rename from api/src/utils/param_utils.py rename to api/src/shared/feed_filters/param_utils.py diff --git a/api/src/utils/location_translation.py b/api/src/utils/location_translation.py index 7aabe6c8e..938c2287b 100644 --- a/api/src/utils/location_translation.py +++ b/api/src/utils/location_translation.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING from sqlalchemy.engine.result import Row -from database_gen.sqlacodegen_models import Location as LocationOrm, t_location_with_translations_en -from database_gen.sqlacodegen_models import Feed as FeedOrm +from shared.database_gen.sqlacodegen_models import Location as LocationOrm, t_location_with_translations_en +from shared.database_gen.sqlacodegen_models import Feed as FeedOrm if TYPE_CHECKING: from sqlalchemy.orm import Session diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 2650fb384..879bdbbf6 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -4,7 +4,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from database.database import Database +from shared.database.database import Database from main import app as application from tests.test_utils.database import populate_database diff --git a/api/tests/integration/test_database.py b/api/tests/integration/test_database.py index 7394b7810..9c416a711 100644 --- a/api/tests/integration/test_database.py +++ b/api/tests/integration/test_database.py @@ -4,8 +4,9 @@ import pytest from sqlalchemy.orm import Query -from database.database import Database, generate_unique_id -from database_gen.sqlacodegen_models import Feature, Gtfsdataset +from shared.common.db_utils import apply_bounding_filtering +from shared.database.database import Database, generate_unique_id +from shared.database_gen.sqlacodegen_models import Feature, Gtfsdataset from feeds.impl.datasets_api_impl import DatasetsApiImpl from feeds.impl.feeds_api_impl import FeedsApiImpl from faker import Faker @@ -39,7 +40,7 @@ def test_bounding_box_dateset_exists(test_database): def assert_bounding_box_found(latitudes, longitudes, method, expected_found, test_database): with test_database.start_db_session() as session: - query = DatasetsApiImpl.apply_bounding_filtering(BASE_QUERY, latitudes, longitudes, method) + query = apply_bounding_filtering(BASE_QUERY, latitudes, longitudes, method) result = test_database.select(session, query=query) assert (len(result) > 0) is 
expected_found diff --git a/api/tests/test_utils/database.py b/api/tests/test_utils/database.py index 22776543d..433adb5d6 100644 --- a/api/tests/test_utils/database.py +++ b/api/tests/test_utils/database.py @@ -6,7 +6,7 @@ from sqlalchemy.engine.url import make_url from tests.test_utils.db_utils import dump_database, is_test_db, dump_raw_database, empty_database -from database.database import Database +from shared.database.database import Database from scripts.populate_db_gtfs import GTFSDatabasePopulateHelper from scripts.populate_db_test_data import DatabasePopulateTestDataHelper diff --git a/api/tests/test_utils/db_utils.py b/api/tests/test_utils/db_utils.py index ef0b98255..cbabafa35 100644 --- a/api/tests/test_utils/db_utils.py +++ b/api/tests/test_utils/db_utils.py @@ -8,7 +8,7 @@ from sqlalchemy import Inspector, delete import json -from database_gen.sqlacodegen_models import Base +from shared.database_gen.sqlacodegen_models import Base class CustomEncoder(json.JSONEncoder): diff --git a/api/tests/unittest/conftest.py b/api/tests/unittest/conftest.py index c02707436..d443fe2ea 100644 --- a/api/tests/unittest/conftest.py +++ b/api/tests/unittest/conftest.py @@ -4,7 +4,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from database.database import Database +from shared.database.database import Database from main import app as application from tests.test_utils.database import populate_database diff --git a/api/tests/unittest/models/test_basic_feed_impl.py b/api/tests/unittest/models/test_basic_feed_impl.py index 8f23c0539..d68d8ff42 100644 --- a/api/tests/unittest/models/test_basic_feed_impl.py +++ b/api/tests/unittest/models/test_basic_feed_impl.py @@ -2,7 +2,7 @@ import unittest from datetime import datetime, date -from database_gen.sqlacodegen_models import ( +from shared.database_gen.sqlacodegen_models import ( Feed, Externalid, Location, diff --git a/api/tests/unittest/models/test_external_id_impl.py b/api/tests/unittest/models/test_external_id_impl.py index c04213425..a23c75ff5 100644 --- a/api/tests/unittest/models/test_external_id_impl.py +++ b/api/tests/unittest/models/test_external_id_impl.py @@ -1,6 +1,6 @@ import unittest -from database_gen.sqlacodegen_models import Externalid +from shared.database_gen.sqlacodegen_models import Externalid from feeds.impl.models.external_id_impl import ExternalIdImpl external_id_orm = Externalid( diff --git a/api/tests/unittest/models/test_gtfs_dataset_impl.py b/api/tests/unittest/models/test_gtfs_dataset_impl.py index d640e89b7..cc32ecc22 100644 --- a/api/tests/unittest/models/test_gtfs_dataset_impl.py +++ b/api/tests/unittest/models/test_gtfs_dataset_impl.py @@ -3,7 +3,7 @@ from geoalchemy2 import WKTElement -from database_gen.sqlacodegen_models import Validationreport, Gtfsdataset, Feed +from shared.database_gen.sqlacodegen_models import Validationreport, Gtfsdataset, Feed from feeds.impl.models.gtfs_dataset_impl import GtfsDatasetImpl POLYGON = "POLYGON ((3.0 1.0, 4.0 1.0, 4.0 2.0, 3.0 2.0, 3.0 1.0))" diff --git a/api/tests/unittest/models/test_gtfs_feed_impl.py b/api/tests/unittest/models/test_gtfs_feed_impl.py index 8667da58f..41ee80a84 100644 --- a/api/tests/unittest/models/test_gtfs_feed_impl.py +++ b/api/tests/unittest/models/test_gtfs_feed_impl.py @@ -4,7 +4,7 @@ from geoalchemy2 import WKTElement -from database_gen.sqlacodegen_models import ( +from shared.database_gen.sqlacodegen_models import ( Redirectingid, Feature, Validationreport, diff --git a/api/tests/unittest/models/test_gtfs_rt_feed_impl.py 
b/api/tests/unittest/models/test_gtfs_rt_feed_impl.py index 76a7b792a..9c54c4422 100644 --- a/api/tests/unittest/models/test_gtfs_rt_feed_impl.py +++ b/api/tests/unittest/models/test_gtfs_rt_feed_impl.py @@ -1,14 +1,20 @@ import unittest import copy -from database_gen.sqlacodegen_models import Gtfsrealtimefeed, Entitytype, Externalid, Location, Redirectingid, Feed +from shared.database_gen.sqlacodegen_models import ( + Gtfsrealtimefeed, + Entitytype, + Externalid, + Location, + Redirectingid, + Feed, +) from feeds_gen.models.source_info import SourceInfo from feeds.impl.models.gtfs_rt_feed_impl import GtfsRTFeedImpl from feeds.impl.models.external_id_impl import ExternalIdImpl from feeds.impl.models.location_impl import LocationImpl from feeds.impl.models.redirect_impl import RedirectImpl - targetFeed = Feed( id="id1", stable_id="target_id", diff --git a/api/tests/unittest/models/test_latest_dataset_impl.py b/api/tests/unittest/models/test_latest_dataset_impl.py index c0f1fdd22..f1f6b9e42 100644 --- a/api/tests/unittest/models/test_latest_dataset_impl.py +++ b/api/tests/unittest/models/test_latest_dataset_impl.py @@ -3,7 +3,7 @@ from geoalchemy2 import WKTElement -from database_gen.sqlacodegen_models import Gtfsdataset, Feed, Validationreport, Notice +from shared.database_gen.sqlacodegen_models import Gtfsdataset, Feed, Validationreport, Notice from feeds.impl.models.bounding_box_impl import BoundingBoxImpl from feeds.impl.models.latest_dataset_impl import LatestDatasetImpl diff --git a/api/tests/unittest/models/test_location_impl.py b/api/tests/unittest/models/test_location_impl.py index 24ba78529..ac0ead801 100644 --- a/api/tests/unittest/models/test_location_impl.py +++ b/api/tests/unittest/models/test_location_impl.py @@ -1,7 +1,7 @@ import unittest from feeds.impl.models.location_impl import LocationImpl -from database_gen.sqlacodegen_models import Location as LocationOrm +from shared.database_gen.sqlacodegen_models import Location as LocationOrm class TestLocationImpl(unittest.TestCase): diff --git a/api/tests/unittest/models/test_redirect_id_impl.py b/api/tests/unittest/models/test_redirect_id_impl.py index 32fb3aa64..9789ec577 100644 --- a/api/tests/unittest/models/test_redirect_id_impl.py +++ b/api/tests/unittest/models/test_redirect_id_impl.py @@ -1,7 +1,7 @@ import unittest -from database_gen.sqlacodegen_models import Redirectingid -from database_gen.sqlacodegen_models import Feed +from shared.database_gen.sqlacodegen_models import Redirectingid +from shared.database_gen.sqlacodegen_models import Feed from feeds.impl.models.redirect_impl import RedirectImpl redirect_orm = Redirectingid( diff --git a/api/tests/unittest/models/test_validation_report_impl.py b/api/tests/unittest/models/test_validation_report_impl.py index 14f97a917..9afdf2d2a 100644 --- a/api/tests/unittest/models/test_validation_report_impl.py +++ b/api/tests/unittest/models/test_validation_report_impl.py @@ -1,7 +1,7 @@ import unittest from datetime import datetime -from database_gen.sqlacodegen_models import Validationreport, Notice, Feature +from shared.database_gen.sqlacodegen_models import Validationreport, Notice, Feature from feeds.impl.models.validation_report_impl import ValidationReportImpl diff --git a/api/tests/unittest/test_feeds.py b/api/tests/unittest/test_feeds.py index 3f7a9091c..22f07a4a5 100644 --- a/api/tests/unittest/test_feeds.py +++ b/api/tests/unittest/test_feeds.py @@ -5,9 +5,16 @@ from fastapi.testclient import TestClient -from database.database import Database -from 
database_gen.sqlacodegen_models import Feed, Externalid, Gtfsdataset, Redirectingid, Gtfsfeed, Gtfsrealtimefeed -from feeds.filters.feed_filter import FeedFilter +from shared.database.database import Database +from shared.database_gen.sqlacodegen_models import ( + Feed, + Externalid, + Gtfsdataset, + Redirectingid, + Gtfsfeed, + Gtfsrealtimefeed, +) +from shared.feed_filters.feed_filter import FeedFilter from feeds.impl.models.basic_feed_impl import BaseFeedImpl from tests.test_utils.database import TEST_GTFS_FEED_STABLE_IDS, TEST_GTFS_RT_FEED_STABLE_ID from tests.test_utils.token import authHeaders diff --git a/api/tests/unittest/test_param_utils.py b/api/tests/unittest/test_param_utils.py index 657db7061..7ade95420 100644 --- a/api/tests/unittest/test_param_utils.py +++ b/api/tests/unittest/test_param_utils.py @@ -1,4 +1,4 @@ -from utils.param_utils import normalize_str_parameter +from shared.feed_filters.param_utils import normalize_str_parameter def test_normalize_str_parameter(): diff --git a/functions-python/batch_process_dataset/README.md b/functions-python/batch_process_dataset/README.md index aaa8f34e4..5fc3ea55c 100644 --- a/functions-python/batch_process_dataset/README.md +++ b/functions-python/batch_process_dataset/README.md @@ -32,7 +32,7 @@ The function expects a Pub/Sub message with the following format: # Function configuration The function is configured using the following environment variables: -- `DATASETS_BUCKET_NANE`: The name of the bucket where the datasets are stored. +- `DATASETS_BUCKET_NAME`: The name of the bucket where the datasets are stored. - `FEEDS_DATABASE_URL`: The URL of the feeds database. - `MAXIMUM_EXECUTIONS`: [Optional] The maximum number of executions per datasets. This controls the number of times a dataset can be processed per execution id. By default, is 1. 
diff --git a/functions-python/batch_process_dataset/src/main.py b/functions-python/batch_process_dataset/src/main.py index d415167bc..f20b049f0 100644 --- a/functions-python/batch_process_dataset/src/main.py +++ b/functions-python/batch_process_dataset/src/main.py @@ -308,8 +308,7 @@ def process_dataset(cloud_event: CloudEvent): Logger.init_logger() logging.info("Function Started") stable_id = "UNKNOWN" - execution_id = "UNKNOWN" - bucket_name = os.getenv("DATASETS_BUCKET_NANE") + bucket_name = os.getenv("DATASETS_BUCKET_NAME") try: # Extract data from message diff --git a/functions-python/export_csv/.coveragerc b/functions-python/export_csv/.coveragerc new file mode 100644 index 000000000..f1916e53f --- /dev/null +++ b/functions-python/export_csv/.coveragerc @@ -0,0 +1,11 @@ +[run] +omit = + */test*/* + */database_gen/* + */dataset_service/* + */helpers/* + */shared/* + +[report] +exclude_lines = + if __name__ == .__main__.: \ No newline at end of file diff --git a/functions-python/export_csv/function_config.json b/functions-python/export_csv/function_config.json new file mode 100644 index 000000000..0fa2d59af --- /dev/null +++ b/functions-python/export_csv/function_config.json @@ -0,0 +1,20 @@ +{ + "name": "export-csv", + "description": "Export the DB feed data as a csv file", + "entry_point": "export_and_upload_csv", + "timeout": 20, + "memory": "1Gi", + "trigger_http": true, + "include_folders": ["helpers", "dataset_service"], + "include_api_folders": ["utils", "database", "feed_filters", "common", "database_gen"], + "secret_environment_variables": [ + { + "key": "FEEDS_DATABASE_URL" + } + ], + "ingress_settings": "ALLOW_INTERNAL_AND_GCLB", + "max_instance_request_concurrency": 1, + "max_instance_count": 1, + "min_instance_count": 0, + "available_cpu": 1 +} diff --git a/functions-python/export_csv/requirements.txt b/functions-python/export_csv/requirements.txt new file mode 100644 index 000000000..7fd97e18c --- /dev/null +++ b/functions-python/export_csv/requirements.txt @@ -0,0 +1,23 @@ +# Common packages +psycopg2-binary==2.9.6 +aiohttp~=3.10.5 +asyncio~=3.4.3 +urllib3~=2.2.2 +requests~=2.32.3 +attrs~=23.1.0 +pluggy~=1.3.0 +certifi~=2024.7.4 +pandas~=2.2.3 +python-dotenv==1.0.0 +fastapi-filter[sqlalchemy]==1.0.0 +packaging~=24.2 + +# SQL Alchemy and Geo Alchemy +SQLAlchemy==2.0.23 +geoalchemy2==0.14.7 +shapely + +# Google +google-cloud-storage +functions-framework==3.* +google-cloud-logging \ No newline at end of file diff --git a/functions-python/export_csv/requirements_dev.txt b/functions-python/export_csv/requirements_dev.txt new file mode 100644 index 000000000..9ee50adce --- /dev/null +++ b/functions-python/export_csv/requirements_dev.txt @@ -0,0 +1,2 @@ +Faker +pytest~=7.4.3 \ No newline at end of file diff --git a/functions-python/export_csv/src/main.py b/functions-python/export_csv/src/main.py new file mode 100644 index 000000000..6f7e4a73c --- /dev/null +++ b/functions-python/export_csv/src/main.py @@ -0,0 +1,360 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import argparse +import logging +import os +import re + +import pandas as pd + +from dotenv import load_dotenv +import functions_framework + +from packaging.version import Version +from functools import reduce +from google.cloud import storage +from geoalchemy2.shape import to_shape + +from shared.helpers.logger import Logger +from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed +from collections import OrderedDict +from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query + +from shared.helpers.database import Database + +load_dotenv() +csv_default_file_path = "./output.csv" +csv_file_path = csv_default_file_path + + +class DataCollector: + """ + A class used to collect and organize data into rows and headers for CSV output. + One particularity of this class is that it uses an OrderedDict to store the data, so that the order of the columns + is preserved when writing to CSV. + """ + + def __init__(self): + self.data = OrderedDict() + self.rows = [] + self.headers = [] + + def add_data(self, key, value): + if key not in self.headers: + self.headers.append(key) + self.data[key] = value + + def finalize_row(self): + self.rows.append(self.data.copy()) + self.data = OrderedDict() + + def write_csv_to_file(self, csv_file_path): + df = pd.DataFrame(self.rows, columns=self.headers) + df.to_csv(csv_file_path, index=False) + + def get_dataframe(self) -> pd: + return pd.DataFrame(self.rows, columns=self.headers) + + +@functions_framework.http +def export_and_upload_csv(request=None): + response = export_csv() + upload_file_to_storage(csv_file_path, "sources_v2.csv") + return response + + +def export_csv(): + """ + HTTP Function entry point Reads the DB and outputs a csv file with feeds data. + This function requires the following environment variables to be set: + FEEDS_DATABASE_URL: database URL + :param request: HTTP request object + :return: HTTP response object + """ + Logger.init_logger() + logging.info("Function Started") + data_collector = collect_data() + data_collector.write_csv_to_file(csv_file_path) + return f"Exported {len(data_collector.rows)} feeds to CSV file {csv_file_path}." + + +def collect_data() -> DataCollector: + """ + Collect data from the DB and write the output to a DataCollector. 
+ :return: A filled DataCollector + """ + db = Database(database_url=os.getenv("FEEDS_DATABASE_URL")) + logging.info(f"Using database {db.database_url}") + try: + with db.start_db_session() as session: + gtfs_feeds_query = get_all_gtfs_feeds_query( + include_wip=False, + db_session=session, + ) + + gtfs_feeds = gtfs_feeds_query.all() + + logging.info(f"Retrieved {len(gtfs_feeds)} GTFS feeds.") + + gtfs_rt_feeds_query = get_all_gtfs_rt_feeds_query( + include_wip=False, + db_session=session, + ) + + gtfs_rt_feeds = gtfs_rt_feeds_query.all() + + logging.info(f"Retrieved {len(gtfs_rt_feeds)} GTFS realtime feeds.") + + data_collector = DataCollector() + + for feed in gtfs_feeds: + data = get_feed_csv_data(feed) + + for key, value in data.items(): + data_collector.add_data(key, value) + data_collector.finalize_row() + logging.info(f"Processed {len(gtfs_feeds)} GTFS feeds.") + + for feed in gtfs_rt_feeds: + data = get_gtfs_rt_feed_csv_data(feed) + for key, value in data.items(): + data_collector.add_data(key, value) + data_collector.finalize_row() + logging.info(f"Processed {len(gtfs_rt_feeds)} GTFS realtime feeds.") + + except Exception as error: + logging.error(f"Error retrieving feeds: {error}") + raise Exception(f"Error retrieving feeds: {error}") + data_collector.write_csv_to_file(csv_file_path) + return data_collector + + +def extract_numeric_version(version): + match = re.match(r"(\d+\.\d+\.\d+)", version) + return match.group(1) if match else version + + +def get_feed_csv_data(feed: Gtfsfeed): + """ + This function takes a GtfsFeed and returns a dictionary with the data to be written to the CSV file. + """ + latest_dataset = next( + ( + dataset + for dataset in (feed.gtfsdatasets or []) + if dataset and dataset.latest + ), + None, + ) + + joined_features = "" + validated_at = None + minimum_latitude = maximum_latitude = minimum_longitude = maximum_longitude = None + + if latest_dataset and latest_dataset.validation_reports: + # Keep the report from the more recent validator version + latest_report = reduce( + lambda a, b: a + if Version(extract_numeric_version(a.validator_version)) + > Version(extract_numeric_version(b.validator_version)) + else b, + latest_dataset.validation_reports, + ) + + if latest_report: + if latest_report.features: + features = latest_report.features + joined_features = ( + "|".join( + sorted(feature.name for feature in features if feature.name) + ) + if features + else "" + ) + if latest_report.validated_at: + validated_at = latest_report.validated_at + if latest_dataset.bounding_box: + shape = to_shape(latest_dataset.bounding_box) + if shape and shape.bounds: + minimum_latitude = shape.bounds[1] + maximum_latitude = shape.bounds[3] + minimum_longitude = shape.bounds[0] + maximum_longitude = shape.bounds[2] + + latest_url = latest_dataset.hosted_url if latest_dataset else None + if latest_url: + # The url for the latest dataset contains the dataset id which includes the date. + # e.g. 
https://dev-files.mobilitydatabase.org/mdb-1/mdb-1-202408202229/mdb-1-202408202229.zip + # For the latest url we just want something using latest.zip, e.g: + # https://dev-files.mobilitydatabase.org/mdb-1/latest.zip + # So use the dataset url, but replace what is after the feed stable id by latest.zip + position = latest_url.find(feed.stable_id) + if position != -1: + # Construct the new URL + latest_url = latest_url[: position + len(feed.stable_id) + 1] + "latest.zip" + + data = { + "id": feed.stable_id, + "data_type": feed.data_type, + "entity_type": None, + "location.country_code": "" + if not feed.locations or not feed.locations[0] + else feed.locations[0].country_code, + "location.subdivision_name": "" + if not feed.locations or not feed.locations[0] + else feed.locations[0].subdivision_name, + "location.municipality": "" + if not feed.locations or not feed.locations[0] + else feed.locations[0].municipality, + "provider": feed.provider, + "name": feed.feed_name, + "note": feed.note, + "feed_contact_email": feed.feed_contact_email, + "static_reference": None, + "urls.direct_download": feed.producer_url, + "urls.authentication_type": feed.authentication_type, + "urls.authentication_info": feed.authentication_info_url, + "urls.api_key_parameter_name": feed.api_key_parameter_name, + "urls.latest": latest_url, + "urls.license": feed.license_url, + "location.bounding_box.minimum_latitude": minimum_latitude, + "location.bounding_box.maximum_latitude": maximum_latitude, + "location.bounding_box.minimum_longitude": minimum_longitude, + "location.bounding_box.maximum_longitude": maximum_longitude, + "location.bounding_box.extracted_on": validated_at, + # We use the report validated_at date as the extracted_on date + "status": feed.status, + "features": joined_features, + } + + redirect_ids = "" + redirect_comments = "" + # Add concatenated redirect IDs + if feed.redirectingids: + for redirect in feed.redirectingids: + if redirect and redirect.target and redirect.target.stable_id: + stripped_id = redirect.target.stable_id.strip() + if stripped_id: + redirect_ids = ( + redirect_ids + "|" + stripped_id + if redirect_ids + else stripped_id + ) + redirect_comments = ( + redirect_comments + "|" + redirect.redirect_comment + if redirect_comments + else redirect.redirect_comment + ) + if redirect_ids == "": + redirect_comments = "" + else: + # If there is no comment but we do have redirects, use an empty string instead of a + # potentially a bunch of vertical bars. + redirect_comments = ( + "" if redirect_comments.strip("|") == "" else redirect_comments + ) + + data["redirect.id"] = redirect_ids + data["redirect.comment"] = redirect_comments + + return data + + +def get_gtfs_rt_feed_csv_data(feed: Gtfsrealtimefeed): + """ + This function takes a GtfsRTFeed and returns a dictionary with the data to be written to the CSV file. 
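+
+    Entity types and static GTFS references are pipe-joined in the output, e.g.
+    (hypothetical values): entity_type "tu|vp", static_reference "mdb-1|mdb-2".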
+ """ + entity_types = "" + if feed.entitytypes: + valid_entity_types = [ + entity_type.name.strip() + for entity_type in feed.entitytypes + if entity_type and entity_type.name + ] + valid_entity_types = sorted(valid_entity_types) + entity_types = "|".join(valid_entity_types) + + static_references = "" + if feed.gtfs_feeds: + valid_feed_references = [ + feed_reference.stable_id.strip() + for feed_reference in feed.gtfs_feeds + if feed_reference and feed_reference.stable_id + ] + static_references = "|".join(valid_feed_references) + + data = { + "id": feed.stable_id, + "data_type": feed.data_type, + "entity_type": entity_types, + "location.country_code": "" + if not feed.locations or not feed.locations[0] + else feed.locations[0].country_code, + "location.subdivision_name": "" + if not feed.locations or not feed.locations[0] + else feed.locations[0].subdivision_name, + "location.municipality": "" + if not feed.locations or not feed.locations[0] + else feed.locations[0].municipality, + "provider": feed.provider, + "name": feed.feed_name, + "note": feed.note, + "feed_contact_email": feed.feed_contact_email, + "static_reference": static_references, + "urls.direct_download": feed.producer_url, + "urls.authentication_type": feed.authentication_type, + "urls.authentication_info": feed.authentication_info_url, + "urls.api_key_parameter_name": feed.api_key_parameter_name, + "urls.latest": None, + "urls.license": feed.license_url, + "location.bounding_box.minimum_latitude": None, + "location.bounding_box.maximum_latitude": None, + "location.bounding_box.minimum_longitude": None, + "location.bounding_box.maximum_longitude": None, + "location.bounding_box.extracted_on": None, + "features": None, + "redirect.id": None, + "redirect.comment": None, + } + + return data + + +def upload_file_to_storage(source_file_path, target_path): + """ + Uploads a file to the GCP bucket + """ + bucket_name = os.getenv("DATASETS_BUCKET_NAME") + logging.info(f"Uploading file to bucket {bucket_name} at path {target_path}") + bucket = storage.Client().get_bucket(bucket_name) + blob = bucket.blob(target_path) + with open(source_file_path, "rb") as file: + blob.upload_from_file(file) + blob.make_public() + return blob + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Export DB feed contents to csv.") + parser.add_argument( + "--outpath", help="Path to the output csv file. Default is ./output.csv" + ) + os.environ[ + "FEEDS_DATABASE_URL" + ] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest" + args = parser.parse_args() + csv_file_path = args.outpath if args.outpath else csv_default_file_path + export_csv() diff --git a/functions-python/export_csv/tests/conftest.py b/functions-python/export_csv/tests/conftest.py new file mode 100644 index 000000000..b210219df --- /dev/null +++ b/functions-python/export_csv/tests/conftest.py @@ -0,0 +1,238 @@ +# +# MobilityData 2023 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
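+#
+# Note on the fixture data below (informal summary derived from this file): every
+# GTFS dataset is created with the bounding polygon
+# POLYGON((-18 -9, -18 9, 18 9, 18 -9, -18 -9)), so the exported CSV is expected to
+# contain minimum/maximum latitude of -9.0/9.0 and minimum/maximum longitude of
+# -18.0/18.0 for feeds that have a latest dataset.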
+# + +from faker import Faker +from datetime import datetime + +from geoalchemy2 import WKTElement + +from shared.database_gen.sqlacodegen_models import ( + Validationreport, + Feature, + Redirectingid, +) +from shared.database_gen.sqlacodegen_models import ( + Gtfsfeed, + Gtfsrealtimefeed, + Gtfsdataset, + Location, + Entitytype, +) +from test_shared.test_utils.database_utils import clean_testing_db, get_testing_session + + +def populate_database(): + """ + Populates the database with fake data with the following distribution: + - 5 GTFS feeds + - 2 active + - 1 inactive + - 2 deprecated + - 5 GTFS Realtime feeds + - 4 3TFS rt datasets, with 1 of them inactive + """ + clean_testing_db() + session = get_testing_session() + fake = Faker() + + feeds = [] + # We create 3 feeds. The first one is active. The third one is inactive and redirected to the first one. + # The second one is active but not redirected. + # First fill the generic paramaters + for i in range(3): + feed = Gtfsfeed( + data_type="gtfs", + feed_name=f"gtfs-{i} Some fake name", + note=f"gtfs-{i} Some fake note", + producer_url=f"https://gtfs-{i}_some_fake_producer_url", + authentication_info_url=None, + api_key_parameter_name=None, + license_url=f"https://gtfs-{i}_some_fake_license_url", + stable_id=f"gtfs-{i}", + feed_contact_email=f"gtfs-{i}_some_fake_email@fake.com", + provider=f"gtfs-{i} Some fake company", + ) + feeds.append(feed) + + # Then fill the specific parameters for each feed + target_feed = feeds[0] + target_feed.id = "e3155a30-81d8-40bb-9e10-013a60436d86" # Just an invented uuid + target_feed.authentication_type = "0" + target_feed.status = "active" + + feed = feeds[1] + feed.id = fake.uuid4() + feed.authentication_type = "0" + feed.status = "active" + + source_feed = feeds[2] + source_feed.id = "6e7c5f17-537a-439a-bf99-9c37f1f01030" + source_feed.authentication_type = "0" + source_feed.status = "inactive" + source_feed.redirectingids = [ + Redirectingid( + source_id=source_feed.id, + target_id=target_feed.id, + redirect_comment="Some redirect comment", + target=target_feed, + ) + ] + + for feed in feeds: + session.add(feed) + + for i in range(2): + feed = Gtfsfeed( + id=fake.uuid4(), + data_type="gtfs", + feed_name=f"gtfs-deprecated-{i} Some fake name", + note=f"gtfs-deprecated-{i} Some fake note", + producer_url=f"https://gtfs-deprecated-{i}_some_fake_producer_url", + authentication_type="0" if (i == 0) else "1", + authentication_info_url=None, + api_key_parameter_name=None, + license_url=f"https://gtfs-{i}_some_fake_license_url", + stable_id=f"gtfs-deprecated-{i}", + status="deprecated", + feed_contact_email=f"gtfs-deprecated-{i}_some_fake_email@fake.com", + provider=f"gtfs-deprecated-{i} Some fake company", + ) + session.add(feed) + + location_entity = Location(id="CA-quebec-montreal") + + location_entity.country = "Canada" + location_entity.country_code = "CA" + location_entity.subdivision_name = "Quebec" + location_entity.municipality = "Montreal" + session.add(location_entity) + locations = [location_entity] + + feature1 = Feature(name="Shapes") + session.add(feature1) + feature2 = Feature(name="Route Colors") + session.add(feature2) + + # GTFS datasets leaving one active feed without a dataset + active_gtfs_feeds = ( + session.query(Gtfsfeed) + .filter(Gtfsfeed.status == "active") + .order_by(Gtfsfeed.stable_id) + .all() + ) + + # the first 2 datasets are for the first feed + for i in range(1, 4): + feed_index = 0 if i in [1, 2] else 1 + wkt_polygon = "POLYGON((-18 -9, -18 9, 18 9, 18 -9, -18 -9))" + 
wkt_element = WKTElement(wkt_polygon, srid=4326) + feed_stable_id = active_gtfs_feeds[feed_index].stable_id + gtfs_dataset = Gtfsdataset( + id=fake.uuid4(), + feed_id=feed_stable_id, + latest=True if i != 2 else False, + bounding_box=wkt_element, + # Use a url containing the stable id. The program should replace all the is after the feed stable id + # by latest.zip + hosted_url=f"https://url_prefix/{feed_stable_id}/dataset-{i}_some_fake_hosted_url", + note=f"dataset-{i} Some fake note", + hash=fake.sha256(), + downloaded_at=datetime.utcnow(), + stable_id=f"dataset-{i}", + ) + validation_report = Validationreport( + id=fake.uuid4(), + validator_version="6.0.1", + validated_at=datetime(2025, 1, 12), + html_report=fake.url(), + json_report=fake.url(), + ) + validation_report.features.append(feature1) + validation_report.features.append(feature2) + + session.add(validation_report) + gtfs_dataset.validation_reports.append(validation_report) + + gtfs_dataset.locations = locations + + active_gtfs_feeds[feed_index].gtfsdatasets.append(gtfs_dataset) + active_gtfs_feeds[0].locations = locations + active_gtfs_feeds[1].locations = locations + + # active_gtfs_feeds[0].gtfsdatasets.append() = gtfs_datasets + + vp_entitytype = session.query(Entitytype).filter_by(name="vp").first() + if not vp_entitytype: + vp_entitytype = Entitytype(name="vp") + session.add(vp_entitytype) + tu_entitytype = session.query(Entitytype).filter_by(name="tu").first() + if not tu_entitytype: + tu_entitytype = Entitytype(name="tu") + session.add(tu_entitytype) + + # GTFS Realtime feeds + for i in range(3): + gtfs_rt_feed = Gtfsrealtimefeed( + id=fake.uuid4(), + data_type="gtfs_rt", + feed_name=f"gtfs-rt-{i} Some fake name", + note=f"gtfs-rt-{i} Some fake note", + producer_url=f"https://gtfs-rt-{i}_some_fake_producer_url", + authentication_type=str(i), + authentication_info_url=f"https://gtfs-rt-{i}_some_fake_authentication_info_url", + api_key_parameter_name=f"gtfs-rt-{i}_fake_api_key_parameter_name", + license_url=f"https://gtfs-rt-{i}_some_fake_license_url", + stable_id=f"gtfs-rt-{i}", + status="inactive" if i == 1 else "active", + feed_contact_email=f"gtfs-rt-{i}_some_fake_email@fake.com", + provider=f"gtfs-rt-{i} Some fake company", + entitytypes=[vp_entitytype, tu_entitytype] if (i == 0) else [vp_entitytype], + ) + session.add(gtfs_rt_feed) + + session.commit() + + +def pytest_configure(config): + """ + Allows plugins and conftest files to perform initial configuration. + This hook is called for every plugin and initial conftest + file after command line options have been parsed. + """ + + +def pytest_sessionstart(session): + """ + Called after the Session object has been created and + before performing collection and entering the run test loop. + """ + clean_testing_db() + populate_database() + + +def pytest_sessionfinish(session, exitstatus): + """ + Called after whole test run finished, right before + returning the exit status to the system. + """ + # Cleaned at the beginning instead of the end so we can examine the DB after the test. + # clean_testing_db() + + +def pytest_unconfigure(config): + """ + called before test process is exited. 
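+    Nothing is cleaned up here on purpose: the database is wiped at the start of
+    the next session instead, so its contents can be inspected after a test run
+    (see the comment in pytest_sessionfinish above).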
+ """ diff --git a/functions-python/export_csv/tests/test_export_csv_main.py b/functions-python/export_csv/tests/test_export_csv_main.py new file mode 100644 index 000000000..2341583db --- /dev/null +++ b/functions-python/export_csv/tests/test_export_csv_main.py @@ -0,0 +1,65 @@ +# +# MobilityData 2023 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import io +import os + +import pandas as pd +import pandas.testing as pdt +import main + +# This CSV has been created by running the tests once to extract the resulting csv, then examine the CSV to make sure +# the data is correct. +expected_csv = """ +id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment +gtfs-0,gtfs,,CA,Quebec,Montreal,gtfs-0 Some fake company,gtfs-0 Some fake name,gtfs-0 Some fake note,gtfs-0_some_fake_email@fake.com,,https://gtfs-0_some_fake_producer_url,0,,,https://url_prefix/gtfs-0/latest.zip,https://gtfs-0_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,Route Colors|Shapes,, +gtfs-1,gtfs,,CA,Quebec,Montreal,gtfs-1 Some fake company,gtfs-1 Some fake name,gtfs-1 Some fake note,gtfs-1_some_fake_email@fake.com,,https://gtfs-1_some_fake_producer_url,0,,,https://url_prefix/gtfs-1/latest.zip,https://gtfs-1_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,Route Colors|Shapes,, +gtfs-2,gtfs,,,,,gtfs-2 Some fake company,gtfs-2 Some fake name,gtfs-2 Some fake note,gtfs-2_some_fake_email@fake.com,,https://gtfs-2_some_fake_producer_url,0,,,,https://gtfs-2_some_fake_license_url,,,,,,inactive,,gtfs-0,Some redirect comment +gtfs-deprecated-0,gtfs,,,,,gtfs-deprecated-0 Some fake company,gtfs-deprecated-0 Some fake name,gtfs-deprecated-0 Some fake note,gtfs-deprecated-0_some_fake_email@fake.com,,https://gtfs-deprecated-0_some_fake_producer_url,0,,,,https://gtfs-0_some_fake_license_url,,,,,,deprecated,,, +gtfs-deprecated-1,gtfs,,,,,gtfs-deprecated-1 Some fake company,gtfs-deprecated-1 Some fake name,gtfs-deprecated-1 Some fake note,gtfs-deprecated-1_some_fake_email@fake.com,,https://gtfs-deprecated-1_some_fake_producer_url,1,,,,https://gtfs-1_some_fake_license_url,,,,,,deprecated,,, +gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,gtfs-rt-0_some_fake_email@fake.com,,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,,,,,,,,, +gtfs-rt-1,gtfs_rt,vp,,,,gtfs-rt-1 Some fake company,gtfs-rt-1 Some fake name,gtfs-rt-1 Some fake 
note,gtfs-rt-1_some_fake_email@fake.com,,https://gtfs-rt-1_some_fake_producer_url,1,https://gtfs-rt-1_some_fake_authentication_info_url,gtfs-rt-1_fake_api_key_parameter_name,,https://gtfs-rt-1_some_fake_license_url,,,,,,,,, +gtfs-rt-2,gtfs_rt,vp,,,,gtfs-rt-2 Some fake company,gtfs-rt-2 Some fake name,gtfs-rt-2 Some fake note,gtfs-rt-2_some_fake_email@fake.com,,https://gtfs-rt-2_some_fake_producer_url,2,https://gtfs-rt-2_some_fake_authentication_info_url,gtfs-rt-2_fake_api_key_parameter_name,,https://gtfs-rt-2_some_fake_license_url,,,,,,,,, +""" # noqa + + +def test_export_csv(): + os.environ[ + "FEEDS_DATABASE_URL" + ] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest" + data_collector = main.collect_data() + print(f"Collected data for {len(data_collector.rows)} feeds.") + + df_extracted = data_collector.get_dataframe() + + csv_buffer = io.StringIO(expected_csv) + df_from_expected_csv = pd.read_csv(csv_buffer) + df_from_expected_csv.fillna("", inplace=True) + + df_extracted.fillna("", inplace=True) + + df_extracted["urls.authentication_type"] = df_extracted[ + "urls.authentication_type" + ].astype(str) + df_from_expected_csv["urls.authentication_type"] = df_from_expected_csv[ + "urls.authentication_type" + ].astype(str) + df_from_expected_csv["location.bounding_box.extracted_on"] = pd.to_datetime( + df_from_expected_csv["location.bounding_box.extracted_on"], utc=True + ) + + # try: + pdt.assert_frame_equal(df_extracted, df_from_expected_csv) + print("DataFrames are equal.") diff --git a/functions-python/helpers/requirements.txt b/functions-python/helpers/requirements.txt index 59b67dd1a..2fbe18ed6 100644 --- a/functions-python/helpers/requirements.txt +++ b/functions-python/helpers/requirements.txt @@ -9,6 +9,7 @@ requests~=2.32.3 attrs~=23.1.0 pluggy~=1.3.0 certifi~=2024.7.4 +python-dotenv==1.0.0 # SQL Alchemy and Geo Alchemy SQLAlchemy==2.0.23 diff --git a/functions-python/helpers/test_config.json b/functions-python/helpers/test_config.json index d5e5379ac..c97d31a8b 100644 --- a/functions-python/helpers/test_config.json +++ b/functions-python/helpers/test_config.json @@ -1,3 +1,3 @@ { - "include_api_folders": ["database_gen"] + "include_api_folders": ["database_gen", "database"] } diff --git a/infra/batch/main.tf b/infra/batch/main.tf index 0808987eb..a9fc5be60 100644 --- a/infra/batch/main.tf +++ b/infra/batch/main.tf @@ -249,7 +249,7 @@ resource "google_cloudfunctions2_function" "pubsub_function" { vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" environment_variables = { - DATASETS_BUCKET_NANE = google_storage_bucket.datasets_bucket.name + DATASETS_BUCKET_NAME = google_storage_bucket.datasets_bucket.name # prevents multiline logs from being truncated on GCP console PYTHONNODEBUGRANGES = 0 DB_REUSE_SESSION = "True" diff --git a/infra/functions-python/main.tf b/infra/functions-python/main.tf index 5455bc570..b30b04a4b 100644 --- a/infra/functions-python/main.tf +++ b/infra/functions-python/main.tf @@ -55,6 +55,9 @@ locals { function_backfill_dataset_service_date_range_config = jsondecode(file("${path.module}/../../functions-python/backfill_dataset_service_date_range/function_config.json")) function_backfill_dataset_service_date_range_zip = "${path.module}/../../functions-python/backfill_dataset_service_date_range/.dist/backfill_dataset_service_date_range.zip" + + function_export_csv_config = jsondecode(file("${path.module}/../../functions-python/export_csv/function_config.json")) + function_export_csv_zip = 
"${path.module}/../../functions-python/export_csv/.dist/export_csv.zip" } locals { @@ -65,7 +68,8 @@ locals { [for x in local.function_extract_location_config.secret_environment_variables : x.key], [for x in local.function_process_validation_report_config.secret_environment_variables : x.key], [for x in local.function_update_validation_report_config.secret_environment_variables : x.key], - [for x in local.function_backfill_dataset_service_date_range_config.secret_environment_variables : x.key] + [for x in local.function_backfill_dataset_service_date_range_config.secret_environment_variables : x.key], + [for x in local.function_export_csv_config.secret_environment_variables : x.key] ) # Convert the list to a set to ensure uniqueness @@ -82,6 +86,10 @@ data "google_pubsub_topic" "datasets_batch_topic" { name = "datasets-batch-topic-${var.environment}" } +data "google_storage_bucket" "datasets_bucket" { + name = "${var.datasets_bucket_name}-${var.environment}" +} + # Service account to execute the cloud functions resource "google_service_account" "functions_service_account" { account_id = "functions-service-account" @@ -98,6 +106,18 @@ resource "google_storage_bucket" "gbfs_snapshots_bucket" { name = "${var.gbfs_bucket_name}-${var.environment}" } +resource "google_storage_bucket_iam_member" "datasets_bucket_functions_service_account" { + bucket = data.google_storage_bucket.datasets_bucket.name + role = "roles/storage.admin" + member = "serviceAccount:${google_service_account.functions_service_account.email}" +} + +resource "google_project_iam_member" "datasets_bucket_functions_service_account" { + project = var.project_id + member = "serviceAccount:${google_service_account.functions_service_account.email}" + role = "roles/storage.admin" +} + # Cloud function source code zip files: # 1. Tokens resource "google_storage_bucket_object" "function_token_zip" { @@ -161,6 +181,13 @@ resource "google_storage_bucket_object" "backfill_dataset_service_date_range_zip } +# 10. Export CSV +resource "google_storage_bucket_object" "export_csv_zip" { + bucket = google_storage_bucket.functions_bucket.name + name = "export-csv-${substr(filebase64sha256(local.function_export_csv_zip), 0, 10)}.zip" + source = local.function_export_csv_zip +} + # Secrets access resource "google_secret_manager_secret_iam_member" "secret_iam_member" { for_each = local.unique_secret_keys @@ -819,7 +846,55 @@ resource "google_cloudfunctions2_function" "backfill_dataset_service_date_range" } } -# IAM entry for all users to invoke the function +# 10. 
functions/export_csv cloud function +resource "google_cloudfunctions2_function" "export_csv" { + name = "${local.function_export_csv_config.name}" + project = var.project_id + description = local.function_export_csv_config.description + location = var.gcp_region + depends_on = [google_secret_manager_secret_iam_member.secret_iam_member] + + build_config { + runtime = var.python_runtime + entry_point = "${local.function_export_csv_config.entry_point}" + source { + storage_source { + bucket = google_storage_bucket.functions_bucket.name + object = google_storage_bucket_object.export_csv_zip.name + } + } + } + service_config { + environment_variables = { + DATASETS_BUCKET_NAME = data.google_storage_bucket.datasets_bucket.name + PROJECT_ID = var.project_id + ENVIRONMENT = var.environment + } + available_memory = local.function_export_csv_config.memory + timeout_seconds = local.function_export_csv_config.timeout + available_cpu = local.function_export_csv_config.available_cpu + max_instance_request_concurrency = local.function_export_csv_config.max_instance_request_concurrency + max_instance_count = local.function_export_csv_config.max_instance_count + min_instance_count = local.function_export_csv_config.min_instance_count + service_account_email = google_service_account.functions_service_account.email + ingress_settings = "ALLOW_ALL" + vpc_connector = data.google_vpc_access_connector.vpc_connector.id + vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" + + dynamic "secret_environment_variables" { + for_each = local.function_export_csv_config.secret_environment_variables + content { + key = secret_environment_variables.value["key"] + project_id = var.project_id + secret = "${upper(var.environment)}_${secret_environment_variables.value["key"]}" + version = "latest" + } + } + } + +} + +# IAM entry for all users to invoke the function resource "google_cloudfunctions2_function_iam_member" "tokens_invoker" { project = var.project_id location = var.gcp_region @@ -867,15 +942,6 @@ resource "google_project_iam_member" "event-receiving" { depends_on = [google_project_iam_member.invoking] } -# Grant read access to the datasets bucket for the service account -resource "google_storage_bucket_iam_binding" "bucket_object_viewer" { - bucket = "${var.datasets_bucket_name}-${var.environment}" - role = "roles/storage.objectViewer" - members = [ - "serviceAccount:${google_service_account.functions_service_account.email}" - ] -} - # Grant write access to the gbfs bucket for the service account resource "google_storage_bucket_iam_binding" "gbfs_bucket_object_creator" { bucket = google_storage_bucket.gbfs_snapshots_bucket.name diff --git a/scripts/api-tests.sh b/scripts/api-tests.sh index d9494655f..73d887d95 100755 --- a/scripts/api-tests.sh +++ b/scripts/api-tests.sh @@ -52,10 +52,10 @@ display_usage() { printf "\nScript Usage:\n" echo "Usage: $0 [options]" echo "Options:" - echo " -test_file Test file name to be executed." - echo " -folder Folder name to be executed." - echo " -html_report Generate HTML coverage report." - echo " -help Display help content." + echo " --test_file Test file name to be executed." + echo " --folder Folder name to be executed." + echo " --html_report Generate HTML coverage report." + echo " --help Display help content." 
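+  # A hypothetical invocation using the long-option form documented above
+  # (the folder name is illustrative):
+  #   ./scripts/api-tests.sh --folder feeds --html_report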
exit 1 } diff --git a/scripts/db-gen.sh b/scripts/db-gen.sh index 232eedab7..1bd4f4a88 100755 --- a/scripts/db-gen.sh +++ b/scripts/db-gen.sh @@ -11,7 +11,7 @@ SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")" # Default filename for OUT_FILE -DEFAULT_FILENAME="api/src/database_gen/sqlacodegen_models.py" +DEFAULT_FILENAME="api/src/shared/database_gen/sqlacodegen_models.py" # Use the first argument as the filename for OUT_FILE; if not provided, use the default filename FILENAME=${1:-$DEFAULT_FILENAME} OUT_FILE=$SCRIPT_PATH/../$FILENAME @@ -19,8 +19,8 @@ OUT_FILE=$SCRIPT_PATH/../$FILENAME ENV_PATH=$SCRIPT_PATH/../config/.env.local source "$ENV_PATH" -rm -rf "$SCRIPT_PATH/../api/src/database_gen/" -mkdir "$SCRIPT_PATH/../api/src/database_gen/" +rm -rf "$SCRIPT_PATH/../api/src/shared/database_gen/" +mkdir "$SCRIPT_PATH/../api/src/shared/database_gen/" pip3 install -r "${SCRIPT_PATH}/../api/requirements.txt" > /dev/null # removing sqlacodegen.log file diff --git a/scripts/function-python-build.sh b/scripts/function-python-build.sh index 0ea4b7c69..515615848 100755 --- a/scripts/function-python-build.sh +++ b/scripts/function-python-build.sh @@ -30,7 +30,7 @@ SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")" ROOT_PATH="$SCRIPT_PATH/.." FUNCTIONS_PATH="$ROOT_PATH/functions-python" -API_SRC_PATH="$ROOT_PATH/api/src" +API_SRC_PATH="$ROOT_PATH/api/src/shared" # function printing usage display_usage() { diff --git a/scripts/function-python-run.sh b/scripts/function-python-run.sh index 97ceaa314..9673b1b3c 100755 --- a/scripts/function-python-run.sh +++ b/scripts/function-python-run.sh @@ -105,7 +105,7 @@ fi export PYTHONPATH="$FX_PATH" # Install a virgin python virtual environment and provision it with the required packages so it's the same as -# the one deployed that will be deployed in the cloud +# the one that will be deployed in the cloud pushd "$FUNCTIONS_PATH/$function_name" >/dev/null printf "\nINFO: installing python virtual environment" rm -rf venv diff --git a/scripts/function-python-setup.sh b/scripts/function-python-setup.sh index 1b8966d09..2d628f3f9 100755 --- a/scripts/function-python-setup.sh +++ b/scripts/function-python-setup.sh @@ -32,7 +32,7 @@ SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")" ROOT_PATH=$(realpath "$SCRIPT_PATH/..") FUNCTIONS_PATH="$ROOT_PATH/functions-python" -API_PATH="$ROOT_PATH/api/src" +API_PATH="$ROOT_PATH/api/src/shared" # function printing usage display_usage() { @@ -44,6 +44,7 @@ display_usage() { echo " -h|--help Display help content." echo " --function_name Name of the function to be executed." echo " --all Build all functions." + echo " --clean Clean shared folders." exit 1 } @@ -149,6 +150,8 @@ create_symbolic_links() { fi for folder in $folders; do + # In case the folder is made of more than one level (e.g. "feeds/filters") just link the parent (e.g. "feeds") + folder=$(echo $folder | cut -d '/' -f 1) src_folder="$root_folder/$folder" if [[ "$dst_folder" != "$src_folder"* ]]; then relative_path=$(python3 -c "import os.path; print(os.path.relpath(\"$src_folder\", \"$dst_folder\"))")
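+        # Illustrative example (hypothetical paths): with src_folder set to
+        # <repo>/api/src/shared/database_gen and dst_folder set to
+        # <repo>/functions-python/<function>/src, relative_path would come out as
+        # ../../../api/src/shared/database_gen.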