Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
782 changes: 773 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[package]
name = "defguard-gateway"
version = "1.6.0"
edition = "2021"
edition = "2024"

[dependencies]
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "8649a9ba225d7bd2066a09c9e1347705c34bd158" }
axum = "0.8"
base64 = "0.22"
clap = { version = "4.5", features = ["derive", "env"] }
defguard_wireguard_rs = "0.7.7"
defguard_wireguard_rs = { git = "https://github.com/DefGuard/wireguard-rs", rev = "0db4ea7bf4a6bd21c449f9ab8fa6676aebf4698f" }
env_logger = "0.11"
gethostname = "1.0"
ipnetwork = "0.21"
Expand Down
1 change: 1 addition & 0 deletions deny.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ allow = [
"Apache-2.0",
"Apache-2.0 WITH LLVM-exception",
"MPL-2.0",
"BSD-2-Clause",
"BSD-3-Clause",
"Unicode-3.0",
"Unicode-DFS-2016", # unicode-ident
Expand Down
4 changes: 2 additions & 2 deletions examples/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
collections::HashMap,
io::{stdout, Write},
io::{Write, stdout},
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::{Arc, Mutex},
};
Expand All @@ -19,7 +19,7 @@ use tokio::{
},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{transport::Server, Request, Response, Status, Streaming};
use tonic::{Request, Response, Status, Streaming, transport::Server};

