diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 27be7e43..aa933988 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -107,7 +107,9 @@ async def fetch_port_redirect_config_and_setup(self): ports_requests: dict[int, dict] = {} try: port_forwarding_settings = await get_user_settings(message.address, "port-forwarding") - ports_requests = port_forwarding_settings.get(self.vm_hash, {}) or {} + vm_port_forwarding = port_forwarding_settings.get(self.vm_hash, {}) or {} + if "ports" in ports_requests: + ports_requests = vm_port_forwarding.get("ports", {}) except Exception: logger.info("Could not fetch the port redirect settings for user %s", message.address, exc_info=True) @@ -138,7 +140,7 @@ async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool] host_port = get_available_host_port(start_port=24000) for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: if target[protocol]: - add_port_redirect_rule(interface, host_port, vm_port, protocol) + add_port_redirect_rule(self.vm.vm_id, interface, host_port, vm_port, protocol) self.mapped_ports[vm_port] = {"host": host_port, **target} for vm_port in redirect_to_check: @@ -148,7 +150,7 @@ async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool] for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT: if current[protocol] != target[protocol]: if target[protocol]: - add_port_redirect_rule(interface, host_port, vm_port, protocol) + add_port_redirect_rule(self.vm.vm_id, interface, host_port, vm_port, protocol) else: remove_port_redirect_rule(interface, host_port, vm_port, protocol) self.mapped_ports[vm_port] = {"host": host_port, **target} diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py index 140e4548..c5dfed2c 100644 --- a/src/aleph/vm/network/firewall.py +++ b/src/aleph/vm/network/firewall.py @@ -393,7 +393,7 @@ def add_or_get_prerouting_chain() -> dict: # Check if prerouting chain exists by looking for chains with prerouting hook existing_chains = get_base_chains_for_hook("prerouting", "ip") if existing_chains: - return existing_chains[0] # Chain already exists, nothing to do + return existing_chains[0]["chain"] # Chain already exists, nothing to do commands = [ { @@ -410,12 +410,13 @@ def add_or_get_prerouting_chain() -> dict: } } ] + execute_json_nft_commands(commands) chain = commands[0]["add"]["chain"] - return execute_json_nft_commands(commands) + return chain def add_port_redirect_rule( - interface: TapInterface, host_port: int, vm_port: int, protocol: Literal["tcp"] | Literal["udp"] = "tcp" + vm_id, interface: TapInterface, host_port: int, vm_port: int, protocol: Literal["tcp"] | Literal["udp"] = "tcp" ) -> dict: """Creates a rule to redirect traffic from a host port to a VM port. @@ -460,6 +461,36 @@ def add_port_redirect_rule( } }, ] + # Add rule to accept that trafic on the host interface to that destination port + table = get_table_for_hook("forward") + commands += [ + { + "add": { + "rule": { + "family": "ip", + "table": table, + "chain": f"{settings.NFTABLES_CHAIN_PREFIX}-vm-filter-{vm_id}", + "expr": [ + { + "match": { + "op": "==", + "left": {"meta": {"key": "iifname"}}, + "right": interface.device_name, + } + }, + { + "match": { + "op": "==", + "left": {"payload": {"protocol": "tcp", "field": "dport"}}, + "right": vm_port, + } + }, + {"accept": None}, + ], + } + } + } + ] return execute_json_nft_commands(commands) @@ -477,6 +508,9 @@ def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: The exit code from executing the nftables commands """ nft_ruleset = get_existing_nftables_ruleset() + chain = add_or_get_prerouting_chain() + table = chain["table"] + commands = [] for entry in nft_ruleset: @@ -484,8 +518,8 @@ def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: isinstance(entry, dict) and "rule" in entry and entry["rule"].get("family") == "ip" - and entry["rule"].get("table") == "nat" - and entry["rule"].get("chain") == "prerouting" + and entry["rule"].get("table") == table + and entry["rule"].get("chain") == chain["name"] and "expr" in entry["rule"] ): expr = entry["rule"]["expr"] @@ -508,8 +542,8 @@ def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port: "delete": { "rule": { "family": "ip", - "table": "nat", - "chain": "prerouting", + "table": table, + "chain": chain["name"], "handle": entry["rule"]["handle"], } }