diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 5fb7fa5a8ad7..2ca88419c789 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -437,14 +437,14 @@ endif if HAVE_DNS_OVER_TLS if HAVE_GNUTLS -dnsdist_LDADD += -lgnutls +dnsdist_LDADD += $(GNUTLS_LIBS) endif endif if HAVE_DNS_OVER_HTTPS if HAVE_GNUTLS -dnsdist_LDADD += -lgnutls +dnsdist_LDADD += $(GNUTLS_LIBS) endif if HAVE_LIBH2OEVLOOP diff --git a/pdns/dnsdistdist/dnsdist-lua-hooks.cc b/pdns/dnsdistdist/dnsdist-lua-hooks.cc index c5ccb48915c1..2904cd37926a 100644 --- a/pdns/dnsdistdist/dnsdist-lua-hooks.cc +++ b/pdns/dnsdistdist/dnsdist-lua-hooks.cc @@ -2,9 +2,13 @@ #include "dnsdist-lua-hooks.hh" #include "dnsdist-lua.hh" #include "lock.hh" +#include "tcpiohandler.hh" namespace dnsdist::lua::hooks { +using MaintenanceCallback = std::function; +using TicketsKeyAddedHook = std::function; + static LockGuarded> s_maintenanceHooks; void runMaintenanceHooks(const LuaContext& context) @@ -15,7 +19,7 @@ void runMaintenanceHooks(const LuaContext& context) } } -void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback) +static void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback) { (void)context; s_maintenanceHooks.lock()->push_back(std::move(callback)); @@ -26,12 +30,29 @@ void clearMaintenanceHooks() s_maintenanceHooks.lock()->clear(); } +static void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAddedHook& hook) +{ + TLSCtx::setTicketsKeyAddedHook([hook](const std::string& key) { + try { + auto lua = g_lua.lock(); + hook(key.c_str(), key.size()); + } + catch (const std::exception& exp) { + warnlog("Error calling the Lua hook after new tickets key has been added: %s", exp.what()); + } + }); +} + void setupLuaHooks(LuaContext& luaCtx) { luaCtx.writeFunction("addMaintenanceCallback", [&luaCtx](const MaintenanceCallback& callback) { setLuaSideEffect(); addMaintenanceCallback(luaCtx, callback); }); + luaCtx.writeFunction("setTicketsKeyAddedHook", [&luaCtx](const TicketsKeyAddedHook& hook) { + setLuaSideEffect(); + setTicketsKeyAddedHook(luaCtx, hook); + }); } } diff --git a/pdns/dnsdistdist/dnsdist-lua-hooks.hh b/pdns/dnsdistdist/dnsdist-lua-hooks.hh index 11a9084883ee..e35c0f10ac5f 100644 --- a/pdns/dnsdistdist/dnsdist-lua-hooks.hh +++ b/pdns/dnsdistdist/dnsdist-lua-hooks.hh @@ -27,9 +27,7 @@ class LuaContext; namespace dnsdist::lua::hooks { -using MaintenanceCallback = std::function; void runMaintenanceHooks(const LuaContext& context); -void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback); void clearMaintenanceHooks(); void setupLuaHooks(LuaContext& luaCtx); } diff --git a/pdns/dnsdistdist/dnsdist.cc b/pdns/dnsdistdist/dnsdist.cc index 6e5ca522488c..4fd2dc499b03 100644 --- a/pdns/dnsdistdist/dnsdist.cc +++ b/pdns/dnsdistdist/dnsdist.cc @@ -888,7 +888,7 @@ void responderThread(std::shared_ptr dss) } } -LockGuarded g_lua{LuaContext()}; +RecursiveLockGuarded g_lua{LuaContext()}; ComboAddress g_serverControl{"127.0.0.1:5199"}; static void spoofResponseFromString(DNSQuestion& dnsQuestion, const string& spoofContent, bool raw) diff --git a/pdns/dnsdistdist/dnsdist.hh b/pdns/dnsdistdist/dnsdist.hh index b643c827f31b..cd1ef1880059 100644 --- a/pdns/dnsdistdist/dnsdist.hh +++ b/pdns/dnsdistdist/dnsdist.hh @@ -1099,7 +1099,7 @@ public: using servers_t = vector>; void responderThread(std::shared_ptr dss); -extern LockGuarded g_lua; +extern RecursiveLockGuarded g_lua; extern std::string g_outputBuffer; // locking for this is ok, as locked by g_luamutex class DNSRule diff --git a/pdns/dnsdistdist/docs/reference/config.rst b/pdns/dnsdistdist/docs/reference/config.rst index 7473624276e6..7e8165968ab1 100644 --- a/pdns/dnsdistdist/docs/reference/config.rst +++ b/pdns/dnsdistdist/docs/reference/config.rst @@ -2173,6 +2173,17 @@ Other functions Code is supplied as a string, not as a function object. Note that this function does nothing in 'client' or 'config-check' modes. +.. function:: setTicketsKeyAddedHook(callback) + + .. versionadded:: 1.9.6 + + Set a Lua function that will be called everytime a new tickets key is added. The function receives: + + * the key content as a string + * the keylen as an integer + + See :doc:`../advanced/tls-sessions-management` for more information. + .. function:: submitToMainThread(cmd, dict) .. versionadded:: 1.8.0 diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc index 013d0ba7c7fa..7d6569073cec 100644 --- a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc @@ -15,7 +15,7 @@ uint16_t g_maxOutstanding{std::numeric_limits::max()}; #include "ext/luawrapper/include/LuaContext.hpp" -LockGuarded g_lua{LuaContext()}; +RecursiveLockGuarded g_lua{LuaContext()}; bool g_snmpEnabled{false}; bool g_snmpTrapsEnabled{false}; diff --git a/pdns/libssl.cc b/pdns/libssl.cc index 3f657326c432..3f1a86b0ceb4 100644 --- a/pdns/libssl.cc +++ b/pdns/libssl.cc @@ -42,6 +42,7 @@ #undef CERT #include "misc.hh" +#include "tcpiohandler.hh" #if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) /* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */ @@ -631,6 +632,13 @@ OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() = default; void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr&& newKey) { d_ticketKeys.write_lock()->push_front(std::move(newKey)); + if (TLSCtx::hasTicketsKeyAddedHook()) { + auto key = d_ticketKeys.read_lock()->front(); + auto keyContent = key->content(); + TLSCtx::getTicketsKeyAddedHook()(keyContent); + // fills mem with 0's + OPENSSL_cleanse(keyContent.data(), keyContent.size()); + } } std::shared_ptr OpenSSLTLSTicketKeysRing::getEncryptionKey() @@ -737,6 +745,19 @@ bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_N return (memcmp(d_name, name, sizeof(d_name)) == 0); } +std::string OpenSSLTLSTicketKey::content() const +{ + std::string result{}; + result.reserve(TLS_TICKETS_KEY_NAME_SIZE + TLS_TICKETS_CIPHER_KEY_SIZE + TLS_TICKETS_MAC_KEY_SIZE); + // NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast) + result.append(reinterpret_cast(d_name), TLS_TICKETS_KEY_NAME_SIZE); + result.append(reinterpret_cast(d_cipherKey), TLS_TICKETS_CIPHER_KEY_SIZE); + result.append(reinterpret_cast(d_hmacKey), TLS_TICKETS_MAC_KEY_SIZE); + // NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast) + + return result; +} + #if OPENSSL_VERSION_MAJOR >= 3 static const std::string sha256KeyName{"sha256"}; #endif diff --git a/pdns/libssl.hh b/pdns/libssl.hh index 8dd7ff373bf6..9bcf460802de 100644 --- a/pdns/libssl.hh +++ b/pdns/libssl.hh @@ -100,6 +100,7 @@ public: #if OPENSSL_VERSION_MAJOR >= 3 int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const; bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const; + [[nodiscard]] std::string content() const; #else int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const; bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const; @@ -124,7 +125,6 @@ public: private: void addKey(std::shared_ptr&& newKey); - SharedLockGuarded > > d_ticketKeys; }; diff --git a/pdns/lock.hh b/pdns/lock.hh index c413ceeab4d9..254c2f87b999 100644 --- a/pdns/lock.hh +++ b/pdns/lock.hh @@ -333,6 +333,111 @@ private: T d_value; }; +template +class RecursiveLockGuardedHolder +{ +public: + explicit RecursiveLockGuardedHolder(T& value, std::recursive_mutex& mutex) : + d_lock(mutex), d_value(value) + { + } + + T& operator*() const noexcept + { + return d_value; + } + + T* operator->() const noexcept + { + return &d_value; + } + +private: + std::lock_guard d_lock; + T& d_value; +}; + +template +class RecursiveLockGuardedTryHolder +{ +public: + explicit RecursiveLockGuardedTryHolder(T& value, std::recursive_mutex& mutex) : + d_lock(mutex, std::try_to_lock), d_value(value) + { + } + + T& operator*() const + { + if (!owns_lock()) { + throw std::runtime_error("Trying to access data protected by a mutex while the lock has not been acquired"); + } + return d_value; + } + + T* operator->() const + { + if (!owns_lock()) { + throw std::runtime_error("Trying to access data protected by a mutex while the lock has not been acquired"); + } + return &d_value; + } + + operator bool() const noexcept + { + return d_lock.owns_lock(); + } + + [[nodiscard]] bool owns_lock() const noexcept + { + return d_lock.owns_lock(); + } + + void lock() + { + d_lock.lock(); + } + +private: + std::unique_lock d_lock; + T& d_value; +}; + +template +class RecursiveLockGuarded +{ +public: + explicit RecursiveLockGuarded(const T& value) : + d_value(value) + { + } + + explicit RecursiveLockGuarded(T&& value) : + d_value(std::move(value)) + { + } + + explicit RecursiveLockGuarded() = default; + + RecursiveLockGuardedTryHolder try_lock() + { + return RecursiveLockGuardedTryHolder(d_value, d_mutex); + } + + RecursiveLockGuardedHolder lock() + { + return RecursiveLockGuardedHolder(d_value, d_mutex); + } + + RecursiveLockGuardedHolder read_only_lock() + { + return RecursiveLockGuardedHolder(d_value, d_mutex); + } + +private: + std::recursive_mutex d_mutex; + T d_value; +}; + template class SharedLockGuardedHolder { diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index cf82471ba84d..edee311ad078 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -11,6 +11,8 @@ const bool TCPIOHandler::s_disableConnectForUnitTests = false; #include #endif /* HAVE_LIBSODIUM */ +TLSCtx::tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr}; + #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) #ifdef HAVE_LIBSSL @@ -987,6 +989,16 @@ class GnuTLSTicketsKey throw; } } + [[nodiscard]] std::string content() const + { + std::string result{}; + if (d_key.data != nullptr && d_key.size > 0) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + result.append(reinterpret_cast(d_key.data), d_key.size); + safe_memory_lock(result.data(), result.size()); + } + return result; + } ~GnuTLSTicketsKey() { @@ -1730,14 +1742,12 @@ class GnuTLSIOCtx: public TLSCtx return connection; } - void rotateTicketsKey(time_t now) override + void addTicketsKey(time_t now, std::shared_ptr&& newKey) { if (!d_enableTickets) { return; } - auto newKey = std::make_shared(); - { *(d_ticketsKey.write_lock()) = std::move(newKey); } @@ -1745,8 +1755,23 @@ class GnuTLSIOCtx: public TLSCtx if (d_ticketsKeyRotationDelay > 0) { d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; } + + if (TLSCtx::hasTicketsKeyAddedHook()) { + auto ticketsKey = *(d_ticketsKey.read_lock()); + auto content = ticketsKey->content(); + TLSCtx::getTicketsKeyAddedHook()(content); + safe_memory_release(content.data(), content.size()); + } } + void rotateTicketsKey(time_t now) override + { + if (!d_enableTickets) { + return; + } + auto newKey = std::make_shared(); + addTicketsKey(now, std::move(newKey)); + } void loadTicketsKeys(const std::string& file) final { if (!d_enableTickets) { @@ -1754,13 +1779,7 @@ class GnuTLSIOCtx: public TLSCtx } auto newKey = std::make_shared(file); - { - *(d_ticketsKey.write_lock()) = std::move(newKey); - } - - if (d_ticketsKeyRotationDelay > 0) { - d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; - } + addTicketsKey(time(nullptr), std::move(newKey)); } size_t getTicketsKeysCount() override diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index 058d10443b71..8420529811e3 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -81,7 +81,6 @@ public: { throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file"); } - void handleTicketsKeyRotation(time_t now) { if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) { @@ -124,10 +123,27 @@ public: return false; } + using tickets_key_added_hook = std::function; + + static void setTicketsKeyAddedHook(const tickets_key_added_hook& hook) + { + TLSCtx::s_ticketsKeyAddedHook = hook; + } + static const tickets_key_added_hook& getTicketsKeyAddedHook() + { + return TLSCtx::s_ticketsKeyAddedHook; + } + static bool hasTicketsKeyAddedHook() + { + return TLSCtx::s_ticketsKeyAddedHook != nullptr; + } protected: std::atomic_flag d_rotatingTicketsKey; std::atomic d_ticketsKeyNextRotation{0}; time_t d_ticketsKeyRotationDelay{0}; + +private: + static tickets_key_added_hook s_ticketsKeyAddedHook; }; class TLSFrontend diff --git a/regression-tests.dnsdist/test_TLS.py b/regression-tests.dnsdist/test_TLS.py index 9803ed550f96..27c2de52fe19 100644 --- a/regression-tests.dnsdist/test_TLS.py +++ b/regression-tests.dnsdist/test_TLS.py @@ -516,3 +516,40 @@ def setUpClass(cls): cls.startResponders() cls.startDNSDist() cls.setUpSockets() + +class TestTLSTicketsKeyAddedCallback(DNSDistTest): + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _tlsServerPort = pickAvailablePort() + _numberOfKeys = 5 + + _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey'] + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%s") + + newServer{address="127.0.0.1:%s"} + addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" }) + + callbackCalled = 0 + function keyAddedCallback(key, keyLen) + callbackCalled = keyLen + end + + """ + + def testLuaThreadCounter(self): + """ + LuaThread: Test the lua newThread interface + """ + self.sendConsoleCommand('setTicketsKeyAddedHook(keyAddedCallback)'); + called = self.sendConsoleCommand('callbackCalled') + self.assertEqual(int(called), 0) + self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()") + called = self.sendConsoleCommand('callbackCalled') + self.assertGreater(int(called), 0)