Skip to content

Commit 0a2e785

Browse files
[batch] Add internal ssl certs to worker server
1 parent e2a420c commit 0a2e785

File tree

23 files changed

+245
-95
lines changed

23 files changed

+245
-95
lines changed

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ docs.tar.gz: hail/build/www
179179

180180
website-image: docs.tar.gz
181181

182-
$(SERVICES_IMAGES): %-image: $(SERVICES_IMAGE_DEPS) $(shell git ls-files $$* ':!:**/deployment.yaml')
182+
.SECONDEXPANSION:
183+
$(SERVICES_IMAGES): %-image: $(SERVICES_IMAGE_DEPS) $$(shell git ls-files $$* ':!:**/deployment.yaml')
183184
./docker-build.sh . $*/Dockerfile $(IMAGE_NAME) --build-arg BASE_IMAGE=$(shell cat hail-ubuntu-image)
184185
echo $(IMAGE_NAME) > $@
185186

batch/batch/cloud/azure/worker/worker_api.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import base64
33
import os
4+
import ssl
45
import tempfile
56
from typing import Dict, List, Optional, Tuple
67

@@ -30,7 +31,13 @@ def from_env():
3031
assert acr_url.endswith('azurecr.io'), acr_url
3132
return AzureWorkerAPI(subscription_id, resource_group, acr_url, hail_oauth_scope)
3233

33-
def __init__(self, subscription_id: str, resource_group: str, acr_url: str, hail_oauth_scope: str):
34+
def __init__(
35+
self,
36+
subscription_id: str,
37+
resource_group: str,
38+
acr_url: str,
39+
hail_oauth_scope: str,
40+
):
3441
self.subscription_id = subscription_id
3542
self.resource_group = resource_group
3643
self.hail_oauth_scope = hail_oauth_scope
@@ -64,6 +71,9 @@ async def user_container_registry_credentials(self, credentials: Dict[str, str])
6471
def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application:
6572
raise NotImplementedError
6673

74+
async def worker_ssl_context(self, namespace: str) -> Optional[ssl.SSLContext]:
75+
return None # TODO
76+
6777
def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> AzureSlimInstanceConfig:
6878
return AzureSlimInstanceConfig.from_dict(config_dict)
6979

batch/batch/cloud/gcp/worker/worker_api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import base64
22
import os
3+
import ssl
34
import tempfile
45
from contextlib import AsyncExitStack
6+
from pathlib import Path
57
from typing import Dict, List
68

79
import orjson
@@ -10,6 +12,7 @@
1012
from hailtop import httpx
1113
from hailtop.aiocloud import aiogoogle
1214
from hailtop.auth.auth import IdentityProvider
15+
from hailtop.tls import internal_server_ssl_context
1316
from hailtop.utils import check_exec_output
1417

1518
from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials
@@ -84,6 +87,19 @@ def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Applica
8487
key = orjson.loads(base64.b64decode(credentials['key.json']).decode())
8588
return create_app(aiogoogle.GoogleServiceAccountCredentials(key), self._metadata_server_client)
8689

90+
async def worker_ssl_context(self, namespace: str) -> ssl.SSLContext:
91+
secret_manager_client = aiogoogle.GoogleSecretManagerClient(self.project, http_session=self._http_session)
92+
async with secret_manager_client:
93+
ssl_config_bytes = await secret_manager_client.get_latest_secret_version(
94+
f'ssl-config-batch-worker-{namespace}'
95+
)
96+
ssl_config = {k: base64.b64decode(v.encode()) for k, v in orjson.loads(ssl_config_bytes).items()}
97+
ssl_config_dir = Path('/ssl-config') / namespace
98+
ssl_config_dir.mkdir(parents=True, exist_ok=True)
99+
for file, contents in ssl_config.items():
100+
(ssl_config_dir / file).write_bytes(contents)
101+
return internal_server_ssl_context(str(ssl_config_dir))
102+
87103
def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig:
88104
return GCPSlimInstanceConfig.from_dict(config_dict)
89105

batch/batch/cloud/terra/azure/worker/worker_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
2-
from typing import Dict, List
2+
import ssl
3+
from typing import Dict, List, Optional
34

45
import orjson
56
from aiohttp import web
@@ -69,6 +70,10 @@ async def user_container_registry_credentials(self, credentials: Dict[str, str])
6970
def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application:
7071
raise NotImplementedError
7172

