|
2 | 2 | import logging |
3 | 3 | import traceback |
4 | 4 | import time |
| 5 | +import socket |
| 6 | +from contextlib import closing |
5 | 7 |
|
6 | 8 | from rafiki.db import Database |
7 | 9 | from rafiki.constants import ServiceStatus, UserType, ServiceType, BudgetType |
8 | | -from rafiki.config import MIN_SERVICE_PORT, MAX_SERVICE_PORT, \ |
9 | | - TRAIN_WORKER_REPLICAS_PER_SUB_TRAIN_JOB, INFERENCE_WORKER_REPLICAS_PER_TRIAL, \ |
| 10 | +from rafiki.config import TRAIN_WORKER_REPLICAS_PER_SUB_TRAIN_JOB, INFERENCE_WORKER_REPLICAS_PER_TRIAL, \ |
10 | 11 | INFERENCE_MAX_BEST_TRIALS, SERVICE_STATUS_WAIT |
11 | 12 | from rafiki.container import DockerSwarmContainerManager, ServiceRequirement, InvalidServiceRequest |
12 | 13 | from rafiki.model import parse_model_install_command |
@@ -122,9 +123,11 @@ def stop_train_services(self, train_job_id): |
122 | 123 | train_job = self._db.get_train_job(train_job_id) |
123 | 124 |
|
124 | 125 | # Stop all workers for train job |
125 | | - workers = self._db.get_workers_of_train_job(train_job_id) |
126 | | - for worker in workers: |
127 | | - self._stop_train_job_worker(worker) |
| 126 | + sub_train_jobs = self._db.get_sub_train_jobs_of_train_job(train_job_id) |
| 127 | + for sub_train_job in sub_train_jobs: |
| 128 | + workers = self._db.get_workers_of_sub_train_job(sub_train_job.id) |
| 129 | + for worker in workers: |
| 130 | + self._stop_train_job_worker(worker) |
128 | 131 |
|
129 | 132 | return train_job |
130 | 133 |
|
@@ -345,19 +348,13 @@ def _create_service(self, service_type, docker_image, |
345 | 348 |
|
346 | 349 | return service |
347 | 350 |
|
348 | | - # Compute next available external port |
349 | 351 | def _get_available_ext_port(self): |
350 | | - services = self._db.get_services(status=ServiceStatus.RUNNING) |
351 | | - used_ports = [int(x.ext_port) for x in services if x.ext_port is not None] |
352 | | - port = MIN_SERVICE_PORT |
353 | | - while port <= MAX_SERVICE_PORT: |
354 | | - if port not in used_ports: |
355 | | - return port |
356 | | - |
357 | | - port += 1 |
358 | | - |
359 | | - return port |
360 | | - |
| 352 | + # Credits to https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number |
| 353 | + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: |
| 354 | + s.bind(('', 0)) |
| 355 | + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 356 | + return s.getsockname()[1] |
| 357 | + |
361 | 358 | def _get_best_trials_for_inference(self, inference_job): |
362 | 359 | best_trials = self._db.get_best_trials_of_train_job(inference_job.train_job_id) |
363 | 360 | return best_trials |
|
0 commit comments