Skip to content

Commit 648921d

Browse files
committed
Add net/unix module
1 parent c73547b commit 648921d

File tree

9 files changed

+321
-23
lines changed

9 files changed

+321
-23
lines changed

src/net/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,18 @@ pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
1414
Ok(value)
1515
}
1616

17+
macro_rules! with_io_timeout {
18+
($timeout:expr, $fut:expr) => {
19+
match $timeout {
20+
Some(dur) => (tokio::time::timeout(dur.0, $fut).await)
21+
.map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e))
22+
.flatten(),
23+
None => $fut.await,
24+
}
25+
};
26+
}
27+
1728
pub mod tcp;
29+
30+
#[cfg(unix)]
31+
pub mod unix;

src/net/tcp/socket.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@ use std::net::SocketAddr;
33
use std::ops::Deref;
44

55
use mlua::{Result, Table};
6-
use tokio::net::TcpSocket as TokioTcpSocket;
76

8-
pub(crate) struct TcpSocket(pub(crate) TokioTcpSocket);
7+
pub(crate) struct TcpSocket(pub(crate) tokio::net::TcpSocket);
98

109
impl Deref for TcpSocket {
11-
type Target = TokioTcpSocket;
10+
type Target = tokio::net::TcpSocket;
1211

1312
#[inline]
1413
fn deref(&self) -> &Self::Target {
@@ -29,8 +28,8 @@ pub(super) struct SocketOptions {
2928
impl TcpSocket {
3029
pub(crate) fn new_for_addr(addr: SocketAddr) -> IoResult<Self> {
3130
let sock = match addr {
32-
SocketAddr::V4(_) => TokioTcpSocket::new_v4()?,
33-
SocketAddr::V6(_) => TokioTcpSocket::new_v6()?,
31+
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
32+
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
3433
};
3534
Ok(TcpSocket(sock))
3635
}

src/net/tcp/stream.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@ use std::result::Result as StdResult;
55
use mlua::{Lua, Result, String as LuaString, Table, UserData, UserDataMethods, UserDataRegistry};
66
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
77
use tokio::net::lookup_host;
8-
use tokio::time::timeout;
98

109
use super::{SocketOptions, TcpSocket};
1110
use crate::time::Duration;
1211

1312
pub struct TcpStream {
14-
stream: tokio::net::TcpStream,
15-
read_timeout: Option<Duration>,
16-
write_timeout: Option<Duration>,
13+
pub(crate) stream: tokio::net::TcpStream,
14+
pub(crate) read_timeout: Option<Duration>,
15+
pub(crate) write_timeout: Option<Duration>,
1716
}
1817

1918
impl Deref for TcpStream {
@@ -32,17 +31,6 @@ impl DerefMut for TcpStream {
3231
}
3332
}
3433

35-
macro_rules! with_io_timeout {
36-
($timeout:expr, $fut:expr) => {
37-
match $timeout {
38-
Some(dur) => (timeout(dur.0, $fut).await)
39-
.map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e))
40-
.flatten(),
41-
None => $fut.await,
42-
}
43-
};
44-
}
45-
4634
impl From<tokio::net::TcpStream> for TcpStream {
4735
fn from(stream: tokio::net::TcpStream) -> Self {
4836
TcpStream {

src/net/unix/listener.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use std::path::PathBuf;
2+
use std::result::Result as StdResult;
3+
use std::sync::Arc;
4+
5+
use mlua::{ExternalResult as _, Lua, Result, Table, UserData, UserDataMethods, UserDataRegistry};
6+
7+
use super::UnixStream;
8+
9+
pub struct UnixListener {
10+
listener: tokio::net::UnixListener,
11+
unlink_on_drop: bool,
12+
}
13+
14+
impl Drop for UnixListener {
15+
fn drop(&mut self) {
16+
if self.unlink_on_drop {
17+
if let Ok(addr) = self.listener.local_addr() {
18+
if let Some(path) = addr.as_pathname() {
19+
let _ = std::fs::remove_file(path);
20+
}
21+
}
22+
}
23+
}
24+
}
25+
26+
impl UserData for UnixListener {
27+
fn register(registry: &mut UserDataRegistry<Self>) {
28+
registry.add_method("local_addr", |_, this, ()| {
29+
this.listener
30+
.local_addr()
31+
.map(|addr| {
32+
addr.as_pathname()
33+
.map(|p| p.to_string_lossy().to_string())
34+
.unwrap_or_else(|| "(unnamed)".to_string())
35+
})
36+
.into_lua_err()
37+
});
38+
39+
registry.add_async_function("listen", listen);
40+
41+
registry.add_async_method("accept", |_, this, ()| async move {
42+
let (stream, _) = lua_try!(this.listener.accept().await);
43+
Ok(Ok(UnixStream::from(stream)))
44+
});
45+
}
46+
}
47+
48+
pub async fn listen(
49+
_: Lua,
50+
(path, params): (String, Option<Table>),
51+
) -> Result<StdResult<UnixListener, String>> {
52+
let path = Arc::new(PathBuf::from(path));
53+
54+
let path2 = path.clone();
55+
let res = tokio::task::spawn_blocking(move || {
56+
// Remove the socket file if it already exists
57+
if path2.exists() {
58+
if let Err(err) = std::fs::remove_file(&*path2) {
59+
return Err(format!("failed to remove existing socket file: {err}"));
60+
}
61+
}
62+
Ok(())
63+
})
64+
.await;
65+
lua_try!(lua_try!(res));
66+
67+
// Control whether to remove the socket file on drop or not
68+
let unlink_on_drop = opt_param!(params, "unlink_on_drop")?.unwrap_or(false);
69+
70+
let listener = lua_try!(tokio::net::UnixListener::bind(&*path));
71+
Ok(Ok(UnixListener {
72+
listener,
73+
unlink_on_drop,
74+
}))
75+
}

src/net/unix/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
mod listener;
2+
mod stream;
3+
4+
pub use listener::{UnixListener, listen};
5+
pub use stream::{UnixStream, connect};
6+
7+
use mlua::{Lua, Result, Table};
8+
9+
/// Registers the `unix` module in the given Lua state.
10+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
11+
let name = name.unwrap_or("@unix");
12+
let t = lua.create_table()?;
13+
t.set("UnixListener", lua.create_proxy::<UnixListener>()?)?;
14+
t.set("UnixStream", lua.create_proxy::<UnixStream>()?)?;
15+
t.set("listen", lua.create_async_function(listen)?)?;
16+
t.set("connect", lua.create_async_function(connect)?)?;
17+
lua.register_module(name, &t)?;
18+
Ok(t)
19+
}

src/net/unix/stream.rs

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
use std::ops::{Deref, DerefMut};
2+
use std::path::PathBuf;
3+
use std::result::Result as StdResult;
4+
5+
use mlua::{Lua, Result, String as LuaString, Table, UserData, UserDataMethods, UserDataRegistry};
6+
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
7+
8+
use crate::time::Duration;
9+
10+
pub struct UnixStream {
11+
pub(crate) stream: tokio::net::UnixStream,
12+
pub(crate) read_timeout: Option<Duration>,
13+
pub(crate) write_timeout: Option<Duration>,
14+
}
15+
16+
impl Deref for UnixStream {
17+
type Target = tokio::net::UnixStream;
18+
19+
#[inline]
20+
fn deref(&self) -> &Self::Target {
21+
&self.stream
22+
}
23+
}
24+
25+
impl DerefMut for UnixStream {
26+
#[inline]
27+
fn deref_mut(&mut self) -> &mut Self::Target {
28+
&mut self.stream
29+
}
30+
}
31+
32+
impl From<tokio::net::UnixStream> for UnixStream {
33+
fn from(stream: tokio::net::UnixStream) -> Self {
34+
UnixStream {
35+
stream,
36+
read_timeout: None,
37+
write_timeout: None,
38+
}
39+
}
40+
}
41+
42+
impl UserData for UnixStream {
43+
fn register(registry: &mut UserDataRegistry<Self>) {
44+
registry.add_async_function("connect", connect);
45+
46+
registry.add_method("local_addr", |_, this, ()| {
47+
Ok(this.local_addr().map(|addr| {
48+
addr.as_pathname()
49+
.map(|p| p.to_string_lossy().to_string())
50+
.unwrap_or_else(|| "(unnamed)".to_string())
51+
})?)
52+
});
53+
54+
registry.add_method("peer_addr", |_, this, ()| {
55+
Ok(this.peer_addr().map(|addr| {
56+
addr.as_pathname()
57+
.map(|p| p.to_string_lossy().to_string())
58+
.unwrap_or_else(|| "(unnamed)".to_string())
59+
})?)
60+
});
61+
62+
registry.add_method_mut("set_read_timeout", |_, this, dur: Option<Duration>| {
63+
this.read_timeout = dur;
64+
Ok(())
65+
});
66+
67+
registry.add_method_mut("set_write_timeout", |_, this, dur: Option<Duration>| {
68+
this.write_timeout = dur;
69+
Ok(())
70+
});
71+
72+
registry.add_async_method_mut("read", |lua, mut this, size: usize| async move {
73+
let mut buf = vec![0; size];
74+
let n = with_io_timeout!(this.read_timeout, this.read(&mut buf));
75+
let n = lua_try!(n);
76+
buf.truncate(n);
77+
Ok(Ok(lua.create_string(buf)?))
78+
});
79+
80+
registry.add_async_method_mut("read_to_end", |lua, mut this, ()| async move {
81+
let mut buf = Vec::new();
82+
let n = with_io_timeout!(this.read_timeout, this.read_to_end(&mut buf));
83+
let _n = lua_try!(n);
84+
Ok(Ok(lua.create_string(buf)?))
85+
});
86+
87+
registry.add_async_method_mut("write", |_, mut this, data: LuaString| async move {
88+
let n = with_io_timeout!(this.write_timeout, this.write(&data.as_bytes()));
89+
let n = lua_try!(n);
90+
Ok(Ok(n))
91+
});
92+
93+
registry.add_async_method_mut("write_all", |_, mut this, data: LuaString| async move {
94+
let r = with_io_timeout!(this.write_timeout, this.write_all(&data.as_bytes()));
95+
lua_try!(r);
96+
Ok(Ok(true))
97+
});
98+
99+
registry.add_async_method_mut("flush", |_, mut this, ()| async move {
100+
let r = with_io_timeout!(this.write_timeout, this.flush());
101+
lua_try!(r);
102+
Ok(Ok(true))
103+
});
104+
105+
registry.add_async_method_mut("shutdown", |_, mut this, ()| async move {
106+
lua_try!(this.shutdown().await);
107+
Ok(Ok(true))
108+
});
109+
}
110+
}
111+
112+
pub async fn connect(
113+
_: Lua,
114+
(path, params): (String, Option<Table>),
115+
) -> Result<StdResult<UnixStream, String>> {
116+
let path = PathBuf::from(path);
117+
118+
let timeout = opt_param!(Duration, params, "timeout")?; // A single timeout for any operation
119+
let connect_timeout = opt_param!(Duration, params, "connect_timeout")?.or(timeout);
120+
let read_timeout = opt_param!(Duration, params, "read_timeout")?.or(timeout);
121+
let write_timeout = opt_param!(Duration, params, "write_timeout")?.or(timeout);
122+
123+
let stream = with_io_timeout!(connect_timeout, tokio::net::UnixStream::connect(path));
124+
let stream = lua_try!(stream);
125+
126+
Ok(Ok(UnixStream {
127+
stream,
128+
read_timeout,
129+
write_timeout,
130+
}))
131+
}

tests/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ async fn run_file(modname: &str) -> Result<()> {
2525
{
2626
mlua_stdlib::net::register(&lua, None)?;
2727
mlua_stdlib::net::tcp::register(&lua, None)?;
28+
#[cfg(unix)]
29+
mlua_stdlib::net::unix::register(&lua, None)?;
2830
}
31+
#[cfg(feature = "tls")]
32+
mlua_stdlib::net::tls::register(&lua, None)?;
2933
#[cfg(feature = "task")]
3034
mlua_stdlib::task::register(&lua, None)?;
3135

@@ -58,11 +62,12 @@ macro_rules! include_tests {
5862
() => {};
5963

6064
// Grouped tests
61-
($(#[$meta:meta])? $group:ident { $($item:ident),* $(,)? }, $($rest:tt)*) => {
65+
($(#[$meta:meta])* $group:ident { $($(#[$item_meta:meta])* $item:ident),* $(,)? }, $($rest:tt)*) => {
6266
$(#[$meta])*
6367
mod $group {
6468
use super::*;
6569
$(
70+
$(#[$item_meta])*
6671
#[tokio::test]
6772
async fn $item() -> Result<()> {
6873
run_file(&format!("{}/{}", stringify!($group), stringify!($item))).await
@@ -99,6 +104,8 @@ include_tests! {
99104
#[cfg(feature = "net")]
100105
net {
101106
tcp,
107+
#[cfg(unix)]
108+
unix,
102109
},
103110

104111
#[cfg(feature = "task")]

tests/lua/net/tcp_tests.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ testing:test("Tcp ping-pong", function(t)
1717
if data == "ping" then
1818
stream:write_all("pong")
1919
else
20-
stream:write_all(string.rep(data, 2))
20+
stream:write_all(string.reverse(data))
2121
end
2222
end
2323
end)
@@ -29,7 +29,7 @@ testing:test("Tcp ping-pong", function(t)
2929
t.assert_eq(response, "pong")
3030
stream:write_all("hello")
3131
local response2 = stream:read(100)
32-
t.assert_eq(response2, "hellohello")
32+
t.assert_eq(response2, "olleh")
3333
stream:shutdown()
3434
end)
3535

0 commit comments

Comments
 (0)