73+
async def worker_ssl_context(self, namespace: str) -> Optional[ssl.SSLContext]:
74+
# There are no internal certs in Terra on Azure
75+
return None
76+
7277
def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> TerraAzureSlimInstanceConfig:
7378
return TerraAzureSlimInstanceConfig.from_dict(config_dict)
7479

batch/batch/driver/instance.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import aiohttp
88

99
from gear import CommonAiohttpAppKeys, Database, transaction
10+
from hailtop import httpx
1011
from hailtop.humanizex import naturaldelta_msec
1112
from hailtop.utils import retry_transient_errors, time_msecs, time_msecs_str
1213

1314
from ..cloud.utils import instance_config_from_config_dict
1415
from ..globals import INSTANCE_VERSION
1516
from ..instance_config import InstanceConfig
17+
from ..utils import instance_base_url
1618

1719
log = logging.getLogger('instance')
1820

@@ -132,7 +134,7 @@ def __init__(
132134
instance_config: InstanceConfig,
133135
):
134136
self.db: Database = app['db']
135-
self.client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION]
137+
self.client_session: httpx.ClientSession = app[CommonAiohttpAppKeys.CLIENT_SESSION]
136138
self.inst_coll = inst_coll
137139
# pending, active, inactive, deleted
138140
self._state = state
@@ -153,6 +155,10 @@ def __init__(
153155
def state(self):
154156
return self._state
155157

158+
@property
159+
def base_url(self) -> str:
160+
return instance_base_url(self.version, self.ip_address)
161+
156162
async def activate(self, ip_address, timestamp):
157163
assert self._state == 'pending'
158164

@@ -197,7 +203,7 @@ async def make_request():
197203
return
198204
try:
199205
await self.client_session.post(
200-
f'http://{self.ip_address}:5000/api/v1alpha/kill', timeout=aiohttp.ClientTimeout(total=30)
206+
f'{self.base_url}/api/v1alpha/kill', timeout=aiohttp.ClientTimeout(total=30)
201207
)
202208
except aiohttp.ClientResponseError as err:
203209
if err.status == 403:
@@ -278,7 +284,7 @@ def failed_request_count(self):
278284
async def check_is_active_and_healthy(self):
279285
if self._state == 'active' and self.ip_address:
280286
try:
281-
async with self.client_session.get(f'http://{self.ip_address}:5000/healthcheck') as resp:
287+
async with self.client_session.get(f'{self.base_url}/healthcheck') as resp:
282288
actual_name = (await resp.json()).get('name')
283289
if actual_name and actual_name != self.name:
284290
return False

batch/batch/driver/job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ async def unschedule_job(app, record):
386386
scheduler_state_changed.notify()
387387
log.info(f'unschedule job {id}, attempt {attempt_id}: updated {instance} free cores')
388388

389-
url = f'http://{instance.ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/delete'
389+
url = f'{instance.base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/delete'
390390

391391
async def make_request():
392392
if instance.state in ('inactive', 'deleted'):
@@ -580,7 +580,7 @@ async def schedule_job(app, record, instance):
580580

581581
try:
582582
await client_session.post(
583-
f'http://{instance.ip_address}:5000/api/v1alpha/batches/jobs/create',
583+
f'{instance.base_url}/api/v1alpha/batches/jobs/create',
584584
json=body,
585585
timeout=aiohttp.ClientTimeout(total=2),
586586
)

batch/batch/front_end/front_end.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
from ..spec_writer import SpecWriter
105105
from ..utils import (
106106
add_metadata_to_request,
107+
instance_base_url,
107108
query_billing_projects_with_cost,
108109
query_billing_projects_without_cost,
109110
regions_to_bits_rep,
@@ -379,7 +380,7 @@ async def _get_job_record(app, batch_id, job_id):
379380

380381
record = await db.select_and_fetchone(
381382
"""
382-
SELECT jobs.state, jobs.spec, ip_address, format_version, jobs.attempt_id, t.attempt_id AS last_cancelled_attempt_id
383+
SELECT jobs.state, jobs.spec, ip_address, format_version, jobs.attempt_id, t.attempt_id AS last_cancelled_attempt_id, instances.version as instance_version
383384
FROM jobs
384385
INNER JOIN batches
385386
ON jobs.batch_id = batches.id
@@ -438,11 +439,12 @@ def attempt_id_from_spec(record) -> Optional[str]:
438439
return record['attempt_id'] or record['last_cancelled_attempt_id']
439440

440441

441-
async def _get_job_container_log_from_worker(client_session, batch_id, job_id, container, ip_address) -> bytes:
442+
async def _get_job_container_log_from_worker(client_session, batch_id, job_id, container, job_record) -> bytes:
443+
base_url = instance_base_url(job_record['instance_version'], job_record['ip_address'])
442444
try:
443445
return await retry_transient_errors(
444446
client_session.get_read,
445-
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}',
447+
f'{base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}',
446448
)
447449
except aiohttp.ClientResponseError:
448450
log.exception(f'while getting log for {(batch_id, job_id)}')
@@ -467,7 +469,11 @@ async def _get_job_container_log(app, batch_id, job_id, container, job_record) -
467469
state = job_record['state']
468470
if state == 'Running':
469471
return await _get_job_container_log_from_worker(
470-
app[CommonAiohttpAppKeys.CLIENT_SESSION], batch_id, job_id, container, job_record['ip_address']
472+
app[CommonAiohttpAppKeys.CLIENT_SESSION],
473+
batch_id,
474+
job_id,
475+
container,
476+
job_record,
471477
)
472478

473479
attempt_id = attempt_id_from_spec(job_record)
@@ -502,7 +508,7 @@ async def _get_job_resource_usage_from_record(
502508
batch_format_version = BatchFormatVersion(record['format_version'])
503509

504510
state = record['state']
505-
ip_address = record['ip_address']
511+
base_url = instance_base_url(record['instance_version'], record['ip_address'])
506512
tasks = job_tasks_from_spec(record)
507513
attempt_id = attempt_id_from_spec(record)
508514

@@ -513,7 +519,7 @@ async def _get_job_resource_usage_from_record(
513519
try:
514520
data = await retry_transient_errors(
515521
client_session.get_read_json,
516-
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage',
522+
f'{base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage',
517523
)
518524
return {
519525
task: ResourceUsageMonitor.decode_to_df(base64.b64decode(encoded_df))
@@ -639,11 +645,11 @@ async def _get_full_job_status(app, record):
639645
assert state == 'Running'
640646
assert record['status'] is None
641647

642-
ip_address = record['ip_address']
648+
base_url = instance_base_url(record['instance_version'], record['ip_address'])
643649
try:
644650
return await retry_transient_errors(
645651
client_session.get_read_json,
646-
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status',
652+
f'{base_url}/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status',
647653
)
648654
except aiohttp.ClientResponseError as e:
649655
if e.status == 404:

batch/batch/globals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
BATCH_FORMAT_VERSION = 7
2525
STATUS_FORMAT_VERSION = 5
26-
INSTANCE_VERSION = 29
26+
INSTANCE_VERSION = 30
2727

2828
MAX_PERSISTENT_SSD_SIZE_GIB = 64 * 1024
2929
RESERVED_STORAGE_GB_PER_CORE = 5

batch/batch/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import os
34
from collections import deque
45
from functools import wraps
56
from typing import Any, Deque, Dict, List, Optional, Set, Tuple, overload
@@ -29,6 +30,13 @@ def authorization_token(request):
2930
return session_id
3031

3132

33+
def instance_base_url(version: int, ip_address: str) -> str:
34+
# TODO Azure
35+
if version < 30 or os.getenv('HAIL_TERRA') or os.getenv('CLOUD') == 'azure':
36+
return f'http://{ip_address}:5000'
37+
return f'https://{ip_address}'
38+
39+
3240
def add_metadata_to_request(fun):
3341
@wraps(fun)
3442
async def wrapped(request, *args, **kwargs):

batch/batch/worker/worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3186,6 +3186,7 @@ async def healthcheck(self, request): # pylint: disable=unused-argument
31863186
return json_response(body)
31873187

31883188
async def run(self):
3189+
assert CLOUD_WORKER_API
31893190
app = web.Application(client_max_size=HTTP_CLIENT_MAX_SIZE)
31903191
app.add_routes([
31913192
web.post('/api/v1alpha/kill', self.kill),
@@ -3201,7 +3202,9 @@ async def run(self):
32013202

32023203
app_runner = web.AppRunner(app, access_log_class=BatchWorkerAccessLogger)
32033204
await app_runner.setup()
3204-
site = web.TCPSite(app_runner, IP_ADDRESS, 5000)
3205+
worker_ssl_context = await CLOUD_WORKER_API.worker_ssl_context(NAMESPACE)
3206+
port = 443 if worker_ssl_context is not None else 5000
3207+
site = web.TCPSite(app_runner, IP_ADDRESS, port, ssl_context=worker_ssl_context)
32053208
await site.start()
32063209

32073210
try:

0 commit comments

Comments
 (0)