pub struct HostConfig {
name: String,
Expand Down
2 changes: 1 addition & 1 deletion src/enterprise/firewall/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl FirewallApi {
pub(crate) trait FirewallManagementApi {
/// Set up the firewall with `default_policy`, `priority`, and cleans up any existing rules.
fn setup(&mut self, default_policy: Policy, priority: Option<i32>)
-> Result<(), FirewallError>;
-> Result<(), FirewallError>;

/// Clean up the firewall rules.
fn cleanup(&mut self) -> Result<(), FirewallError>;
Expand Down
6 changes: 1 addition & 5 deletions src/enterprise/firewall/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,7 @@ pub(crate) enum Policy {

impl From<bool> for Policy {
fn from(allow: bool) -> Self {
if allow {
Self::Allow
} else {
Self::Deny
}
if allow { Self::Allow } else { Self::Deny }
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/enterprise/firewall/nftables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use netfilter::{
use nftnl::Batch;

use super::{
Address, FirewallError, FirewallRule, Policy, Port, Protocol, SnatBinding,
api::{FirewallApi, FirewallManagementApi},
iprange::IpAddrRangeError,
Address, FirewallError, FirewallRule, Policy, Port, Protocol, SnatBinding,
};
use crate::enterprise::firewall::iprange::IpAddrRange;

Expand Down Expand Up @@ -273,7 +273,9 @@ impl FirewallManagementApi for FirewallApi {
masquerade_enabled: bool,
snat_bindings: &[SnatBinding],
) -> Result<(), FirewallError> {
debug!("Setting up POSTROUTING chain rules with masquerade status: {masquerade_enabled} and SNAT bindings: {snat_bindings:?}");
debug!(
"Setting up POSTROUTING chain rules with masquerade status: {masquerade_enabled} and SNAT bindings: {snat_bindings:?}"
);

if let Some(batch) = &mut self.batch {
set_nat_rules(batch, &self.ifname, masquerade_enabled, snat_bindings)?;
Expand Down
51 changes: 24 additions & 27 deletions src/enterprise/firewall/nftables/netfilter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use std::{
};

use nftnl::{
Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table,
expr::{Expression, Immediate, InterfaceName, Nat, NatType, Register},
nft_expr, nftnl_sys,
set::{Set, SetKey},
Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table,
};

use super::{get_set_id, Address, FilterRule, Policy, Port, Protocol, State};
use crate::enterprise::firewall::{iprange::IpAddrRange, max_address, FirewallError, SnatBinding};
use super::{Address, FilterRule, Policy, Port, Protocol, State, get_set_id};
use crate::enterprise::firewall::{FirewallError, SnatBinding, iprange::IpAddrRange, max_address};

const FILTER_TABLE: &str = "filter";
const NAT_TABLE: &str = "nat";
Expand Down Expand Up @@ -98,7 +98,7 @@ fn add_address_to_set(set: *mut nftnl_sys::nftnl_set, ip: &Address) -> Result<()
return Err(FirewallError::InvalidConfiguration(format!(
"Expected both addresses to be of the same type, got {net:?} and \
{upper_bound:?}",
)))
)));
}
}
}
Expand Down Expand Up @@ -308,29 +308,26 @@ impl FirewallRule for FilterRule<'_> {
// 1 Protocol
// > 0 Ports
else if !self.dest_ports.is_empty() {
if let Some(protocol) = self.protocols.first() {
if protocol.supports_ports() {
let set = new_anon_set::<InetService>(
chain.get_table(),
ProtoFamily::Inet,
true,
)?;
batch.add(&set, nftnl::MsgType::Add);

for port in self.dest_ports {
add_port_to_set(set.as_ptr(), port)?;
}

// <protocol> dport {x, x-x}
set.elems_iter().for_each(|elem| {
batch.add(&elem, nftnl::MsgType::Add);
});

rule.add_expr(&nft_expr!(meta l4proto));
rule.add_expr(&nft_expr!(cmp == *protocol as u8));
rule.add_expr(protocol.as_port_payload_expr()?);
rule.add_expr(&nft_expr!(lookup & set));
if let Some(protocol) = self.protocols.first()
&& protocol.supports_ports()
{
let set =
new_anon_set::<InetService>(chain.get_table(), ProtoFamily::Inet, true)?;
batch.add(&set, nftnl::MsgType::Add);

for port in self.dest_ports {
add_port_to_set(set.as_ptr(), port)?;
}

// <protocol> dport {x, x-x}
set.elems_iter().for_each(|elem| {
batch.add(&elem, nftnl::MsgType::Add);
});

rule.add_expr(&nft_expr!(meta l4proto));
rule.add_expr(&nft_expr!(cmp == *protocol as u8));
rule.add_expr(protocol.as_port_payload_expr()?);
rule.add_expr(&nft_expr!(lookup & set));
}

debug!(
Expand Down Expand Up @@ -876,7 +873,7 @@ pub(crate) fn send_batch(batch: &FinalizedBatch) -> Result<(), FirewallError> {
Err(err) => {
return Err(FirewallError::NetlinkError(format!(
"There was an error while sending netlink messages: {err:?}"
)))
)));
}
};
}
Expand Down
6 changes: 3 additions & 3 deletions src/enterprise/firewall/packetfilter/api.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::os::fd::AsRawFd;

use super::{
calls::{pf_begin, pf_commit, pf_rollback, IocTrans, IocTransElement},
rule::RuleSet,
FirewallRule,
calls::{IocTrans, IocTransElement, pf_begin, pf_commit, pf_rollback},
rule::RuleSet,
};
use crate::enterprise::firewall::{
api::{FirewallApi, FirewallManagementApi},
FirewallError, Policy, SnatBinding,
api::{FirewallApi, FirewallManagementApi},
};

impl FirewallManagementApi for FirewallApi {
Expand Down
4 changes: 2 additions & 2 deletions src/enterprise/firewall/packetfilter/calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
use std::{
ffi::{c_char, c_int, c_long, c_uchar, c_uint, c_ulong, c_ushort, c_void},
fmt,
mem::{size_of, zeroed, MaybeUninit},
mem::{MaybeUninit, size_of, zeroed},
ptr,
};

use ipnetwork::IpNetwork;
use libc::{pid_t, uid_t, IFNAMSIZ};
use libc::{IFNAMSIZ, pid_t, uid_t};
use nix::{ioctl_none, ioctl_readwrite};

use super::rule::{Action, AddressFamily, Direction, PacketFilterRule, RuleSet, State};
Expand Down
6 changes: 3 additions & 3 deletions src/enterprise/firewall/packetfilter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ mod rule;

use std::os::fd::{AsRawFd, RawFd};

use calls::{pf_begin_addrs, IocPoolAddr};
use calls::{IocPoolAddr, pf_begin_addrs};
use rule::PacketFilterRule;

use self::calls::{pf_add_rule, Change, IocRule, Rule};
use super::{api::FirewallApi, FirewallError, FirewallRule};
use self::calls::{Change, IocRule, Rule, pf_add_rule};
use super::{FirewallError, FirewallRule, api::FirewallApi};
use crate::enterprise::firewall::Port;

const ANCHOR_PREFIX: &str = "defguard/";
Expand Down
52 changes: 23 additions & 29 deletions src/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,50 @@
use defguard_version::{
client::ClientVersionInterceptor, get_tracing_variables, ComponentInfo, DefguardComponent,
Version,
ComponentInfo, DefguardComponent, Version, client::ClientVersionInterceptor,
get_tracing_variables,
};
use defguard_wireguard_rs::{net::IpAddrMask, WireguardInterfaceApi};
use defguard_wireguard_rs::{WireguardInterfaceApi, net::IpAddrMask};
use gethostname::gethostname;
use std::{
collections::HashMap,
fs::read_to_string,
str::FromStr,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
time::{Duration, SystemTime},
};
use tokio::{
select,
sync::mpsc,
task::{spawn, JoinHandle},
task::{JoinHandle, spawn},
time::{interval, sleep},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{
Request, Status, Streaming,
codegen::InterceptedService,
metadata::{Ascii, MetadataValue},
service::{Interceptor, InterceptorLayer},
transport::{Certificate, Channel, ClientTlsConfig, Endpoint},
Request, Status, Streaming,
};
use tower::ServiceBuilder;
use tracing::{instrument, Instrument};
use tracing::{Instrument, instrument};

use crate::{
VERSION,
config::Config,
enterprise::firewall::{
api::{FirewallApi, FirewallManagementApi},
FirewallConfig, FirewallRule, SnatBinding,
api::{FirewallApi, FirewallManagementApi},
},
error::GatewayError,
execute_command, mask,
proto::gateway::{
gateway_service_client::GatewayServiceClient, stats_update::Payload, update, Configuration,
ConfigurationRequest, Peer, StatsUpdate, Update,
Configuration, ConfigurationRequest, Peer, StatsUpdate, Update,
gateway_service_client::GatewayServiceClient, stats_update::Payload, update,
},
version::ensure_core_version_supported,
VERSION,
};

const TEN_SECS: Duration = Duration::from_secs(10);
Expand All @@ -55,7 +55,7 @@ struct InterfaceConfiguration {
name: String,
prvkey: String,
addresses: Vec<IpAddrMask>,
port: u32,
port: u16,
}

impl From<Configuration> for InterfaceConfiguration {
Expand All @@ -70,7 +70,7 @@ impl From<Configuration> for InterfaceConfiguration {
name: config.name,
prvkey: config.prvkey,
addresses,
port: config.port,
port: config.port as u16,
}
}
}
Expand Down Expand Up @@ -613,7 +613,7 @@ impl Gateway {
{
error!("Failed to update peer: {err}");
}
};
}
}
Some(update::Update::FirewallConfig(config)) => {
if self.config.disable_firewall_management {
Expand Down Expand Up @@ -746,14 +746,19 @@ mod tests {

#[cfg(not(any(target_os = "macos", target_os = "netbsd")))]
use defguard_wireguard_rs::Kernel;
#[cfg(target_os = "macos")]
#[cfg(any(target_os = "macos", target_os = "netbsd"))]
use defguard_wireguard_rs::Userspace;
use defguard_wireguard_rs::WGApi;
use ipnetwork::IpNetwork;

use super::*;
use crate::enterprise::firewall::{Address, FirewallRule, Policy, Port, Protocol};

#[cfg(any(target_os = "macos", target_os = "netbsd"))]
type WG = WGApi<Userspace>;
#[cfg(not(any(target_os = "macos", target_os = "netbsd")))]
type WG = WGApi<Kernel>;

#[tokio::test]
async fn test_configuration_comparison() {
let old_config = InterfaceConfiguration {
Expand Down Expand Up @@ -783,10 +788,7 @@ mod tests {
.map(|peer| (peer.pubkey.clone(), peer))
.collect();

#[cfg(any(target_os = "macos", target_os = "netbsd"))]
let wgapi = WGApi::<Userspace>::new("wg0".into()).unwrap();
#[cfg(not(target_os = "macos"))]
let wgapi = WGApi::<Kernel>::new("wg0".into()).unwrap();
let wgapi = WG::new("wg0").unwrap();
let config = Config::default();
let client = Gateway::setup_client(&config).unwrap();
let firewall_api = FirewallApi::new("wg0").unwrap();
Expand Down Expand Up @@ -975,11 +977,7 @@ mod tests {
snat_bindings: vec![],
};

#[cfg(target_os = "macos")]
let wgapi = WGApi::<Userspace>::new("wg0".into()).unwrap();
#[cfg(not(target_os = "macos"))]
let wgapi = WGApi::<Kernel>::new("wg0".into()).unwrap();

let wgapi = WG::new("wg0").unwrap();
let config = Config::default();
let client = Gateway::setup_client(&config).unwrap();
let mut gateway = Gateway {
Expand Down Expand Up @@ -1048,11 +1046,7 @@ mod tests {
snat_bindings: vec![],
};

#[cfg(target_os = "macos")]
let wgapi = WGApi::<Userspace>::new("wg0".into()).unwrap();
#[cfg(not(target_os = "macos"))]
let wgapi = WGApi::<Kernel>::new("wg0".into()).unwrap();

let wgapi = WG::new("wg0").unwrap();
let config = Config::default();
let client = Gateway::setup_client(&config).unwrap();
let mut gateway = Gateway {
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ extern crate log;
use std::{process::Command, str::FromStr, time::SystemTime};

use config::Config;
use defguard_wireguard_rs::{host::Peer, net::IpAddrMask, InterfaceConfiguration};
use defguard_wireguard_rs::{InterfaceConfiguration, host::Peer, net::IpAddrMask};
use error::GatewayError;
use syslog::{BasicLogger, Facility, Formatter3164};

Expand All @@ -33,7 +33,7 @@ pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "+", env!("VERGEN_G
/// Used to log sensitive/secret objects.
#[macro_export]
macro_rules! mask {
($object:expr, $field:ident) => {{
($object:expr_2021, $field:ident) => {{
let mut object = $object.clone();
object.$field = String::from("***");
object
Expand Down Expand Up @@ -99,7 +99,7 @@ impl From<proto::gateway::Configuration> for InterfaceConfiguration {
name: config.name,
prvkey: config.prvkey,
addresses,
port: config.port,
port: config.port as u16,
peers,
mtu: None,
}
Expand Down
Loading
Loading