From d24b7c10dc4d786ed20fb081c66e6053b77c2749 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Mon, 26 May 2025 17:11:52 +0200 Subject: [PATCH 01/23] wip --- src/aleph/vm/models.py | 21 ++++ src/aleph/vm/network/firewall.py | 104 ++++++++++++++++++ .../vm/network/port_availability_checker.py | 41 +++++++ src/aleph/vm/orchestrator/views/__init__.py | 1 + src/aleph/vm/pool.py | 44 +++++++- 5 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 src/aleph/vm/network/port_availability_checker.py diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 880b8891..5e483849 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -30,7 +30,9 @@ AlephQemuConfidentialInstance, AlephQemuConfidentialResources, ) +from aleph.vm.network.firewall import add_port_redirect_rule from aleph.vm.network.interfaces import TapInterface +from aleph.vm.network.port_availability_checker import get_available_host_port from aleph.vm.orchestrator.metrics import ( ExecutionRecord, delete_record, @@ -92,6 +94,25 @@ class VmExecution: systemd_manager: SystemDManager | None persistent: bool = False + mapped_ports: list[tuple[int, int]] | None = None # internal, external + + def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]): + if self.mapped_ports is None: + self.mapped_ports = [] + assert self.vm + for requested_port, protocol_detail in requested_ports.items(): + tcp = protocol_detail["tcp"] + udp = protocol_detail["udp"] + host_port = get_available_host_port(start_port=25000) + interface = self.vm.tap_interface + vm_id = self.vm.vm_id + if tcp: + protocol = "tcp" + add_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) + if udp: + protocol = "udp" + add_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) + self.mapped_ports.append((host_port, requested_port)) @property def is_starting(self) -> bool: diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index eedc6ae1..fd0399b4 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -376,6 +376,110 @@ def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> int: return execute_json_nft_commands(command) +def add_port_redirect_rule( + vm_id: int, interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp" +) -> int: + """Creates a rule to redirect traffic from a host port to a VM port. 
+ + Args: + vm_id: The ID of the VM + interface: The TapInterface instance for the VM + host_port: The port number on the host to listen on + vm_port: The port number to forward to on the VM + protocol: The protocol to use (tcp or udp, defaults to tcp) + + Returns: + The exit code from executing the nftables commands + """ + table = get_table_for_hook("prerouting") + + # FIXME DO NOT HARDCODE + prerouting_command = [ + { + "add": { + "chain": { + "family": "ip", + "table": "nat", + "name": "prerouting", + "type": "nat", + "hook": "prerouting", + "prio": -100, + "policy": "accept", + } + } + } + ] + + commands = prerouting_command + [ + { + "add": { + "rule": { + "family": "ip", + "table": "nat", + # "chain": f"{settings.NFTABLES_CHAIN_PREFIX}-vm-nat-{vm_id}", + "chain": "prerouting", + "expr": [ + { + "match": { + "op": "==", + "left": {"meta": {"key": "iifname"}}, + "right": settings.NETWORK_INTERFACE, + } + }, + { + "match": { + "op": "==", + "left": {"payload": {"protocol": "tcp", "field": "dport"}}, + "right": host_port, + } + }, + { + "dnat": {"addr": str(interface.guest_ip.ip), "port": vm_port}, + }, + ], + } + } + }, + ] + # Add corresponding forward rule to allow the redirected traffic + # { + # "add": { + # "rule": { + # "family": "ip", + # "table": table, + # "chain": f"{settings.NFTABLES_CHAIN_PREFIX}-vm-filter-{vm_id}", + # "expr": [ + # { + # "match": { + # "op": "==", + # "left": {"meta": {"key": "iifname"}}, + # "right": settings.NETWORK_INTERFACE, + # } + # }, + # { + # "match": { + # "op": "==", + # "left": {"meta": {"key": "l4proto"}}, + # "right": protocol, + # } + # }, + # { + # "match": { + # "op": "==", + # "left": {"payload": {"protocol": protocol, "field": "dport"}}, + # "right": vm_port, + # } + # }, + # {"accept": None}, + # ], + # } + # } + # }, + # ] + + return execute_json_nft_commands(commands) + + def setup_nftables_for_vm(vm_id: int, interface: TapInterface) -> None: """Sets up chains for filter and nat purposes specific to this VM, and makes sure those chains are jumped to""" add_postrouting_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-vm-nat-{vm_id}") diff --git a/src/aleph/vm/network/port_availability_checker.py b/src/aleph/vm/network/port_availability_checker.py new file mode 100644 index 00000000..46008e8d --- /dev/null +++ b/src/aleph/vm/network/port_availability_checker.py @@ -0,0 +1,41 @@ +import socket +from typing import Optional + +MIN_DYNAMIC_PORT = 24000 +MAX_PORT = 65535 + + +def get_available_host_port(start_port: Optional[int] = None) -> int: + """Find an available port on the host system. + + Args: + start_port: Optional starting port number. 
If not provided, starts from MIN_DYNAMIC_PORT + + Returns: + An available port number + + Raises: + RuntimeError: If no ports are available in the valid range + """ + port = start_port if start_port and start_port >= MIN_DYNAMIC_PORT else MIN_DYNAMIC_PORT + + while port <= MAX_PORT: + try: + # Try both TCP and UDP on all interfaces + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_sock: + tcp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + tcp_sock.bind(("0.0.0.0", port)) + tcp_sock.listen(1) + + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_sock: + udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + udp_sock.bind(("0.0.0.0", port)) + + return port + + except (socket.error, OSError): + pass + + port += 1 + + raise RuntimeError(f"No available ports found in range {MIN_DYNAMIC_PORT}-{MAX_PORT}") diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index e9bfff97..440a0a62 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -187,6 +187,7 @@ async def list_executions_v2(request: web.Request) -> web.Response: "ipv4_network": execution.vm.tap_interface.ip_network, "ipv6_network": execution.vm.tap_interface.ipv6_network, "ipv6_ip": execution.vm.tap_interface.guest_ipv6.ip, + "mapped_ports": execution.mapped_ports, } if execution.vm and execution.vm.tap_interface else {}, diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 17a8d386..4eac4ef3 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -11,7 +11,6 @@ from aleph_message.models import ( Chain, ExecutableMessage, - InstanceContent, ItemHash, Payment, PaymentType, @@ -33,6 +32,36 @@ logger = logging.getLogger(__name__) +async def get_user_aggregate(addr: str, keys_arg: list[str]) -> dict: + """ + Get the settings Aggregate dict from the PyAleph API Aggregate. 
+ + API Endpoint: + GET /api/v0/aggregates/{address}.json?keys=settings + + For more details, see the PyAleph API documentation: + https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62 + """ + + import aiohttp + + async with aiohttp.ClientSession() as session: + url = f"{settings.API_SERVER}/api/v0/aggregates/{addr}.json" + logger.info(f"Fetching aggregate from {url}") + resp = await session.get(url, params={"keys": ",".join(keys_arg)}) + + # Raise an error if the request failed + resp.raise_for_status() + + resp_data = await resp.json() + return resp_data["data"] or {} + + +async def get_user_settings(addr: str, key) -> dict: + aggregate = await get_user_aggregate(addr, [key]) + return aggregate.get(key, {}) + + class VmPool: """Pool of existing VMs @@ -168,11 +197,24 @@ async def create_a_vm( if self.network.interface_exists(vm_id): await tap_interface.delete() await self.network.create_tap(vm_id, tap_interface) + else: tap_interface = None execution.create(vm_id=vm_id, tap_interface=tap_interface) await execution.start() + try: + port_forwarding_settings = await get_user_settings(message.address, "port-forwarding") + port_requests = port_forwarding_settings.get(execution.vm_hash) + except Exception as e: + logger.exception(e) + + ports_requests = { + 22: {"tcp": True, "udp": False}, + } + execution.map_requested_ports(ports_requests) + # port = get_available_host_port(start_port=25000) + # add_port_redirect_rule(vm_id, interface, port, 22) # clear the user reservations for resource in resources: From 6bc897ae3b1727a72e264cd5a4b5b44e6132d1d9 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Tue, 27 May 2025 16:59:24 +0200 Subject: [PATCH 02/23] Fix started_at not being set for program --- src/aleph/vm/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 5e483849..38af06cc 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -371,7 +371,8 @@ async def start(self): if self.vm and self.vm.support_snapshot and self.snapshot_manager: await self.snapshot_manager.start_for(vm=self.vm) - + else: + self.times.started_at = datetime.now(tz=timezone.utc) self.ready_event.set() await self.save() except Exception: From 44f343ae093c67a79023754bd80037361223a8d0 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Tue, 27 May 2025 17:00:00 +0200 Subject: [PATCH 03/23] DB: fix execution.save crashing when updating --- src/aleph/vm/models.py | 82 +++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 50 deletions(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 38af06cc..cfd3a29a 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -95,6 +95,7 @@ class VmExecution: persistent: bool = False mapped_ports: list[tuple[int, int]] | None = None # internal, external + record: ExecutionRecord | None = None def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]): if self.mapped_ports is None: @@ -519,57 +520,38 @@ async def save(self): """Save to DB""" assert self.vm, "The VM attribute has to be set before calling save()" - pid_info = self.vm.to_dict() if self.vm else None - # Handle cases when the process cannot be accessed - if not self.persistent and pid_info and pid_info.get("process"): - await save_record( - ExecutionRecord( - uuid=str(self.uuid), - vm_hash=self.vm_hash, - vm_id=self.vm_id, - time_defined=self.times.defined_at, - time_prepared=self.times.prepared_at, - 
time_started=self.times.started_at, - time_stopping=self.times.stopping_at, - cpu_time_user=pid_info["process"]["cpu_times"].user, - cpu_time_system=pid_info["process"]["cpu_times"].system, - io_read_count=pid_info["process"]["io_counters"][0], - io_write_count=pid_info["process"]["io_counters"][1], - io_read_bytes=pid_info["process"]["io_counters"][2], - io_write_bytes=pid_info["process"]["io_counters"][3], - vcpus=self.vm.hardware_resources.vcpus, - memory=self.vm.hardware_resources.memory, - network_tap=self.vm.tap_interface.device_name if self.vm.tap_interface else "", - message=self.message.model_dump_json(), - original_message=self.original.model_dump_json(), - persistent=self.persistent, - ) - ) - else: - # The process cannot be accessed, or it's a persistent VM. - await save_record( - ExecutionRecord( - uuid=str(self.uuid), - vm_hash=self.vm_hash, - vm_id=self.vm_id, - time_defined=self.times.defined_at, - time_prepared=self.times.prepared_at, - time_started=self.times.started_at, - time_stopping=self.times.stopping_at, - cpu_time_user=None, - cpu_time_system=None, - io_read_count=None, - io_write_count=None, - io_read_bytes=None, - io_write_bytes=None, - vcpus=self.vm.hardware_resources.vcpus, - memory=self.vm.hardware_resources.memory, - message=self.message.model_dump_json(), - original_message=self.original.model_dump_json(), - persistent=self.persistent, - gpus=json.dumps(self.gpus, default=pydantic_encoder), - ) + if not self.record: + self.record = ExecutionRecord( + uuid=str(self.uuid), + vm_hash=self.vm_hash, + vm_id=self.vm_id, + time_defined=self.times.defined_at, + time_prepared=self.times.prepared_at, + time_started=self.times.started_at, + time_stopping=self.times.stopping_at, + cpu_time_user=None, + cpu_time_system=None, + io_read_count=None, + io_write_count=None, + io_read_bytes=None, + io_write_bytes=None, + vcpus=self.vm.hardware_resources.vcpus, + memory=self.vm.hardware_resources.memory, + message=self.message.model_dump_json(), + original_message=self.original.model_dump_json(), + persistent=self.persistent, + gpus=json.dumps(self.gpus, default=pydantic_encoder), ) + pid_info = self.vm.to_dict() if self.vm else None + # Handle cases when the process cannot be accessed + if not self.persistent and pid_info and pid_info.get("process"): + self.record.cpu_time_user = pid_info["process"]["cpu_times"].user + self.record.cpu_time_system = pid_info["process"]["cpu_times"].system + self.record.io_read_count = pid_info["process"]["io_counters"][0] + self.record.io_write_count = pid_info["process"]["io_counters"][1] + self.record.io_read_bytes = pid_info["process"]["io_counters"][2] + self.record.io_write_bytes = pid_info["process"]["io_counters"][3] + await save_record(self.record) async def record_usage(self): await delete_record(execution_uuid=str(self.uuid)) From 049bf3da61851dd29eed5bc4dad9611c85e8308b Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Tue, 27 May 2025 17:00:27 +0200 Subject: [PATCH 04/23] Handle user aggregate not existing gracefully --- src/aleph/vm/pool.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 4eac4ef3..42cc4690 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -6,7 +6,7 @@ import shutil from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import Any, cast +from typing import Any from aleph_message.models import ( Chain, @@ -49,8 +49,13 @@ async def get_user_aggregate(addr: str, keys_arg: 
list[str]) -> dict: url = f"{settings.API_SERVER}/api/v0/aggregates/{addr}.json" logger.info(f"Fetching aggregate from {url}") resp = await session.get(url, params={"keys": ",".join(keys_arg)}) + # No aggregate for the user + if resp.status == 404: + return {} # Raise an error if the request failed + + resp.raise_for_status() resp_data = await resp.json() From d4d52d467bfbc0023a7e904d6bb1e67ef507e2cf Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Thu, 29 May 2025 23:21:53 +0200 Subject: [PATCH 05/23] Fix alembic crash when generating migrations --- src/aleph/vm/conf.py | 4 ++++ src/aleph/vm/orchestrator/migrations/env.py | 7 ++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/aleph/vm/conf.py b/src/aleph/vm/conf.py index b422aea6..d110046c 100644 --- a/src/aleph/vm/conf.py +++ b/src/aleph/vm/conf.py @@ -529,5 +529,9 @@ def make_db_url(): return f"sqlite+aiosqlite:///{settings.EXECUTION_DATABASE}" +def make_sync_db_url(): + return f"sqlite:///{settings.EXECUTION_DATABASE}" + + # Settings singleton settings = Settings() diff --git a/src/aleph/vm/orchestrator/migrations/env.py b/src/aleph/vm/orchestrator/migrations/env.py index 2cf116bb..bd5eadc7 100644 --- a/src/aleph/vm/orchestrator/migrations/env.py +++ b/src/aleph/vm/orchestrator/migrations/env.py @@ -1,7 +1,7 @@ from alembic import context from sqlalchemy import create_engine -from aleph.vm.conf import make_db_url +from aleph.vm.conf import make_sync_db_url # Auto-generate migrations from aleph.vm.orchestrator.metrics import Base @@ -36,7 +36,8 @@ def run_migrations_offline() -> None: script output. """ - url = make_db_url() + url = make_sync_db_url() + context.configure( url=url, target_metadata=target_metadata, @@ -55,7 +56,7 @@ def run_migrations_online() -> None: and associate a connection with the context. 
""" - connectable = create_engine(make_db_url()) + connectable = create_engine(make_sync_db_url()) with connectable.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) From f10a6274cc5d13c0de44e416db6b7fda84ae6dc6 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 09:24:39 +0200 Subject: [PATCH 06/23] SAve port mapping to database --- src/aleph/vm/models.py | 6 +++-- src/aleph/vm/orchestrator/metrics.py | 1 + .../2da719d72cea_add_mapped_ports_column.py | 24 +++++++++++++++++++ src/aleph/vm/pool.py | 4 +--- 4 files changed, 30 insertions(+), 5 deletions(-) create mode 100644 src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index cfd3a29a..9ceb71e0 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -97,7 +97,7 @@ class VmExecution: mapped_ports: list[tuple[int, int]] | None = None # internal, external record: ExecutionRecord | None = None - def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]): + async def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]): if self.mapped_ports is None: self.mapped_ports = [] assert self.vm @@ -114,6 +114,8 @@ def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]): protocol = "udp" add_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) self.mapped_ports.append((host_port, requested_port)) + self.record.mapped_ports = self.mapped_ports + await save_record(self.record) @property def is_starting(self) -> bool: @@ -521,7 +523,7 @@ async def save(self): assert self.vm, "The VM attribute has to be set before calling save()" if not self.record: - self.record = ExecutionRecord( + self.record = ExecutionRecord( uuid=str(self.uuid), vm_hash=self.vm_hash, vm_id=self.vm_id, diff --git a/src/aleph/vm/orchestrator/metrics.py b/src/aleph/vm/orchestrator/metrics.py index 6c9b8eea..bc2bd305 100644 --- a/src/aleph/vm/orchestrator/metrics.py +++ b/src/aleph/vm/orchestrator/metrics.py @@ -77,6 +77,7 @@ class ExecutionRecord(Base): persistent = Column(Boolean, nullable=True) gpus = Column(JSON, nullable=True) + mapped_ports = Column(JSON, nullable=True) def __repr__(self): return f"" diff --git a/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py b/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py new file mode 100644 index 00000000..a4e8ee7e --- /dev/null +++ b/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py @@ -0,0 +1,24 @@ +"""add mapped_ports column + +Revision ID: 2da719d72cea +Revises: 5c6ae643c69b +Create Date: 2025-05-29 23:20:42.801850 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "2da719d72cea" +down_revision = "5c6ae643c69b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column("executions", sa.Column("mapped_ports", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("executions", "mapped_ports") diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 42cc4690..54c876ab 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -51,11 +51,9 @@ async def get_user_aggregate(addr: str, keys_arg: list[str]) -> dict: resp = await session.get(url, params={"keys": ",".join(keys_arg)}) # No aggregate for the user if resp.status == 404: - return {} # Raise an error if the request failed - resp.raise_for_status() resp_data = await resp.json() @@ -217,7 +215,7 @@ async def create_a_vm( ports_requests = { 22: {"tcp": True, "udp": False}, } - execution.map_requested_ports(ports_requests) + await execution.map_requested_ports(ports_requests) # port = get_available_host_port(start_port=25000) # add_port_redirect_rule(vm_id, interface, port, 22) From a82b9b0c64f365e8cf19522eb895dda448a4dcfd Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 13:41:26 +0200 Subject: [PATCH 07/23] firewall Do not create prerouting chain if it already exists --- src/aleph/vm/network/firewall.py | 152 ++++++++++++++++++------------- 1 file changed, 91 insertions(+), 61 deletions(-) diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index fd0399b4..f683eb08 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -5,8 +5,7 @@ from nftables import Nftables from aleph.vm.conf import settings - -from .interfaces import TapInterface +from aleph.vm.network.interfaces import TapInterface logger = logging.getLogger(__name__) @@ -37,19 +36,16 @@ def execute_json_nft_commands(commands: list[dict]) -> int: if return_code != 0: logger.error("Failed to add nftables rules: %s -- %s", error, json.dumps(commands, indent=4)) - return return_code + return output def get_existing_nftables_ruleset() -> dict: """Retrieves the full nftables ruleset and returns it""" nft = get_customized_nftables() - return_code, output, error = nft.cmd("list ruleset") - - if return_code != 0: - logger.error(f"Unable to get nftables ruleset: {error}") - return {"nftables": []} + # List all NAT rules + commands = [{"list": {"ruleset": {"family": "ip", "table": "nat"}}}] - nft_ruleset = json.loads(output) + nft_ruleset = execute_json_nft_commands(commands) return nft_ruleset @@ -376,25 +372,18 @@ def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> int: return execute_json_nft_commands(command) -def add_port_redirect_rule( - vm_id: int, interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp" -) -> int: - """Creates a rule to redirect traffic from a host port to a VM port. - - Args: - vm_id: The ID of the VM - interface: The TapInterface instance for the VM - host_port: The port number on the host to listen on - vm_port: The port number to forward to on the VM - protocol: The protocol to use (tcp or udp, defaults to tcp) +def add_prerouting_chain() -> int: + """Creates the prerouting chain if it doesn't exist already. 
Returns: - The exit code from executing the nftables commands + int: The exit code from executing the nftables commands """ - table = get_table_for_hook("prerouting") + # Check if prerouting chain exists by looking for chains with prerouting hook + existing_chains = get_base_chains_for_hook("prerouting", "ip") + if existing_chains: + return 0 # Chain already exists, nothing to do - # FIXME DO NOT HARDCODE - prerouting_command = [ + commands = [ { "add": { "chain": { @@ -409,14 +398,32 @@ def add_port_redirect_rule( } } ] + return execute_json_nft_commands(commands) + + +def add_port_redirect_rule( + vm_id: int, interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp" +) -> int: + """Creates a rule to redirect traffic from a host port to a VM port. + + Args: + vm_id: The ID of the VM + interface: The TapInterface instance for the VM + host_port: The port number on the host to listen on + vm_port: The port number to forward to on the VM + protocol: The protocol to use (tcp or udp, defaults to tcp) - commands = prerouting_command + [ + Returns: + The exit code from executing the nftables commands + """ + # table = get_table_for_hook("prerouting") + add_prerouting_chain() + commands = [ { "add": { "rule": { "family": "ip", "table": "nat", - # "chain": f"{settings.NFTABLES_CHAIN_PREFIX}-vm-nat-{vm_id}", "chain": "prerouting", "expr": [ { @@ -441,41 +448,64 @@ def add_port_redirect_rule( } }, ] - # Add corresponding forward rule to allow the redirected traffic - # { - # "add": { - # "rule": { - # "family": "ip", - # "table": table, - # "chain": f"{settings.NFTABLES_CHAIN_PREFIX}-vm-filter-{vm_id}", - # "expr": [ - # { - # "match": { - # "op": "==", - # "left": {"meta": {"key": "iifname"}}, - # "right": settings.NETWORK_INTERFACE, - # } - # }, - # { - # "match": { - # "op": "==", - # "left": {"meta": {"key": "l4proto"}}, - # "right": protocol, - # } - # }, - # { - # "match": { - # "op": "==", - # "left": {"payload": {"protocol": protocol, "field": "dport"}}, - # "right": vm_port, - # } - # }, - # {"accept": None}, - # ], - # } - # } - # }, - # ] + + return execute_json_nft_commands(commands) + + +def remove_port_redirect_rule( + vm_id: int, interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp" +) -> int: + """Removes a rule that redirects traffic from a host port to a VM port. 
+ + Args: + vm_id: The ID of the VM + interface: The TapInterface instance for the VM + host_port: The port number on the host that was being listened on + vm_port: The port number that was being forwarded to on the VM + protocol: The protocol that was used (tcp or udp) + + Returns: + The exit code from executing the nftables commands + """ + nft_ruleset = get_existing_nftables_ruleset() + commands = [] + + for entry in nft_ruleset["nftables"]: + if ( + isinstance(entry, dict) + and "rule" in entry + and entry["rule"].get("family") == "ip" + and entry["rule"].get("table") == "nat" + and entry["rule"].get("chain") == "prerouting" + and "expr" in entry["rule"] + ): + expr = entry["rule"]["expr"] + # Check if this is our redirect rule by matching all conditions + if ( + len(expr) == 3 + and "match" in expr[0] + and expr[0]["match"]["left"].get("meta", {}).get("key") == "iifname" + and expr[0]["match"]["right"] == settings.NETWORK_INTERFACE + and "match" in expr[1] + and expr[1]["match"]["left"].get("payload", {}).get("protocol") == protocol + and expr[1]["match"]["left"]["payload"].get("field") == "dport" + and expr[1]["match"]["right"] == host_port + and "dnat" in expr[2] + and expr[2]["dnat"].get("addr") == str(interface.guest_ip.ip) + and expr[2]["dnat"].get("port") == vm_port + ): + commands.append( + { + "delete": { + "rule": { + "family": "ip", + "table": "nat", + "chain": "prerouting", + "handle": entry["rule"]["handle"], + } + } + } + ) return execute_json_nft_commands(commands) From 56b3bef8e4d49fbf98793ffbfa2a93080d0e9e16 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 13:41:55 +0200 Subject: [PATCH 08/23] fix protocol specification --- src/aleph/vm/network/firewall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index f683eb08..b9be821d 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -436,7 +436,7 @@ def add_port_redirect_rule( { "match": { "op": "==", - "left": {"payload": {"protocol": "tcp", "field": "dport"}}, + "left": {"payload": {"protocol": protocol, "field": "dport"}}, "right": host_port, } }, From 09b65999013c7e7e4cd756143f46cf778fa0428f Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 13:42:24 +0200 Subject: [PATCH 09/23] Do not assign an already forward port, even if nothing is listening on it yet --- src/aleph/vm/network/firewall.py | 44 +++++++++++++++++++ .../vm/network/port_availability_checker.py | 13 +++--- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index b9be821d..25496e3d 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -522,3 +522,47 @@ def teardown_nftables_for_vm(vm_id: int) -> None: """Remove all nftables rules related to the specified VM""" remove_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-vm-nat-{vm_id}") remove_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-vm-filter-{vm_id}") + + +def check_nftables_redirections(port: int) -> bool: + """Check if there are any NAT rules redirecting to the given port. 
+ + Args: + port: The port number to check + + Returns: + True if the port is being used in any NAT redirection rules + """ + try: + rules = get_existing_nftables_ruleset() + # Navigate through the JSON structure + nftables = rules.get("nftables", []) + + for item in nftables: + if "rule" not in item: + continue + + rule = item["rule"] + if "expr" not in rule: + continue + + for expr in rule["expr"]: + # Check destination port in matches + if "match" in expr: + match = expr["match"] + if "left" in match and "payload" in match["left"]: + payload = match["left"]["payload"] + if payload.get("field") == "dport" and match.get("right") == port: + return True + + # Check port in dnat rules + if "dnat" in expr: + dnat = expr["dnat"] + if dnat.get("port") == port: + return True + + return False + + except Exception as e: + logger.warning(f"Error checking NAT redirections: {e}") + return False # Assume no redirections in case of error diff --git a/src/aleph/vm/network/port_availability_checker.py b/src/aleph/vm/network/port_availability_checker.py index 46008e8d..7d8137e2 100644 --- a/src/aleph/vm/network/port_availability_checker.py +++ b/src/aleph/vm/network/port_availability_checker.py @@ -1,6 +1,8 @@ import socket from typing import Optional +from aleph.vm.network.firewall import check_nftables_redirections + MIN_DYNAMIC_PORT = 24000 MAX_PORT = 65535 @@ -17,10 +19,13 @@ def get_available_host_port(start_port: Optional[int] = None) -> int: Raises: RuntimeError: If no ports are available in the valid range """ - port = start_port if start_port and start_port >= MIN_DYNAMIC_PORT else MIN_DYNAMIC_PORT - - while port <= MAX_PORT: + start_port = start_port if start_port and start_port >= MIN_DYNAMIC_PORT else MIN_DYNAMIC_PORT + for port in range(start_port, MAX_PORT): try: + print(port) + # check if there is already a redirect to that port + if check_nftables_redirections(port): + continue # Try both TCP and UDP on all interfaces with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_sock: tcp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -36,6 +41,4 @@ def get_available_host_port(start_port: Optional[int] = None) -> int: except (socket.error, OSError): pass - port += 1 - raise RuntimeError(f"No available ports found in range {MIN_DYNAMIC_PORT}-{MAX_PORT}") From c2efaf4723c7d0507ea2ac2e7cef45c34da1809f Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 13:43:33 +0200 Subject: [PATCH 10/23] fix db migration --- .../2da719d72cea_add_mapped_ports_column.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py b/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py index a4e8ee7e..a8f7df0d 100644 --- a/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py +++ b/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py @@ -9,6 +9,12 @@ import sqlalchemy as sa from alembic import op +# revision identifiers, used by Alembic. +from sqlalchemy import create_engine +from sqlalchemy.engine import reflection + +from aleph.vm.conf import make_db_url + # revision identifiers, used by Alembic. 
revision = "2da719d72cea" down_revision = "5c6ae643c69b" @@ -17,7 +23,16 @@ def upgrade() -> None: - op.add_column("executions", sa.Column("mapped_ports", sa.JSON(), nullable=True)) + engine = create_engine(make_db_url()) + inspector = reflection.Inspector.from_engine(engine) + + # The table already exists on most CRNs. + tables = inspector.get_table_names() + if "executions" in tables: + columns = inspector.get_columns("executions") + column_names = [c["name"] for c in columns] + if "mapped_ports" not in column_names: + op.add_column("executions", sa.Column("mapped_ports", sa.JSON(), nullable=True)) def downgrade() -> None: From 7992d8fb5cefc557b542ce02d7e438a33a5ea0f7 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 13:44:43 +0200 Subject: [PATCH 11/23] fix mapped_port_format --- src/aleph/vm/models.py | 8 ++++---- src/aleph/vm/pool.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 9ceb71e0..acec0c99 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -94,12 +94,10 @@ class VmExecution: systemd_manager: SystemDManager | None persistent: bool = False - mapped_ports: list[tuple[int, int]] | None = None # internal, external + mapped_ports: dict[int, dict] # Port redirect to the VM record: ExecutionRecord | None = None async def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]): - if self.mapped_ports is None: - self.mapped_ports = [] assert self.vm for requested_port, protocol_detail in requested_ports.items(): tcp = protocol_detail["tcp"] @@ -113,7 +111,8 @@ async def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]) if udp: protocol = "udp" add_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) - self.mapped_ports.append((host_port, requested_port)) + + self.mapped_ports[requested_port] = {"host": host_port, **protocol_detail} self.record.mapped_ports = self.mapped_ports await save_record(self.record) @@ -214,6 +213,7 @@ def __init__( self.snapshot_manager = snapshot_manager self.systemd_manager = systemd_manager self.persistent = persistent + self.mapped_ports = {} def to_dict(self) -> dict: return { diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 54c876ab..293ecf15 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -326,6 +326,7 @@ async def load_persistent_executions(self): if saved_execution.gpus else [] ) + execution.mapped_ports = saved_execution.mapped_ports # Load and instantiate the rest of resources and already assigned GPUs await execution.prepare() if self.network: From ce6b4e94cb7344e2fb0420e5bd92e46f0a4f146b Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 13:45:07 +0200 Subject: [PATCH 12/23] add remove port redirection method --- src/aleph/vm/models.py | 21 ++++++++++++++++++++- src/aleph/vm/pool.py | 11 +++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index acec0c99..5b5cf7ea 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -30,7 +30,7 @@ AlephQemuConfidentialInstance, AlephQemuConfidentialResources, ) -from aleph.vm.network.firewall import add_port_redirect_rule +from aleph.vm.network.firewall import add_port_redirect_rule, remove_port_redirect_rule from aleph.vm.network.interfaces import TapInterface from aleph.vm.network.port_availability_checker import get_available_host_port from aleph.vm.orchestrator.metrics import ( @@ -116,6 +116,23 @@ 
async def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]) self.record.mapped_ports = self.mapped_ports await save_record(self.record) + async def removed_ports_redirection(self): + for requested_port, map_detail in self.mapped_ports.items(): + tcp = map_detail["tcp"] + udp = map_detail["udp"] + host_port = map_detail["host"] + + interface = self.vm.tap_interface + vm_id = self.vm.vm_id + if tcp: + protocol = "tcp" + remove_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) + if udp: + protocol = "tcp" + remove_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) + + del self.mapped_ports[requested_port] + @property def is_starting(self) -> bool: return bool(self.times.starting_at and not self.times.started_at and not self.times.stopping_at) @@ -481,6 +498,8 @@ async def stop(self) -> None: await self.all_runs_complete() await self.record_usage() await self.vm.teardown() + await self.removed_ports_redirection() + self.times.stopped_at = datetime.now(tz=timezone.utc) self.cancel_expiration() self.cancel_update() diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 293ecf15..2b2f8ca3 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -208,16 +208,12 @@ async def create_a_vm( await execution.start() try: port_forwarding_settings = await get_user_settings(message.address, "port-forwarding") - port_requests = port_forwarding_settings.get(execution.vm_hash) + ports_requests = port_forwarding_settings.get(execution.vm_hash) except Exception as e: + ports_requests = {} logger.exception(e) - ports_requests = { - 22: {"tcp": True, "udp": False}, - } await execution.map_requested_ports(ports_requests) - # port = get_available_host_port(start_port=25000) - # add_port_redirect_rule(vm_id, interface, port, 22) # clear the user reservations for resource in resources: @@ -225,7 +221,10 @@ async def create_a_vm( del self.reservations[resource] except Exception: # ensure the VM is removed from the pool on creation error + await execution.removed_ports_redirection() self.forget_vm(vm_hash) + + raise self._schedule_forget_on_stop(execution) From 7671b92e4b08be29d964ff3257ff35463a927fed Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 14:25:57 +0200 Subject: [PATCH 13/23] fixup! 
Do not assign an already forward port, even if nothing is listening on it yet --- src/aleph/vm/network/port_availability_checker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/aleph/vm/network/port_availability_checker.py b/src/aleph/vm/network/port_availability_checker.py index 7d8137e2..da7cba43 100644 --- a/src/aleph/vm/network/port_availability_checker.py +++ b/src/aleph/vm/network/port_availability_checker.py @@ -22,7 +22,6 @@ def get_available_host_port(start_port: Optional[int] = None) -> int: start_port = start_port if start_port and start_port >= MIN_DYNAMIC_PORT else MIN_DYNAMIC_PORT for port in range(start_port, MAX_PORT): try: - print(port) # check if there is already a redirect to that port if check_nftables_redirections(port): continue From 6c1103a73d75fd33d5af32107c8dfa0738816220 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 14:26:34 +0200 Subject: [PATCH 14/23] Update ports --- src/aleph/vm/models.py | 74 +++++++++++++++++++------------- src/aleph/vm/network/firewall.py | 15 +++---- src/aleph/vm/pool.py | 8 ++-- 3 files changed, 55 insertions(+), 42 deletions(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 5b5cf7ea..110f53b4 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -45,6 +45,8 @@ from aleph.vm.systemd import SystemDManager from aleph.vm.utils import create_task_log_exceptions, dumps_for_json, is_pinging +SUPPORTED_PROTOCOL_FOR_REDIRECT = ["udp", "tcp"] + logger = logging.getLogger(__name__) @@ -97,41 +99,55 @@ class VmExecution: mapped_ports: dict[int, dict] # Port redirect to the VM record: ExecutionRecord | None = None - async def map_requested_ports(self, requested_ports: dict[int, dict[str, bool]]): + async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]]): assert self.vm - for requested_port, protocol_detail in requested_ports.items(): - tcp = protocol_detail["tcp"] - udp = protocol_detail["udp"] + redirect_to_remove = set(self.mapped_ports.keys()) - set(requested_ports.keys()) + redirect_to_add = set(requested_ports.keys()) - set(self.mapped_ports.keys()) + redirect_to_check = set(requested_ports.keys()).intersection(set(self.mapped_ports.keys())) + interface = self.vm.tap_interface + + for vm_port in redirect_to_remove: + current = self.mapped_ports[vm_port] + + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if current[protocol]: + host_port = current["host"] + remove_port_redirect_rule(interface, host_port, vm_port, protocol) + del self.mapped_ports[vm_port] + + for vm_port in redirect_to_add: + target = requested_ports[vm_port] host_port = get_available_host_port(start_port=25000) - interface = self.vm.tap_interface - vm_id = self.vm.vm_id - if tcp: - protocol = "tcp" - add_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) - if udp: - protocol = "udp" - add_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) - - self.mapped_ports[requested_port] = {"host": host_port, **protocol_detail} + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if target[protocol]: + add_port_redirect_rule(interface, host_port, vm_port, protocol) + self.mapped_ports[vm_port] = {"host": host_port, **target} + + for vm_port in redirect_to_check: + current = self.mapped_ports[vm_port] + target = requested_ports[vm_port] + host_port = current["host"] + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if current[protocol] != target[protocol]: + if target[protocol]: + add_port_redirect_rule(interface, host_port, vm_port, 
protocol) + else: + remove_port_redirect_rule(interface, host_port, vm_port, protocol) + self.mapped_ports[vm_port] = {"host": host_port, **target} + + # Save to DB self.record.mapped_ports = self.mapped_ports await save_record(self.record) - async def removed_ports_redirection(self): - for requested_port, map_detail in self.mapped_ports.items(): - tcp = map_detail["tcp"] - udp = map_detail["udp"] + async def removed_all_ports_redirection(self): + interface = self.vm.tap_interface + for vm_port, map_detail in self.mapped_ports.items(): host_port = map_detail["host"] + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if map_detail[protocol]: + remove_port_redirect_rule(interface, host_port, vm_port, protocol) - interface = self.vm.tap_interface - vm_id = self.vm.vm_id - if tcp: - protocol = "tcp" - remove_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) - if udp: - protocol = "tcp" - remove_port_redirect_rule(vm_id, interface, host_port, requested_port, protocol) - - del self.mapped_ports[requested_port] + del self.mapped_ports[vm_port] @property def is_starting(self) -> bool: @@ -498,7 +514,7 @@ async def stop(self) -> None: await self.all_runs_complete() await self.record_usage() await self.vm.teardown() - await self.removed_ports_redirection() + await self.removed_all_ports_redirection() self.times.stopped_at = datetime.now(tz=timezone.utc) self.cancel_expiration() diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index 25496e3d..0954a3f1 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -401,9 +401,7 @@ def add_prerouting_chain() -> int: return execute_json_nft_commands(commands) -def add_port_redirect_rule( - vm_id: int, interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp" -) -> int: +def add_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> int: """Creates a rule to redirect traffic from a host port to a VM port. Args: @@ -452,17 +450,14 @@ def add_port_redirect_rule( return execute_json_nft_commands(commands) -def remove_port_redirect_rule( - vm_id: int, interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp" -) -> int: +def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> int: """Removes a rule that redirects traffic from a host port to a VM port. 
Args: - vm_id: The ID of the VM interface: The TapInterface instance for the VM - host_port: The port number on the host that was being listened on - vm_port: The port number that was being forwarded to on the VM - protocol: The protocol that was used (tcp or udp) + host_port: The port number on the host that is listened on + vm_port: The port number that is being forwarded to on the VM + protocol: The protocol used (tcp or udp) Returns: The exit code from executing the nftables commands diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 2b2f8ca3..657292ad 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -209,11 +209,13 @@ async def create_a_vm( try: port_forwarding_settings = await get_user_settings(message.address, "port-forwarding") ports_requests = port_forwarding_settings.get(execution.vm_hash) + ports_requests = {22: {"tcp": True, "udp": False}} + except Exception as e: ports_requests = {} logger.exception(e) - - await execution.map_requested_ports(ports_requests) + ports_requests = ports_requests or {} + await execution.update_port_redirects(ports_requests) # clear the user reservations for resource in resources: @@ -221,7 +223,7 @@ async def create_a_vm( del self.reservations[resource] except Exception: # ensure the VM is removed from the pool on creation error - await execution.removed_ports_redirection() + await execution.removed_all_ports_redirection() self.forget_vm(vm_hash) From 16d0fd8d6b70a6b8dc3aec11a0624fc00d740aca Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 14:29:26 +0200 Subject: [PATCH 15/23] Add /control/update endpoint to refresh port config add host ipv4 to v2 endpoint --- src/aleph/vm/models.py | 18 ++++++++- src/aleph/vm/network/get_interface_ipv4.py | 41 +++++++++++++++++++ src/aleph/vm/network/hostnetwork.py | 3 ++ src/aleph/vm/orchestrator/supervisor.py | 2 + src/aleph/vm/orchestrator/views/__init__.py | 18 +++++++++ src/aleph/vm/pool.py | 45 +-------------------- src/aleph/vm/utils/aggregate.py | 38 +++++++++++++++++ 7 files changed, 120 insertions(+), 45 deletions(-) create mode 100644 src/aleph/vm/network/get_interface_ipv4.py create mode 100644 src/aleph/vm/utils/aggregate.py diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 110f53b4..741af59f 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -44,6 +44,7 @@ from aleph.vm.resources import GpuDevice, HostGPU from aleph.vm.systemd import SystemDManager from aleph.vm.utils import create_task_log_exceptions, dumps_for_json, is_pinging +from aleph.vm.utils.aggregate import get_user_settings SUPPORTED_PROTOCOL_FOR_REDIRECT = ["udp", "tcp"] @@ -99,8 +100,24 @@ class VmExecution: mapped_ports: dict[int, dict] # Port redirect to the VM record: ExecutionRecord | None = None + async def fetch_port_redirect_config_and_setup(self): + message = self.message + try: + port_forwarding_settings = await get_user_settings(message.address, "port-forwarding") + ports_requests = port_forwarding_settings.get(self.vm_hash) + + except Exception: + ports_requests = {} + logger.exception("Could not fetch the port redirect settings for user") + if not ports_requests: + # FIXME DEBUG FOR NOW + ports_requests = {22: {"tcp": True, "udp": False}} + ports_requests = ports_requests or {} + await self.update_port_redirects(ports_requests) + async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]]): assert self.vm + logger.info("Updating port redirect. 
Current %s, New %s", self.mapped_ports, requested_ports) redirect_to_remove = set(self.mapped_ports.keys()) - set(requested_ports.keys()) redirect_to_add = set(requested_ports.keys()) - set(self.mapped_ports.keys()) redirect_to_check = set(requested_ports.keys()).intersection(set(self.mapped_ports.keys())) @@ -108,7 +125,6 @@ async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool] for vm_port in redirect_to_remove: current = self.mapped_ports[vm_port] - for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: if current[protocol]: host_port = current["host"] diff --git a/src/aleph/vm/network/get_interface_ipv4.py b/src/aleph/vm/network/get_interface_ipv4.py new file mode 100644 index 00000000..fe0ea91e --- /dev/null +++ b/src/aleph/vm/network/get_interface_ipv4.py @@ -0,0 +1,41 @@ +import ipaddress + +import netifaces + + +def get_interface_ipv4(interface_name: str) -> str: + """ + Get the main IPv4 address from a network interface. + + Args: + interface_name: Name of the network interface (e.g., 'eth0', 'wlan0') + + Returns: + str: The IPv4 address of the interface + + Raises: + ValueError: If the interface doesn't exist or has no IPv4 address + """ + try: + # Check if the interface exists + if interface_name not in netifaces.interfaces(): + raise ValueError(f"Interface {interface_name} does not exist") + + # Get addresses for the interface + addrs = netifaces.ifaddresses(interface_name) + + # Check for IPv4 addresses (AF_INET is IPv4) + if netifaces.AF_INET not in addrs: + raise ValueError(f"No IPv4 address found for interface {interface_name}") + + # Get the first IPv4 address + ipv4_info = addrs[netifaces.AF_INET][0] + ipv4_addr = ipv4_info["addr"] + + # Validate that it's a proper IPv4 address + ipaddress.IPv4Address(ipv4_addr) + + return ipv4_addr + + except Exception as e: + raise ValueError(f"Error getting IPv4 address: {str(e)}") diff --git a/src/aleph/vm/network/hostnetwork.py b/src/aleph/vm/network/hostnetwork.py index e9fcaa79..68b05fa8 100644 --- a/src/aleph/vm/network/hostnetwork.py +++ b/src/aleph/vm/network/hostnetwork.py @@ -10,6 +10,7 @@ from aleph.vm.vm_type import VmType from .firewall import initialize_nftables, setup_nftables_for_vm, teardown_nftables +from .get_interface_ipv4 import get_interface_ipv4 from .interfaces import TapInterface from .ipaddresses import IPv4NetworkWithInterfaces from .ndp_proxy import NdpProxy @@ -115,6 +116,7 @@ class Network: ipv6_address_pool: IPv6Network network_size: int ndp_proxy: NdpProxy | None = None + host_ipv4: str IPV6_SUBNET_PREFIX: int = 124 @@ -138,6 +140,7 @@ def __init__( self.ipv6_forwarding_enabled = ipv6_forwarding_enabled self.use_ndp_proxy = use_ndp_proxy self.ndb = pyroute2.NDB() + self.host_ipv4 = get_interface_ipv4(external_interface) if not self.ipv4_address_pool.is_private: logger.warning(f"Using a network range that is not private: {self.ipv4_address_pool}") diff --git a/src/aleph/vm/orchestrator/supervisor.py b/src/aleph/vm/orchestrator/supervisor.py index 08735230..8b745a9f 100644 --- a/src/aleph/vm/orchestrator/supervisor.py +++ b/src/aleph/vm/orchestrator/supervisor.py @@ -36,6 +36,7 @@ list_executions_v2, notify_allocation, operate_reserve_resources, + operate_update, run_code_from_hostname, run_code_from_path, status_check_fastapi, @@ -134,6 +135,7 @@ def setup_webapp(pool: VmPool | None): # /control APIs are used to control the VMs and access their logs web.post("/control/allocation/notify", notify_allocation), web.post("/control/reserve_resources", operate_reserve_resources), + 
web.post("/control/machine/{ref}/update", operate_update), web.get("/control/machine/{ref}/stream_logs", stream_logs), web.get("/control/machine/{ref}/logs", operate_logs_json), web.post("/control/machine/{ref}/expire", operate_expire), diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index 440a0a62..b2c369e2 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -25,6 +25,7 @@ ) from aleph.vm.controllers.firecracker.program import FileTooLargeError from aleph.vm.hypervisors.firecracker.microvm import MicroVMFailedInitError +from aleph.vm.models import VmExecution from aleph.vm.orchestrator import payment, status from aleph.vm.orchestrator.chain import STREAM_CHAINS from aleph.vm.orchestrator.custom_logs import set_vm_for_logging @@ -55,6 +56,7 @@ check_host_egress_ipv4, check_host_egress_ipv6, ) +from aleph.vm.orchestrator.views.operator import get_itemhash_or_400 from aleph.vm.pool import VmPool from aleph.vm.utils import ( HostNotFoundError, @@ -185,6 +187,7 @@ async def list_executions_v2(request: web.Request) -> web.Response: item_hash: { "networking": { "ipv4_network": execution.vm.tap_interface.ip_network, + "host_ipv4": pool.network.host_ipv4, "ipv6_network": execution.vm.tap_interface.ipv6_network, "ipv6_ip": execution.vm.tap_interface.guest_ipv6.ip, "mapped_ports": execution.mapped_ports, @@ -684,3 +687,18 @@ async def operate_reserve_resources(request: web.Request, authenticated_sender: }, dumps=dumps_for_json, ) + + +@cors_allow_all +async def operate_update(request: web.Request, authenticated_sender: str) -> web.Response: + """Notify that the instance configuration has changed + + For now used to notify the CRN that port forwarding config has changed + and that the we should fetch the latest user aggregate""" + vm_hash = get_itemhash_or_400(request.match_info) + + pool: VmPool = request.app["vm_pool"] + execution: VmExecution = pool.executions.get(vm_hash) + await execution.fetch_port_redirect_config_and_setup() + + return web.json_response({}, dumps=dumps_for_json, status=200) diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 657292ad..d178129d 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -32,39 +32,6 @@ logger = logging.getLogger(__name__) -async def get_user_aggregate(addr: str, keys_arg: list[str]) -> dict: - """ - Get the settings Aggregate dict from the PyAleph API Aggregate. 
- - API Endpoint: - GET /api/v0/aggregates/{address}.json?keys=settings - - For more details, see the PyAleph API documentation: - https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62 - """ - - import aiohttp - - async with aiohttp.ClientSession() as session: - url = f"{settings.API_SERVER}/api/v0/aggregates/{addr}.json" - logger.info(f"Fetching aggregate from {url}") - resp = await session.get(url, params={"keys": ",".join(keys_arg)}) - # No aggregate for the user - if resp.status == 404: - return {} - # Raise an error if the request failed - - resp.raise_for_status() - - resp_data = await resp.json() - return resp_data["data"] or {} - - -async def get_user_settings(addr: str, key) -> dict: - aggregate = await get_user_aggregate(addr, [key]) - return aggregate.get(key, {}) - - class VmPool: """Pool of existing VMs @@ -206,16 +173,7 @@ async def create_a_vm( execution.create(vm_id=vm_id, tap_interface=tap_interface) await execution.start() - try: - port_forwarding_settings = await get_user_settings(message.address, "port-forwarding") - ports_requests = port_forwarding_settings.get(execution.vm_hash) - ports_requests = {22: {"tcp": True, "udp": False}} - - except Exception as e: - ports_requests = {} - logger.exception(e) - ports_requests = ports_requests or {} - await execution.update_port_redirects(ports_requests) + await execution.fetch_port_redirect_config_and_setup() # clear the user reservations for resource in resources: @@ -226,7 +184,6 @@ async def create_a_vm( await execution.removed_all_ports_redirection() self.forget_vm(vm_hash) - raise self._schedule_forget_on_stop(execution) diff --git a/src/aleph/vm/utils/aggregate.py b/src/aleph/vm/utils/aggregate.py new file mode 100644 index 00000000..3087bf9b --- /dev/null +++ b/src/aleph/vm/utils/aggregate.py @@ -0,0 +1,38 @@ +from logging import getLogger + +from aleph.vm.conf import settings + +logger = getLogger(__name__) + + +async def get_user_aggregate(addr: str, keys_arg: list[str]) -> dict: + """ + Get the settings Aggregate dict from the PyAleph API Aggregate. 
+ + API Endpoint: + GET /api/v0/aggregates/{address}.json?keys=settings + + For more details, see the PyAleph API documentation: + https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62 + """ + + import aiohttp + + async with aiohttp.ClientSession() as session: + url = f"{settings.API_SERVER}/api/v0/aggregates/{addr}.json" + logger.info(f"Fetching aggregate from {url}") + resp = await session.get(url, params={"keys": ",".join(keys_arg)}) + # No aggregate for the user + if resp.status == 404: + return {} + # Raise an error if the request failed + + resp.raise_for_status() + + resp_data = await resp.json() + return resp_data["data"] or {} + + +async def get_user_settings(addr: str, key) -> dict: + aggregate = await get_user_aggregate(addr, [key]) + return aggregate.get(key, {}) From 536edcb75b62f44801e5ae90c2de5346159d5f59 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 15:01:54 +0200 Subject: [PATCH 16/23] typing and test --- src/aleph/vm/models.py | 7 +++++-- src/aleph/vm/network/firewall.py | 21 ++++++++++----------- tests/supervisor/test_views.py | 3 +++ 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 741af59f..a359a232 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -152,10 +152,13 @@ async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool] self.mapped_ports[vm_port] = {"host": host_port, **target} # Save to DB - self.record.mapped_ports = self.mapped_ports - await save_record(self.record) + if self.record: + self.record.mapped_ports = self.mapped_ports + await save_record(self.record) async def removed_all_ports_redirection(self): + if not self.vm: + return interface = self.vm.tap_interface for vm_port, map_detail in self.mapped_ports.items(): host_port = map_detail["host"] diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index 0954a3f1..554cef5f 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -21,7 +21,7 @@ def get_customized_nftables() -> Nftables: return nft -def execute_json_nft_commands(commands: list[dict]) -> int: +def execute_json_nft_commands(commands: list[dict]) -> dict: """Executes a list of nftables commands, and returns the exit status""" nft = get_customized_nftables() commands_dict = {"nftables": commands} @@ -197,7 +197,7 @@ def teardown_nftables() -> None: remove_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-supervisor-filter") -def add_chain(family: str, table: str, name: str) -> int: +def add_chain(family: str, table: str, name: str) -> dict: """Helper function to quickly create a new chain in the nftables ruleset Returns the exit code from executing the nftables commands""" commands = [_make_add_chain_command(family, table, name)] @@ -216,7 +216,7 @@ def _make_add_chain_command(family: str, table: str, name: str) -> dict: } -def remove_chain(name: str) -> int: +def remove_chain(name: str) -> dict: """Removes all rules that jump to the chain, and then removes the chain itself. Returns the exit code from executing the nftables commands""" nft_ruleset = get_existing_nftables_ruleset() @@ -260,7 +260,7 @@ def remove_chain(name: str) -> int: return execute_json_nft_commands(commands) -def add_postrouting_chain(name: str) -> int: +def add_postrouting_chain(name: str) -> dict: """Adds a chain and creates a rule from the base chain with the postrouting hook. 
Returns the exit code from executing the nftables commands""" table = get_table_for_hook("postrouting") @@ -280,7 +280,7 @@ def add_postrouting_chain(name: str) -> int: return execute_json_nft_commands(command) -def add_forward_chain(name: str) -> int: +def add_forward_chain(name: str) -> dict: """Adds a chain and creates a rule from the base chain with the forward hook. Returns the exit code from executing the nftables commands""" table = get_table_for_hook("forward") @@ -300,7 +300,7 @@ def add_forward_chain(name: str) -> int: return execute_json_nft_commands(command) -def add_masquerading_rule(vm_id: int, interface: TapInterface) -> int: +def add_masquerading_rule(vm_id: int, interface: TapInterface) -> dict: """Creates a rule for the VM with the specified id to allow outbound traffic to be masqueraded (NAT) Returns the exit code from executing the nftables commands""" table = get_table_for_hook("postrouting") @@ -336,7 +336,7 @@ def add_masquerading_rule(vm_id: int, interface: TapInterface) -> int: return execute_json_nft_commands(command) -def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> int: +def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> dict: """Creates a rule for the VM with the specified id to allow outbound traffic Returns the exit code from executing the nftables commands""" table = get_table_for_hook("forward") @@ -372,7 +372,7 @@ def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> int: return execute_json_nft_commands(command) -def add_prerouting_chain() -> int: +def add_prerouting_chain() -> dict: """Creates the prerouting chain if it doesn't exist already. Returns: @@ -401,7 +401,7 @@ def add_prerouting_chain() -> int: return execute_json_nft_commands(commands) -def add_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> int: +def add_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> dict: """Creates a rule to redirect traffic from a host port to a VM port. Args: @@ -414,7 +414,6 @@ def add_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int Returns: The exit code from executing the nftables commands """ - # table = get_table_for_hook("prerouting") add_prerouting_chain() commands = [ { @@ -450,7 +449,7 @@ def add_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int return execute_json_nft_commands(commands) -def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> int: +def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> dict: """Removes a rule that redirects traffic from a host port to a VM port. 
Args: diff --git a/tests/supervisor/test_views.py b/tests/supervisor/test_views.py index 02a665bc..2d04f5e8 100644 --- a/tests/supervisor/test_views.py +++ b/tests/supervisor/test_views.py @@ -437,6 +437,7 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi ipv6_forwarding_enabled=False, ) network.setup() + network.host_ipv4 = "10.0.5.201" from aleph.vm.vm_type import VmType @@ -455,9 +456,11 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi assert await response.json() == { "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca": { "networking": { + "host_ipv4": "10.0.5.201", "ipv4_network": "172.16.3.0/24", "ipv6_network": "fc00:1:2:3:3:deca:deca:dec0/124", "ipv6_ip": "fc00:1:2:3:3:deca:deca:dec1", + "mapped_ports" : {} }, "status": { "defined_at": str(execution.times.defined_at), From a61624a7172098d83281d4e1271601a62921d1f9 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 15:06:18 +0200 Subject: [PATCH 17/23] mod rework network --- src/aleph/vm/network/firewall.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index 554cef5f..ca07aded 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -41,12 +41,11 @@ def execute_json_nft_commands(commands: list[dict]) -> dict: def get_existing_nftables_ruleset() -> dict: """Retrieves the full nftables ruleset and returns it""" - nft = get_customized_nftables() # List all NAT rules commands = [{"list": {"ruleset": {"family": "ip", "table": "nat"}}}] nft_ruleset = execute_json_nft_commands(commands) - return nft_ruleset + return nft_ruleset["nftables"] def get_base_chains_for_hook(hook: str, family: str = "ip") -> list: @@ -55,7 +54,7 @@ def get_base_chains_for_hook(hook: str, family: str = "ip") -> list: nft_ruleset = get_existing_nftables_ruleset() chains = [] - for entry in nft_ruleset["nftables"]: + for entry in nft_ruleset: if ( not isinstance(entry, dict) or "chain" not in entry @@ -81,7 +80,7 @@ def get_table_for_hook(hook: str, family: str = "ip") -> str: def check_if_table_exists(family: str, table: str) -> bool: """Checks whether the specified table exists in the nftables ruleset""" nft_ruleset = get_existing_nftables_ruleset() - for entry in nft_ruleset["nftables"]: + for entry in nft_ruleset: if ( isinstance(entry, dict) and "table" in entry @@ -223,7 +222,7 @@ def remove_chain(name: str) -> dict: commands = [] remove_chain_commands = [] - for entry in nft_ruleset["nftables"]: + for entry in nft_ruleset: if ( isinstance(entry, dict) and "rule" in entry @@ -530,9 +529,7 @@ def check_nftables_redirections(port: int) -> bool: try: rules = get_existing_nftables_ruleset() # Navigate through the JSON structure - nftables = rules.get("nftables", []) - - for item in nftables: + for item in rules: if "rule" not in item: continue From c7c7cd448fdc227e6ea235ac553649e0e83db16a Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 30 May 2025 15:06:48 +0200 Subject: [PATCH 18/23] mod rework network --- src/aleph/vm/network/firewall.py | 2 +- tests/supervisor/test_views.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index ca07aded..e423c223 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -380,7 +380,7 @@ def add_prerouting_chain() -> dict: # Check if prerouting chain exists 
by looking for chains with prerouting hook existing_chains = get_base_chains_for_hook("prerouting", "ip") if existing_chains: - return 0 # Chain already exists, nothing to do + return {} # Chain already exists, nothing to do commands = [ { diff --git a/tests/supervisor/test_views.py b/tests/supervisor/test_views.py index 2d04f5e8..8cd438bb 100644 --- a/tests/supervisor/test_views.py +++ b/tests/supervisor/test_views.py @@ -460,7 +460,7 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi "ipv4_network": "172.16.3.0/24", "ipv6_network": "fc00:1:2:3:3:deca:deca:dec0/124", "ipv6_ip": "fc00:1:2:3:3:deca:deca:dec1", - "mapped_ports" : {} + "mapped_ports": {}, }, "status": { "defined_at": str(execution.times.defined_at), From 890b091ebbe84b64487f739fb3e6a0cdc7bde818 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Mon, 2 Jun 2025 10:25:27 +0200 Subject: [PATCH 19/23] fix tests --- tests/supervisor/test_views.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/supervisor/test_views.py b/tests/supervisor/test_views.py index 8cd438bb..9e37ae4d 100644 --- a/tests/supervisor/test_views.py +++ b/tests/supervisor/test_views.py @@ -1,4 +1,3 @@ -import asyncio import os import tempfile from copy import deepcopy @@ -408,7 +407,7 @@ async def test_v2_executions_list_one_vm(aiohttp_client, mock_app_with_pool, moc async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_with_pool, mock_instance_content): "Test locally but do not create" web_app = await mock_app_with_pool - pool = web_app["vm_pool"] + pool: VmPool = web_app["vm_pool"] message = InstanceContent.model_validate(mock_instance_content) vm_hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca" @@ -424,6 +423,8 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi vm_id = 3 from aleph.vm.network.hostnetwork import Network, make_ipv6_allocator + mocker.patch("aleph.vm.network.hostnetwork.get_interface_ipv4", return_value="10.0.5.201") + network = Network( vm_ipv4_address_pool_range=settings.IPV4_ADDRESS_POOL, vm_network_size=settings.IPV4_NETWORK_PREFIX_LENGTH, @@ -437,7 +438,7 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi ipv6_forwarding_enabled=False, ) network.setup() - network.host_ipv4 = "10.0.5.201" + pool.network = network from aleph.vm.vm_type import VmType @@ -617,6 +618,7 @@ def mock_is_kernel_enabled_gpu(pci_host: str) -> bool: ) mocker.patch.object(settings, "ENABLE_GPU_SUPPORT", True) + mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False) pool = VmPool() await pool.setup() app = setup_webapp(pool=pool) From 0b7fc796476761d16bdb229e7f0f3df3caa004b4 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Mon, 2 Jun 2025 13:27:11 +0200 Subject: [PATCH 20/23] Add 2 tests for operate_update --- src/aleph/vm/models.py | 3 +- src/aleph/vm/orchestrator/views/__init__.py | 12 +++-- tests/supervisor/test_views.py | 55 +++++++++++++++++++++ 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index a359a232..3654dd8c 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -116,7 +116,8 @@ async def fetch_port_redirect_config_and_setup(self): await self.update_port_redirects(ports_requests) async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]]): - assert self.vm + assert self.vm, "The VM attribute has to be set before calling update_port_redirects()" + 
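+        # Diff the ports that are currently mapped against the newly requested ones:
+        # redirects that are no longer requested are torn down, new ones are set up below.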
logger.info("Updating port redirect. Current %s, New %s", self.mapped_ports, requested_ports) redirect_to_remove = set(self.mapped_ports.keys()) - set(requested_ports.keys()) redirect_to_add = set(requested_ports.keys()) - set(self.mapped_ports.keys()) diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index b2c369e2..cd8b23ea 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -690,15 +690,19 @@ async def operate_reserve_resources(request: web.Request, authenticated_sender: @cors_allow_all -async def operate_update(request: web.Request, authenticated_sender: str) -> web.Response: +async def operate_update(request: web.Request) -> web.Response: """Notify that the instance configuration has changed - For now used to notify the CRN that port forwarding config has changed - and that the we should fetch the latest user aggregate""" + For now used to notify the CRN that port-forwarding config has changed + and that it should be fetched and the setup upgraded""" vm_hash = get_itemhash_or_400(request.match_info) pool: VmPool = request.app["vm_pool"] execution: VmExecution = pool.executions.get(vm_hash) + if not execution: + raise HTTPNotFound(reason="VM not found") + if not execution.vm: + # Configuration will be fetched when the VM start so no need to return an error + return web.json_response({"status": "ok", "msg": "VM not starting yet"}, dumps=dumps_for_json, status=200) await execution.fetch_port_redirect_config_and_setup() - return web.json_response({}, dumps=dumps_for_json, status=200) diff --git a/tests/supervisor/test_views.py b/tests/supervisor/test_views.py index 9e37ae4d..2fadc1c3 100644 --- a/tests/supervisor/test_views.py +++ b/tests/supervisor/test_views.py @@ -838,3 +838,58 @@ async def test_reserve_resources_double_fail(aiohttp_client, mocker, mock_app_wi resp = await response.json() assert resp["status"] == "error", await response.text() assert len(app["vm_pool"].reservations) == 0 + + +@pytest.mark.asyncio +async def test_operate_not_started(aiohttp_client, mock_app_with_pool, mock_instance_content, mocker): + web_app = await mock_app_with_pool + pool = web_app["vm_pool"] + client = await aiohttp_client(web_app) + vm_hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca" + message = InstanceContent.model_validate(mock_instance_content) + + mocker.patch.object(VmExecution, "update_port_redirects", new=mocker.AsyncMock(return_value=False)) + execution = VmExecution( + vm_hash=vm_hash, + message=message, + original=message, + persistent=False, + snapshot_manager=None, + systemd_manager=None, + ) + + pool.executions = {vm_hash: execution} + response: web.Response = await client.post("/control/machine/{ref}/update".format(ref=vm_hash)) + assert response.status == 200, await response.text() + response.raise_for_status() + resp = await response.json() + assert resp == {"msg": "VM not starting yet", "status": "ok"} + assert execution.update_port_redirects.call_count == 0 + + +@pytest.mark.asyncio +async def test_operate(aiohttp_client, mock_app_with_pool, mock_instance_content, mocker): + web_app = await mock_app_with_pool + pool = web_app["vm_pool"] + client = await aiohttp_client(web_app) + vm_hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca" + message = InstanceContent.model_validate(mock_instance_content) + + mocker.patch.object(VmExecution, "update_port_redirects", new=mocker.AsyncMock(return_value=False)) + execution = 
VmExecution( + vm_hash=vm_hash, + message=message, + original=message, + persistent=False, + snapshot_manager=None, + systemd_manager=None, + ) + execution.vm = mocker.Mock() + + pool.executions = {vm_hash: execution} + response: web.Response = await client.post("/control/machine/{ref}/update".format(ref=vm_hash)) + assert response.status == 200, await response.text() + response.raise_for_status() + resp = await response.json() + assert resp == {} + assert execution.update_port_redirects.call_count == 1 From eb353ccf7ce5a6bad4448d503473050e1ab0c5df Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Tue, 3 Jun 2025 13:37:42 +0200 Subject: [PATCH 21/23] Fix cleanup issue --- src/aleph/vm/models.py | 8 ++++++-- src/aleph/vm/network/firewall.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 3654dd8c..00bfd425 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -161,7 +161,8 @@ async def removed_all_ports_redirection(self): if not self.vm: return interface = self.vm.tap_interface - for vm_port, map_detail in self.mapped_ports.items(): + # copy in a list since we modify dict during iteration + for vm_port, map_detail in list(self.mapped_ports.items()): host_port = map_detail["host"] for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: if map_detail[protocol]: @@ -476,7 +477,10 @@ async def non_blocking_wait_for_boot(self): except Exception as e: logger.warning("%s failed to responded to ping or is not running, stopping it.: %s ", self, e) assert self.vm - await self.stop() + try: + await self.stop() + except Exception as f: + logger.exception("%s failed to stop: %s", self, f) return False async def wait_for_init(self): diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index e423c223..6e828a3d 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -463,7 +463,7 @@ def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: nft_ruleset = get_existing_nftables_ruleset() commands = [] - for entry in nft_ruleset["nftables"]: + for entry in nft_ruleset: if ( isinstance(entry, dict) and "rule" in entry From 675d3e0d74ae045cf806db90e6af85c9efe4fe72 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Tue, 3 Jun 2025 15:22:47 +0200 Subject: [PATCH 22/23] POC Rescue console , auth disabled. 
log not working --- src/aleph/vm/hypervisors/qemu/qemuvm.py | 30 +++++++-- src/aleph/vm/orchestrator/supervisor.py | 2 + src/aleph/vm/orchestrator/views/operator.py | 71 +++++++++++++++++++++ 3 files changed, 97 insertions(+), 6 deletions(-) diff --git a/src/aleph/vm/hypervisors/qemu/qemuvm.py b/src/aleph/vm/hypervisors/qemu/qemuvm.py index 79fd60e5..579befd7 100644 --- a/src/aleph/vm/hypervisors/qemu/qemuvm.py +++ b/src/aleph/vm/hypervisors/qemu/qemuvm.py @@ -44,6 +44,9 @@ def __init__(self, vm_hash, config: QemuVMConfiguration): self.image_path = config.image_path self.monitor_socket_path = config.monitor_socket_path self.qmp_socket_path = config.qmp_socket_path + self.dir = Path("/var/lib/aleph/vm/executions/{}".format(vm_hash)) + self.dir.mkdir(exist_ok=True, parents=True) + self.serial_socket_path = self.dir / "serial.sock" self.vcpu_count = config.vcpu_count self.mem_size_mb = config.mem_size_mb self.interface_name = config.interface_name @@ -80,6 +83,9 @@ async def start( self.journal_stderr: BinaryIO = journal.stream(self._journal_stderr_name) # hardware_resources.published ports -> not implemented at the moment # hardware_resources.seconds -> only for microvm + # open('/proc/self/stdout', 'w').write('x\r\n') + # open('/dev/stdout', 'wb').write('x\r\n') + args = [ self.qemu_bin_path, "-enable-kvm", @@ -90,10 +96,10 @@ async def start( str(self.vcpu_count), "-drive", f"file={self.image_path},media=disk,if=virtio", - # To debug you can pass gtk or curses instead + # To debug pass gtk or curses instead of none "-display", "none", - # "--no-reboot", # Rebooting from inside the VM shuts down the machine + # "--no-reboot", # Rebooting from inside the VM shuts down the machine # Disable --no-reboot so user can reboot from inside the VM. see ALEPH-472 # Listen for commands on this socket "-monitor", @@ -103,10 +109,21 @@ async def start( "-qmp", f"unix:{self.qmp_socket_path},server,nowait", # Tell to put the output to std fd, so we can include them in the log - "-serial", - "stdio", - # nographics. Seems redundant with -serial stdio but without it the boot process is not displayed on stdout + # "-serial", + # "stdio", + # nographic. 
Disable graphic ui, redirect serial to stdio (which will be modified afterward), redirect parallel to stdio also (necessary to see bios boot) "-nographic", + # Redirect the serial, which expose an unix console, to a unix socket, used for remote debug + # Ideally, parallel should be redirected too with mux=on, but in practice it crashes qemu + # logfile=/dev/stdout allow the output to be displayed in journalctl for the log endpoint + "-chardev", + f"socket,id=iounix,path={str(self.serial_socket_path)},wait=off,server=on" + # f",logfile=/dev/stdout,mux=off,logappend=off", # DOES NOT WORK WITH SYSTEMD + f",logfile={self.dir/ "execution.log"},mux=off,logappend=off", # DOES NOT WORK WITH SYSTEMD + # Ideally insted of logfile we would use chardev.hub but it is qemu >= 10 only + "-serial", + "chardev:iounix", + ### # Boot # order=c only first hard drive # reboot-timeout in combination with -no-reboot, makes it so qemu stop if there is no bootable device @@ -114,7 +131,7 @@ async def start( "order=c,reboot-timeout=1", # Uncomment for debug # "-serial", "telnet:localhost:4321,server,nowait", - # "-snapshot", # Do not save anything to disk + # "-snapshot", # Do not save anything to disk ] if self.interface_name: # script=no, downscript=no tell qemu not to try to set up the network itself @@ -130,6 +147,7 @@ async def start( self.qemu_process = proc = await asyncio.create_subprocess_exec( *args, stdin=asyncio.subprocess.DEVNULL, + # stdout= asyncio.subprocess.STDOUT, stdout=self.journal_stdout, stderr=self.journal_stderr, ) diff --git a/src/aleph/vm/orchestrator/supervisor.py b/src/aleph/vm/orchestrator/supervisor.py index 8b745a9f..a760b4e9 100644 --- a/src/aleph/vm/orchestrator/supervisor.py +++ b/src/aleph/vm/orchestrator/supervisor.py @@ -57,6 +57,7 @@ operate_reboot, operate_stop, stream_logs, + operate_serial, ) logger = logging.getLogger(__name__) @@ -138,6 +139,7 @@ def setup_webapp(pool: VmPool | None): web.post("/control/machine/{ref}/update", operate_update), web.get("/control/machine/{ref}/stream_logs", stream_logs), web.get("/control/machine/{ref}/logs", operate_logs_json), + web.get("/control/machine/{ref}/serial", operate_serial), # websocket web.post("/control/machine/{ref}/expire", operate_expire), web.post("/control/machine/{ref}/stop", operate_stop), web.post("/control/machine/{ref}/erase", operate_erase), diff --git a/src/aleph/vm/orchestrator/views/operator.py b/src/aleph/vm/orchestrator/views/operator.py index 3178dfdc..b31e911d 100644 --- a/src/aleph/vm/orchestrator/views/operator.py +++ b/src/aleph/vm/orchestrator/views/operator.py @@ -1,8 +1,10 @@ +import asyncio import json import logging from datetime import timedelta from http import HTTPStatus +import aiohttp import aiohttp.web_exceptions import pydantic from aiohttp import web @@ -421,3 +423,72 @@ async def operate_erase(request: web.Request, authenticated_sender: str) -> web. volume.path_on_host.unlink() return web.Response(status=200, body=f"Erased VM with ref {vm_hash}") + + +BUFFER_SIZE = 1024 + + +@cors_allow_all +async def operate_serial(request: web.Request) -> web.StreamResponse: + """Connect to the serial console of a virtual machine. + + The authentication method is slightly different because browsers do not + allow Javascript to set headers in WebSocket requests. 
+ """ + # "/var/lib/aleph/vm/executions/decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca/serial.sock" + vm_hash = get_itemhash_or_400(request.match_info) + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) + + if execution.vm is None: + raise web.HTTPBadRequest(body=f"VM {vm_hash} is not running") + unix_socket_path = f"/var/lib/aleph/vm/executions/{execution.vm_hash}/serial.sock" + ws = web.WebSocketResponse() + logger.info(f"starting websocket: {request.path}") + await ws.prepare(request) + try: + # await authenticate_websocket_for_vm_or_403(execution, vm_hash, ws) + # await ws.send_json({"status": "connected"}) + reader, writer = await asyncio.open_unix_connection(unix_socket_path) + + async def ws_to_unix(): + msg: aiohttp.WSMessage + async for msg in ws: + if msg.type == aiohttp.WSMsgType.BINARY: + logger.info('bytes, %s', msg.data) + writer.write(msg.data) + await writer.drain() + elif msg.type == aiohttp.WSMsgType.TEXT: + print('r') + writer.write(msg.data.encode()) # allow basic text clients + await writer.drain() + elif msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.ERROR): + break + writer.close() + await writer.wait_closed() + + async def unix_to_ws(): + try: + while not ws.closed: + data = await reader.read(BUFFER_SIZE) + if not data: + break + logger.info('sending data, %s', data) + await ws.send_bytes(data) + except asyncio.CancelledError: + pass + + to_unix = asyncio.create_task(ws_to_unix()) + from_unix = asyncio.create_task(unix_to_ws()) + + await asyncio.wait([to_unix, from_unix], return_when=asyncio.FIRST_COMPLETED) + + for task in (to_unix, from_unix): + task.cancel() + + return ws + + finally: + await ws.close() + logger.info(f"connection {ws} closed") From 1869b998012bd842cc5180e0119544e33ee867a2 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Tue, 3 Jun 2025 15:23:06 +0200 Subject: [PATCH 23/23] add sample client --- scripts/serial_client.py | 55 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 scripts/serial_client.py diff --git a/scripts/serial_client.py b/scripts/serial_client.py new file mode 100644 index 00000000..34add635 --- /dev/null +++ b/scripts/serial_client.py @@ -0,0 +1,55 @@ +import asyncio +import os +import sys +import termios +import tty + +import websockets + +WS_URL = "ws://localhost:4020/control/machine/{ref}/serial" +EXIT_SEQUENCE = b"\x1d" # Ctrl-] + + +async def client(): + old_settings = termios.tcgetattr(sys.stdin) + ws_url = WS_URL.format(ref="decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca") + + + try: + async with websockets.connect(ws_url) as ws: + print("[Connected] Use Ctrl-] to disconnect ", file=sys.stderr) + tty.setraw(sys.stdin.fileno()) # Set terminal to raw mode + + async def send_input(): + while True: + data = await asyncio.get_event_loop().run_in_executor(None, os.read, sys.stdin.fileno(), 1) + if not data: + break + if data == EXIT_SEQUENCE: + print("\n[Exit sequence received — quitting]", file=sys.stderr) + break + if data == b"\r": + data = b"\r\n" + await ws.send(data) + + async def receive_output(): + async for message in ws: + if isinstance(message, bytes): + os.write(sys.stdout.fileno(), message) + else: + os.write(sys.stdout.fileno(), message.encode()) + + task_send = asyncio.create_task(send_input()) + task_recv = asyncio.create_task(receive_output()) + + done, pending = await asyncio.wait([task_send, task_recv], 
return_when=asyncio.FIRST_COMPLETED) + + for task in pending: + task.cancel() + finally: + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) # Restore terminal + print("\n[Disconnected]", file=sys.stderr) + + +if __name__ == "__main__": + asyncio.run(client())
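The sample client above gives a quick way to exercise the rescue console end to end: start an instance on a local orchestrator (the script assumes it listens on localhost:4020 and hard-codes the example item hash, so adjust WS_URL and the ref for a real execution), run

    python scripts/serial_client.py

and press Ctrl-] to disconnect. Note that authentication is disabled in this proof of concept (the authenticate_websocket_for_vm_or_403 call in operate_serial is commented out), so the serial endpoint should not be exposed on an untrusted network as-is.

A similar plain HTTP call can exercise the /control/machine/{ref}/update endpoint covered by the tests above after the port-forwarding aggregate has changed. A minimal sketch, assuming a local orchestrator on port 4020 and ignoring any operator authentication the deployed endpoint may require; CRN_URL and VM_HASH are placeholders:

    import asyncio

    import aiohttp

    CRN_URL = "http://localhost:4020"  # assumption: local supervisor, same port as the sample serial client
    VM_HASH = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca"  # placeholder item hash


    async def notify_update(crn_url: str, vm_hash: str) -> dict:
        # POST /control/machine/{ref}/update asks the CRN to re-fetch the
        # port-forwarding aggregate and update the nftables redirects.
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{crn_url}/control/machine/{vm_hash}/update") as resp:
                resp.raise_for_status()
                return await resp.json()


    if __name__ == "__main__":
        print(asyncio.run(notify_update(CRN_URL, VM_HASH)))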