diff --git a/.github/workflows/deploy_server.yaml b/.github/workflows/deploy_server.yaml
index db498218..2b275aad 100644
--- a/.github/workflows/deploy_server.yaml
+++ b/.github/workflows/deploy_server.yaml
@@ -15,8 +15,10 @@ jobs:
       DOCKER_BUILDKIT: 1
       BUILDKIT_PROGRESS: plain
       CLOUDSDK_CORE_DISABLE_PROMPTS: 1
-      DRIVER_IMAGE: australia-southeast1-docker.pkg.dev/analysis-runner/images/driver
-      SERVER_IMAGE: australia-southeast1-docker.pkg.dev/analysis-runner/images/server
+      AZURE_CONTAINER_REGISTRY: cpgcommonimages.azurecr.io
+      GCP_CONTAINER_REGISTRY: australia-southeast1-docker.pkg.dev
+      DRIVER_IMAGE: analysis-runner/images/driver
+      SERVER_IMAGE: australia-southeast1-docker.pkg.dev/analysis-runner/images/server

     steps:
       - name: "checkout analysis-runner repo"
@@ -48,21 +50,37 @@ jobs:
         run: |
           gcloud auth configure-docker marketplace.gcr.io,australia-southeast1-docker.pkg.dev

+      - name: "azure setup"
+        uses: azure/login@v1
+        with:
+          creds: ${{ secrets.AZURE_CREDENTIALS }}
+
+      - name: "azure docker auth"
+        uses: azure/docker-login@v1
+        with:
+          login-server: ${{ env.AZURE_CONTAINER_REGISTRY }}
+          username: ${{ secrets.AZURE_REGISTRY_USERNAME }}
+          password: ${{ secrets.AZURE_REGISTRY_PASSWORD }}
+
       - name: "build driver image"
         run: |
           docker build -f driver/Dockerfile.hail --build-arg HAIL_SHA=$HAIL_SHA --tag $DRIVER_IMAGE:$IMAGE_TAG driver

-      - name: "push driver image"
+      - name: "tag and push gcp image [driver]"
+        run: |
+          docker image tag $DRIVER_IMAGE:$IMAGE_TAG $GCP_CONTAINER_REGISTRY/$DRIVER_IMAGE:latest &&
+          docker push $GCP_CONTAINER_REGISTRY/$DRIVER_IMAGE:latest
+
+      - name: "tag and push azure image [driver]"
         run: |
-          docker push $DRIVER_IMAGE:$IMAGE_TAG
-          docker tag $DRIVER_IMAGE:$IMAGE_TAG $DRIVER_IMAGE:latest
-          docker push $DRIVER_IMAGE:latest
+          docker image tag $DRIVER_IMAGE:$IMAGE_TAG $AZURE_CONTAINER_REGISTRY/$DRIVER_IMAGE:latest &&
+          docker push $AZURE_CONTAINER_REGISTRY/$DRIVER_IMAGE:latest

       - name: "build server image"
         run: |
           docker build --build-arg DRIVER_IMAGE=$DRIVER_IMAGE:$IMAGE_TAG --tag $SERVER_IMAGE:$IMAGE_TAG server

-      - name: "push server image"
+      - name: "push server image to gcp"
         run: |
           docker push $SERVER_IMAGE:$IMAGE_TAG
           docker tag $SERVER_IMAGE:$IMAGE_TAG $SERVER_IMAGE:latest
diff --git a/.gitignore b/.gitignore
index b2867c56..6a52dd99 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,3 +131,4 @@ dmypy.json
 .idea/
 .DS_Store
 .vscode/
+.*.json
diff --git a/analysis_runner/cli.py b/analysis_runner/cli.py
index 404fba53..18bcc19a 100755
--- a/analysis_runner/cli.py
+++ b/analysis_runner/cli.py
@@ -78,4 +78,4 @@ def main_from_args(args=None):


 if __name__ == '__main__':
-    main_from_args()
+    main_from_args(args=sys.argv[1:])
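Note on the `cli.py` change above: a minimal sketch of the explicit-argv entrypoint pattern, assuming `import sys` already appears at the top of `cli.py` (the import block is outside this hunk):

```python
import argparse
import sys


def main_from_args(args=None):
    """Parse args and dispatch; argparse reads sys.argv[1:] when args is None."""
    parser = argparse.ArgumentParser()
    parser.parse_args(args)


if __name__ == '__main__':
    # Passing sys.argv[1:] explicitly is behaviour-preserving, but makes the
    # entrypoint easy to drive from tests with a hand-built argv list.
    main_from_args(args=sys.argv[1:])
```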
diff --git a/analysis_runner/cli_analysisrunner.py b/analysis_runner/cli_analysisrunner.py
index 93894bd0..314dad24 100644
--- a/analysis_runner/cli_analysisrunner.py
+++ b/analysis_runner/cli_analysisrunner.py
@@ -9,7 +9,7 @@
 import requests

 from cpg_utils.config import read_configs
-from cpg_utils.cloud import get_google_identity_token
+from cpg_utils.cloud import get_google_identity_token, get_azure_identity_token
 from analysis_runner.constants import get_server_endpoint
 from analysis_runner.git import (
     get_git_default_remote,
@@ -25,6 +25,8 @@
     logger,
 )

+SUPPORTED_CLOUD_ENVIRONMENTS = {'gcp', 'azure'}
+DEFAULT_CLOUD_ENVIRONMENT = 'gcp'


 def add_analysis_runner_args(parser=None) -> argparse.ArgumentParser:
     """
@@ -35,6 +37,15 @@ def add_analysis_runner_args(parser=None) -> argparse.ArgumentParser:

     add_general_args(parser)

+    parser.add_argument(
+        '-c',
+        '--cloud',
+        required=False,
+        default=DEFAULT_CLOUD_ENVIRONMENT,
+        choices=SUPPORTED_CLOUD_ENVIRONMENTS,
+        help=f'Backend cloud environment to use. Supported options are: {", ".join(sorted(SUPPORTED_CLOUD_ENVIRONMENTS))}',
+    )
+
     parser.add_argument(
         '--image',
         help=(
@@ -104,6 +115,7 @@ def run_analysis_runner(  # pylint: disable=too-many-arguments
     commit=None,
     repository=None,
     cwd=None,
+    cloud=DEFAULT_CLOUD_ENVIRONMENT,
     image=None,
     cpu=None,
     memory=None,
@@ -216,6 +228,7 @@ def run_analysis_runner(  # pylint: disable=too-many-arguments
             'script': _script,
             'description': description,
             'cwd': _cwd,
+            'cloud': cloud,
             'image': image,
             'cpu': cpu,
             'memory': memory,
diff --git a/analysis_runner/constants.py b/analysis_runner/constants.py
index eb29acdd..d083fe6d 100644
--- a/analysis_runner/constants.py
+++ b/analysis_runner/constants.py
@@ -1,5 +1,8 @@
 """Constants for analysis-runner"""

+import os
+import distutils.util
+
 SERVER_ENDPOINT = 'https://server-a2pko7ameq-ts.a.run.app'
 SERVER_TEST_ENDPOINT = 'https://server-test-a2pko7ameq-ts.a.run.app'
 ANALYSIS_RUNNER_PROJECT_ID = 'analysis-runner'
@@ -11,13 +14,17 @@
     'gcloud -q auth activate-service-account --key-file=/gsa-key/key.json'
 )

+USE_LOCAL_SERVER = distutils.util.strtobool(os.getenv('ANALYSIS_RUNNER_LOCAL', 'False'))
+

 def get_server_endpoint(is_test: bool = False):
     """
     Get the server endpoint {production / test}
     Do it in a function so it's easy to fix if the logic changes
     """
-    if is_test:
+    if USE_LOCAL_SERVER:
+        return 'http://localhost:8080'
+    elif is_test:
         return SERVER_TEST_ENDPOINT

     return SERVER_ENDPOINT
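A quick usage sketch of the new local-server toggle in `constants.py`. `USE_LOCAL_SERVER` is evaluated once at import time, so the environment variable must be set before the module is first imported:

```python
import os

# strtobool accepts '1', 'true', 'yes', 'on' (and their negative counterparts).
os.environ['ANALYSIS_RUNNER_LOCAL'] = 'true'

from analysis_runner.constants import get_server_endpoint

# The local override takes precedence over both production and test endpoints.
assert get_server_endpoint() == 'http://localhost:8080'
assert get_server_endpoint(is_test=True) == 'http://localhost:8080'
```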
diff --git a/examples/batch/hail_batch_job.py b/examples/batch/hail_batch_job.py
new file mode 100755
index 00000000..a62b8842
--- /dev/null
+++ b/examples/batch/hail_batch_job.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
+"""
+Simple script to test whether the CPG infrastructure and permissions are
+configured appropriately to permit running AIP.
+"""
+
+import click
+
+from cpg_utils.config import get_config
+from cpg_utils.hail_batch import remote_tmpdir
+import hailtop.batch as hb
+
+
+@click.command()
+def main():
+    """
+    Submit a two-job batch: the first job writes a file, the second reads it back.
+    """
+
+    service_backend = hb.ServiceBackend(
+        billing_project=get_config()['hail']['billing_project'],
+        remote_tmpdir=remote_tmpdir(),
+    )
+    batch = hb.Batch(
+        name='Test CPG Infra',
+        backend=service_backend,
+        cancel_after_n_failures=1,
+        default_timeout=6000,
+        default_memory='highmem',
+    )
+
+    j = batch.new_job(name='Write the file')
+    j.command(f'echo "Hello World." > {j.ofile}')
+
+    k = batch.new_job(name='Read the file')
+    k.command(f'cat {j.ofile}')
+
+    batch.run(wait=False)
+
+
+if __name__ == '__main__':
+    main()  # pylint: disable=E1120
diff --git a/examples/batch/hail_batch_job.toml b/examples/batch/hail_batch_job.toml
new file mode 100644
index 00000000..2a14098d
--- /dev/null
+++ b/examples/batch/hail_batch_job.toml
@@ -0,0 +1,19 @@
+[buckets]
+web_suffix = "web"
+tmp_suffix = "tmp"
+analysis_suffix = "analysis"
+
+[workflow]
+dataset = "thousand-genomes"
+access_level = "test"
+dataset_path = "cpgthousandgenomes/test"
+output_prefix = "output"
+path_scheme = "az"
+image_registry_prefix = "cpgcommonimages.azurecr.io"
+
+[hail]
+billing_project = "fewgenomes"
+bucket = "az://cpgthousandgenomes/test"
+
+[images]
+hail = "hailgenetics/hail:0.2.93"
diff --git a/examples/batch/run_analysis.sh b/examples/batch/run_analysis.sh
new file mode 100644
index 00000000..1030c9ff
--- /dev/null
+++ b/examples/batch/run_analysis.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+analysis-runner \
+    --dataset thousand-genomes \
+    --description 'Test script for batch on Azure' \
+    --output-dir test \
+    --cloud azure \
+    --access-level test \
+    --config examples/batch/hail_batch_job.toml \
+    --image cpg_workflows:latest \
+    examples/batch/hail_batch_job.py
diff --git a/server/Dockerfile b/server/Dockerfile
index 10589286..131de7a1 100644
--- a/server/Dockerfile
+++ b/server/Dockerfile
@@ -19,6 +19,8 @@
 EXPOSE $PORT

 COPY main.py cromwell.py util.py ./

 # Prepare the Hail deploy config to point to the CPG domain.
-COPY deploy-config.json /deploy-config/deploy-config.json
+ENV HAIL_DEPLOY_CONFIG_FILE /deploy-config/deploy-config-gcp.json
+COPY deploy-config-gcp.json /deploy-config/deploy-config-gcp.json
+COPY deploy-config-azure.json /deploy-config/deploy-config-azure.json

 CMD gunicorn --bind :$PORT --worker-class aiohttp.GunicornWebWorker main:init_func
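The `ENV HAIL_DEPLOY_CONFIG_FILE` default set above is what `server/main.py` overrides per request; a condensed sketch of that selection logic, using the `DEPLOY_CONFIG_PATHS` mapping added in `server/util.py` further down:

```python
import os

DEPLOY_CONFIG_PATHS = {
    'gcp': '/deploy-config/deploy-config-gcp.json',
    'azure': '/deploy-config/deploy-config-azure.json',
}


def select_deploy_config(cloud_environment: str) -> None:
    # hailtop.config.get_deploy_config() honours HAIL_DEPLOY_CONFIG_FILE, so
    # re-pointing it swaps the Hail domain (hail.populationgenomics.org.au
    # vs azhail.populationgenomics.org.au) for subsequent Batch calls.
    os.environ['HAIL_DEPLOY_CONFIG_FILE'] = DEPLOY_CONFIG_PATHS[cloud_environment]
```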
diff --git a/server/cromwell.py b/server/cromwell.py
index 3c3449ce..c2f238c4 100644
--- a/server/cromwell.py
+++ b/server/cromwell.py
@@ -3,6 +3,7 @@
 Exports 'add_cromwell_routes', to add the following route to a flask API:
     POST /cromwell: Posts a workflow to a cromwell_url
 """
+import os
 import json
 from datetime import datetime

@@ -100,11 +101,6 @@ async def cromwell(request):  # pylint: disable=too-many-locals
         input_jsons = params.get('input_json_paths') or []
         input_dict = params.get('inputs_dict')

-        if access_level == 'test':
-            workflow_output_dir = f'gs://cpg-{dataset}-test/{output_dir}'
-        else:
-            workflow_output_dir = f'gs://cpg-{dataset}-main/{output_dir}'
-
         timestamp = datetime.now().astimezone().isoformat()

         # Prepare the job's configuration and write it to a blob.
@@ -121,6 +117,10 @@ async def cromwell(request):  # pylint: disable=too-many-locals
         config_path = write_config(config, cloud_environment)

         # This metadata dictionary gets stored at the output_dir location.
+        workflow_output_dir = os.path.join(
+            config.get('storage', {}).get('default', {}).get('default'),
+            output_dir,
+        )
         metadata = get_analysis_runner_metadata(
             timestamp=timestamp,
             dataset=dataset,
diff --git a/server/deploy-config-azure.json b/server/deploy-config-azure.json
new file mode 100644
index 00000000..7e42e445
--- /dev/null
+++ b/server/deploy-config-azure.json
@@ -0,0 +1,5 @@
+{
+    "location": "external",
+    "default_namespace": "default",
+    "domain": "azhail.populationgenomics.org.au"
+}
diff --git a/server/deploy-config-gcp.json b/server/deploy-config-gcp.json
new file mode 100644
index 00000000..3be2a787
--- /dev/null
+++ b/server/deploy-config-gcp.json
@@ -0,0 +1,5 @@
+{
+    "location": "external",
+    "default_namespace": "default",
+    "domain": "hail.populationgenomics.org.au"
+}
diff --git a/server/deploy-config.json b/server/deploy-config.json
deleted file mode 100644
index 7fb35202..00000000
--- a/server/deploy-config.json
+++ /dev/null
@@ -1 +0,0 @@
-{"location": "external", "default_namespace": "default", "domain": "hail.populationgenomics.org.au"}
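Both `cromwell.py` above and `main.py` below now derive the output directory from the merged run config instead of a hard-coded `gs://cpg-{dataset}-{namespace}` bucket. A worked sketch; the nested `storage.default.default` shape and the `az://` bucket value are assumptions mirroring the lookups in these hunks and the example TOML earlier in this PR:

```python
import os

# Assumed config shape, matching config['storage']['default']['default']
# lookups in the server code; the bucket value here is hypothetical.
run_config = {'storage': {'default': {'default': 'az://cpgthousandgenomes/test'}}}

output_prefix = 'output'
output_dir = os.path.join(
    run_config.get('storage', {}).get('default', {}).get('default'),
    output_prefix,
)
assert output_dir == 'az://cpgthousandgenomes/test/output'
# os.path.join preserves the URL scheme on POSIX; the server runs in Linux
# containers, though posixpath.join would make that assumption explicit.
```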
diff --git a/server/main.py b/server/main.py
index 3e59a699..ff4de9d3 100644
--- a/server/main.py
+++ b/server/main.py
@@ -1,7 +1,8 @@
 """The analysis-runner server, running Hail Batch pipelines on users' behalf."""
 # pylint: disable=wrong-import-order
-import datetime
+import os
 import json
+import datetime
 import logging
 import traceback
 from shlex import quote
@@ -15,6 +16,9 @@ from util import (
     DRIVER_IMAGE,
     PUBSUB_TOPIC,
+    SUPPORTED_CLOUD_ENVIRONMENTS,
+    DEFAULT_CLOUD_ENVIRONMENT,
+    DEPLOY_CONFIG_PATHS,
     _get_hail_version,
     check_allowed_repos,
     check_dataset_and_group,
@@ -41,9 +45,6 @@

 routes = web.RouteTableDef()

-SUPPORTED_CLOUD_ENVIRONMENTS = {'gcp'}
-
-
 # pylint: disable=too-many-statements
 @routes.post('/')
 async def index(request):
@@ -56,11 +57,14 @@ async def index(request):
         output_prefix = validate_output_dir(params['output'])
         dataset = params['dataset']

-        cloud_environment = params.get('cloud_environment', 'gcp')
+        cloud_environment = params.get('cloud', DEFAULT_CLOUD_ENVIRONMENT)
         if cloud_environment not in SUPPORTED_CLOUD_ENVIRONMENTS:
             raise web.HTTPBadRequest(
                 reason=f'analysis-runner does not yet support the {cloud_environment} environment'
             )
+
+        # Point Hail at the deploy config matching the cloud environment.
+        os.environ['HAIL_DEPLOY_CONFIG_FILE'] = DEPLOY_CONFIG_PATHS[cloud_environment]

         dataset_config = check_dataset_and_group(
             server_config=get_server_config(),
@@ -114,8 +118,8 @@ async def index(request):

         # Prepare the job's configuration and write it to a blob.
         run_config = get_baseline_run_config(
-            environment=cloud_environment,
-            gcp_project_id=environment_config.get('projectId'),
+            environment='gcp',
+            gcp_project_id=dataset_config.get('gcp', {}).get('projectId'),
             dataset=dataset,
             access_level=access_level,
             output_prefix=output_prefix,
@@ -125,6 +129,10 @@ async def index(request):
         update_dict(run_config, user_config)
         config_path = write_config(run_config, environment=cloud_environment)

+        output_dir = os.path.join(
+            run_config.get('storage', {}).get('default', {}).get('default'),
+            output_prefix,
+        )
         metadata = get_analysis_runner_metadata(
             timestamp=timestamp,
             dataset=dataset,
@@ -134,12 +142,12 @@ async def index(request):
             commit=commit,
             script=' '.join(script),
             description=params['description'],
-            output_prefix=output_prefix,
             hailVersion=hail_version,
             driver_image=image,
             config_path=config_path,
             cwd=cwd,
             environment=cloud_environment,
+            output_dir=output_dir,
         )

         user_name = email.split('@')[0]
@@ -217,7 +225,7 @@ async def config(request):
         output_prefix = validate_output_dir(params['output'])
         dataset = params['dataset']

-        cloud_environment = params.get('cloud_environment', 'gcp')
+        cloud_environment = params.get('cloud', DEFAULT_CLOUD_ENVIRONMENT)
         if cloud_environment not in SUPPORTED_CLOUD_ENVIRONMENTS:
             raise web.HTTPBadRequest(
                 reason=f'analysis-runner config does not yet support the {cloud_environment} environment'
@@ -229,7 +237,6 @@ async def config(request):
             dataset=dataset,
             email=email,
         )
-        environment_config = dataset_config.get(cloud_environment)

         image = params.get('image') or DRIVER_IMAGE
         access_level = params['accessLevel']
@@ -241,8 +248,8 @@ async def config(request):

         # Prepare the job's configuration to return
         run_config = get_baseline_run_config(
-            environment=cloud_environment,
-            gcp_project_id=environment_config.get('projectId'),
+            environment='gcp',
+            gcp_project_id=dataset_config.get('gcp', {}).get('projectId'),
             dataset=dataset,
             access_level=access_level,
             output_prefix=output_prefix,
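The `server/util.py` hunk below registers a module-level default client so that plain `AnyPath('az://…')` calls (as used by `write_config`) authenticate against the common storage account. A self-contained sketch of that pattern:

```python
from azure.identity import EnvironmentCredential
from cloudpathlib import AnyPath, AzureBlobClient

# EnvironmentCredential reads AZURE_CLIENT_ID / AZURE_TENANT_ID /
# AZURE_CLIENT_SECRET; how those reach the server is deployment-specific.
client = AzureBlobClient(
    account_url='https://cpgcommon.blob.core.windows.net/',
    credential=EnvironmentCredential(),
)
client.set_as_default_client()

# Every az:// path created from here on resolves through that client.
config_blob = AnyPath('az://cpg-config/example.toml')
```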
diff --git a/server/util.py b/server/util.py
index b1c2e34a..08b4afc4 100644
--- a/server/util.py
+++ b/server/util.py
@@ -6,10 +6,12 @@
 import json
 import uuid
 import toml
+import distutils.util

 from aiohttp import web, ClientSession
-from cloudpathlib import AnyPath
+from cloudpathlib import AzureBlobClient, AnyPath
 from hailtop.config import get_deploy_config
+from azure.identity import EnvironmentCredential
 from google.cloud import secretmanager, pubsub_v1
 from cpg_utils.config import update_dict
 from cpg_utils.cloud import (
@@ -33,7 +35,28 @@
 MEMBERS_CACHE_LOCATION = os.getenv('MEMBERS_CACHE_LOCATION')
 assert MEMBERS_CACHE_LOCATION

-CONFIG_PATH_PREFIXES = {'gcp': 'gs://cpg-config'}
+SUPPORTED_CLOUD_ENVIRONMENTS = {'gcp', 'azure'}
+DEFAULT_CLOUD_ENVIRONMENT = 'gcp'
+
+USE_LOCAL_SERVER = distutils.util.strtobool(os.getenv('ANALYSIS_RUNNER_LOCAL', 'False'))
+PREFIX = os.path.dirname(os.path.abspath(__file__)) if USE_LOCAL_SERVER else '/deploy-config/'
+DEPLOY_CONFIG_PATHS = {
+    'gcp': os.path.join(PREFIX, 'deploy-config-gcp.json'),
+    'azure': os.path.join(PREFIX, 'deploy-config-azure.json'),
+}
+
+AZURE_STORAGE_ACCOUNT = 'cpgcommon'
+CONFIG_PATH_PREFIXES = {
+    'gcp': 'gs://cpg-config',
+    'azure': 'az://cpg-config',
+}
+
+# Set the default AnyPath client for az:// paths.
+client = AzureBlobClient(
+    account_url=f'https://{AZURE_STORAGE_ACCOUNT}.blob.core.windows.net/',
+    credential=EnvironmentCredential(),
+)
+client.set_as_default_client()

 secret_manager = secretmanager.SecretManagerServiceClient()
 publisher = pubsub_v1.PublisherClient()
@@ -46,7 +69,7 @@ def get_server_config() -> dict:

 async def _get_hail_version(environment: str) -> str:
     """ASYNC get hail version for the hail server in the local deploy_config"""
-    if not environment == 'gcp':
+    if environment not in SUPPORTED_CLOUD_ENVIRONMENTS:
         raise web.HTTPBadRequest(
             reason=f'Unsupported Hail Batch deploy config environment: {environment}'
         )
@@ -110,13 +133,20 @@ def check_dataset_and_group(server_config, environment: str, dataset, email) ->
             reason=f'Dataset {dataset} does not support the {environment} environment'
         )

-    # do this to check access-members cache
-    gcp_project = dataset_config.get('gcp', {}).get('projectId')
+    if environment == 'gcp':
+        # do this to check the access-members cache
+        gcp_project = dataset_config.get('gcp', {}).get('projectId')
+
+        if not gcp_project:
+            raise web.HTTPBadRequest(
+                reason=f'The analysis-runner does not support checking group members for the {environment} environment'
+            )
+    elif environment == 'azure':
+        if environment not in dataset_config:
+            raise web.HTTPBadRequest(
+                reason=f'Dataset {dataset} has no {environment} configuration'
+            )

-    if not gcp_project:
-        raise web.HTTPBadRequest(
-            reason=f'The analysis-runner does not support checking group members for the {environment} environment'
-        )
     if not is_member_in_cached_group(
         f'{dataset}-analysis', email, members_cache_location=MEMBERS_CACHE_LOCATION
     ):
@@ -137,19 +167,17 @@ def get_analysis_runner_metadata(
     commit,
     script,
     description,
-    output_prefix,
     driver_image,
     config_path,
     cwd,
     environment,
+    output_dir,
     **kwargs,
 ):
     """
     Get well-formed analysis-runner metadata, requiring the core listed keys
     with some flexibility to provide your own keys (as **kwargs)
     """
-    output_dir = f'gs://cpg-{dataset}-{cpg_namespace(access_level)}/{output_prefix}'
-
     return {
         'timestamp': timestamp,
         'dataset': dataset,
@@ -170,7 +198,7 @@

 def run_batch_job_and_print_url(batch, wait, environment):
     """Call batch.run(), return the URL, and wait for job to finish if wait=True"""
-    if not environment == 'gcp':
+    if environment not in SUPPORTED_CLOUD_ENVIRONMENTS:
         raise web.HTTPBadRequest(
             reason=f'Unsupported Hail Batch deploy config environment: {environment}'
         )
@@ -198,13 +226,20 @@ def validate_image(container: str, is_test: bool):

 def write_config(config: dict, environment: str) -> str:
     """Writes the given config dictionary to a blob and returns its unique path."""
     prefix = CONFIG_PATH_PREFIXES.get(environment)
     if not prefix:
         raise web.HTTPBadRequest(reason=f'Bad environment for config: {environment}')

+    # az:// paths resolve through the default AzureBlobClient set at the top of
+    # this module, so Azure configs always land in AZURE_STORAGE_ACCOUNT.
     config_path = AnyPath(prefix) / (str(uuid.uuid4()) + '.toml')
     with config_path.open('w') as f:
         toml.dump(config, f)
+
+    if environment == 'azure':
+        # Hail Batch expects the hail-az:// scheme with an explicit account name.
+        return os.path.join(f'hail-az://{AZURE_STORAGE_ACCOUNT}', str(config_path).removeprefix('az://'))
+
     return str(config_path)


@@ -228,8 +263,6 @@ def get_baseline_run_config(
     baseline_config = {
         'hail': {
             'billing_project': dataset,
-            # TODO: how would this work for Azure
-            'bucket': f'cpg-{dataset}-hail',
         },
         'workflow': {
             'access_level': access_level,