Skip to content

Commit

Permalink
Merge pull request #11 from kevinsunny1996/task/fix_key_path_keyerror
Browse files Browse the repository at this point in the history
Update sa fetching for astro cloud
  • Loading branch information
kevinsunny1996 authored Apr 28, 2024
2 parents 3920c3e + 51daa08 commit ecca380
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions dags/utils/gcp_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Airflow Base Hook to get connection
from airflow.hooks.base import BaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook

# GCS Python Library
from google.cloud import storage
Expand Down Expand Up @@ -45,15 +46,18 @@ def get_gcp_connection_and_upload_to_gcs(bucket_name: str, dataframe_name: pd.Da
airflow_service_account_connection = 'gcp'

# Get the connection to GCP using the service account
gcp_connection_sa = BaseHook.get_connection(conn_id=airflow_service_account_connection)
info_logger.info(f'Retrieved connection: {gcp_connection_sa.conn_id}')
# gcp_connection_sa = BaseHook.get_connection(conn_id=airflow_service_account_connection)
gcp_file_upload_hook = GoogleBaseHook(gcp_conn_id=airflow_service_account_connection)
sa_creds = gcp_file_upload_hook.get_credentials()
# info_logger.info(f'Retrieved connection: {gcp_connection_sa.conn_id}')

# Get the credentials from the connection
gcp_connection_sa_credentials = gcp_connection_sa.extra_dejson['key_path']
info_logger.info(f'Retrieved credentials: {gcp_connection_sa_credentials}')
# gcp_connection_sa_credentials = gcp_connection_sa.extra_dejson['key_path']
# info_logger.info(f'Retrieved credentials: {gcp_connection_sa_credentials}')

# Initialize GCS client
client = storage.Client.from_service_account_json(gcp_connection_sa_credentials)
# client = storage.Client.from_service_account_json(gcp_connection_sa_credentials)
client = storage.Client(credentials=sa_creds,project=gcp_file_upload_hook.project_id)

# Intialize bucket object
bucket = client.bucket(bucket_name)
Expand Down

0 comments on commit ecca380

Please sign in to comment.