POC rescue console endpoint #807

Draft: wants to merge 23 commits into base: main
55 changes: 55 additions & 0 deletions scripts/serial_client.py
@@ -0,0 +1,55 @@
import asyncio
import os
import sys
import termios
import tty

import websockets

WS_URL = "ws://localhost:4020/control/machine/{ref}/serial"
EXIT_SEQUENCE = b"\x1d" # Ctrl-]


async def client():
old_settings = termios.tcgetattr(sys.stdin)
ws_url = WS_URL.format(ref="decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca")


try:
async with websockets.connect(ws_url) as ws:
print("[Connected] Use Ctrl-] to disconnect ", file=sys.stderr)
tty.setraw(sys.stdin.fileno()) # Set terminal to raw mode

async def send_input():
while True:
data = await asyncio.get_event_loop().run_in_executor(None, os.read, sys.stdin.fileno(), 1)
if not data:
break
if data == EXIT_SEQUENCE:
print("\n[Exit sequence received — quitting]", file=sys.stderr)
break
if data == b"\r":
data = b"\r\n"
await ws.send(data)

async def receive_output():
async for message in ws:
if isinstance(message, bytes):
os.write(sys.stdout.fileno(), message)
else:
os.write(sys.stdout.fileno(), message.encode())

task_send = asyncio.create_task(send_input())
task_recv = asyncio.create_task(receive_output())

done, pending = await asyncio.wait([task_send, task_recv], return_when=asyncio.FIRST_COMPLETED)

for task in pending:
task.cancel()
finally:
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) # Restore terminal
print("\n[Disconnected]", file=sys.stderr)


if __name__ == "__main__":
asyncio.run(client())
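
The WebSocket endpoint this client connects to is not among the files shown here. For orientation only, a bridge between the /control/machine/{ref}/serial WebSocket and the per-execution serial Unix socket created in qemuvm.py below could look roughly like the sketch that follows; an aiohttp-style handler is assumed, and the handler name, route wiring and error handling are hypothetical.

import asyncio
from aiohttp import WSMsgType, web

async def serial_console_bridge(request: web.Request) -> web.WebSocketResponse:
    # Hypothetical bridge between the /serial WebSocket and the QEMU serial chardev socket
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    # Assumed path layout, matching QemuVM.__init__ in this PR
    serial_socket = f"/var/lib/aleph/vm/executions/{request.match_info['ref']}/serial.sock"
    reader, writer = await asyncio.open_unix_connection(serial_socket)

    async def socket_to_ws() -> None:
        # Forward console output to the WebSocket client
        while data := await reader.read(4096):
            await ws.send_bytes(data)

    pump = asyncio.create_task(socket_to_ws())
    try:
        # Forward keystrokes from the WebSocket client to the console
        async for msg in ws:
            if msg.type == WSMsgType.BINARY:
                writer.write(msg.data)
            elif msg.type == WSMsgType.TEXT:
                writer.write(msg.data.encode())
            await writer.drain()
    finally:
        pump.cancel()
        writer.close()
    return ws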
4 changes: 4 additions & 0 deletions src/aleph/vm/conf.py
@@ -529,5 +529,9 @@ def make_db_url():
return f"sqlite+aiosqlite:///{settings.EXECUTION_DATABASE}"


def make_sync_db_url():
return f"sqlite:///{settings.EXECUTION_DATABASE}"


# Settings singleton
settings = Settings()
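
make_sync_db_url() mirrors make_db_url() without the aiosqlite driver, which suggests it is intended for synchronous access to the same execution database. A minimal sketch of that kind of use, assuming a plain SQLAlchemy engine (the engine creation below is not part of this diff):

from sqlalchemy import create_engine, text

from aleph.vm.conf import make_sync_db_url

# Synchronous engine over the same SQLite file the async code path uses
engine = create_engine(make_sync_db_url())
with engine.connect() as connection:
    connection.execute(text("SELECT 1"))  # simple connectivity check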
30 changes: 24 additions & 6 deletions src/aleph/vm/hypervisors/qemu/qemuvm.py
@@ -44,6 +44,9 @@ def __init__(self, vm_hash, config: QemuVMConfiguration):
self.image_path = config.image_path
self.monitor_socket_path = config.monitor_socket_path
self.qmp_socket_path = config.qmp_socket_path
self.dir = Path("/var/lib/aleph/vm/executions/{}".format(vm_hash))
self.dir.mkdir(exist_ok=True, parents=True)
self.serial_socket_path = self.dir / "serial.sock"
self.vcpu_count = config.vcpu_count
self.mem_size_mb = config.mem_size_mb
self.interface_name = config.interface_name
@@ -80,6 +83,9 @@ async def start(
self.journal_stderr: BinaryIO = journal.stream(self._journal_stderr_name)
# hardware_resources.published ports -> not implemented at the moment
# hardware_resources.seconds -> only for microvm
# open('/proc/self/stdout', 'w').write('x\r\n')
# open('/dev/stdout', 'wb').write('x\r\n')

args = [
self.qemu_bin_path,
"-enable-kvm",
@@ -90,10 +96,10 @@
str(self.vcpu_count),
"-drive",
f"file={self.image_path},media=disk,if=virtio",
# To debug you can pass gtk or curses instead
# To debug pass gtk or curses instead of none
"-display",
"none",
# "--no-reboot", # Rebooting from inside the VM shuts down the machine
# "--no-reboot", # Rebooting from inside the VM shuts down the machine
# Disable --no-reboot so the user can reboot from inside the VM. See ALEPH-472
# Listen for commands on this socket
"-monitor",
@@ -103,18 +109,29 @@
"-qmp",
f"unix:{self.qmp_socket_path},server,nowait",
# Send the output to the standard file descriptors so we can include it in the logs
"-serial",
"stdio",
# nographics. Seems redundant with -serial stdio but without it the boot process is not displayed on stdout
# "-serial",
# "stdio",
# nographic: disable the graphical UI, redirect serial to stdio (modified afterward), and redirect parallel to stdio as well (necessary to see the BIOS boot)
"-nographic",
# Redirect the serial port, which exposes a Unix console, to a Unix socket used for remote debugging
# Ideally, the parallel port should be redirected too with mux=on, but in practice it crashes qemu
# logfile=/dev/stdout allows the output to be displayed in journalctl for the log endpoint
"-chardev",
f"socket,id=iounix,path={str(self.serial_socket_path)},wait=off,server=on"
# f",logfile=/dev/stdout,mux=off,logappend=off", # DOES NOT WORK WITH SYSTEMD
f",logfile={self.dir / 'execution.log'},mux=off,logappend=off",
# Ideally, instead of logfile we would use chardev.hub, but it is qemu >= 10 only
"-serial",
"chardev:iounix",
###
# Boot
# order=c: boot only from the first hard drive
# reboot-timeout, in combination with -no-reboot, makes qemu stop if there is no bootable device
"-boot",
"order=c,reboot-timeout=1",
# Uncomment for debug
# "-serial", "telnet:localhost:4321,server,nowait",
# "-snapshot", # Do not save anything to disk
# "-snapshot", # Do not save anything to disk
]
if self.interface_name:
# script=no, downscript=no tell qemu not to try to set up the network itself
@@ -130,6 +147,7 @@
self.qemu_process = proc = await asyncio.create_subprocess_exec(
*args,
stdin=asyncio.subprocess.DEVNULL,
# stdout= asyncio.subprocess.STDOUT,
stdout=self.journal_stdout,
stderr=self.journal_stderr,
)
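
With the -chardev arguments above, each execution exposes its console as serial.sock under /var/lib/aleph/vm/executions/<vm_hash>/. A quick host-side check of that socket, independent of the WebSocket endpoint, could look like the following sketch (the helper below is illustrative and not part of this diff):

import asyncio

async def peek_serial(vm_hash: str) -> bytes:
    # Path layout follows QemuVM.__init__ in this diff
    path = f"/var/lib/aleph/vm/executions/{vm_hash}/serial.sock"
    reader, writer = await asyncio.open_unix_connection(path)
    writer.write(b"\n")  # nudge the guest console to reprint its prompt
    await writer.drain()
    try:
        return await asyncio.wait_for(reader.read(4096), timeout=2)
    except asyncio.TimeoutError:
        return b""
    finally:
        writer.close()
        await writer.wait_closed()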
169 changes: 117 additions & 52 deletions src/aleph/vm/models.py
@@ -30,7 +30,9 @@
AlephQemuConfidentialInstance,
AlephQemuConfidentialResources,
)
from aleph.vm.network.firewall import add_port_redirect_rule, remove_port_redirect_rule
from aleph.vm.network.interfaces import TapInterface
from aleph.vm.network.port_availability_checker import get_available_host_port
from aleph.vm.orchestrator.metrics import (
ExecutionRecord,
delete_record,
@@ -42,6 +44,9 @@
from aleph.vm.resources import GpuDevice, HostGPU
from aleph.vm.systemd import SystemDManager
from aleph.vm.utils import create_task_log_exceptions, dumps_for_json, is_pinging
from aleph.vm.utils.aggregate import get_user_settings

SUPPORTED_PROTOCOL_FOR_REDIRECT = ["udp", "tcp"]

logger = logging.getLogger(__name__)

@@ -92,6 +97,78 @@ class VmExecution:
systemd_manager: SystemDManager | None

persistent: bool = False
mapped_ports: dict[int, dict]  # Port redirects into the VM, keyed by VM port
record: ExecutionRecord | None = None

async def fetch_port_redirect_config_and_setup(self):
message = self.message
try:
port_forwarding_settings = await get_user_settings(message.address, "port-forwarding")
ports_requests = port_forwarding_settings.get(self.vm_hash)

except Exception:
ports_requests = {}
logger.exception("Could not fetch the port redirect settings for user")
if not ports_requests:
# FIXME DEBUG FOR NOW
ports_requests = {22: {"tcp": True, "udp": False}}
ports_requests = ports_requests or {}
await self.update_port_redirects(ports_requests)

async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]]):
assert self.vm, "The VM attribute has to be set before calling update_port_redirects()"

logger.info("Updating port redirect. Current %s, New %s", self.mapped_ports, requested_ports)
redirect_to_remove = set(self.mapped_ports.keys()) - set(requested_ports.keys())
redirect_to_add = set(requested_ports.keys()) - set(self.mapped_ports.keys())
redirect_to_check = set(requested_ports.keys()).intersection(set(self.mapped_ports.keys()))
interface = self.vm.tap_interface

for vm_port in redirect_to_remove:
current = self.mapped_ports[vm_port]
for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
if current[protocol]:
host_port = current["host"]
remove_port_redirect_rule(interface, host_port, vm_port, protocol)
del self.mapped_ports[vm_port]

for vm_port in redirect_to_add:
target = requested_ports[vm_port]
host_port = get_available_host_port(start_port=25000)
for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
if target[protocol]:
add_port_redirect_rule(interface, host_port, vm_port, protocol)
self.mapped_ports[vm_port] = {"host": host_port, **target}

for vm_port in redirect_to_check:
current = self.mapped_ports[vm_port]
target = requested_ports[vm_port]
host_port = current["host"]
for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
if current[protocol] != target[protocol]:
if target[protocol]:
add_port_redirect_rule(interface, host_port, vm_port, protocol)
else:
remove_port_redirect_rule(interface, host_port, vm_port, protocol)
self.mapped_ports[vm_port] = {"host": host_port, **target}

# Save to DB
if self.record:
self.record.mapped_ports = self.mapped_ports
await save_record(self.record)

async def removed_all_ports_redirection(self):
if not self.vm:
return
interface = self.vm.tap_interface
# copy in a list since we modify dict during iteration
for vm_port, map_detail in list(self.mapped_ports.items()):
host_port = map_detail["host"]
for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
if map_detail[protocol]:
remove_port_redirect_rule(interface, host_port, vm_port, protocol)

del self.mapped_ports[vm_port]

@property
def is_starting(self) -> bool:
@@ -190,6 +267,7 @@ def __init__(
self.snapshot_manager = snapshot_manager
self.systemd_manager = systemd_manager
self.persistent = persistent
self.mapped_ports = {}

def to_dict(self) -> dict:
return {
@@ -350,7 +428,8 @@ async def start(self):

if self.vm and self.vm.support_snapshot and self.snapshot_manager:
await self.snapshot_manager.start_for(vm=self.vm)

else:
self.times.started_at = datetime.now(tz=timezone.utc)
self.ready_event.set()
await self.save()
except Exception:
@@ -398,7 +477,10 @@ async def non_blocking_wait_for_boot(self):
except Exception as e:
logger.warning("%s failed to respond to ping or is not running, stopping it: %s", self, e)
assert self.vm
await self.stop()
try:
await self.stop()
except Exception as f:
logger.exception("%s failed to stop: %s", self, f)
return False

async def wait_for_init(self):
@@ -456,6 +538,8 @@ async def stop(self) -> None:
await self.all_runs_complete()
await self.record_usage()
await self.vm.teardown()
await self.removed_all_ports_redirection()

self.times.stopped_at = datetime.now(tz=timezone.utc)
self.cancel_expiration()
self.cancel_update()
@@ -497,57 +581,38 @@ async def save(self):
"""Save to DB"""
assert self.vm, "The VM attribute has to be set before calling save()"

pid_info = self.vm.to_dict() if self.vm else None
# Handle cases when the process cannot be accessed
if not self.persistent and pid_info and pid_info.get("process"):
await save_record(
ExecutionRecord(
uuid=str(self.uuid),
vm_hash=self.vm_hash,
vm_id=self.vm_id,
time_defined=self.times.defined_at,
time_prepared=self.times.prepared_at,
time_started=self.times.started_at,
time_stopping=self.times.stopping_at,
cpu_time_user=pid_info["process"]["cpu_times"].user,
cpu_time_system=pid_info["process"]["cpu_times"].system,
io_read_count=pid_info["process"]["io_counters"][0],
io_write_count=pid_info["process"]["io_counters"][1],
io_read_bytes=pid_info["process"]["io_counters"][2],
io_write_bytes=pid_info["process"]["io_counters"][3],
vcpus=self.vm.hardware_resources.vcpus,
memory=self.vm.hardware_resources.memory,
network_tap=self.vm.tap_interface.device_name if self.vm.tap_interface else "",
message=self.message.model_dump_json(),
original_message=self.original.model_dump_json(),
persistent=self.persistent,
)
)
else:
# The process cannot be accessed, or it's a persistent VM.
await save_record(
ExecutionRecord(
uuid=str(self.uuid),
vm_hash=self.vm_hash,
vm_id=self.vm_id,
time_defined=self.times.defined_at,
time_prepared=self.times.prepared_at,
time_started=self.times.started_at,
time_stopping=self.times.stopping_at,
cpu_time_user=None,
cpu_time_system=None,
io_read_count=None,
io_write_count=None,
io_read_bytes=None,
io_write_bytes=None,
vcpus=self.vm.hardware_resources.vcpus,
memory=self.vm.hardware_resources.memory,
message=self.message.model_dump_json(),
original_message=self.original.model_dump_json(),
persistent=self.persistent,
gpus=json.dumps(self.gpus, default=pydantic_encoder),
)
)
if not self.record:
self.record = ExecutionRecord(
uuid=str(self.uuid),
vm_hash=self.vm_hash,
vm_id=self.vm_id,
time_defined=self.times.defined_at,
time_prepared=self.times.prepared_at,
time_started=self.times.started_at,
time_stopping=self.times.stopping_at,
cpu_time_user=None,
cpu_time_system=None,
io_read_count=None,
io_write_count=None,
io_read_bytes=None,
io_write_bytes=None,
vcpus=self.vm.hardware_resources.vcpus,
memory=self.vm.hardware_resources.memory,
message=self.message.model_dump_json(),
original_message=self.original.model_dump_json(),
persistent=self.persistent,
gpus=json.dumps(self.gpus, default=pydantic_encoder),
)
pid_info = self.vm.to_dict() if self.vm else None
# Handle cases when the process cannot be accessed
if not self.persistent and pid_info and pid_info.get("process"):
self.record.cpu_time_user = pid_info["process"]["cpu_times"].user
self.record.cpu_time_system = pid_info["process"]["cpu_times"].system
self.record.io_read_count = pid_info["process"]["io_counters"][0]
self.record.io_write_count = pid_info["process"]["io_counters"][1]
self.record.io_read_bytes = pid_info["process"]["io_counters"][2]
self.record.io_write_bytes = pid_info["process"]["io_counters"][3]
await save_record(self.record)

async def record_usage(self):
await delete_record(execution_uuid=str(self.uuid))
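
The reconciliation in update_port_redirects() is plain set arithmetic over the requested ports (from the user's port-forwarding aggregate, shaped like {vm_port: {"tcp": bool, "udp": bool}}) and the currently mapped ports. With made-up values (both dicts below are hypothetical), the three cases split as follows:

mapped_ports = {22: {"host": 25001, "tcp": True, "udp": False}}  # current state
requested_ports = {22: {"tcp": True, "udp": True}, 8080: {"tcp": True, "udp": False}}  # desired state

to_remove = set(mapped_ports) - set(requested_ports)  # set(): nothing to tear down
to_add = set(requested_ports) - set(mapped_ports)  # {8080}: new redirect rule on a fresh host port
to_check = set(requested_ports) & set(mapped_ports)  # {22}: per-protocol diff, udp is switched on here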