Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion batch/batch/cloud/azure/worker/worker_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import base64
import os
import ssl
import tempfile
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -30,7 +31,13 @@ def from_env():
assert acr_url.endswith('azurecr.io'), acr_url
return AzureWorkerAPI(subscription_id, resource_group, acr_url, hail_oauth_scope)

def __init__(self, subscription_id: str, resource_group: str, acr_url: str, hail_oauth_scope: str):
def __init__(
self,
subscription_id: str,
resource_group: str,
acr_url: str,
hail_oauth_scope: str,
):
self.subscription_id = subscription_id
self.resource_group = resource_group
self.hail_oauth_scope = hail_oauth_scope
Expand Down Expand Up @@ -64,6 +71,9 @@ async def user_container_registry_credentials(self, credentials: Dict[str, str])
def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application:
raise NotImplementedError

async def worker_ssl_context(self, namespace: str) -> Optional[ssl.SSLContext]:
return None # TODO

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> AzureSlimInstanceConfig:
return AzureSlimInstanceConfig.from_dict(config_dict)

Expand Down
16 changes: 16 additions & 0 deletions batch/batch/cloud/gcp/worker/worker_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import os
import ssl
import tempfile
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Dict, List

import orjson
Expand All @@ -10,6 +12,7 @@
from hailtop import httpx
from hailtop.aiocloud import aiogoogle
from hailtop.auth.auth import IdentityProvider
from hailtop.tls import internal_server_ssl_context
from hailtop.utils import check_exec_output

from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials
Expand Down Expand Up @@ -84,6 +87,19 @@ def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Applica
key = orjson.loads(base64.b64decode(credentials['key.json']).decode())
return create_app(aiogoogle.GoogleServiceAccountCredentials(key), self._metadata_server_client)

async def worker_ssl_context(self, namespace: str) -> ssl.SSLContext:
secret_manager_client = aiogoogle.GoogleSecretManagerClient(self.project, http_session=self._http_session)
async with secret_manager_client:
ssl_config_bytes = await secret_manager_client.get_latest_secret_version(
f'ssl-config-batch-worker-{namespace}'
)
ssl_config = {k: base64.b64decode(v.encode()) for k, v in orjson.loads(ssl_config_bytes).items()}
ssl_config_dir = Path('/ssl-config') / namespace
ssl_config_dir.mkdir(parents=True, exist_ok=True)
for file, contents in ssl_config.items():
(ssl_config_dir / file).write_bytes(contents)
return internal_server_ssl_context(str(ssl_config_dir))

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig:
return GCPSlimInstanceConfig.from_dict(config_dict)

Expand Down
7 changes: 6 additions & 1 deletion batch/batch/cloud/terra/azure/worker/worker_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Dict, List
import ssl
from typing import Dict, List, Optional

import orjson
from aiohttp import web
Expand Down Expand Up @@ -69,6 +70,10 @@ async def user_container_registry_credentials(self, credentials: Dict[str, str])
def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application:
raise NotImplementedError

async def worker_ssl_context(self, namespace: str) -> Optional[ssl.SSLContext]:
# There are no internal certs in Terra on Azure
return None

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> TerraAzureSlimInstanceConfig:
return TerraAzureSlimInstanceConfig.from_dict(config_dict)

Expand Down
12 changes: 9 additions & 3 deletions batch/batch/driver/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import aiohttp

from gear import CommonAiohttpAppKeys, Database, transaction
from hailtop import httpx
from hailtop.humanizex import naturaldelta_msec
from hailtop.utils import retry_transient_errors, time_msecs, time_msecs_str

from ..cloud.utils import instance_config_from_config_dict
from ..globals import INSTANCE_VERSION
from ..instance_config import InstanceConfig
from ..utils import instance_base_url

log = logging.getLogger('instance')

