diff --git a/docs/libp2p.security.pnet.rst b/docs/libp2p.security.pnet.rst new file mode 100644 index 000000000..9e7be3ea5 --- /dev/null +++ b/docs/libp2p.security.pnet.rst @@ -0,0 +1,29 @@ +libp2p.security.pnet package +================================ + +Submodules +---------- + +libp2p.security.pnet.protector module +------------------------------------- + +.. automodule:: libp2p.security.pnet.protector + :members: + :undoc-members: + :show-inheritance: + +libp2p.security.pnet.psk_conn module +------------------------------------ + +.. automodule:: libp2p.security.pnet.psk_conn + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.security.pnet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.security.rst b/docs/libp2p.security.rst index fc55df33b..41ea5f399 100644 --- a/docs/libp2p.security.rst +++ b/docs/libp2p.security.rst @@ -9,6 +9,7 @@ Subpackages libp2p.security.insecure libp2p.security.noise + libp2p.security.pnet libp2p.security.secio libp2p.security.tls diff --git a/examples/ping/ping.py b/examples/ping/ping.py index f62689aa5..be37310fa 100644 --- a/examples/ping/ping.py +++ b/examples/ping/ping.py @@ -25,6 +25,7 @@ PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0") PING_LENGTH = 32 RESP_TIMEOUT = 60 +PSK = "dffb7e3135399a8b1612b2aaca1c36a3a8ac2cd0cca51ceeb2ced87d308cac6d" async def handle_ping(stream: INetStream) -> None: @@ -60,18 +61,24 @@ async def send_ping(stream: INetStream) -> None: print(f"error occurred : {e}") -async def run(port: int, destination: str) -> None: +async def run(port: int, destination: str, psk: int, transport: str) -> None: from libp2p.utils.address_validation import ( find_free_port, get_available_interfaces, - get_optimal_binding_address, ) if port <= 0: port = find_free_port() - listen_addrs = get_available_interfaces(port) - host = new_host(listen_addrs=listen_addrs) + if transport == "tcp": + listen_addrs = get_available_interfaces(port) + if transport == "ws": + listen_addrs = [multiaddr.Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/ws")] + + if psk == 1: + host = new_host(listen_addrs=listen_addrs, psk=PSK) + else: + host = new_host(listen_addrs=listen_addrs) async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task @@ -87,12 +94,9 @@ async def run(port: int, destination: str) -> None: for addr in all_addrs: print(f"{addr}") - # Use optimal address for the client command - optimal_addr = get_optimal_binding_address(port) - optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}" print( f"\nRun this from the same folder in another console:\n\n" - f"ping-demo -d {optimal_addr_with_peer}\n" + f"ping-demo -d {host.get_addrs()[0]} -psk {psk} -t {transport}\n" ) print("Waiting for incoming connection...") @@ -130,10 +134,23 @@ def main() -> None: type=str, help=f"destination multiaddr string, e.g. {example_maddr}", ) + + parser.add_argument( + "-psk", "--psk", default=0, type=int, help="Enable PSK in the transport layer" + ) + + parser.add_argument( + "-t", + "--transport", + default="tcp", + type=str, + help="Choose the transport layer for ping TCP/WS", + ) + args = parser.parse_args() try: - trio.run(run, *(args.port, args.destination)) + trio.run(run, *(args.port, args.destination, args.psk, args.transport)) except KeyboardInterrupt: pass diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 9624bc5ba..d0ac3d8ab 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -182,6 +182,7 @@ def new_swarm( connection_config: ConnectionConfig | QUICTransportConfig | None = None, tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, + psk: str | None = None ) -> INetworkService: logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}") """ @@ -195,6 +196,7 @@ def new_swarm( :param listen_addrs: optional list of multiaddrs to listen on :param enable_quic: enable quic for transport :param quic_transport_opt: options for transport + :param psk: optional pre-shared key for PSK encryption in transport :return: return a default swarm instance Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer @@ -300,7 +302,8 @@ def new_swarm( upgrader, transport, retry_config=retry_config, - connection_config=connection_config + connection_config=connection_config, + psk=psk ) @@ -320,6 +323,7 @@ def new_host( quic_transport_opt: QUICTransportConfig | None = None, tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, + psk: str | None = None ) -> IHost: """ Create a new libp2p host based on the given parameters. @@ -337,6 +341,7 @@ def new_host( :param quic_transport_opt: optional configuration for quic transport :param tls_client_config: optional TLS client configuration for WebSocket transport :param tls_server_config: optional TLS server configuration for WebSocket transport + :param psk: optional pre-shared key (PSK) :return: return a host instance """ @@ -353,7 +358,8 @@ def new_host( listen_addrs=listen_addrs, connection_config=quic_transport_opt if enable_quic else None, tls_client_config=tls_client_config, - tls_server_config=tls_server_config + tls_server_config=tls_server_config, + psk=psk ) if disc_opt is not None: diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 08a38ab21..8f66d8bbb 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -102,6 +102,7 @@ def __init__( bootstrap: list[str] | None = None, default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, + psk: str | None = None, ) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) @@ -120,6 +121,7 @@ def __init__( self.bootstrap = None if bootstrap: self.bootstrap = BootstrapDiscovery(network, bootstrap) + self.psk = psk # Cache a signed-record if the local-node in the PeerStore envelope = create_signed_peer_record( diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 94d9c7a39..6f92b6722 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -34,6 +34,7 @@ from libp2p.peer.peerstore import ( PeerStoreError, ) +from libp2p.security.pnet.protector import new_protected_conn from libp2p.tools.async_service import ( Service, ) @@ -98,11 +99,13 @@ def __init__( transport: ITransport, retry_config: RetryConfig | None = None, connection_config: ConnectionConfig | QUICTransportConfig | None = None, + psk: str | None = None, ): self.self_id = peer_id self.peerstore = peerstore self.upgrader = upgrader self.transport = transport + self.psk = psk # Enhanced: Initialize retry and connection configuration self.retry_config = retry_config or RetryConfig() @@ -327,6 +330,10 @@ async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetC try: addr = Multiaddr(f"{addr}/p2p/{peer_id}") raw_conn = await self.transport.dial(addr) + + # Enable PNET if psk is provvided + if self.psk is not None: + raw_conn = new_protected_conn(raw_conn, self.psk) except OpenConnectionError as error: logger.debug("fail to dial peer %s over base transport", peer_id) raise SwarmException( @@ -515,6 +522,10 @@ async def conn_handler( raw_conn = RawConnection(read_write_closer, False) + # Enable PNET is psk is provided + if self.psk is not None: + raw_conn = new_protected_conn(raw_conn, self.psk) + # Per, https://discuss.libp2p.io/t/multistream-security/130, we first # secure the conn and then mux the conn try: diff --git a/libp2p/security/pnet/__init__.py b/libp2p/security/pnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libp2p/security/pnet/protector.py b/libp2p/security/pnet/protector.py new file mode 100644 index 000000000..af9143f0d --- /dev/null +++ b/libp2p/security/pnet/protector.py @@ -0,0 +1,10 @@ +from libp2p.abc import IRawConnection +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.security.pnet.psk_conn import PskConn + + +def new_protected_conn(conn: RawConnection | IRawConnection, psk: str) -> PskConn: + if len(psk) != 64: + raise ValueError("Expected 32-byte pre shared key (PSK)") + + return PskConn(conn, psk) diff --git a/libp2p/security/pnet/psk_conn.py b/libp2p/security/pnet/psk_conn.py new file mode 100644 index 000000000..5ef8a06c8 --- /dev/null +++ b/libp2p/security/pnet/psk_conn.py @@ -0,0 +1,58 @@ +import os + +from Crypto.Cipher import Salsa20 + +from libp2p.abc import IRawConnection +from libp2p.network.connection.raw_connection import RawConnection + + +class PskConn(RawConnection): + _psk: bytes + _conn: RawConnection | IRawConnection + + def __init__(self, conn: RawConnection | IRawConnection, psk: str) -> None: + self._psk = bytes.fromhex(psk) + self._conn = conn + + self.read_cipher: Salsa20.Salsa20Cipher | None = None + self.write_cipher: Salsa20.Salsa20Cipher | None = None + + async def write(self, data: bytes) -> None: + """ + Encrpyts and writes data to the stream. + On the first call, generates a 24-byte nonce and sends it first. + """ + if self.write_cipher is None: + nonce = os.urandom(8) + await self._conn.write(nonce) + self.write_cipher = Salsa20.new(key=self._psk, nonce=nonce) + + assert self.write_cipher is not None + ciphertext = self.write_cipher.encrypt(data) + + await self._conn.write(ciphertext) + + async def read(self, n: int | None = None) -> bytes: + """ + Reads and decrypts data. On the first call, it reads a 8-byte + nonce to initialize the decryption stream + """ + if self.read_cipher is None: + nonce = await self._conn.read(8) + if len(nonce) != 8: + raise ValueError("short nonce from stream") + + self.read_cipher = Salsa20.new(key=self._psk, nonce=nonce) + + data = await self._conn.read(n) + if not data: + return b"" + + plaintext = self.read_cipher.decrypt(data) + return plaintext + + async def close(self) -> None: + await self._conn.close() + + def get_remote_address(self) -> tuple[str, int] | None: + return self._conn.get_remote_address() diff --git a/tests/core/security/test_pnet.py b/tests/core/security/test_pnet.py new file mode 100644 index 000000000..e12ccfc44 --- /dev/null +++ b/tests/core/security/test_pnet.py @@ -0,0 +1,82 @@ +import pytest +import trio + +from libp2p.io.abc import ReadWriteCloser +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.security.pnet.protector import new_protected_conn + + +# --- MemoryPipe: implements ReadWriteCloser interface --- +class MemoryPipe(ReadWriteCloser): + """Wrap a pair of Trio memory channels into a ReadWriteCloser-like object.""" + + def __init__( + self, send: trio.MemorySendChannel, receive: trio.MemoryReceiveChannel + ): + self._send = send + self._receive = receive + + async def read(self, n: int | None = None) -> bytes: + """Read next chunk from receive channel.""" + return await self._receive.receive() + + async def write(self, data: bytes) -> None: + """Write a chunk to send channel.""" + await self._send.send(data) + + async def close(self) -> None: + """Close channels (noop for memory channels).""" + pass + + def get_remote_address(self) -> tuple[str, int] | None: + # Memory pipe doesn’t have a real address, so return None + return None + + +# --- Helper function to create a connected pair of PskConns --- +async def make_psk_pair(psk_hex: str): + send1, recv1 = trio.open_memory_channel(0) + send2, recv2 = trio.open_memory_channel(0) + + pipe1 = MemoryPipe(send1, recv2) + pipe2 = MemoryPipe(send2, recv1) + + raw1 = RawConnection(pipe1, False) + raw2 = RawConnection(pipe2, False) + + # NOTE: The new_protected_conn function needs to perform the handshake. + # We'll assume it does for this example. If not, a handshake() call + # might be needed here within a nursery. + psk_conn1 = new_protected_conn(raw1, psk_hex) + psk_conn2 = new_protected_conn(raw2, psk_hex) + + return psk_conn1, psk_conn2 + + +@pytest.mark.trio +async def test_psk_simple_message(): + # Use a fixed PSK for testing + psk_hex = "dffb7e3135399a8b1612b2aaca1c36a3a8ac2cd0cca51ceeb2ced87d308cac6d" + conn1, conn2 = await make_psk_pair(psk_hex) + + msg = b"hello world" + + async with trio.open_nursery() as nursery: + nursery.start_soon(conn1.write, msg) + received = await conn2.read(len(msg)) + + assert received == msg, "Decrypted message does not match original" + + +@pytest.mark.trio +async def test_psk_empty_message(): + # PSK for testing + psk_hex = "dffb7e3135399a8b1612b2aaca1c36a3a8ac2cd0cca51ceeb2ced87d308cac6d" + conn1, conn2 = await make_psk_pair(psk_hex) + + # Empty message should round-trip correctly + async with trio.open_nursery() as nursery: + nursery.start_soon(conn1.write, b"") + received = await conn2.read(0) + + assert received == b"", "Empty message failed"