Skip to content

Commit

Permalink
Merge pull request #14327 from chbruyand/dnsdist-tickets-key-hook
Browse files Browse the repository at this point in the history
dnsdist: add support for a callback when a new tickets key is added
  • Loading branch information
rgacogne authored Jul 4, 2024
2 parents 53bcb9f + 2eca15e commit d1ac698
Show file tree
Hide file tree
Showing 13 changed files with 248 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pdns/dnsdistdist/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,14 @@ endif

if HAVE_DNS_OVER_TLS
if HAVE_GNUTLS
dnsdist_LDADD += -lgnutls
dnsdist_LDADD += $(GNUTLS_LIBS)
endif
endif

if HAVE_DNS_OVER_HTTPS

if HAVE_GNUTLS
dnsdist_LDADD += -lgnutls
dnsdist_LDADD += $(GNUTLS_LIBS)
endif

if HAVE_LIBH2OEVLOOP
Expand Down
23 changes: 22 additions & 1 deletion pdns/dnsdistdist/dnsdist-lua-hooks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
#include "dnsdist-lua-hooks.hh"
#include "dnsdist-lua.hh"
#include "lock.hh"
#include "tcpiohandler.hh"

namespace dnsdist::lua::hooks
{
using MaintenanceCallback = std::function<void()>;
using TicketsKeyAddedHook = std::function<void(const char*, size_t)>;

static LockGuarded<std::vector<MaintenanceCallback>> s_maintenanceHooks;

void runMaintenanceHooks(const LuaContext& context)
Expand All @@ -15,7 +19,7 @@ void runMaintenanceHooks(const LuaContext& context)
}
}

void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback)
static void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback)
{
(void)context;
s_maintenanceHooks.lock()->push_back(std::move(callback));
Expand All @@ -26,12 +30,29 @@ void clearMaintenanceHooks()
s_maintenanceHooks.lock()->clear();
}

static void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAddedHook& hook)
{
TLSCtx::setTicketsKeyAddedHook([hook](const std::string& key) {
try {
auto lua = g_lua.lock();
hook(key.c_str(), key.size());
}
catch (const std::exception& exp) {
warnlog("Error calling the Lua hook after new tickets key has been added: %s", exp.what());
}
});
}

void setupLuaHooks(LuaContext& luaCtx)
{
luaCtx.writeFunction("addMaintenanceCallback", [&luaCtx](const MaintenanceCallback& callback) {
setLuaSideEffect();
addMaintenanceCallback(luaCtx, callback);
});
luaCtx.writeFunction("setTicketsKeyAddedHook", [&luaCtx](const TicketsKeyAddedHook& hook) {
setLuaSideEffect();
setTicketsKeyAddedHook(luaCtx, hook);
});
}

}
2 changes: 0 additions & 2 deletions pdns/dnsdistdist/dnsdist-lua-hooks.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ class LuaContext;

namespace dnsdist::lua::hooks
{
using MaintenanceCallback = std::function<void()>;
void runMaintenanceHooks(const LuaContext& context);
void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback);
void clearMaintenanceHooks();
void setupLuaHooks(LuaContext& luaCtx);
}
2 changes: 1 addition & 1 deletion pdns/dnsdistdist/dnsdist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
}
}

LockGuarded<LuaContext> g_lua{LuaContext()};
RecursiveLockGuarded<LuaContext> g_lua{LuaContext()};
ComboAddress g_serverControl{"127.0.0.1:5199"};

static void spoofResponseFromString(DNSQuestion& dnsQuestion, const string& spoofContent, bool raw)
Expand Down
2 changes: 1 addition & 1 deletion pdns/dnsdistdist/dnsdist.hh
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ public:
using servers_t = vector<std::shared_ptr<DownstreamState>>;

void responderThread(std::shared_ptr<DownstreamState> dss);
extern LockGuarded<LuaContext> g_lua;
extern RecursiveLockGuarded<LuaContext> g_lua;
extern std::string g_outputBuffer; // locking for this is ok, as locked by g_luamutex