Expand Down Expand Up @@ -132,7 +134,7 @@ def __init__(
instance_config: InstanceConfig,
):
self.db: Database = app['db']
self.client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION]
self.client_session: httpx.ClientSession = app[CommonAiohttpAppKeys.CLIENT_SESSION]
self.inst_coll = inst_coll
# pending, active, inactive, deleted
self._state = state
Expand All @@ -153,6 +155,10 @@ def __init__(
def state(self):
return self._state

@property
def base_url(self) -> str:
return instance_base_url(self.version, self.ip_address)

async def activate(self, ip_address, timestamp):
assert self._state == 'pending'

Expand Down Expand Up @@ -197,7 +203,7 @@ async def make_request():
return
try:
await self.client_session.post(
f'http://{self.ip_address}:5000/api/v1alpha/kill', timeout=aiohttp.ClientTimeout(total=30)
f'{self.base_url}/api/v1alpha/kill', timeout=aiohttp.ClientTimeout(total=30)
)
except aiohttp.ClientResponseError as err:
if err.status == 403:
Expand Down Expand Up @@ -278,7 +284,7 @@ def failed_request_count(self):
async def check_is_active_and_healthy(self):
if self._state == 'active' and self.ip_address:
try:
async with self.client_session.get(f'http://{self.ip_address}:5000/healthcheck') as resp:
async with self.client_session.get(f'{self.base_url}/healthcheck') as resp:
actual_name = (await resp.json()).get('name')
if actual_name and actual_name != self.name:
return False
Expand Down
4 changes: 2 additions & 2 deletions batch/batch/driver/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ async def unschedule_job(app, record):
scheduler_state_changed.notify()
log.info(f'unschedule job {id}, attempt {attempt_id}: updated {instance} free cores')

url = f'http://{instance.ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/delete'
url = f'{instance.base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/delete'

async def make_request():
if instance.state in ('inactive', 'deleted'):
Expand Down Expand Up @@ -580,7 +580,7 @@ async def schedule_job(app, record, instance):

try:
await client_session.post(
f'http://{instance.ip_address}:5000/api/v1alpha/batches/jobs/create',
f'{instance.base_url}/api/v1alpha/batches/jobs/create',
json=body,
timeout=aiohttp.ClientTimeout(total=2),
)
Expand Down
22 changes: 14 additions & 8 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from ..spec_writer import SpecWriter
from ..utils import (
add_metadata_to_request,
instance_base_url,
query_billing_projects_with_cost,
query_billing_projects_without_cost,
regions_to_bits_rep,
Expand Down Expand Up @@ -379,7 +380,7 @@ async def _get_job_record(app, batch_id, job_id):

record = await db.select_and_fetchone(
"""
SELECT jobs.state, jobs.spec, ip_address, format_version, jobs.attempt_id, t.attempt_id AS last_cancelled_attempt_id
SELECT jobs.state, jobs.spec, ip_address, format_version, jobs.attempt_id, t.attempt_id AS last_cancelled_attempt_id, instances.version as instance_version
FROM jobs
INNER JOIN batches
ON jobs.batch_id = batches.id
Expand Down Expand Up @@ -438,11 +439,12 @@ def attempt_id_from_spec(record) -> Optional[str]:
return record['attempt_id'] or record['last_cancelled_attempt_id']


async def _get_job_container_log_from_worker(client_session, batch_id, job_id, container, ip_address) -> bytes:
async def _get_job_container_log_from_worker(client_session, batch_id, job_id, container, job_record) -> bytes:
base_url = instance_base_url(job_record['instance_version'], job_record['ip_address'])
try:
return await retry_transient_errors(
client_session.get_read,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}',
f'{base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}',
)
except aiohttp.ClientResponseError:
log.exception(f'while getting log for {(batch_id, job_id)}')
Expand All @@ -467,7 +469,11 @@ async def _get_job_container_log(app, batch_id, job_id, container, job_record) -
state = job_record['state']
if state == 'Running':
return await _get_job_container_log_from_worker(
app[CommonAiohttpAppKeys.CLIENT_SESSION], batch_id, job_id, container, job_record['ip_address']
app[CommonAiohttpAppKeys.CLIENT_SESSION],
batch_id,
job_id,
container,
job_record,
)

attempt_id = attempt_id_from_spec(job_record)
Expand Down Expand Up @@ -502,7 +508,7 @@ async def _get_job_resource_usage_from_record(
batch_format_version = BatchFormatVersion(record['format_version'])

state = record['state']
ip_address = record['ip_address']
base_url = instance_base_url(record['instance_version'], record['ip_address'])
tasks = job_tasks_from_spec(record)
attempt_id = attempt_id_from_spec(record)

Expand All @@ -513,7 +519,7 @@ async def _get_job_resource_usage_from_record(
try:
data = await retry_transient_errors(
client_session.get_read_json,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage',
f'{base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage',
)
return {
task: ResourceUsageMonitor.decode_to_df(base64.b64decode(encoded_df))
Expand Down Expand Up @@ -639,11 +645,11 @@ async def _get_full_job_status(app, record):
assert state == 'Running'
assert record['status'] is None

ip_address = record['ip_address']
base_url = instance_base_url(record['instance_version'], record['ip_address'])
try:
return await retry_transient_errors(
client_session.get_read_json,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status',
f'{base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status',
)
except aiohttp.ClientResponseError as e:
if e.status == 404:
Expand Down
2 changes: 1 addition & 1 deletion batch/batch/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

BATCH_FORMAT_VERSION = 7
STATUS_FORMAT_VERSION = 5
INSTANCE_VERSION = 29
INSTANCE_VERSION = 30

MAX_PERSISTENT_SSD_SIZE_GIB = 64 * 1024
RESERVED_STORAGE_GB_PER_CORE = 5
8 changes: 8 additions & 0 deletions batch/batch/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import os
from collections import deque
from functools import wraps
from typing import Any, Deque, Dict, List, Optional, Set, Tuple, overload
Expand Down Expand Up @@ -29,6 +30,13 @@ def authorization_token(request):
return session_id


def instance_base_url(version: int, ip_address: str) -> str:
# TODO Azure
if version < 30 or os.getenv('HAIL_TERRA') or os.getenv('CLOUD') == 'azure':
return f'http://{ip_address}:5000'
return f'https://{ip_address}'


def add_metadata_to_request(fun):
@wraps(fun)
async def wrapped(request, *args, **kwargs):
Expand Down
5 changes: 4 additions & 1 deletion batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3186,6 +3186,7 @@ async def healthcheck(self, request): # pylint: disable=unused-argument
return json_response(body)

async def run(self):
assert CLOUD_WORKER_API
app = web.Application(client_max_size=HTTP_CLIENT_MAX_SIZE)
app.add_routes([
web.post('/api/v1alpha/kill', self.kill),
Expand All @@ -3201,7 +3202,9 @@ async def run(self):

app_runner = web.AppRunner(app, access_log_class=BatchWorkerAccessLogger)
await app_runner.setup()
site = web.TCPSite(app_runner, IP_ADDRESS, 5000)
worker_ssl_context = await CLOUD_WORKER_API.worker_ssl_context(NAMESPACE)
port = 443 if worker_ssl_context is not None else 5000
site = web.TCPSite(app_runner, IP_ADDRESS, port, ssl_context=worker_ssl_context)
await site.start()

try:
Expand Down
7 changes: 6 additions & 1 deletion batch/batch/worker/worker_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
from typing import Dict, List, TypedDict, Union
import ssl
from typing import Dict, List, Optional, TypedDict, Union

from aiohttp import web

Expand Down Expand Up @@ -49,6 +50,10 @@ async def user_container_registry_credentials(self, credentials: Dict[str, str])
def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application:
raise NotImplementedError

@abc.abstractmethod
async def worker_ssl_context(self, namespace: str) -> Optional[ssl.SSLContext]:
raise NotImplementedError

@abc.abstractmethod
def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> InstanceConfig:
raise NotImplementedError
Expand Down
12 changes: 8 additions & 4 deletions build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,12 @@ steps:
- merge_code
- kind: buildImage2
name: create_certs_image
dockerFile: /io/tls/Dockerfile
contextPath: /io/tls
dockerFile: /io/repo/tls/Dockerfile
contextPath: /io/repo
publishAs: create_certs_image
inputs:
- from: /repo/tls
to: /io/tls
- from: /repo
to: /io/repo
dependsOn:
- hail_ubuntu_image
- merge_code
Expand Down Expand Up @@ -243,6 +243,10 @@ steps:
valueFrom: create_certs_image.image
script: |
set -ex
export CLOUD={{ global.cloud }}
{% if global.cloud == 'gcp' %}
export HAIL_PROJECT={{ global.gcp_project }}
{% endif %}
python3 create_certs.py \
{{ default_ns.name }} \
config.yaml \
Expand Down
1 change: 1 addition & 0 deletions dev-docs/services/tls-cookbook.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ kubectl create secret generic \
make -C $HAIL/hail python/hailtop/hail_version

PYTHONPATH=$HAIL/hail/python \
HAIL_PROJECT=hail-vdc \
python3 $HAIL/tls/create_certs.py \
default \
$HAIL/tls/config.yaml \
Expand Down
2 changes: 2 additions & 0 deletions hail/python/hailtop/aiocloud/aiogoogle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
GoogleIAmClient,
GoogleLoggingClient,
GoogleMetadataServerClient,
GoogleSecretManagerClient,
GoogleStorageAsyncFS,
GoogleStorageAsyncFSFactory,
GoogleStorageClient,
Expand All @@ -32,6 +33,7 @@
'GoogleIAmClient',
'GoogleLoggingClient',
'GoogleMetadataServerClient',
'GoogleSecretManagerClient',
'GoogleStorageClient',
'GoogleStorageAsyncFS',
'GoogleStorageAsyncFSFactory',
Expand Down
2 changes: 2 additions & 0 deletions hail/python/hailtop/aiocloud/aiogoogle/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .iam_client import GoogleIAmClient
from .logging_client import GoogleLoggingClient
from .metadata_server_client import GoogleMetadataServerClient
from .secret_manager_client import GoogleSecretManagerClient
from .storage_client import (
GCSRequesterPaysConfiguration,
GoogleStorageAsyncFS,
Expand All @@ -20,6 +21,7 @@
'GoogleIAmClient',
'GoogleLoggingClient',
'GoogleMetadataServerClient',
'GoogleSecretManagerClient',
'GCSRequesterPaysConfiguration',
'GoogleStorageClient',
'GoogleStorageAsyncFS',
Expand Down
Loading