feat: tl scraping function (#847)
cka-y authored Dec 6, 2024
1 parent dc5700c commit 73ea4bc
Showing 18 changed files with 810 additions and 1,241 deletions.
9 changes: 8 additions & 1 deletion functions-python/batch_process_dataset/src/main.py
@@ -323,14 +323,21 @@ def process_dataset(cloud_event: CloudEvent):
dataset_file: DatasetFile = None
error_message = None
try:
# Extract data from message
# Extract data from message
logging.info(f"Cloud Event: {cloud_event}")
data = base64.b64decode(cloud_event.data["message"]["data"]).decode()
json_payload = json.loads(data)
logging.info(
f"[{json_payload['feed_stable_id']}] JSON Payload: {json.dumps(json_payload)}"
)
stable_id = json_payload["feed_stable_id"]
execution_id = json_payload["execution_id"]
except Exception as e:
error_message = f"[{stable_id}] Error parsing message: [{e}]"
logging.error(error_message)
logging.error(f"Function completed with error:{error_message}")
return
try:
trace_service = DatasetTraceService()

trace = trace_service.get_by_execution_and_stable_ids(execution_id, stable_id)
@@ -394,7 +394,7 @@ def test_process_dataset_normal_execution(
@patch("batch_process_dataset.src.main.Logger")
@patch("batch_process_dataset.src.main.DatasetTraceService")
@patch("batch_process_dataset.src.main.DatasetProcessor")
def test_process_dataset_exception(
def test_process_dataset_exception_caught(
self, mock_dataset_processor, mock_dataset_trace, _
):
db_url = os.getenv("TEST_FEEDS_DATABASE_URL", default=default_db_url)
@@ -413,11 +413,7 @@ def test_process_dataset_exception(
mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0

# Call the function
try:
process_dataset(cloud_event)
assert False
except AttributeError:
assert True
process_dataset(cloud_event)

@patch("batch_process_dataset.src.main.Logger")
@patch("batch_process_dataset.src.main.DatasetTraceService")
@@ -207,6 +207,9 @@ def update_location(
.filter(Gtfsfeed.stable_id == dataset.feed.stable_id)
.one_or_none()
)
if gtfs_feed is None:
logging.error(f"Feed {dataset.feed.stable_id} not found as a GTFS feed.")
raise Exception(f"Feed {dataset.feed.stable_id} not found as a GTFS feed.")

for gtfs_rt_feed in gtfs_feed.gtfs_rt_feeds:
logging.info(f"Updating GTFS-RT feed with stable ID {gtfs_rt_feed.stable_id}")
189 changes: 100 additions & 89 deletions functions-python/feed_sync_dispatcher_transitland/src/main.py
@@ -14,13 +14,11 @@
# limitations under the License.
#

import json
import logging
import os
import random
import time
from dataclasses import dataclass, asdict
from typing import Optional, List
from typing import Optional

import functions_framework
import pandas as pd
@@ -29,14 +29,15 @@
from requests.exceptions import RequestException, HTTPError
from sqlalchemy.orm import Session

from database_gen.sqlacodegen_models import Gtfsfeed
from database_gen.sqlacodegen_models import Feed
from helpers.feed_sync.feed_sync_common import FeedSyncProcessor, FeedSyncPayload
from helpers.feed_sync.feed_sync_dispatcher import feed_sync_dispatcher
from helpers.feed_sync.models import TransitFeedSyncPayload
from helpers.logger import Logger
from helpers.pub_sub import get_pubsub_client, get_execution_id
from typing import Tuple, List
from collections import defaultdict

# Logging configuration
logging.basicConfig(level=logging.INFO)

# Environment variables
PUBSUB_TOPIC_NAME = os.getenv("PUBSUB_TOPIC_NAME")
@@ -45,68 +44,66 @@
TRANSITLAND_API_KEY = os.getenv("TRANSITLAND_API_KEY")
TRANSITLAND_OPERATOR_URL = os.getenv("TRANSITLAND_OPERATOR_URL")
TRANSITLAND_FEED_URL = os.getenv("TRANSITLAND_FEED_URL")
spec = ["gtfs", "gtfs-rt"]

# session instance to reuse connections
session = requests.Session()


@dataclass
class TransitFeedSyncPayload:
def process_feed_urls(feed: dict, urls_in_db: List[str]) -> Tuple[List[str], List[str]]:
"""
Data class for transit feed sync payloads.
Extracts the valid feed URLs and their corresponding entity types from the feed dictionary. If the same URL
corresponds to multiple entity types, the types are concatenated with a comma.
"""
url_keys_to_types = {
"static_current": "",
"realtime_alerts": "sa",
"realtime_trip_updates": "tu",
"realtime_vehicle_positions": "vp",
}

external_id: str
feed_id: str
feed_url: Optional[str] = None
execution_id: Optional[str] = None
spec: Optional[str] = None
auth_info_url: Optional[str] = None
auth_param_name: Optional[str] = None
type: Optional[str] = None
operator_name: Optional[str] = None
country: Optional[str] = None
state_province: Optional[str] = None
city_name: Optional[str] = None
source: Optional[str] = None
payload_type: Optional[str] = None
urls = feed.get("urls", {})
url_to_entity_types = defaultdict(list)

def to_dict(self):
return asdict(self)
for key, entity_type in url_keys_to_types.items():
if (url := urls.get(key)) and (url not in urls_in_db):
if entity_type:
logging.info(f"Found URL for entity type: {entity_type}")
url_to_entity_types[url].append(entity_type)

def to_json(self):
return json.dumps(self.to_dict())
valid_urls = []
entity_types = []

for url, types in url_to_entity_types.items():
valid_urls.append(url)
logging.info(f"URL = {url}, Entity types = {types}")
entity_types.append(",".join(types))

class TransitFeedSyncProcessor(FeedSyncProcessor):
def check_url_status(self, url: str) -> bool:
"""
Checks if a URL returns a valid response status code.
"""
try:
logging.info(f"Checking URL: {url}")
if url is None or len(url) == 0:
logging.warning("URL is empty. Skipping check.")
return False
response = requests.head(url, timeout=25)
logging.info(f"URL status code: {response.status_code}")
return response.status_code < 400
except requests.RequestException as e:
logging.warning(f"Failed to reach {url}: {e}")
return False
return valid_urls, entity_types
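
For illustration only (not part of the commit), a minimal sketch of how process_feed_urls maps a feed record to URLs and comma-joined entity-type strings; the URLs below are hypothetical and the database is assumed to contain none of them.

# Hypothetical input record; none of these URLs exist in the database.
sample_feed = {
    "urls": {
        "static_current": "https://example.com/gtfs.zip",
        "realtime_trip_updates": "https://example.com/rt",
        "realtime_vehicle_positions": "https://example.com/rt",
    }
}
urls, types = process_feed_urls(sample_feed, urls_in_db=[])
# urls  == ["https://example.com/gtfs.zip", "https://example.com/rt"]
# types == ["", "tu,vp"]  # the shared realtime URL carries both entity types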


class TransitFeedSyncProcessor(FeedSyncProcessor):
def process_sync(
self, db_session: Optional[Session] = None, execution_id: Optional[str] = None
self, db_session: Session, execution_id: Optional[str] = None
) -> List[FeedSyncPayload]:
"""
Process data synchronously to fetch, extract, combine, filter and prepare payloads for publishing
to a queue based on conditions related to the data retrieved from TransitLand API.
"""
feeds_data = self.get_data(
TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, spec, session
feeds_data_gtfs_rt = self.get_data(
TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, "gtfs_rt", session
)
logging.info(
"Fetched %s GTFS-RT feeds from TransitLand API",
len(feeds_data_gtfs_rt["feeds"]),
)

feeds_data_gtfs = self.get_data(
TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, "gtfs", session
)
logging.info(
"Fetched %s GTFS feeds from TransitLand API", len(feeds_data_gtfs["feeds"])
)
logging.info("Fetched %s feeds from TransitLand API", len(feeds_data["feeds"]))
feeds_data = feeds_data_gtfs["feeds"] + feeds_data_gtfs_rt["feeds"]

operators_data = self.get_data(
TRANSITLAND_OPERATOR_URL, TRANSITLAND_API_KEY, session=session
@@ -115,8 +112,10 @@ def process_sync(
"Fetched %s operators from TransitLand API",
len(operators_data["operators"]),
)

feeds = self.extract_feeds_data(feeds_data)
all_urls = set(
[element[0] for element in db_session.query(Feed.producer_url).all()]
)
feeds = self.extract_feeds_data(feeds_data, all_urls)
operators = self.extract_operators_data(operators_data)

# Converts operators and feeds to pandas DataFrames
@@ -135,16 +134,18 @@ def process_sync(
# Filtered out rows where 'feed_url' is missing
combined_df = combined_df[combined_df["feed_url"].notna()]

# Group by 'feed_id' and concatenate 'operator_name' while keeping first values of other columns
# Group by 'stable_id' and concatenate 'operator_name' while keeping first values of other columns
df_grouped = (
combined_df.groupby("feed_id")
combined_df.groupby("stable_id")
.agg(
{
"operator_name": lambda x: ", ".join(x),
"feeds_onestop_id": "first",
"feed_id": "first",
"feed_url": "first",
"operator_feed_id": "first",
"spec": "first",
"entity_types": "first",
"country": "first",
"state_province": "first",
"city_name": "first",
@@ -173,11 +174,6 @@ def process_sync(
filtered_df = filtered_df.drop_duplicates(
subset=["feed_url"]
) # Drop duplicates
filtered_df = filtered_df[filtered_df["feed_url"].apply(self.check_url_status)]
logging.info(
"Filtered out %s feeds with invalid URLs",
len(df_grouped) - len(filtered_df),
)

# Convert filtered DataFrame to dictionary format
combined_data = filtered_df.to_dict(orient="records")
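
As a side note, a self-contained sketch of the stable_id grouping a few lines above, with made-up rows; it shows operator names being concatenated while other columns keep their first value.

import pandas as pd

# Two hypothetical rows that share a stable_id.
df = pd.DataFrame(
    {
        "stable_id": ["feed-123-tu_vp", "feed-123-tu_vp"],
        "operator_name": ["Operator A", "Operator B"],
        "feed_url": ["https://example.com/rt", "https://example.com/rt"],
    }
)
grouped = (
    df.groupby("stable_id")
    .agg({"operator_name": lambda x: ", ".join(x), "feed_url": "first"})
    .reset_index()
)
# grouped["operator_name"][0] == "Operator A, Operator B"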
@@ -187,7 +183,7 @@ def process_sync(
for data in combined_data:
external_id = data["feeds_onestop_id"]
feed_url = data["feed_url"]
source = "TLD"
source = "tld"

if not self.check_external_id(db_session, external_id, source):
payload_type = "new"
@@ -201,6 +197,8 @@
# prepare payload
payload = TransitFeedSyncPayload(
external_id=external_id,
stable_id=data["stable_id"],
entity_types=data["entity_types"],
feed_id=data["feed_id"],
execution_id=execution_id,
feed_url=feed_url,
@@ -212,7 +210,7 @@ def process_sync(
country=data["country"],
state_province=data["state_province"],
city_name=data["city_name"],
source="TLD",
source="tld",
payload_type=payload_type,
)
payloads.append(FeedSyncPayload(external_id=external_id, payload=payload))
@@ -277,25 +275,39 @@ def get_data(
logging.info("Finished fetching data.")
return all_data

def extract_feeds_data(self, feeds_data: dict) -> List[dict]:
def extract_feeds_data(self, feeds_data: dict, urls_in_db: List[str]) -> List[dict]:
"""
This function extracts relevant data from the Transitland feeds endpoint containing feeds information.
Returns a list of dictionaries representing each feed.
"""
feeds = []
for feed in feeds_data["feeds"]:
feed_url = feed["urls"].get("static_current")
feeds.append(
{
"feed_id": feed["id"],
"feed_url": feed_url,
"spec": feed["spec"].lower(),
"feeds_onestop_id": feed["onestop_id"],
"auth_info_url": feed["authorization"].get("info_url"),
"auth_param_name": feed["authorization"].get("param_name"),
"type": feed["authorization"].get("type"),
}
)
for feed in feeds_data:
feed_urls, entity_types = process_feed_urls(feed, urls_in_db)
logging.info("Feed %s has %s valid URL(s)", feed["id"], len(feed_urls))
logging.info("Feed %s entity types: %s", feed["id"], entity_types)
if len(feed_urls) == 0:
logging.warning("Feed URL not found for feed %s", feed["id"])
continue

for feed_url, entity_types in zip(feed_urls, entity_types):
if entity_types is not None and len(entity_types) > 0:
stable_id = f"{feed['id']}-{entity_types.replace(',', '_')}"
else:
stable_id = feed["id"]
logging.info("Stable ID: %s", stable_id)
feeds.append(
{
"feed_id": feed["id"],
"stable_id": stable_id,
"feed_url": feed_url,
"entity_types": entity_types if len(entity_types) > 0 else None,
"spec": feed["spec"].lower(),
"feeds_onestop_id": feed["onestop_id"],
"auth_info_url": feed["authorization"].get("info_url"),
"auth_param_name": feed["authorization"].get("param_name"),
"type": feed["authorization"].get("type"),
}
)
return feeds
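
A brief illustration (hypothetical feed id, not from TransitLand) of the stable_id rule applied in extract_feeds_data above.

feed_id = "f-example"    # hypothetical TransitLand feed id
entity_types = "tu,vp"   # comma-joined realtime entity types for one URL
stable_id = (
    f"{feed_id}-{entity_types.replace(',', '_')}" if entity_types else feed_id
)
# stable_id == "f-example-tu_vp"; a static URL with no entity types keeps "f-example"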

def extract_operators_data(self, operators_data: dict) -> List[dict]:
Expand All @@ -309,16 +321,15 @@ def extract_operators_data(self, operators_data: dict) -> List[dict]:
places = operator["agencies"][0]["places"]
place = places[1] if len(places) > 1 else places[0]

operator_data = {
"operator_name": operator.get("name"),
"operator_feed_id": operator["feeds"][0]["id"]
if operator.get("feeds")
else None,
"country": place.get("adm0_name") if place else None,
"state_province": place.get("adm1_name") if place else None,
"city_name": place.get("city_name") if place else None,
}
operators.append(operator_data)
for related_feed in operator.get("feeds", []):
operator_data = {
"operator_name": operator.get("name"),
"operator_feed_id": related_feed["id"],
"country": place.get("adm0_name") if place else None,
"state_province": place.get("adm1_name") if place else None,
"city_name": place.get("city_name") if place else None,
}
operators.append(operator_data)
return operators
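
For context, a sketch of what the reworked loop yields for an operator that references two feeds; every value below is invented.

operator = {
    "name": "Operator A",
    "feeds": [{"id": "f-one"}, {"id": "f-two"}],
    "agencies": [
        {"places": [{"adm0_name": "Country", "adm1_name": "Province", "city_name": "City"}]}
    ],
}
# extract_operators_data({"operators": [operator]}) would now return two records,
# one with operator_feed_id "f-one" and one with "f-two", instead of only the first feed.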

def check_external_id(
Expand All @@ -328,12 +339,12 @@ def check_external_id(
Checks if the external_id exists in the public.externalid table for the given source.
:param db_session: SQLAlchemy session
:param external_id: The external_id (feeds_onestop_id) to check
:param source: The source to filter by (e.g., 'TLD' for TransitLand)
:param source: The source to filter by (e.g., 'tld' for TransitLand)
:return: True if the feed exists, False otherwise
"""
results = (
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.externalids.any(associated_id=external_id))
db_session.query(Feed)
.filter(Feed.externalids.any(associated_id=external_id))
.all()
)
return results is not None and len(results) > 0
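
A quick usage sketch, assuming 'processor' is a TransitFeedSyncProcessor instance, 'db_session' is an open SQLAlchemy session, and the onestop id is a placeholder.

# Returns True when a Feed row already carries this external id.
if not processor.check_external_id(db_session, "example-onestop-id", "tld"):
    print("feed not registered yet")  # process_sync marks such feeds as "new"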
@@ -345,12 +356,12 @@ def get_mbd_feed_url(
Retrieves the feed_url from the public.feed table in the mbd for the given external_id.
:param db_session: SQLAlchemy session
:param external_id: The external_id (feeds_onestop_id) from TransitLand
:param source: The source to filter by (e.g., 'TLD' for TransitLand)
:param source: The source to filter by (e.g., 'tld' for TransitLand)
:return: feed_url in mbd if exists, otherwise None
"""
results = (
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.externalids.any(associated_id=external_id))
db_session.query(Feed)
.filter(Feed.externalids.any(associated_id=external_id))
.all()
)
return results[0].producer_url if results else None