class DNSRule
Expand Down
11 changes: 11 additions & 0 deletions pdns/dnsdistdist/docs/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,17 @@ Other functions
Code is supplied as a string, not as a function object.
Note that this function does nothing in 'client' or 'config-check' modes.

.. function:: setTicketsKeyAddedHook(callback)

.. versionadded:: 1.9.6

Set a Lua function that will be called everytime a new tickets key is added. The function receives:

* the key content as a string
* the keylen as an integer

See :doc:`../advanced/tls-sessions-management` for more information.

.. function:: submitToMainThread(cmd, dict)

.. versionadded:: 1.8.0
Expand Down
2 changes: 1 addition & 1 deletion pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()};

#include "ext/luawrapper/include/LuaContext.hpp"
LockGuarded<LuaContext> g_lua{LuaContext()};
RecursiveLockGuarded<LuaContext> g_lua{LuaContext()};

bool g_snmpEnabled{false};
bool g_snmpTrapsEnabled{false};
Expand Down
21 changes: 21 additions & 0 deletions pdns/libssl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

#undef CERT
#include "misc.hh"
#include "tcpiohandler.hh"

#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL)
/* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
Expand Down Expand Up @@ -631,6 +632,13 @@ OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() = default;
void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey)
{
d_ticketKeys.write_lock()->push_front(std::move(newKey));
if (TLSCtx::hasTicketsKeyAddedHook()) {
auto key = d_ticketKeys.read_lock()->front();
auto keyContent = key->content();
TLSCtx::getTicketsKeyAddedHook()(keyContent);
// fills mem with 0's
OPENSSL_cleanse(keyContent.data(), keyContent.size());
}
}

std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getEncryptionKey()
Expand Down Expand Up @@ -737,6 +745,19 @@ bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_N
return (memcmp(d_name, name, sizeof(d_name)) == 0);
}

std::string OpenSSLTLSTicketKey::content() const
{
std::string result{};
result.reserve(TLS_TICKETS_KEY_NAME_SIZE + TLS_TICKETS_CIPHER_KEY_SIZE + TLS_TICKETS_MAC_KEY_SIZE);
// NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
result.append(reinterpret_cast<const char*>(d_name), TLS_TICKETS_KEY_NAME_SIZE);
result.append(reinterpret_cast<const char*>(d_cipherKey), TLS_TICKETS_CIPHER_KEY_SIZE);
result.append(reinterpret_cast<const char*>(d_hmacKey), TLS_TICKETS_MAC_KEY_SIZE);
// NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)

return result;
}

#if OPENSSL_VERSION_MAJOR >= 3
static const std::string sha256KeyName{"sha256"};
#endif
Expand Down
2 changes: 1 addition & 1 deletion pdns/libssl.hh
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public:
#if OPENSSL_VERSION_MAJOR >= 3
int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const;
bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const;
[[nodiscard]] std::string content() const;
#else
int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const;
bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const;
Expand All @@ -124,7 +125,6 @@ public:

private:
void addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey);

SharedLockGuarded<boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > > d_ticketKeys;
};

Expand Down
105 changes: 105 additions & 0 deletions pdns/lock.hh
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,111 @@ private:
T d_value;
};

template <typename T>
class RecursiveLockGuardedHolder
{
public:
explicit RecursiveLockGuardedHolder(T& value, std::recursive_mutex& mutex) :
d_lock(mutex), d_value(value)
{
}

T& operator*() const noexcept
{
return d_value;
}

T* operator->() const noexcept
{
return &d_value;
}

private:
std::lock_guard<std::recursive_mutex> d_lock;
T& d_value;
};

template <typename T>
class RecursiveLockGuardedTryHolder
{
public:
explicit RecursiveLockGuardedTryHolder(T& value, std::recursive_mutex& mutex) :
d_lock(mutex, std::try_to_lock), d_value(value)
{
}

T& operator*() const
{
if (!owns_lock()) {
throw std::runtime_error("Trying to access data protected by a mutex while the lock has not been acquired");
}
return d_value;
}

T* operator->() const
{
if (!owns_lock()) {
throw std::runtime_error("Trying to access data protected by a mutex while the lock has not been acquired");
}
return &d_value;
}

operator bool() const noexcept
{
return d_lock.owns_lock();
}

[[nodiscard]] bool owns_lock() const noexcept
{
return d_lock.owns_lock();
}

void lock()
{
d_lock.lock();
}

private:
std::unique_lock<std::recursive_mutex> d_lock;
T& d_value;
};

template <typename T>
class RecursiveLockGuarded
{
public:
explicit RecursiveLockGuarded(const T& value) :
d_value(value)
{
}

explicit RecursiveLockGuarded(T&& value) :
d_value(std::move(value))
{
}

explicit RecursiveLockGuarded() = default;

RecursiveLockGuardedTryHolder<T> try_lock()
{
return RecursiveLockGuardedTryHolder<T>(d_value, d_mutex);
}

RecursiveLockGuardedHolder<T> lock()
{
return RecursiveLockGuardedHolder<T>(d_value, d_mutex);
}

RecursiveLockGuardedHolder<const T> read_only_lock()
{
return RecursiveLockGuardedHolder<const T>(d_value, d_mutex);
}

private:
std::recursive_mutex d_mutex;
T d_value;
};

template <typename T>
class SharedLockGuardedHolder
{
Expand Down
39 changes: 29 additions & 10 deletions pdns/tcpiohandler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ const bool TCPIOHandler::s_disableConnectForUnitTests = false;
#include <sodium.h>
#endif /* HAVE_LIBSODIUM */

