Skip to content

Commit

Permalink
Expose mqtt events
Browse files Browse the repository at this point in the history
  • Loading branch information
akiroz committed Feb 5, 2024
1 parent e7498d1 commit e749cd5
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "zika"
version = "3.3.6"
version = "3.4.0"
license = "MIT"
description = "IP Tunneling over MQTT"
repository = "https://github.com/akiroz/zika"
Expand Down
31 changes: 9 additions & 22 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ use std::error::Error as StdError;
use std::net::Ipv4Addr;
use std::sync::Arc;

use bytes::Bytes;
use base64::{engine::general_purpose, Engine as _};
use futures::{SinkExt, stream::{SplitSink, StreamExt}};
use rand::{thread_rng, Rng, distributions::Standard};

use rumqttc;
use ipnetwork::Ipv4Network;
use tokio::{task, sync::{broadcast, Mutex}};
use tokio::{task, sync::Mutex};
use tokio_util::codec::Framed;
use tun::{AsyncDevice, TunPacket, TunPacketCodec};

Expand All @@ -23,8 +22,7 @@ type TunSink = SplitSink<Framed<AsyncDevice, TunPacketCodec>, TunPacket>;
pub struct Client {
pub local_addr: Ipv4Addr,
tunnels: Arc<Vec<Tunnel>>,
pub remote: Arc<Mutex<remote::Remote>>, // Allow external mqtt ops
pub remote_passthru_recv: Arc<Mutex<broadcast::Receiver<(String, Bytes)>>>,
pub remote: Arc<Mutex<remote::Remote>>, // Allow external mqtt access
}

struct Tunnel {
Expand Down Expand Up @@ -95,12 +93,10 @@ impl Client {
});
}

let (remote_passthru_send, remote_passthru_recv) = broadcast::channel(1);
let client = Arc::new(Self {
local_addr,
tunnels: Arc::new(tunnels),
remote: Arc::new(Mutex::new(remote)),
remote_passthru_recv: Arc::new(Mutex::new(remote_passthru_recv)),
});

