diff --git a/scripts/serial_client.py b/scripts/serial_client.py new file mode 100644 index 000000000..34add6358 --- /dev/null +++ b/scripts/serial_client.py @@ -0,0 +1,55 @@ +import asyncio +import os +import sys +import termios +import tty + +import websockets + +WS_URL = "ws://localhost:4020/control/machine/{ref}/serial" +EXIT_SEQUENCE = b"\x1d" # Ctrl-] + + +async def client(): + old_settings = termios.tcgetattr(sys.stdin) + ws_url = WS_URL.format(ref="decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca") + + + try: + async with websockets.connect(ws_url) as ws: + print("[Connected] Use Ctrl-] to disconnect ", file=sys.stderr) + tty.setraw(sys.stdin.fileno()) # Set terminal to raw mode + + async def send_input(): + while True: + data = await asyncio.get_event_loop().run_in_executor(None, os.read, sys.stdin.fileno(), 1) + if not data: + break + if data == EXIT_SEQUENCE: + print("\n[Exit sequence received — quitting]", file=sys.stderr) + break + if data == b"\r": + data = b"\r\n" + await ws.send(data) + + async def receive_output(): + async for message in ws: + if isinstance(message, bytes): + os.write(sys.stdout.fileno(), message) + else: + os.write(sys.stdout.fileno(), message.encode()) + + task_send = asyncio.create_task(send_input()) + task_recv = asyncio.create_task(receive_output()) + + done, pending = await asyncio.wait([task_send, task_recv], return_when=asyncio.FIRST_COMPLETED) + + for task in pending: + task.cancel() + finally: + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) # Restore terminal + print("\n[Disconnected]", file=sys.stderr) + + +if __name__ == "__main__": + asyncio.run(client()) diff --git a/src/aleph/vm/conf.py b/src/aleph/vm/conf.py index b422aea6b..d110046cb 100644 --- a/src/aleph/vm/conf.py +++ b/src/aleph/vm/conf.py @@ -529,5 +529,9 @@ def make_db_url(): return f"sqlite+aiosqlite:///{settings.EXECUTION_DATABASE}" +def make_sync_db_url(): + return f"sqlite:///{settings.EXECUTION_DATABASE}" + + # Settings singleton settings = Settings() diff --git a/src/aleph/vm/hypervisors/qemu/qemuvm.py b/src/aleph/vm/hypervisors/qemu/qemuvm.py index 79fd60e51..579befd75 100644 --- a/src/aleph/vm/hypervisors/qemu/qemuvm.py +++ b/src/aleph/vm/hypervisors/qemu/qemuvm.py @@ -44,6 +44,9 @@ def __init__(self, vm_hash, config: QemuVMConfiguration): self.image_path = config.image_path self.monitor_socket_path = config.monitor_socket_path self.qmp_socket_path = config.qmp_socket_path + self.dir = Path("/var/lib/aleph/vm/executions/{}".format(vm_hash)) + self.dir.mkdir(exist_ok=True, parents=True) + self.serial_socket_path = self.dir / "serial.sock" self.vcpu_count = config.vcpu_count self.mem_size_mb = config.mem_size_mb self.interface_name = config.interface_name @@ -80,6 +83,9 @@ async def start( self.journal_stderr: BinaryIO = journal.stream(self._journal_stderr_name) # hardware_resources.published ports -> not implemented at the moment # hardware_resources.seconds -> only for microvm + # open('/proc/self/stdout', 'w').write('x\r\n') + # open('/dev/stdout', 'wb').write('x\r\n') + args = [ self.qemu_bin_path, "-enable-kvm", @@ -90,10 +96,10 @@ async def start( str(self.vcpu_count), "-drive", f"file={self.image_path},media=disk,if=virtio", - # To debug you can pass gtk or curses instead + # To debug pass gtk or curses instead of none "-display", "none", - # "--no-reboot", # Rebooting from inside the VM shuts down the machine + # "--no-reboot", # Rebooting from inside the VM shuts down the machine # Disable --no-reboot so user can reboot from inside the 
VM. See ALEPH-472. # Listen for commands on this socket "-monitor", f"unix:{self.monitor_socket_path},server,nowait", @@ -103,10 +109,21 @@ "-qmp", f"unix:{self.qmp_socket_path},server,nowait", # Tell to put the output to std fd, so we can include them in the log - "-serial", - "stdio", - # nographics. Seems redundant with -serial stdio but without it the boot process is not displayed on stdout + # "-serial", + # "stdio", + # nographic. Disable the graphical UI, redirect serial to stdio (modified below) and also redirect parallel to stdio (necessary to see the BIOS boot) "-nographic", + # Redirect the serial port, which exposes the guest console, to a Unix socket used for remote debugging + # Ideally, parallel should be redirected too with mux=on, but in practice it crashes qemu + # logfile=/dev/stdout allows the output to be displayed in journalctl for the log endpoint + "-chardev", + f"socket,id=iounix,path={str(self.serial_socket_path)},wait=off,server=on" + # f",logfile=/dev/stdout,mux=off,logappend=off", # DOES NOT WORK WITH SYSTEMD + f",logfile={self.dir / 'execution.log'},mux=off,logappend=off", + # Ideally, instead of logfile we would use a chardev hub, but that requires QEMU >= 10 + "-serial", + "chardev:iounix", + ### # Boot # order=c only first hard drive # reboot-timeout in combination with -no-reboot, makes it so qemu stop if there is no bootable device @@ -114,7 +131,7 @@ "order=c,reboot-timeout=1", # Uncomment for debug # "-serial", "telnet:localhost:4321,server,nowait", - # "-snapshot", # Do not save anything to disk + # "-snapshot", # Do not save anything to disk ] if self.interface_name: # script=no, downscript=no tell qemu not to try to set up the network itself @@ -130,6 +147,7 @@ self.qemu_process = proc = await asyncio.create_subprocess_exec( *args, stdin=asyncio.subprocess.DEVNULL, + # stdout= asyncio.subprocess.STDOUT, stdout=self.journal_stdout, stderr=self.journal_stderr, )
diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 880b88914..00bfd4259 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -30,7 +30,9 @@ AlephQemuConfidentialInstance, AlephQemuConfidentialResources, ) +from aleph.vm.network.firewall import add_port_redirect_rule, remove_port_redirect_rule from aleph.vm.network.interfaces import TapInterface +from aleph.vm.network.port_availability_checker import get_available_host_port from aleph.vm.orchestrator.metrics import ( ExecutionRecord, delete_record, @@ -42,6 +44,9 @@ from aleph.vm.resources import GpuDevice, HostGPU from aleph.vm.systemd import SystemDManager from aleph.vm.utils import create_task_log_exceptions, dumps_for_json, is_pinging +from aleph.vm.utils.aggregate import get_user_settings + +SUPPORTED_PROTOCOL_FOR_REDIRECT = ["udp", "tcp"] logger = logging.getLogger(__name__) @@ -92,6 +97,78 @@ class VmExecution: systemd_manager: SystemDManager | None persistent: bool = False + mapped_ports: dict[int, dict] # Port redirects to the VM + record: ExecutionRecord | None = None + + async def fetch_port_redirect_config_and_setup(self): + message = self.message + try: + port_forwarding_settings = await get_user_settings(message.address, "port-forwarding") + ports_requests = port_forwarding_settings.get(self.vm_hash) + + except Exception: + ports_requests = {} + logger.exception("Could not fetch the port redirect settings for user") + if not ports_requests: + # FIXME DEBUG FOR NOW + ports_requests = {22: {"tcp": True, "udp": False}} + ports_requests = ports_requests or {} + await 
self.update_port_redirects(ports_requests) + + async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]]): + assert self.vm, "The VM attribute has to be set before calling update_port_redirects()" + + logger.info("Updating port redirect. Current %s, New %s", self.mapped_ports, requested_ports) + redirect_to_remove = set(self.mapped_ports.keys()) - set(requested_ports.keys()) + redirect_to_add = set(requested_ports.keys()) - set(self.mapped_ports.keys()) + redirect_to_check = set(requested_ports.keys()).intersection(set(self.mapped_ports.keys())) + interface = self.vm.tap_interface + + for vm_port in redirect_to_remove: + current = self.mapped_ports[vm_port] + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if current[protocol]: + host_port = current["host"] + remove_port_redirect_rule(interface, host_port, vm_port, protocol) + del self.mapped_ports[vm_port] + + for vm_port in redirect_to_add: + target = requested_ports[vm_port] + host_port = get_available_host_port(start_port=25000) + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if target[protocol]: + add_port_redirect_rule(interface, host_port, vm_port, protocol) + self.mapped_ports[vm_port] = {"host": host_port, **target} + + for vm_port in redirect_to_check: + current = self.mapped_ports[vm_port] + target = requested_ports[vm_port] + host_port = current["host"] + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if current[protocol] != target[protocol]: + if target[protocol]: + add_port_redirect_rule(interface, host_port, vm_port, protocol) + else: + remove_port_redirect_rule(interface, host_port, vm_port, protocol) + self.mapped_ports[vm_port] = {"host": host_port, **target} + + # Save to DB + if self.record: + self.record.mapped_ports = self.mapped_ports + await save_record(self.record) + + async def removed_all_ports_redirection(self): + if not self.vm: + return + interface = self.vm.tap_interface + # copy in a list since we modify dict during iteration + for vm_port, map_detail in list(self.mapped_ports.items()): + host_port = map_detail["host"] + for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: + if map_detail[protocol]: + remove_port_redirect_rule(interface, host_port, vm_port, protocol) + + del self.mapped_ports[vm_port] @property def is_starting(self) -> bool: @@ -190,6 +267,7 @@ def __init__( self.snapshot_manager = snapshot_manager self.systemd_manager = systemd_manager self.persistent = persistent + self.mapped_ports = {} def to_dict(self) -> dict: return { @@ -350,7 +428,8 @@ async def start(self): if self.vm and self.vm.support_snapshot and self.snapshot_manager: await self.snapshot_manager.start_for(vm=self.vm) - + else: + self.times.started_at = datetime.now(tz=timezone.utc) self.ready_event.set() await self.save() except Exception: @@ -398,7 +477,10 @@ async def non_blocking_wait_for_boot(self): except Exception as e: logger.warning("%s failed to responded to ping or is not running, stopping it.: %s ", self, e) assert self.vm - await self.stop() + try: + await self.stop() + except Exception as f: + logger.exception("%s failed to stop: %s", self, f) return False async def wait_for_init(self): @@ -456,6 +538,8 @@ async def stop(self) -> None: await self.all_runs_complete() await self.record_usage() await self.vm.teardown() + await self.removed_all_ports_redirection() + self.times.stopped_at = datetime.now(tz=timezone.utc) self.cancel_expiration() self.cancel_update() @@ -497,57 +581,38 @@ async def save(self): """Save to DB""" assert self.vm, "The VM attribute has to be set before calling 
save()" - pid_info = self.vm.to_dict() if self.vm else None - # Handle cases when the process cannot be accessed - if not self.persistent and pid_info and pid_info.get("process"): - await save_record( - ExecutionRecord( - uuid=str(self.uuid), - vm_hash=self.vm_hash, - vm_id=self.vm_id, - time_defined=self.times.defined_at, - time_prepared=self.times.prepared_at, - time_started=self.times.started_at, - time_stopping=self.times.stopping_at, - cpu_time_user=pid_info["process"]["cpu_times"].user, - cpu_time_system=pid_info["process"]["cpu_times"].system, - io_read_count=pid_info["process"]["io_counters"][0], - io_write_count=pid_info["process"]["io_counters"][1], - io_read_bytes=pid_info["process"]["io_counters"][2], - io_write_bytes=pid_info["process"]["io_counters"][3], - vcpus=self.vm.hardware_resources.vcpus, - memory=self.vm.hardware_resources.memory, - network_tap=self.vm.tap_interface.device_name if self.vm.tap_interface else "", - message=self.message.model_dump_json(), - original_message=self.original.model_dump_json(), - persistent=self.persistent, - ) - ) - else: - # The process cannot be accessed, or it's a persistent VM. - await save_record( - ExecutionRecord( - uuid=str(self.uuid), - vm_hash=self.vm_hash, - vm_id=self.vm_id, - time_defined=self.times.defined_at, - time_prepared=self.times.prepared_at, - time_started=self.times.started_at, - time_stopping=self.times.stopping_at, - cpu_time_user=None, - cpu_time_system=None, - io_read_count=None, - io_write_count=None, - io_read_bytes=None, - io_write_bytes=None, - vcpus=self.vm.hardware_resources.vcpus, - memory=self.vm.hardware_resources.memory, - message=self.message.model_dump_json(), - original_message=self.original.model_dump_json(), - persistent=self.persistent, - gpus=json.dumps(self.gpus, default=pydantic_encoder), - ) + if not self.record: + self.record = ExecutionRecord( + uuid=str(self.uuid), + vm_hash=self.vm_hash, + vm_id=self.vm_id, + time_defined=self.times.defined_at, + time_prepared=self.times.prepared_at, + time_started=self.times.started_at, + time_stopping=self.times.stopping_at, + cpu_time_user=None, + cpu_time_system=None, + io_read_count=None, + io_write_count=None, + io_read_bytes=None, + io_write_bytes=None, + vcpus=self.vm.hardware_resources.vcpus, + memory=self.vm.hardware_resources.memory, + message=self.message.model_dump_json(), + original_message=self.original.model_dump_json(), + persistent=self.persistent, + gpus=json.dumps(self.gpus, default=pydantic_encoder), ) + pid_info = self.vm.to_dict() if self.vm else None + # Handle cases when the process cannot be accessed + if not self.persistent and pid_info and pid_info.get("process"): + self.record.cpu_time_user = pid_info["process"]["cpu_times"].user + self.record.cpu_time_system = pid_info["process"]["cpu_times"].system + self.record.io_read_count = pid_info["process"]["io_counters"][0] + self.record.io_write_count = pid_info["process"]["io_counters"][1] + self.record.io_read_bytes = pid_info["process"]["io_counters"][2] + self.record.io_write_bytes = pid_info["process"]["io_counters"][3] + await save_record(self.record) async def record_usage(self): await delete_record(execution_uuid=str(self.uuid)) diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index eedc6ae11..6e828a3d3 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -5,8 +5,7 @@ from nftables import Nftables from aleph.vm.conf import settings - -from .interfaces import TapInterface +from aleph.vm.network.interfaces import 
TapInterface logger = logging.getLogger(__name__) @@ -22,7 +21,7 @@ def get_customized_nftables() -> Nftables: return nft -def execute_json_nft_commands(commands: list[dict]) -> int: +def execute_json_nft_commands(commands: list[dict]) -> dict: """Executes a list of nftables commands, and returns the exit status""" nft = get_customized_nftables() commands_dict = {"nftables": commands} @@ -37,20 +36,16 @@ def execute_json_nft_commands(commands: list[dict]) -> int: if return_code != 0: logger.error("Failed to add nftables rules: %s -- %s", error, json.dumps(commands, indent=4)) - return return_code + return output def get_existing_nftables_ruleset() -> dict: """Retrieves the full nftables ruleset and returns it""" - nft = get_customized_nftables() - return_code, output, error = nft.cmd("list ruleset") + # List all NAT rules + commands = [{"list": {"ruleset": {"family": "ip", "table": "nat"}}}] - if return_code != 0: - logger.error(f"Unable to get nftables ruleset: {error}") - return {"nftables": []} - - nft_ruleset = json.loads(output) - return nft_ruleset + nft_ruleset = execute_json_nft_commands(commands) + return nft_ruleset["nftables"] def get_base_chains_for_hook(hook: str, family: str = "ip") -> list: @@ -59,7 +54,7 @@ def get_base_chains_for_hook(hook: str, family: str = "ip") -> list: nft_ruleset = get_existing_nftables_ruleset() chains = [] - for entry in nft_ruleset["nftables"]: + for entry in nft_ruleset: if ( not isinstance(entry, dict) or "chain" not in entry @@ -85,7 +80,7 @@ def get_table_for_hook(hook: str, family: str = "ip") -> str: def check_if_table_exists(family: str, table: str) -> bool: """Checks whether the specified table exists in the nftables ruleset""" nft_ruleset = get_existing_nftables_ruleset() - for entry in nft_ruleset["nftables"]: + for entry in nft_ruleset: if ( isinstance(entry, dict) and "table" in entry @@ -201,7 +196,7 @@ def teardown_nftables() -> None: remove_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-supervisor-filter") -def add_chain(family: str, table: str, name: str) -> int: +def add_chain(family: str, table: str, name: str) -> dict: """Helper function to quickly create a new chain in the nftables ruleset Returns the exit code from executing the nftables commands""" commands = [_make_add_chain_command(family, table, name)] @@ -220,14 +215,14 @@ def _make_add_chain_command(family: str, table: str, name: str) -> dict: } -def remove_chain(name: str) -> int: +def remove_chain(name: str) -> dict: """Removes all rules that jump to the chain, and then removes the chain itself. Returns the exit code from executing the nftables commands""" nft_ruleset = get_existing_nftables_ruleset() commands = [] remove_chain_commands = [] - for entry in nft_ruleset["nftables"]: + for entry in nft_ruleset: if ( isinstance(entry, dict) and "rule" in entry @@ -264,7 +259,7 @@ def remove_chain(name: str) -> int: return execute_json_nft_commands(commands) -def add_postrouting_chain(name: str) -> int: +def add_postrouting_chain(name: str) -> dict: """Adds a chain and creates a rule from the base chain with the postrouting hook. Returns the exit code from executing the nftables commands""" table = get_table_for_hook("postrouting") @@ -284,7 +279,7 @@ def add_postrouting_chain(name: str) -> int: return execute_json_nft_commands(command) -def add_forward_chain(name: str) -> int: +def add_forward_chain(name: str) -> dict: """Adds a chain and creates a rule from the base chain with the forward hook. 
Returns the exit code from executing the nftables commands""" table = get_table_for_hook("forward") @@ -304,7 +299,7 @@ def add_forward_chain(name: str) -> int: return execute_json_nft_commands(command) -def add_masquerading_rule(vm_id: int, interface: TapInterface) -> int: +def add_masquerading_rule(vm_id: int, interface: TapInterface) -> dict: """Creates a rule for the VM with the specified id to allow outbound traffic to be masqueraded (NAT) Returns the exit code from executing the nftables commands""" table = get_table_for_hook("postrouting") @@ -340,7 +335,7 @@ def add_masquerading_rule(vm_id: int, interface: TapInterface) -> int: return execute_json_nft_commands(command) -def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> int: +def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> dict: """Creates a rule for the VM with the specified id to allow outbound traffic Returns the exit code from executing the nftables commands""" table = get_table_for_hook("forward") @@ -376,6 +371,138 @@ def add_forward_rule_to_external(vm_id: int, interface: TapInterface) -> int: return execute_json_nft_commands(command) +def add_prerouting_chain() -> dict: + """Creates the prerouting chain if it doesn't exist already. + + Returns: + int: The exit code from executing the nftables commands + """ + # Check if prerouting chain exists by looking for chains with prerouting hook + existing_chains = get_base_chains_for_hook("prerouting", "ip") + if existing_chains: + return {} # Chain already exists, nothing to do + + commands = [ + { + "add": { + "chain": { + "family": "ip", + "table": "nat", + "name": "prerouting", + "type": "nat", + "hook": "prerouting", + "prio": -100, + "policy": "accept", + } + } + } + ] + return execute_json_nft_commands(commands) + + +def add_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> dict: + """Creates a rule to redirect traffic from a host port to a VM port. + + Args: + vm_id: The ID of the VM + interface: The TapInterface instance for the VM + host_port: The port number on the host to listen on + vm_port: The port number to forward to on the VM + protocol: The protocol to use (tcp or udp, defaults to tcp) + + Returns: + The exit code from executing the nftables commands + """ + add_prerouting_chain() + commands = [ + { + "add": { + "rule": { + "family": "ip", + "table": "nat", + "chain": "prerouting", + "expr": [ + { + "match": { + "op": "==", + "left": {"meta": {"key": "iifname"}}, + "right": settings.NETWORK_INTERFACE, + } + }, + { + "match": { + "op": "==", + "left": {"payload": {"protocol": protocol, "field": "dport"}}, + "right": host_port, + } + }, + { + "dnat": {"addr": str(interface.guest_ip.ip), "port": vm_port}, + }, + ], + } + } + }, + ] + + return execute_json_nft_commands(commands) + + +def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: int, protocol: str = "tcp") -> dict: + """Removes a rule that redirects traffic from a host port to a VM port. 
+ + Args: + interface: The TapInterface instance for the VM + host_port: The port number on the host that is listened on + vm_port: The port number that is being forwarded to on the VM + protocol: The protocol used (tcp or udp) + + Returns: + The exit code from executing the nftables commands + """ + nft_ruleset = get_existing_nftables_ruleset() + commands = [] + + for entry in nft_ruleset: + if ( + isinstance(entry, dict) + and "rule" in entry + and entry["rule"].get("family") == "ip" + and entry["rule"].get("table") == "nat" + and entry["rule"].get("chain") == "prerouting" + and "expr" in entry["rule"] + ): + expr = entry["rule"]["expr"] + # Check if this is our redirect rule by matching all conditions + if ( + len(expr) == 3 + and "match" in expr[0] + and expr[0]["match"]["left"].get("meta", {}).get("key") == "iifname" + and expr[0]["match"]["right"] == settings.NETWORK_INTERFACE + and "match" in expr[1] + and expr[1]["match"]["left"].get("payload", {}).get("protocol") == protocol + and expr[1]["match"]["left"]["payload"].get("field") == "dport" + and expr[1]["match"]["right"] == host_port + and "dnat" in expr[2] + and expr[2]["dnat"].get("addr") == str(interface.guest_ip.ip) + and expr[2]["dnat"].get("port") == vm_port + ): + commands.append( + { + "delete": { + "rule": { + "family": "ip", + "table": "nat", + "chain": "prerouting", + "handle": entry["rule"]["handle"], + } + } + } + ) + + return execute_json_nft_commands(commands) + + def setup_nftables_for_vm(vm_id: int, interface: TapInterface) -> None: """Sets up chains for filter and nat purposes specific to this VM, and makes sure those chains are jumped to""" add_postrouting_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-vm-nat-{vm_id}") @@ -388,3 +515,45 @@ def teardown_nftables_for_vm(vm_id: int) -> None: """Remove all nftables rules related to the specified VM""" remove_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-vm-nat-{vm_id}") remove_chain(f"{settings.NFTABLES_CHAIN_PREFIX}-vm-filter-{vm_id}") + + +def check_nftables_redirections(port: int) -> bool: + """Check if there are any NAT rules redirecting to the given port. + + Args: + port: The port number to check + + Returns: + True if the port is being used in any NAT redirection rules + """ + try: + rules = get_existing_nftables_ruleset() + # Navigate through the JSON structure + for item in rules: + if "rule" not in item: + continue + + rule = item["rule"] + if "expr" not in rule: + continue + + for expr in rule["expr"]: + # Check destination port in matches + if "match" in expr: + match = expr["match"] + if "left" in match and "payload" in match["left"]: + payload = match["left"]["payload"] + if payload.get("field") == "dport" and match.get("right") == port: + return True + + # Check port in dnat rules + if "dnat" in expr: + dnat = expr["dnat"] + if dnat.get("port") == port: + return True + + return False + + except Exception as e: + logger.warning(f"Error checking NAT redirections: {e}") + return False # Assume no redirections in case of error diff --git a/src/aleph/vm/network/get_interface_ipv4.py b/src/aleph/vm/network/get_interface_ipv4.py new file mode 100644 index 000000000..fe0ea91e9 --- /dev/null +++ b/src/aleph/vm/network/get_interface_ipv4.py @@ -0,0 +1,41 @@ +import ipaddress + +import netifaces + + +def get_interface_ipv4(interface_name: str) -> str: + """ + Get the main IPv4 address from a network interface. 
+ + Args: + interface_name: Name of the network interface (e.g., 'eth0', 'wlan0') + + Returns: + str: The IPv4 address of the interface + + Raises: + ValueError: If the interface doesn't exist or has no IPv4 address + """ + try: + # Check if the interface exists + if interface_name not in netifaces.interfaces(): + raise ValueError(f"Interface {interface_name} does not exist") + + # Get addresses for the interface + addrs = netifaces.ifaddresses(interface_name) + + # Check for IPv4 addresses (AF_INET is IPv4) + if netifaces.AF_INET not in addrs: + raise ValueError(f"No IPv4 address found for interface {interface_name}") + + # Get the first IPv4 address + ipv4_info = addrs[netifaces.AF_INET][0] + ipv4_addr = ipv4_info["addr"] + + # Validate that it's a proper IPv4 address + ipaddress.IPv4Address(ipv4_addr) + + return ipv4_addr + + except Exception as e: + raise ValueError(f"Error getting IPv4 address: {str(e)}") diff --git a/src/aleph/vm/network/hostnetwork.py b/src/aleph/vm/network/hostnetwork.py index e9fcaa794..68b05fa89 100644 --- a/src/aleph/vm/network/hostnetwork.py +++ b/src/aleph/vm/network/hostnetwork.py @@ -10,6 +10,7 @@ from aleph.vm.vm_type import VmType from .firewall import initialize_nftables, setup_nftables_for_vm, teardown_nftables +from .get_interface_ipv4 import get_interface_ipv4 from .interfaces import TapInterface from .ipaddresses import IPv4NetworkWithInterfaces from .ndp_proxy import NdpProxy @@ -115,6 +116,7 @@ class Network: ipv6_address_pool: IPv6Network network_size: int ndp_proxy: NdpProxy | None = None + host_ipv4: str IPV6_SUBNET_PREFIX: int = 124 @@ -138,6 +140,7 @@ def __init__( self.ipv6_forwarding_enabled = ipv6_forwarding_enabled self.use_ndp_proxy = use_ndp_proxy self.ndb = pyroute2.NDB() + self.host_ipv4 = get_interface_ipv4(external_interface) if not self.ipv4_address_pool.is_private: logger.warning(f"Using a network range that is not private: {self.ipv4_address_pool}") diff --git a/src/aleph/vm/network/port_availability_checker.py b/src/aleph/vm/network/port_availability_checker.py new file mode 100644 index 000000000..da7cba435 --- /dev/null +++ b/src/aleph/vm/network/port_availability_checker.py @@ -0,0 +1,43 @@ +import socket +from typing import Optional + +from aleph.vm.network.firewall import check_nftables_redirections + +MIN_DYNAMIC_PORT = 24000 +MAX_PORT = 65535 + + +def get_available_host_port(start_port: Optional[int] = None) -> int: + """Find an available port on the host system. + + Args: + start_port: Optional starting port number. 
If not provided, starts from MIN_DYNAMIC_PORT + + Returns: + An available port number + + Raises: + RuntimeError: If no ports are available in the valid range + """ + start_port = start_port if start_port and start_port >= MIN_DYNAMIC_PORT else MIN_DYNAMIC_PORT + for port in range(start_port, MAX_PORT): + try: + # check if there is already a redirect to that port + if check_nftables_redirections(port): + continue + # Try both TCP and UDP on all interfaces + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_sock: + tcp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + tcp_sock.bind(("0.0.0.0", port)) + tcp_sock.listen(1) + + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_sock: + udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + udp_sock.bind(("0.0.0.0", port)) + + return port + + except (socket.error, OSError): + pass + + raise RuntimeError(f"No available ports found in range {MIN_DYNAMIC_PORT}-{MAX_PORT}") diff --git a/src/aleph/vm/orchestrator/metrics.py b/src/aleph/vm/orchestrator/metrics.py index 6c9b8eea0..bc2bd3053 100644 --- a/src/aleph/vm/orchestrator/metrics.py +++ b/src/aleph/vm/orchestrator/metrics.py @@ -77,6 +77,7 @@ class ExecutionRecord(Base): persistent = Column(Boolean, nullable=True) gpus = Column(JSON, nullable=True) + mapped_ports = Column(JSON, nullable=True) def __repr__(self): return f"" diff --git a/src/aleph/vm/orchestrator/migrations/env.py b/src/aleph/vm/orchestrator/migrations/env.py index 2cf116bb6..bd5eadc7a 100644 --- a/src/aleph/vm/orchestrator/migrations/env.py +++ b/src/aleph/vm/orchestrator/migrations/env.py @@ -1,7 +1,7 @@ from alembic import context from sqlalchemy import create_engine -from aleph.vm.conf import make_db_url +from aleph.vm.conf import make_sync_db_url # Auto-generate migrations from aleph.vm.orchestrator.metrics import Base @@ -36,7 +36,8 @@ def run_migrations_offline() -> None: script output. """ - url = make_db_url() + url = make_sync_db_url() + context.configure( url=url, target_metadata=target_metadata, @@ -55,7 +56,7 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - connectable = create_engine(make_db_url()) + connectable = create_engine(make_sync_db_url()) with connectable.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py b/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py new file mode 100644 index 000000000..a8f7df0d9 --- /dev/null +++ b/src/aleph/vm/orchestrator/migrations/versions/2da719d72cea_add_mapped_ports_column.py @@ -0,0 +1,39 @@ +"""add mapped_ports column + +Revision ID: 2da719d72cea +Revises: 5c6ae643c69b +Create Date: 2025-05-29 23:20:42.801850 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +from sqlalchemy import create_engine +from sqlalchemy.engine import reflection + +from aleph.vm.conf import make_db_url + +# revision identifiers, used by Alembic. +revision = "2da719d72cea" +down_revision = "5c6ae643c69b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + engine = create_engine(make_db_url()) + inspector = reflection.Inspector.from_engine(engine) + + # The table already exists on most CRNs. 
+ tables = inspector.get_table_names() + if "executions" in tables: + columns = inspector.get_columns("executions") + column_names = [c["name"] for c in columns] + if "mapped_ports" not in column_names: + op.add_column("executions", sa.Column("mapped_ports", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("executions", "mapped_ports") diff --git a/src/aleph/vm/orchestrator/supervisor.py b/src/aleph/vm/orchestrator/supervisor.py index 087352301..a760b4e9e 100644 --- a/src/aleph/vm/orchestrator/supervisor.py +++ b/src/aleph/vm/orchestrator/supervisor.py @@ -36,6 +36,7 @@ list_executions_v2, notify_allocation, operate_reserve_resources, + operate_update, run_code_from_hostname, run_code_from_path, status_check_fastapi, @@ -56,6 +57,7 @@ operate_reboot, operate_stop, stream_logs, + operate_serial, ) logger = logging.getLogger(__name__) @@ -134,8 +136,10 @@ def setup_webapp(pool: VmPool | None): # /control APIs are used to control the VMs and access their logs web.post("/control/allocation/notify", notify_allocation), web.post("/control/reserve_resources", operate_reserve_resources), + web.post("/control/machine/{ref}/update", operate_update), web.get("/control/machine/{ref}/stream_logs", stream_logs), web.get("/control/machine/{ref}/logs", operate_logs_json), + web.get("/control/machine/{ref}/serial", operate_serial), # websocket web.post("/control/machine/{ref}/expire", operate_expire), web.post("/control/machine/{ref}/stop", operate_stop), web.post("/control/machine/{ref}/erase", operate_erase), diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index e9bfff97a..cd8b23eaf 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -25,6 +25,7 @@ ) from aleph.vm.controllers.firecracker.program import FileTooLargeError from aleph.vm.hypervisors.firecracker.microvm import MicroVMFailedInitError +from aleph.vm.models import VmExecution from aleph.vm.orchestrator import payment, status from aleph.vm.orchestrator.chain import STREAM_CHAINS from aleph.vm.orchestrator.custom_logs import set_vm_for_logging @@ -55,6 +56,7 @@ check_host_egress_ipv4, check_host_egress_ipv6, ) +from aleph.vm.orchestrator.views.operator import get_itemhash_or_400 from aleph.vm.pool import VmPool from aleph.vm.utils import ( HostNotFoundError, @@ -185,8 +187,10 @@ async def list_executions_v2(request: web.Request) -> web.Response: item_hash: { "networking": { "ipv4_network": execution.vm.tap_interface.ip_network, + "host_ipv4": pool.network.host_ipv4, "ipv6_network": execution.vm.tap_interface.ipv6_network, "ipv6_ip": execution.vm.tap_interface.guest_ipv6.ip, + "mapped_ports": execution.mapped_ports, } if execution.vm and execution.vm.tap_interface else {}, @@ -683,3 +687,22 @@ async def operate_reserve_resources(request: web.Request, authenticated_sender: }, dumps=dumps_for_json, ) + + +@cors_allow_all +async def operate_update(request: web.Request) -> web.Response: + """Notify that the instance configuration has changed + + For now used to notify the CRN that port-forwarding config has changed + and that it should be fetched and the setup upgraded""" + vm_hash = get_itemhash_or_400(request.match_info) + + pool: VmPool = request.app["vm_pool"] + execution: VmExecution = pool.executions.get(vm_hash) + if not execution: + raise HTTPNotFound(reason="VM not found") + if not execution.vm: + # Configuration will be fetched when the VM start so no need to return an error + return 
web.json_response({"status": "ok", "msg": "VM not starting yet"}, dumps=dumps_for_json, status=200) + await execution.fetch_port_redirect_config_and_setup() + return web.json_response({}, dumps=dumps_for_json, status=200)
diff --git a/src/aleph/vm/orchestrator/views/operator.py b/src/aleph/vm/orchestrator/views/operator.py index 3178dfdc1..b31e911de 100644 --- a/src/aleph/vm/orchestrator/views/operator.py +++ b/src/aleph/vm/orchestrator/views/operator.py @@ -1,8 +1,10 @@ +import asyncio import json import logging from datetime import timedelta from http import HTTPStatus +import aiohttp import aiohttp.web_exceptions import pydantic from aiohttp import web @@ -421,3 +423,72 @@ async def operate_erase(request: web.Request, authenticated_sender: str) -> web. volume.path_on_host.unlink() return web.Response(status=200, body=f"Erased VM with ref {vm_hash}") + + +BUFFER_SIZE = 1024 + + +@cors_allow_all +async def operate_serial(request: web.Request) -> web.StreamResponse: + """Connect to the serial console of a virtual machine. + + The authentication method is slightly different because browsers do not + allow Javascript to set headers in WebSocket requests. + """ + # e.g. "/var/lib/aleph/vm/executions/decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca/serial.sock" + vm_hash = get_itemhash_or_400(request.match_info) + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) + + if execution.vm is None: + raise web.HTTPBadRequest(body=f"VM {vm_hash} is not running") + unix_socket_path = f"/var/lib/aleph/vm/executions/{execution.vm_hash}/serial.sock" + ws = web.WebSocketResponse() + logger.info(f"starting websocket: {request.path}") + await ws.prepare(request) + try: + # await authenticate_websocket_for_vm_or_403(execution, vm_hash, ws) + # await ws.send_json({"status": "connected"}) + reader, writer = await asyncio.open_unix_connection(unix_socket_path) + + async def ws_to_unix(): + msg: aiohttp.WSMessage + async for msg in ws: + if msg.type == aiohttp.WSMsgType.BINARY: + logger.debug("Forwarding %d bytes to the serial socket", len(msg.data)) + writer.write(msg.data) + await writer.drain() + elif msg.type == aiohttp.WSMsgType.TEXT: + logger.debug("Forwarding a text frame to the serial socket") + writer.write(msg.data.encode()) # allow basic text clients + await writer.drain() + elif msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.ERROR): + break + writer.close() + await writer.wait_closed() + + async def unix_to_ws(): + try: + while not ws.closed: + data = await reader.read(BUFFER_SIZE) + if not data: + break + logger.debug("Forwarding %d bytes to the websocket", len(data)) + await ws.send_bytes(data) + except asyncio.CancelledError: + pass + + to_unix = asyncio.create_task(ws_to_unix()) + from_unix = asyncio.create_task(unix_to_ws()) + + await asyncio.wait([to_unix, from_unix], return_when=asyncio.FIRST_COMPLETED) + + for task in (to_unix, from_unix): + task.cancel() + + return ws + + finally: + await ws.close() + logger.info(f"connection {ws} closed")
diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index 17a8d386c..d178129d6 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -6,12 +6,11 @@ import shutil from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import Any, cast +from typing import Any from aleph_message.models import ( Chain, ExecutableMessage, - InstanceContent, ItemHash, Payment, PaymentType, ) @@ -168,11 +167,13 @@ async def create_a_vm( if self.network.interface_exists(vm_id): await tap_interface.delete() await 
self.network.create_tap(vm_id, tap_interface) + else: tap_interface = None execution.create(vm_id=vm_id, tap_interface=tap_interface) await execution.start() + await execution.fetch_port_redirect_config_and_setup() # clear the user reservations for resource in resources: @@ -180,7 +181,9 @@ async def create_a_vm( del self.reservations[resource] except Exception: # ensure the VM is removed from the pool on creation error + await execution.removed_all_ports_redirection() self.forget_vm(vm_hash) + raise self._schedule_forget_on_stop(execution) @@ -281,6 +284,7 @@ async def load_persistent_executions(self): if saved_execution.gpus else [] ) + execution.mapped_ports = saved_execution.mapped_ports # Load and instantiate the rest of resources and already assigned GPUs await execution.prepare() if self.network: diff --git a/src/aleph/vm/utils/aggregate.py b/src/aleph/vm/utils/aggregate.py new file mode 100644 index 000000000..3087bf9b7 --- /dev/null +++ b/src/aleph/vm/utils/aggregate.py @@ -0,0 +1,38 @@ +from logging import getLogger + +from aleph.vm.conf import settings + +logger = getLogger(__name__) + + +async def get_user_aggregate(addr: str, keys_arg: list[str]) -> dict: + """ + Get the settings Aggregate dict from the PyAleph API Aggregate. + + API Endpoint: + GET /api/v0/aggregates/{address}.json?keys=settings + + For more details, see the PyAleph API documentation: + https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62 + """ + + import aiohttp + + async with aiohttp.ClientSession() as session: + url = f"{settings.API_SERVER}/api/v0/aggregates/{addr}.json" + logger.info(f"Fetching aggregate from {url}") + resp = await session.get(url, params={"keys": ",".join(keys_arg)}) + # No aggregate for the user + if resp.status == 404: + return {} + # Raise an error if the request failed + + resp.raise_for_status() + + resp_data = await resp.json() + return resp_data["data"] or {} + + +async def get_user_settings(addr: str, key) -> dict: + aggregate = await get_user_aggregate(addr, [key]) + return aggregate.get(key, {}) diff --git a/tests/supervisor/test_views.py b/tests/supervisor/test_views.py index 02a665bca..2fadc1c31 100644 --- a/tests/supervisor/test_views.py +++ b/tests/supervisor/test_views.py @@ -1,4 +1,3 @@ -import asyncio import os import tempfile from copy import deepcopy @@ -408,7 +407,7 @@ async def test_v2_executions_list_one_vm(aiohttp_client, mock_app_with_pool, moc async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_with_pool, mock_instance_content): "Test locally but do not create" web_app = await mock_app_with_pool - pool = web_app["vm_pool"] + pool: VmPool = web_app["vm_pool"] message = InstanceContent.model_validate(mock_instance_content) vm_hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca" @@ -424,6 +423,8 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi vm_id = 3 from aleph.vm.network.hostnetwork import Network, make_ipv6_allocator + mocker.patch("aleph.vm.network.hostnetwork.get_interface_ipv4", return_value="10.0.5.201") + network = Network( vm_ipv4_address_pool_range=settings.IPV4_ADDRESS_POOL, vm_network_size=settings.IPV4_NETWORK_PREFIX_LENGTH, @@ -437,6 +438,7 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi ipv6_forwarding_enabled=False, ) network.setup() + pool.network = network from aleph.vm.vm_type import VmType @@ -455,9 +457,11 @@ async def test_v2_executions_list_vm_network(aiohttp_client, mocker, mock_app_wi 
assert await response.json() == { "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca": { "networking": { + "host_ipv4": "10.0.5.201", "ipv4_network": "172.16.3.0/24", "ipv6_network": "fc00:1:2:3:3:deca:deca:dec0/124", "ipv6_ip": "fc00:1:2:3:3:deca:deca:dec1", + "mapped_ports": {}, }, "status": { "defined_at": str(execution.times.defined_at), @@ -614,6 +618,7 @@ def mock_is_kernel_enabled_gpu(pci_host: str) -> bool: ) mocker.patch.object(settings, "ENABLE_GPU_SUPPORT", True) + mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False) pool = VmPool() await pool.setup() app = setup_webapp(pool=pool) @@ -833,3 +838,58 @@ async def test_reserve_resources_double_fail(aiohttp_client, mocker, mock_app_wi resp = await response.json() assert resp["status"] == "error", await response.text() assert len(app["vm_pool"].reservations) == 0 + + +@pytest.mark.asyncio +async def test_operate_not_started(aiohttp_client, mock_app_with_pool, mock_instance_content, mocker): + web_app = await mock_app_with_pool + pool = web_app["vm_pool"] + client = await aiohttp_client(web_app) + vm_hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca" + message = InstanceContent.model_validate(mock_instance_content) + + mocker.patch.object(VmExecution, "update_port_redirects", new=mocker.AsyncMock(return_value=False)) + execution = VmExecution( + vm_hash=vm_hash, + message=message, + original=message, + persistent=False, + snapshot_manager=None, + systemd_manager=None, + ) + + pool.executions = {vm_hash: execution} + response: web.Response = await client.post("/control/machine/{ref}/update".format(ref=vm_hash)) + assert response.status == 200, await response.text() + response.raise_for_status() + resp = await response.json() + assert resp == {"msg": "VM not starting yet", "status": "ok"} + assert execution.update_port_redirects.call_count == 0 + + +@pytest.mark.asyncio +async def test_operate(aiohttp_client, mock_app_with_pool, mock_instance_content, mocker): + web_app = await mock_app_with_pool + pool = web_app["vm_pool"] + client = await aiohttp_client(web_app) + vm_hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca" + message = InstanceContent.model_validate(mock_instance_content) + + mocker.patch.object(VmExecution, "update_port_redirects", new=mocker.AsyncMock(return_value=False)) + execution = VmExecution( + vm_hash=vm_hash, + message=message, + original=message, + persistent=False, + snapshot_manager=None, + systemd_manager=None, + ) + execution.vm = mocker.Mock() + + pool.executions = {vm_hash: execution} + response: web.Response = await client.post("/control/machine/{ref}/update".format(ref=vm_hash)) + assert response.status == 200, await response.text() + response.raise_for_status() + resp = await response.json() + assert resp == {} + assert execution.update_port_redirects.call_count == 1
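
A note on the data shapes involved in the port-forwarding path above: the "port-forwarding" aggregate key, the per-port {protocol: bool} schema and the mapped_ports layout are taken from fetch_port_redirect_config_and_setup() and update_port_redirects() in this diff; the concrete hashes and port numbers below are illustrative only. A minimal sketch in Python:

# Value returned by get_user_settings(message.address, "port-forwarding"),
# indexed by VM item hash (keys arriving from the JSON API may be strings rather than ints):
port_forwarding_settings = {
    "<vm-item-hash>": {
        22: {"tcp": True, "udp": False},   # forward SSH over TCP only
        8080: {"tcp": True, "udp": True},  # forward a web port over TCP and UDP
    }
}

# Resulting execution.mapped_ports after update_port_redirects(), once
# get_available_host_port(start_port=25000) has picked free host ports:
mapped_ports = {
    22: {"host": 25000, "tcp": True, "udp": False},
    8080: {"host": 25001, "tcp": True, "udp": True},
}

# Each enabled protocol becomes a DNAT rule in the ip/nat prerouting chain added by
# add_port_redirect_rule(), roughly equivalent to:
#   nft add rule ip nat prerouting iifname "<NETWORK_INTERFACE>" tcp dport 25000 dnat to <guest-ip>:22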
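
The new POST /control/machine/{ref}/update route added in supervisor.py only asks the CRN to re-fetch that aggregate and reconcile the redirects. A minimal client sketch, assuming the supervisor listens on localhost:4020 as in scripts/serial_client.py:

import asyncio

import aiohttp


async def notify_port_forwarding_update(vm_hash: str) -> dict:
    # Returns {} once the redirects were refreshed, or {"status": "ok", "msg": "VM not starting yet"}
    # when the VM is not running (operate_update then defers the setup to the next start).
    url = f"http://localhost:4020/control/machine/{vm_hash}/update"
    async with aiohttp.ClientSession() as session:
        async with session.post(url) as resp:
            resp.raise_for_status()
            return await resp.json()


# asyncio.run(notify_port_forwarding_update("<vm-item-hash>"))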
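
scripts/serial_client.py above is interactive and hard-codes the VM hash in WS_URL; a non-interactive sketch against the same /control/machine/{ref}/serial WebSocket, assuming the third-party websockets package and the same localhost:4020 endpoint:

import asyncio

import websockets


async def poke_serial(vm_hash: str) -> None:
    # Send a newline to the guest's serial console and print whatever it
    # writes back within two seconds, then disconnect.
    url = f"ws://localhost:4020/control/machine/{vm_hash}/serial"
    async with websockets.connect(url) as ws:
        await ws.send(b"\r\n")
        try:
            while True:
                frame = await asyncio.wait_for(ws.recv(), timeout=2)
                text = frame.decode(errors="replace") if isinstance(frame, bytes) else frame
                print(text, end="", flush=True)
        except asyncio.TimeoutError:
            pass  # no more output within the window


# asyncio.run(poke_serial("<vm-item-hash>"))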