Skip to content

Commit 35dfdda

Browse files
committed
Add net/udp module
1 parent 3a5a556 commit 35dfdda

File tree

8 files changed

+247
-7
lines changed

8 files changed

+247
-7
lines changed

src/net/common.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ use mlua::{IntoLua, Lua, Result, Value};
77

88
/// Socket address that can be either TCP or Unix domain socket.
99
pub enum AnySocketAddr {
10-
Tcp(std::net::SocketAddr),
10+
IP(std::net::SocketAddr),
1111
#[cfg(unix)]
1212
Unix(tokio::net::unix::SocketAddr),
1313
}
1414

1515
impl fmt::Display for AnySocketAddr {
1616
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1717
match self {
18-
AnySocketAddr::Tcp(addr) => write!(f, "{addr}"),
18+
AnySocketAddr::IP(addr) => write!(f, "{addr}"),
1919
#[cfg(unix)]
2020
AnySocketAddr::Unix(addr) => {
2121
let path = addr
@@ -42,11 +42,11 @@ pub trait AddressProvider {
4242

4343
impl AddressProvider for tokio::net::TcpStream {
4444
fn local_addr(&self) -> io::Result<AnySocketAddr> {
45-
Ok(AnySocketAddr::Tcp(self.local_addr()?))
45+
Ok(AnySocketAddr::IP(self.local_addr()?))
4646
}
4747

4848
fn peer_addr(&self) -> io::Result<AnySocketAddr> {
49-
Ok(AnySocketAddr::Tcp(self.peer_addr()?))
49+
Ok(AnySocketAddr::IP(self.peer_addr()?))
5050
}
5151
}
5252

src/net/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@ mod common;
3232
pub mod tcp;
3333
#[cfg(feature = "tls")]
3434
pub mod tls;
35+
pub mod udp;
3536
#[cfg(unix)]
3637
pub mod unix;

src/net/tcp/listener.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub struct TcpListener(pub(crate) tokio::net::TcpListener);
1212

1313
impl TcpListener {
1414
pub(crate) fn local_addr(&self) -> io::Result<AnySocketAddr> {
15-
self.0.local_addr().map(AnySocketAddr::Tcp)
15+
self.0.local_addr().map(AnySocketAddr::IP)
1616
}
1717
}
1818

src/net/tcp/stream.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ impl From<tokio::net::TcpStream> for TcpStream {
4747

4848
impl AddressProvider for TcpStream {
4949
fn local_addr(&self) -> io::Result<AnySocketAddr> {
50-
self.stream.local_addr().map(AnySocketAddr::Tcp)
50+
self.stream.local_addr().map(AnySocketAddr::IP)
5151
}
5252

5353
fn peer_addr(&self) -> io::Result<AnySocketAddr> {
54-
self.stream.peer_addr().map(AnySocketAddr::Tcp)
54+
self.stream.peer_addr().map(AnySocketAddr::IP)
5555
}
5656
}
5757

src/net/udp/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use mlua::{Lua, Result, Table};
2+
3+
pub use socket::{UdpSocket, bind};
4+
5+
/// A loader for the `net/udp` module.
6+
fn loader(lua: &Lua) -> Result<Table> {
7+
let t = lua.create_table()?;
8+
t.set("UdpSocket", lua.create_proxy::<UdpSocket>()?)?;
9+
t.set("bind", lua.create_async_function(bind)?)?;
10+
Ok(t)
11+
}
12+
13+
/// Registers the `net/udp` module in the given Lua state.
14+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
15+
let name = name.unwrap_or("@net/udp");
16+
let value = loader(lua)?;
17+
lua.register_module(name, &value)?;
18+
Ok(value)
19+
}
20+
21+
mod socket;

src/net/udp/socket.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
use std::io;
2+
use std::ops::Deref;
3+
use std::result::Result as StdResult;
4+
5+
use mlua::{Lua, Result, String as LuaString, Table, UserData, UserDataMethods, UserDataRegistry, Value};
6+
7+
use crate::net::{AddressProvider, AnySocketAddr};
8+
use crate::time::Duration;
9+
10+
/// UDP socket wrapper.
11+
pub struct UdpSocket {
12+
pub(crate) socket: tokio::net::UdpSocket,
13+
pub(crate) recv_timeout: Option<Duration>,
14+
}
15+
16+
impl Deref for UdpSocket {
17+
type Target = tokio::net::UdpSocket;
18+
19+
#[inline]
20+
fn deref(&self) -> &Self::Target {
21+
&self.socket
22+
}
23+
}
24+
25+
impl From<tokio::net::UdpSocket> for UdpSocket {
26+
fn from(socket: tokio::net::UdpSocket) -> Self {
27+
UdpSocket {
28+
socket,
29+
recv_timeout: None,
30+
}
31+
}
32+
}
33+
34+
impl AddressProvider for UdpSocket {
35+
fn local_addr(&self) -> io::Result<AnySocketAddr> {
36+
self.socket.local_addr().map(AnySocketAddr::IP)
37+
}
38+
39+
fn peer_addr(&self) -> io::Result<AnySocketAddr> {
40+
self.socket.peer_addr().map(AnySocketAddr::IP)
41+
}
42+
}
43+
44+
impl UserData for UdpSocket {
45+
fn register(registry: &mut UserDataRegistry<Self>) {
46+
registry.add_async_function("bind", bind);
47+
48+
registry.add_method("local_addr", |_, this, ()| Ok(this.local_addr()?));
49+
registry.add_method("peer_addr", |_, this, ()| Ok(this.peer_addr()?));
50+
51+
registry.add_method_mut("set_recv_timeout", |_, this, dur: Option<Duration>| {
52+
this.recv_timeout = dur;
53+
Ok(())
54+
});
55+
56+
registry.add_method("set_ttl", |_, this, ttl: u32| {
57+
lua_try!(this.socket.set_ttl(ttl));
58+
Ok(Ok(()))
59+
});
60+
61+
registry.add_method("ttl", |_, this, ()| {
62+
let ttl = lua_try!(this.socket.ttl());
63+
Ok(Ok(ttl))
64+
});
65+
66+
registry.add_async_method("connect", |_, this, (host, port): (String, u16)| async move {
67+
lua_try!(this.socket.connect((host, port)).await);
68+
Ok(Ok(true))
69+
});
70+
71+
registry.add_async_method_mut("send", |_, this, data: LuaString| async move {
72+
let n = lua_try!(this.socket.send(&data.as_bytes()).await);
73+
Ok(Ok(n))
74+
});
75+
76+
registry.add_async_method_mut("recv", |lua, this, size: Option<usize>| async move {
77+
let size = size.unwrap_or(1472); // Default MTU size minus UDP header
78+
let mut buf = vec![0; size]; // TODO: reuse buffer?
79+
let n = with_io_timeout!(this.recv_timeout, this.socket.recv(&mut buf));
80+
let n = lua_try!(n);
81+
buf.truncate(n);
82+
Ok(Ok(lua.create_string(buf)?))
83+
});
84+
85+
registry.add_async_method_mut(
86+
"send_to",
87+
|_, this, (data, host, port): (LuaString, String, u16)| async move {
88+
let n = lua_try!(this.socket.send_to(&data.as_bytes(), (host, port)).await);
89+
Ok(Ok(n))
90+
},
91+
);
92+
93+
registry.add_async_method_mut("recv_from", |lua, this, size: Option<usize>| async move {
94+
let size = size.unwrap_or(1472); // Default MTU size minus UDP header
95+
let mut buf = vec![0; size]; // TODO: reuse buffer?
96+
match with_io_timeout!(this.recv_timeout, this.socket.recv_from(&mut buf)) {
97+
Ok((n, addr)) => {
98+
buf.truncate(n);
99+
let data = lua.create_string(buf)?;
100+
Ok((Value::String(data), addr.to_string()))
101+
}
102+
Err(e) => Ok((Value::Nil, e.to_string())),
103+
}
104+
});
105+
106+
registry.add_method("set_broadcast", |_, this, enable: bool| {
107+
lua_try!(this.socket.set_broadcast(enable));
108+
Ok(Ok(true))
109+
});
110+
111+
registry.add_method("broadcast", |_, this, ()| {
112+
let enabled = lua_try!(this.socket.broadcast());
113+
Ok(Ok(enabled))
114+
});
115+
}
116+
}
117+
118+
/// Binds a UDP socket to the given host and port with optional parameters.
119+
pub async fn bind(
120+
_: Lua,
121+
(host, port, params): (String, Option<u16>, Option<Table>),
122+
) -> Result<StdResult<UdpSocket, String>> {
123+
let port = port.unwrap_or(0);
124+
let recv_timeout = opt_param!(Duration, params, "recv_timeout")?;
125+
126+
let socket = lua_try!(tokio::net::UdpSocket::bind((host, port)).await);
127+
128+
Ok(Ok(UdpSocket { socket, recv_timeout }))
129+
}

tests/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ async fn run_file(modname: &str) -> Result<()> {
2525
{
2626
mlua_stdlib::net::register(&lua, None)?;
2727
mlua_stdlib::net::tcp::register(&lua, None)?;
28+
mlua_stdlib::net::udp::register(&lua, None)?;
2829
#[cfg(unix)]
2930
mlua_stdlib::net::unix::register(&lua, None)?;
3031
}
@@ -104,6 +105,7 @@ include_tests! {
104105
#[cfg(feature = "net")]
105106
net {
106107
tcp,
108+
udp,
107109
#[cfg(feature = "tls")]
108110
tls,
109111
#[cfg(unix)]

tests/lua/net/udp_tests.lua

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
local task = require("@task")
2+
local time = require("@time")
3+
local udp = require("@net/udp")
4+
5+
testing:test("UDP ping-pong", function(t)
6+
local server, server_err = udp.bind("127.0.0.1")
7+
t.assert_ne(server, nil, server_err)
8+
local server_port = server:local_addr():match(":(%d+)$")
9+
10+
local client, client_err = udp.bind("127.0.0.1")
11+
t.assert_ne(client, nil, client_err)
12+
13+
-- Server task
14+
task.spawn(function()
15+
local data, addr_or_err = server:recv_from(100)
16+
t.assert_ne(data, nil, addr_or_err)
17+
local addr = addr_or_err
18+
t.assert_eq(data, "ping")
19+
20+
-- Extract port from address
21+
local client_host, client_port = addr:match("([^:]+):(%d+)")
22+
local sent, send_err = server:send_to("pong", client_host, client_port)
23+
t.assert_ne(sent, nil, send_err)
24+
end)
25+
26+
-- Client send_to/receive_from
27+
local sent, send_err = client:send_to("ping", "127.0.0.1", server_port)
28+
t.assert_ne(sent, nil, send_err)
29+
t.assert_eq(sent, 4)
30+
local data, recv_err = client:recv_from(100)
31+
t.assert_ne(data, nil, recv_err)
32+
t.assert_eq(data, "pong")
33+
end)
34+
35+
testing:test("UDP connected socket", function(t)
36+
local server, server_err = udp.bind("127.0.0.1")
37+
t.assert_ne(server, nil, server_err)
38+
local server_port = server:local_addr():match(":(%d+)$")
39+
40+
local client, client_err = udp.bind("127.0.0.1")
41+
t.assert_ne(client, nil, client_err)
42+
43+
-- Connect client to server
44+
local ok, connect_err = client:connect("127.0.0.1", server_port)
45+
t.assert_ne(ok, nil, connect_err)
46+
47+
-- Server task
48+
task.spawn(function()
49+
local data, addr, recv_err = server:recv_from(100)
50+
t.assert_ne(data, nil, recv_err)
51+
t.assert_eq(data, "hello")
52+
53+
local client_host, client_port = addr:match("([^:]+):(%d+)")
54+
server:send_to("world", client_host, client_port)
55+
end)
56+
57+
-- Client uses send/recv (after connect)
58+
local sent, send_err = client:send("hello")
59+
t.assert_ne(sent, nil, send_err)
60+
t.assert_eq(sent, 5)
61+
local data, recv_err = client:recv(100)
62+
t.assert_ne(data, nil, recv_err)
63+
t.assert_eq(data, "world")
64+
end)
65+
66+
testing:test("UDP timeout", function(t)
67+
local socket, err = udp.bind("127.0.0.1", nil, { recv_timeout = "100ms" })
68+
t.assert_ne(socket, nil, err)
69+
70+
local start = time.instant()
71+
local data, recv_err = socket:recv(100)
72+
local elapsed = start:elapsed():as_secs()
73+
74+
t.assert_eq(data, nil)
75+
t.assert_match(recv_err, "deadline has elapsed")
76+
t.assert(elapsed >= 0.1, "elapsed time should be at least 100ms, got " .. tostring(elapsed))
77+
end)
78+
79+
testing:test("UDP broadcast", function(t)
80+
local socket, err = udp.bind("0.0.0.0")
81+
t.assert_ne(socket, nil, err)
82+
83+
t.assert_eq(socket:broadcast(), false)
84+
local bc_ok, bc_err = socket:set_broadcast(true)
85+
t.assert_eq(bc_ok, true, bc_err)
86+
t.assert_eq(socket:broadcast(), true)
87+
end)

0 commit comments

Comments
 (0)