TLSCtx::tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr};

#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
#ifdef HAVE_LIBSSL

Expand Down Expand Up @@ -987,6 +989,16 @@ class GnuTLSTicketsKey
throw;
}
}
[[nodiscard]] std::string content() const
{
std::string result{};
if (d_key.data != nullptr && d_key.size > 0) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
result.append(reinterpret_cast<const char*>(d_key.data), d_key.size);
safe_memory_lock(result.data(), result.size());
}
return result;
}

~GnuTLSTicketsKey()
{
Expand Down Expand Up @@ -1730,37 +1742,44 @@ class GnuTLSIOCtx: public TLSCtx
return connection;
}

void rotateTicketsKey(time_t now) override
void addTicketsKey(time_t now, std::shared_ptr<GnuTLSTicketsKey>&& newKey)
{
if (!d_enableTickets) {
return;
}

auto newKey = std::make_shared<GnuTLSTicketsKey>();

{
*(d_ticketsKey.write_lock()) = std::move(newKey);
}

if (d_ticketsKeyRotationDelay > 0) {
d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
}

if (TLSCtx::hasTicketsKeyAddedHook()) {
auto ticketsKey = *(d_ticketsKey.read_lock());
auto content = ticketsKey->content();
TLSCtx::getTicketsKeyAddedHook()(content);
safe_memory_release(content.data(), content.size());
}
}
void rotateTicketsKey(time_t now) override
{
if (!d_enableTickets) {
return;
}

auto newKey = std::make_shared<GnuTLSTicketsKey>();
addTicketsKey(now, std::move(newKey));
}
void loadTicketsKeys(const std::string& file) final
{
if (!d_enableTickets) {
return;
}

auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
{
*(d_ticketsKey.write_lock()) = std::move(newKey);
}

if (d_ticketsKeyRotationDelay > 0) {
d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
}
addTicketsKey(time(nullptr), std::move(newKey));
}

size_t getTicketsKeysCount() override
Expand Down
Loading

0 comments on commit d1ac698

Please sign in to comment.