feat: Added export_csv cloud function to generate new sources.csv (#888)
jcpitre authored Jan 28, 2025
1 parent e59d956 commit 5c62297
Showing 70 changed files with 1,279 additions and 231 deletions.
22 changes: 11 additions & 11 deletions .github/workflows/api-deployer.yml
@@ -199,7 +199,7 @@ jobs:
- uses: actions/download-artifact@v4
with:
name: database_gen
path: api/src/database_gen/
path: api/src/shared/database_gen/

# api schema was generated and uploaded in api-build-test job above.
- uses: actions/download-artifact@v4
@@ -219,7 +219,7 @@ jobs:
- name: Build & Publish Docker Image
run: |
# We want to generate the image even if it's the same commit that has been tagged. So use the version
# (coming from the tag) in the docker image tag (If the docket tag does not change it's won't be uploaded)
# (coming from the tag) in the docker image tag (If the docker tag does not change it won't be uploaded)
DOCKER_IMAGE_VERSION=$EXTRACTED_VERSION.$FEED_API_IMAGE_VERSION
scripts/docker-build-push.sh -project_id $PROJECT_ID -repo_name feeds-$ENVIRONMENT -service feed-api -region $REGION -version $DOCKER_IMAGE_VERSION
@@ -243,7 +243,7 @@ jobs:
- uses: actions/download-artifact@v4
with:
name: database_gen
path: api/src/database_gen/
path: api/src/shared/database_gen/

# api schema was generated and uploaded in api-build-test job above.
- uses: actions/download-artifact@v4
@@ -318,18 +318,18 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PLAN_OUTPUT: ${{ steps.plan.outputs.stdout }}

- name: Persist TF plan
uses: actions/upload-artifact@v4
with:
name: terraform-plan.txt
path: infra/terraform-plan.txt
overwrite: true

- name: Terraform Apply
if: ${{ inputs.TF_APPLY }}
run: |
cd infra
terraform apply -auto-approve tf.plan
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PLAN_OUTPUT: ${{ steps.plan.outputs.stdout }}

- name: Persist TF plan
uses: actions/upload-artifact@v4
with:
name: terraform-plan.txt
path: infra/terraform-plan.txt
overwrite: true
PLAN_OUTPUT: ${{ steps.plan.outputs.stdout }}
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yml
@@ -105,7 +105,7 @@ jobs:
uses: actions/upload-artifact@v4
with:
name: database_gen
path: api/src/database_gen/
path: api/src/shared/database_gen/
overwrite: true

- name: Upload API generated code
2 changes: 1 addition & 1 deletion .github/workflows/datasets-batch-deployer.yml
@@ -119,7 +119,7 @@ jobs:
uses: actions/upload-artifact@v4
with:
name: database_gen
path: api/src/database_gen/
path: api/src/shared/database_gen/

- name: Build python functions
run: |
2 changes: 1 addition & 1 deletion api/.flake8
@@ -1,5 +1,5 @@
[flake8]
max-line-length = 120
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,src/feeds_gen,src/database_gen
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,feeds_gen,database_gen
# Ignored because conflict with black
extend-ignore = E203
64 changes: 4 additions & 60 deletions api/src/feeds/impl/datasets_api_impl.py
@@ -1,20 +1,17 @@
from typing import List
from typing import Tuple

from geoalchemy2 import WKTElement
from sqlalchemy import or_
from sqlalchemy.orm import Query, Session

from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
from shared.database.database import Database, with_db_session
from shared.database_gen.sqlacodegen_models import (
Gtfsdataset,
Feed,
)
from feeds.impl.error_handling import (
invalid_bounding_coordinates,
invalid_bounding_method,
raise_http_validation_error,
raise_http_error,
)
from shared.common.error_handling import (
dataset_not_found,
)
from feeds.impl.models.gtfs_dataset_impl import GtfsDatasetImpl
@@ -39,59 +36,6 @@ def create_dataset_query() -> Query:
]
).join(Feed, Feed.id == Gtfsdataset.feed_id)

@staticmethod
def apply_bounding_filtering(
query: Query,
bounding_latitudes: str,
bounding_longitudes: str,
bounding_filter_method: str,
) -> Query:
"""Create a new query based on the bounding parameters."""

if not bounding_latitudes or not bounding_longitudes or not bounding_filter_method:
return query

if (
len(bounding_latitudes_tokens := bounding_latitudes.split(",")) != 2
or len(bounding_longitudes_tokens := bounding_longitudes.split(",")) != 2
):
raise_http_validation_error(invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes))
min_latitude, max_latitude = bounding_latitudes_tokens
min_longitude, max_longitude = bounding_longitudes_tokens
try:
min_latitude = float(min_latitude)
max_latitude = float(max_latitude)
min_longitude = float(min_longitude)
max_longitude = float(max_longitude)
except ValueError:
raise_http_validation_error(invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes))
points = [
(min_longitude, min_latitude),
(min_longitude, max_latitude),
(max_longitude, max_latitude),
(max_longitude, min_latitude),
(min_longitude, min_latitude),
]
wkt_polygon = f"POLYGON(({', '.join(f'{lon} {lat}' for lon, lat in points)}))"
bounding_box = WKTElement(
wkt_polygon,
srid=Gtfsdataset.bounding_box.type.srid,
)

if bounding_filter_method == "partially_enclosed":
return query.filter(
or_(
Gtfsdataset.bounding_box.ST_Overlaps(bounding_box),
Gtfsdataset.bounding_box.ST_Contains(bounding_box),
)
)
elif bounding_filter_method == "completely_enclosed":
return query.filter(bounding_box.ST_Covers(Gtfsdataset.bounding_box))
elif bounding_filter_method == "disjoint":
return query.filter(Gtfsdataset.bounding_box.ST_Disjoint(bounding_box))
else:
raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method))

@staticmethod
def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]:
# Results are sorted by stable_id because Database.select(group_by=) requires it so
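
The bounding-box filtering removed here is re-exposed later in this commit through get_gtfs_feeds_query, which accepts the same dataset_latitudes, dataset_longitudes, and bounding_filter_method parameters, so the logic presumably now lives in the shared helpers. To make the WKT construction concrete, a small worked example of the polygon the removed code builds:

points = [(-74.0, 40.0), (-74.0, 41.0), (-73.0, 41.0), (-73.0, 40.0), (-74.0, 40.0)]
wkt_polygon = f"POLYGON(({', '.join(f'{lon} {lat}' for lon, lat in points)}))"
# For bounding_latitudes="40,41" and bounding_longitudes="-74,-73" this yields a
# closed ring: POLYGON((-74.0 40.0, -74.0 41.0, -73.0 41.0, -73.0 40.0, -74.0 40.0))
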
21 changes: 10 additions & 11 deletions api/src/feeds/impl/error_handling.py
@@ -1,16 +1,15 @@
from typing import Final

from fastapi import HTTPException

invalid_date_message: Final[
str
] = "Invalid date format for '{}'. Expected ISO 8601 format, example: '2021-01-01T00:00:00Z'"
invalid_bounding_coordinates: Final[str] = "Invalid bounding coordinates {} {}"
invalid_bounding_method: Final[str] = "Invalid bounding_filter_method {}"
feed_not_found: Final[str] = "Feed '{}' not found"
gtfs_feed_not_found: Final[str] = "GTFS feed '{}' not found"
gtfs_rt_feed_not_found: Final[str] = "GTFS realtime Feed '{}' not found"
dataset_not_found: Final[str] = "Dataset '{}' not found"
from shared.common.error_handling import InternalHTTPException


def convert_exception(input_exception: InternalHTTPException) -> HTTPException:
"""Convert an InternalHTTPException to an HTTPException.
HTTPException depends on FastAPI, which we don't necessarily want to deploy with the Python functions.
That's why InternalHTTPException (a class we do deploy) is raised instead of HTTPException.
Since InternalHTTPException is internal, it needs to be converted before being sent up the stack.
"""
return HTTPException(status_code=input_exception.status_code, detail=input_exception.detail)


def raise_http_error(status_code: int, error: str):
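
The pattern this file now implements can be sketched end to end. A minimal, self-contained version, assuming InternalHTTPException simply mirrors HTTPException's status_code and detail fields (its real definition lives in shared.common.error_handling, which this excerpt does not show):

from fastapi import HTTPException


class InternalHTTPException(Exception):
    """Assumed shape of the framework-free exception from shared.common.error_handling."""

    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def convert_exception(input_exception: InternalHTTPException) -> HTTPException:
    # Shared code raises InternalHTTPException so it can run without FastAPI installed;
    # the API layer converts it at the boundary before re-raising.
    return HTTPException(status_code=input_exception.status_code, detail=input_exception.detail)

A caller such as get_gtfs_feeds then wraps the shared query call in try/except InternalHTTPException and re-raises via convert_exception, as shown in feeds_api_impl.py below.
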
106 changes: 55 additions & 51 deletions api/src/feeds/impl/feeds_api_impl.py
@@ -6,8 +6,9 @@
from sqlalchemy.orm import joinedload, Session
from sqlalchemy.orm.query import Query

from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
from shared.common.db_utils import get_gtfs_feeds_query, get_gtfs_rt_feeds_query, get_joinedload_options
from shared.database.database import Database, with_db_session
from shared.database_gen.sqlacodegen_models import (
Feed,
Gtfsdataset,
Gtfsfeed,
@@ -17,18 +18,17 @@
t_location_with_translations_en,
Entitytype,
)
from feeds.filters.feed_filter import FeedFilter
from feeds.filters.gtfs_dataset_filter import GtfsDatasetFilter
from feeds.filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter
from feeds.filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter
from shared.feed_filters.feed_filter import FeedFilter
from shared.feed_filters.gtfs_dataset_filter import GtfsDatasetFilter
from shared.feed_filters.gtfs_feed_filter import LocationFilter
from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter
from feeds.impl.datasets_api_impl import DatasetsApiImpl
from feeds.impl.error_handling import (
raise_http_validation_error,
from shared.common.error_handling import (
invalid_date_message,
raise_http_error,
feed_not_found,
gtfs_feed_not_found,
gtfs_rt_feed_not_found,
InternalHTTPException,
)
from feeds.impl.models.basic_feed_impl import BasicFeedImpl
from feeds.impl.models.entity_type_enum import EntityType
Expand All @@ -39,6 +39,7 @@
from feeds_gen.models.gtfs_dataset import GtfsDataset
from feeds_gen.models.gtfs_feed import GtfsFeed
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from feeds.impl.error_handling import raise_http_error, raise_http_validation_error, convert_exception
from middleware.request_context import is_user_email_restricted
from utils.date_utils import valid_iso_date
from utils.location_translation import (
@@ -117,7 +118,7 @@ def get_feeds(
)
# Results are sorted by provider
feed_query = feed_query.order_by(Feed.provider, Feed.stable_id)
feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options())
feed_query = feed_query.options(*get_joinedload_options())
if limit is not None:
feed_query = feed_query.limit(limit)
if offset is not None:
@@ -158,7 +159,7 @@ def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
*get_joinedload_options(),
)
).all()
if len(results) > 0 and results[0].Gtfsfeed:
@@ -237,46 +238,29 @@ def get_gtfs_feeds(
is_official: bool,
db_session: Session,
) -> List[GtfsFeed]:
"""Get some (or all) GTFS feeds from the Mobility Database."""
gtfs_feed_filter = GtfsFeedFilter(
stable_id=None,
provider__ilike=provider,
producer_url__ilike=producer_url,
location=LocationFilter(
try:
include_wip = not is_user_email_restricted()
feed_query = get_gtfs_feeds_query(
limit=limit,
offset=offset,
provider=provider,
producer_url=producer_url,
country_code=country_code,
subdivision_name__ilike=subdivision_name,
municipality__ilike=municipality,
),
)

subquery = gtfs_feed_filter.filter(select(Gtfsfeed.id).join(Location, Gtfsfeed.locations))
subquery = DatasetsApiImpl.apply_bounding_filtering(
subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method
).subquery()

is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")
feed_query = (
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.id.in_(subquery))
.filter(
or_(
Gtfsfeed.operational_status == None, # noqa: E711
Gtfsfeed.operational_status != "wip",
not is_email_restricted, # Allow all feeds to be returned if the user is not restricted
)
subdivision_name=subdivision_name,
municipality=municipality,
dataset_latitudes=dataset_latitudes,
dataset_longitudes=dataset_longitudes,
bounding_filter_method=bounding_filter_method,
is_official=is_official,
include_wip=include_wip,
db_session=db_session,
)
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
)
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.limit(limit).offset(offset)
except InternalHTTPException as e:
# get_gtfs_feeds_query cannot raise HTTPException since that class is part of FastAPI, which is
# not necessarily deployed (e.g. for Python functions). Instead it raises an InternalHTTPException,
# which needs to be converted to an HTTPException before being re-raised.
raise convert_exception(e)

return self._get_response(feed_query, GtfsFeedImpl, db_session)

@with_db_session
@@ -303,7 +287,7 @@ def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
*get_joinedload_options(),
)
).all()

@@ -328,6 +312,26 @@ def get_gtfs_rt_feeds(
db_session: Session,
) -> List[GtfsRTFeed]:
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
try:
include_wip = not is_user_email_restricted()
feed_query = get_gtfs_rt_feeds_query(
limit=limit,
offset=offset,
provider=provider,
producer_url=producer_url,
entity_types=entity_types,
country_code=country_code,
subdivision_name=subdivision_name,
municipality=municipality,
is_official=is_official,
include_wip=include_wip,
db_session=db_session,
)
except InternalHTTPException as e:
raise convert_exception(e)

return self._get_response(feed_query, GtfsRTFeedImpl, db_session)

entity_types_list = entity_types.split(",") if entity_types else None

# Validate entity types using the EntityType enum
@@ -369,7 +373,7 @@ def get_gtfs_rt_feeds(
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
*get_joinedload_options(),
)
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
)
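
The include_wip flag passed to the shared query helpers presumably replaces the inline operational-status filter deleted above; the body of get_gtfs_feeds_query in shared.common.db_utils is not part of this excerpt. A hedged sketch of that filter as a standalone helper:

from sqlalchemy import or_
from sqlalchemy.orm import Query

from shared.database_gen.sqlacodegen_models import Gtfsfeed


def apply_wip_filter(query: Query, include_wip: bool) -> Query:
    # Hypothetical helper name; the logic mirrors the filter removed from
    # get_gtfs_feeds: restricted users should not see work-in-progress feeds.
    if include_wip:
        return query
    return query.filter(
        or_(
            Gtfsfeed.operational_status == None,  # noqa: E711
            Gtfsfeed.operational_status != "wip",
        )
    )
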
15 changes: 1 addition & 14 deletions api/src/feeds/impl/models/basic_feed_impl.py
@@ -1,7 +1,4 @@
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.strategy_options import _AbstractLoad

from database_gen.sqlacodegen_models import Feed
from shared.database_gen.sqlacodegen_models import Feed
from feeds.impl.models.external_id_impl import ExternalIdImpl
from feeds.impl.models.redirect_impl import RedirectImpl
from feeds_gen.models.basic_feed import BasicFeed
@@ -47,16 +44,6 @@ def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
redirects=sorted([RedirectImpl.from_orm(item) for item in feed.redirectingids], key=lambda x: x.target_id),
)

@staticmethod
def get_joinedload_options() -> [_AbstractLoad]:
"""Returns common joinedload options for feeds queries."""
return [
joinedload(Feed.locations),
joinedload(Feed.externalids),
joinedload(Feed.redirectingids),
joinedload(Feed.officialstatushistories),
]


class BasicFeedImpl(BaseFeedImpl, BasicFeed):
"""Implementation of the `BasicFeed` model.
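
The deleted get_joinedload_options is imported from shared.common.db_utils elsewhere in this commit, so the method presumably moved there essentially unchanged. A sketch of its assumed new home, reconstructed from the removed body above:

from sqlalchemy.orm import joinedload
from sqlalchemy.orm.strategy_options import _AbstractLoad

from shared.database_gen.sqlacodegen_models import Feed


def get_joinedload_options() -> list[_AbstractLoad]:
    """Common joinedload options for feed queries, moved out of BasicFeedImpl."""
    return [
        joinedload(Feed.locations),
        joinedload(Feed.externalids),
        joinedload(Feed.redirectingids),
        joinedload(Feed.officialstatushistories),
    ]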