let loop_client = client.clone();
Expand All @@ -120,20 +116,13 @@ impl Client {
let loop2_client = client.clone();
task::spawn(async move {
loop {
if let Some((topic, msg)) = remote_recv.recv().await {
let handle_result = loop2_client.handle_remote_message(&mut tun_sink, topic.as_str(), &msg).await;
match handle_result {
Err(err) => log::error!("handle_remote_message error {:?}", err),
Ok(handled) => {
if !handled {
if let Err(err) = remote_passthru_send.send((topic, msg)) {
log::warn!("remote_passthru_send error {:?}", err);
}
}
match remote_recv.recv().await {
None => panic!("remote_recv: None"),
Some((topic, msg)) => {
if let Err(err) = loop2_client.handle_remote_message(&mut tun_sink, topic.as_str(), &msg).await {
log::error!("handle_remote_message {:?}", err);
}
}
} else {
break;
}
}
});
Expand Down Expand Up @@ -166,14 +155,12 @@ impl Client {
}

// mqtt -> tun
async fn handle_remote_message(&self, tun_sink: &mut TunSink, topic: &str, msg: &[u8]) -> Result<bool, Box<dyn StdError>> {
async fn handle_remote_message(&self, tun_sink: &mut TunSink, topic: &str, msg: &[u8]) -> Result<(), Box<dyn StdError>> {
if let Some(tunnel) = self.tunnels.iter().find(|&t| t.topic == topic) {
let pkt = nat::do_nat(msg, tunnel.bind_addr, self.local_addr)?;
tun_sink.send(TunPacket::new(pkt)).await?;
Ok(true)
} else {
Ok(false)
}
Ok(())
}

}
39 changes: 26 additions & 13 deletions src/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use std::ops::Range;

use log;
use bytes::Bytes;
use tokio::{sync::{mpsc, Mutex}, task};
use tokio::{sync::{mpsc, broadcast, Mutex}, task};
use rumqttc::v5::{
self as mqtt,
mqttbytes::{v5::Packet, QoS, v5::PublishProperties},
mqttbytes::{v5::{Packet, PublishProperties}, QoS},
};

use crate::lookup_pool::LookupPool;
Expand All @@ -16,8 +16,9 @@ use crate::lookup_pool::LookupPool;
struct RemoteIncomingContext {
nth: usize,
mqtt_client: Arc<mqtt::AsyncClient>,
sender: mpsc::Sender<(String, Bytes)>,
subs: Arc<Mutex<Vec<String>>>,
msg_send: mpsc::Sender<(String, Bytes)>,
evt_send: broadcast::Sender<(usize, Packet)>,
}

struct RemoteClient {
Expand All @@ -29,18 +30,21 @@ struct RemoteClient {
pub struct Remote {
clients: Vec<RemoteClient>,
subs: Arc<Mutex<Vec<String>>>,
pub on_event: broadcast::Receiver<(usize, Packet)>,
}

impl Remote {
pub fn new(
broker_opts: &Vec<mqtt::MqttOptions>,
topics: Vec<String>,
) -> (Self, mpsc::Receiver<(String, Bytes)>) {
let (sender, receiver) = mpsc::channel(128);
pub fn new(broker_opts: &Vec<mqtt::MqttOptions>, topics: Vec<String>) -> (
Self,
mpsc::Receiver<(String, Bytes)>,
) {
let (msg_send, msg_recv) = mpsc::channel(64);
let (evt_send, evt_recv) = broadcast::channel(1);
let subs = Arc::new(Mutex::new(topics));
let mut remote = Self {
clients: Vec::with_capacity(broker_opts.len()),
subs: subs.clone(),
on_event: evt_recv,
};
for (idx, opt) in broker_opts.iter().enumerate() {
log::debug!("broker[{}] opts {:?}", idx, opt);
Expand All @@ -59,16 +63,23 @@ impl Remote {
let mut context = RemoteIncomingContext {
nth: idx,
mqtt_client: arc_mqtt_client,
sender: sender.clone(),
subs: subs.clone(),
msg_send: msg_send.clone(),
evt_send: evt_send.clone(),

};
task::spawn(async move {
loop {
let evt = event_loop.poll().await;
use mqtt::Event::Incoming;
match event_loop.poll().await {
match evt {
Ok(Incoming(pkt)) => {
log::trace!("broker[{}] recv {:?}", idx, pkt);
Self::handle_packet(&mut context, pkt).await;
Self::handle_packet(&mut context, pkt.clone()).await;
context.evt_send.send((idx, pkt)).unwrap_or_else(|err| {
log::warn!("broker[{}] evt_send {:?}", idx, err);
0
});
}
Err(err) => {
log::warn!("broker[{}] recv {:?}", idx, err);
Expand All @@ -84,7 +95,7 @@ impl Remote {
});
remote.clients.push(remote_client);
}
(remote, receiver)
(remote, msg_recv)
}

async fn handle_packet(context: &mut RemoteIncomingContext, pkt: Packet) {
Expand Down Expand Up @@ -114,7 +125,9 @@ impl Remote {
.ok()
.filter(|n| n.len() > 0);
if let Some(topic) = topic_str {
_ = context.sender.send((topic, payload)).await; // What if it's not ok?
context.msg_send.send((topic, payload)).await.unwrap_or_else(|err| {
log::warn!("broker[{}] msg_send {:?}", context.nth, err);
});
} else {
log::debug!("drop packet, non utf8 topic: {:?}", topic);
}
Expand Down
27 changes: 16 additions & 11 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct Server {
pub topic: String,
pub local_addr: Ipv4Addr,
ip_pool: Arc<Mutex<IpPool>>,
pub remote: Arc<Mutex<remote::Remote>>, // Allow external mqtt access
}

impl Server {
Expand All @@ -42,7 +43,7 @@ impl Server {
let mut ip_iter = SizedIpv4NetworkIterator::new(ip_network);
let local_addr = ip_iter.next().expect("subnet size > 1");

let (mut remote, mut remote_recv) = remote::Remote::new(&mqtt_options, vec![server_config.topic.clone()]);
let (remote, mut remote_recv) = remote::Remote::new(&mqtt_options, vec![server_config.topic.clone()]);

log::info!("bind {:?}/{}", local_addr, ip_network.prefix());

Expand All @@ -66,14 +67,17 @@ impl Server {
topic: server_config.topic,
local_addr,
ip_pool: Arc::new(Mutex::new(LookupPool::new(ip_iter))),
remote: Arc::new(Mutex::new(remote)),
});

let loop_ip_pool = server.ip_pool.clone();
let loop_remote = server.remote.clone();
task::spawn(async move {
while let Some(packet) = tun_stream.next().await {
match packet {
Ok(pkt) => {
let mut ip_pool = loop_ip_pool.lock().await;
let mut remote = loop_remote.lock().await;
let result = Self::handle_packet(&mut remote, &mut ip_pool, &pkt).await;
if let Err(err) = result {
log::error!("handle_packet error {:?}", err);
Expand All @@ -87,14 +91,15 @@ impl Server {
let loop_server = server.clone();
task::spawn(async move {
loop {
if let Some((_topic, payload)) = remote_recv.recv().await {
let (id, msg) = payload.split_at(server_config.id_length);
let handle_result = loop_server.handle_remote_message(&mut tun_sink, id, msg).await;
if let Err(err) = handle_result {
log::error!("handle_remote_message error {:?}", err);
match remote_recv.recv().await {
None => panic!("remote_recv: None"),
Some((_topic, msg)) => {
let (id, pkt) = msg.split_at(server_config.id_length);
let handle_result = loop_server.handle_remote_message(&mut tun_sink, id, pkt).await;
if let Err(err) = handle_result {
log::error!("handle_remote_message error {:?}", err);
}
}
} else {
break;
}
}
});
Expand All @@ -118,7 +123,7 @@ impl Server {
}

// mqtt -> tun
async fn handle_remote_message(&self, tun_sink: &mut TunSink, id: &[u8], msg: &[u8]) -> Result<(), Box<dyn StdError>> {
async fn handle_remote_message(&self, tun_sink: &mut TunSink, id: &[u8], pkt: &[u8]) -> Result<(), Box<dyn StdError>> {
let base64_id = general_purpose::URL_SAFE_NO_PAD.encode(id);
let (existing_tunnel, ip) = {
let mut ip_pool = self.ip_pool.lock().await;
Expand All @@ -127,8 +132,8 @@ impl Server {
if !existing_tunnel {
log::info!("alloc tunnel {} (IP {})", base64_id, ip);
}
let pkt = nat::do_nat(msg, ip, self.local_addr)?;
tun_sink.send(TunPacket::new(pkt)).await?;
let nat_pkt = nat::do_nat(pkt, ip, self.local_addr)?;
tun_sink.send(TunPacket::new(nat_pkt)).await?;
Ok(())
}
}

0 comments on commit e749cd5

Please sign in to comment.