Skip to content

Commit

Permalink
Merge pull request #12383 from rgacogne/ddist-stronger-udp-path
Browse files Browse the repository at this point in the history
dnsdist: Stronger guarantees against data race in the UDP path
  • Loading branch information
rgacogne authored Jan 11, 2023
2 parents f4b68ef + e1a6df9 commit 8b093b3
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 241 deletions.
126 changes: 52 additions & 74 deletions pdns/dnsdist-idstate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -160,75 +160,27 @@ struct IDState
}

IDState(const IDState& orig) = delete;
IDState(IDState&& rhs)
IDState(IDState&& rhs) noexcept :
internal(std::move(rhs.internal))
{
if (rhs.isInUse()) {
throw std::runtime_error("Trying to move an in-use IDState");
}

#ifdef __SANITIZE_THREAD__
inUse.store(rhs.inUse.load());
age.store(rhs.age.load());
#else
age = rhs.age;
#endif
internal = std::move(rhs.internal);
}

IDState& operator=(IDState&& rhs)
IDState& operator=(IDState&& rhs) noexcept
{
if (isInUse()) {
throw std::runtime_error("Trying to overwrite an in-use IDState");
}

if (rhs.isInUse()) {
throw std::runtime_error("Trying to move an in-use IDState");
}
#ifdef __SANITIZE_THREAD__
inUse.store(rhs.inUse.load());
age.store(rhs.age.load());
#else
age = rhs.age;
#endif

internal = std::move(rhs.internal);

return *this;
}

static const int64_t unusedIndicator = -1;

static bool isInUse(int64_t usageIndicator)
{
return usageIndicator != unusedIndicator;
}

bool isInUse() const
{
return usageIndicator != unusedIndicator;
}

/* return true if the value has been successfully replaced meaning that
no-one updated the usage indicator in the meantime */
bool tryMarkUnused(int64_t expectedUsageIndicator)
{
return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator);
return inUse;
}

/* mark as used no matter what, return true if the state was in use before */
bool markAsUsed()
{
auto currentGeneration = generation++;
return markAsUsed(currentGeneration);
}

/* mark as used no matter what, return true if the state was in use before */
bool markAsUsed(int64_t currentGeneration)
{
int64_t oldUsage = usageIndicator.exchange(currentGeneration);
return oldUsage != unusedIndicator;
}

/* We use this value to detect whether this state is in use.
For performance reasons we don't want to use a lock here, but that means
/* For performance reasons we don't want to use a lock here, but that means
we need to be very careful when modifying this value. Modifications happen
from:
- one of the UDP or DoH 'client' threads receiving a query, selecting a backend
Expand All @@ -246,26 +198,52 @@ struct IDState
the corresponding state and sending the response to the client ;
- the 'healthcheck' thread scanning the states to actively discover timeouts,
mostly to keep some counters like the 'outstanding' one sane.
We previously based that logic on the origFD (FD on which the query was received,
and therefore from where the response should be sent) but this suffered from an
ABA problem since it was quite likely that a UDP 'client thread' would reset it to the
same value since we only have so many incoming sockets:
- 1/ 'client' thread gets a query and sets origFD to its FD, say 5 ;
- 2/ 'receiver' thread gets a response, reads the value of origFD as 5, checks that the qname,
qtype and qclass match
- 3/ during that time the 'client' thread reuses the state, setting again origFD to 5 ;
- 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still
5, except it's not the same 5 anymore and it overrides a fresh state.
We now use a 32-bit unsigned counter instead, which is incremented every time the state is set,
wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1
when the state is unused and the value of our counter otherwise.
We have two flags:
- inUse tells us if there currently is an in-flight query whose state is stored
in this state
- locked tells us whether someone currently owns the state, so no-one else can touch
it
*/
InternalQueryState internal;
std::atomic<int64_t> usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty // 8
std::atomic<uint32_t> generation{0}; // increased every time a state is used, to be able to detect an ABA issue // 4
#ifdef __SANITIZE_THREAD__
std::atomic<uint16_t> age{0};
#else
uint16_t age{0}; // 2
#endif

/* Scope guard for an IDState: it keeps a reference to the state it was
   built from and calls release() on that state when it is destroyed.
   The constructor does not lock anything by itself, it only records which
   state has to be released later on.
   The guard can neither be copied nor moved, so release() is called exactly
   once per guard. */
class StateGuard
{
public:
  /* explicit, so that an IDState is never silently converted into a
     temporary guard, which would call release() as soon as it dies */
  explicit StateGuard(IDState& ids) :
    d_ids(ids)
  {
  }
  ~StateGuard()
  {
    d_ids.release();
  }
  StateGuard(const StateGuard&) = delete;
  StateGuard(StateGuard&&) = delete;
  StateGuard& operator=(const StateGuard&) = delete;
  StateGuard& operator=(StateGuard&&) = delete;

private:
  /* the state that will be released on destruction; it is held by
     reference, so it has to outlive the guard */
  IDState& d_ids;
};

[[nodiscard]] std::optional<StateGuard> acquire()
{
bool expected = false;
if (locked.compare_exchange_strong(expected, true)) {
return std::optional<StateGuard>(*this);
}
return std::nullopt;
}

/* Clear the 'locked' flag unconditionally.
   Called from the StateGuard destructor. */
void release()
{
  /* same ordering as the default used by a plain store() */
  locked.store(false, std::memory_order_seq_cst);
}

std::atomic<bool> inUse{false}; // 1

private:
std::atomic<bool> locked{false}; // 1
};
131 changes: 41 additions & 90 deletions pdns/dnsdist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff,
doLatencyStats(incomingProtocol, udiff);
}

static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated, std::optional<uint16_t> queryId)
static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated)
{
DNSResponse dr(ids, response, ds);

Expand All @@ -653,9 +653,6 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));

if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) {
if (queryId) {
ds->releaseState(*queryId);
}
return;
}

Expand Down Expand Up @@ -686,10 +683,6 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
else {
handleResponseSent(ids, 0., dr.ids.origRemote, ComboAddress(), response.size(), cleartextDH, dnsdist::Protocol::DoUDP);
}

if (queryId) {
ds->releaseState(*queryId);
}
}

// listens on a dedicated socket, lobs answers from downstream servers to original requestors
Expand Down Expand Up @@ -728,57 +721,23 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
queryId = dh->id;

IDState* ids = dss->getExistingState(queryId);
if (ids == nullptr) {
auto ids = dss->getState(queryId);
if (!ids) {
continue;
}

int64_t usageIndicator = ids->usageIndicator;

if (!IDState::isInUse(usageIndicator)) {
/* the corresponding state is marked as not in use, meaning that:
- it was already cleaned up by another thread and the state is gone ;
- we already got a response for this query and this one is a duplicate.
Either way, we don't touch it.
*/
continue;
}

/* setting age to 0 to prevent the maintainer thread from
cleaning this IDS while we process the response.
*/
ids->age = 0;

unsigned int qnameWireLength = 0;
if (fd != ids->internal.backendFD || !responseContentMatches(response, ids->internal.qname, ids->internal.qtype, ids->internal.qclass, dss, qnameWireLength)) {
if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) {
dss->restoreState(queryId, std::move(*ids));
continue;
}

DOHUnitUniquePtr du(nullptr, DOHUnit::release);
/* atomically mark the state as available, but only if it has not been altered
in the meantime */
if (ids->tryMarkUnused(usageIndicator)) {
/* clear the potential DOHUnit asap, it's ours now
and since we just marked the state as unused,
someone could overwrite it. */
du = std::move(ids->internal.du);
/* we only decrement the outstanding counter if the value was not
altered in the meantime, which would mean that the state has been actively reused
and the other thread has not incremented the outstanding counter, so we don't
want it to be decremented twice. */
--dss->outstanding; // you'd think an attacker could game this, but we're using connected socket
} else {
/* someone updated the state in the meantime, we can't touch the existing pointer */
du.release();
/* since the state has been updated, we can't safely access it so let's just drop
this response */
continue;
}
auto du = std::move(ids->du);

dh->id = ids->internal.origID;
dh->id = ids->origID;
++dss->responses;

double udiff = ids->internal.queryRealTime.udiff();
double udiff = ids->queryRealTime.udiff();
// do that _before_ the processing, otherwise it's not fair to the backend
dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0;
dss->reportResponse(dh->rcode);
Expand All @@ -787,13 +746,12 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
if (du) {
#ifdef HAVE_DNS_OVER_HTTPS
// DoH query, we cannot touch du after that
handleUDPResponseForDoH(std::move(du), std::move(response), std::move(ids->internal));
handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids));
#endif
dss->releaseState(queryId);
continue;
}

handleResponseForUDPClient(ids->internal, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, queryId);
handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false);
}
}
catch (const std::exception& e) {
Expand Down Expand Up @@ -1445,7 +1403,7 @@ class UDPTCPCrossQuerySender : public TCPQuerySender
static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();

handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated, std::nullopt);
handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated);
}

void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
Expand Down Expand Up @@ -1487,62 +1445,55 @@ class UDPCrossProtocolQuery : public CrossProtocolQuery
}
};

bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest)
bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest)
{
bool doh = dq.ids.du != nullptr;
unsigned int idOffset = 0;
int64_t generation;
IDState* ids = ds->getIDState(idOffset, generation);

dq.getHeader()->id = idOffset;

bool failed = false;
size_t proxyPayloadSize = 0;
if (ds->d_config.useProxyProtocol) {
try {
size_t payloadSize = 0;
if (addProxyProtocol(dq, &payloadSize)) {
if (addProxyProtocol(dq, &proxyPayloadSize)) {
if (dq.ids.du) {
dq.ids.du->proxyProtocolPayloadSize = payloadSize;
dq.ids.du->proxyProtocolPayloadSize = proxyPayloadSize;
}
}
}
catch (const std::exception& e) {
vinfolog("Adding proxy protocol payload to %squery from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what());
failed = true;
vinfolog("Adding proxy protocol payload to %s query from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what());
return false;
}
}

try {
if (!failed) {
int fd = ds->pickSocketForSending();
dq.ids.backendFD = fd;
dq.ids.origID = queryID;
dq.ids.forwardedOverUDP = true;
ids->internal = std::move(dq.ids);
int fd = ds->pickSocketForSending();
dq.ids.backendFD = fd;
dq.ids.origID = queryID;
dq.ids.forwardedOverUDP = true;

vinfolog("Got query for %s|%s from %s%s, relayed to %s", ids->internal.qname.toLogString(), QType(ids->internal.qtype).toString(), ids->internal.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr());
/* you can't touch du after this line, unless the call returned a non-negative value,
because it might already have been freed */
ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);
vinfolog("Got query for %s|%s from %s%s, relayed to %s", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr());

if (ret < 0) {
failed = true;
}
}
else {
ids->internal = std::move(dq.ids);
auto idOffset = ds->saveState(std::move(dq.ids));
/* set the correct ID */
memcpy(query.data() + proxyPayloadSize, &idOffset, sizeof(idOffset));

/* you can't touch ids or du after this line, unless the call returned a non-negative value,
because it might already have been freed */
ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);

if (ret < 0) {
failed = true;
}

if (failed) {
/* we are about to handle the error, make sure that
this pointer is not accessed when the state is cleaned,
but first check that it still belongs to us */
if (ids->tryMarkUnused(generation) && ids->internal.du) {
dq.ids.du = std::move(ids->internal.du);
--ds->outstanding;
}
if (dq.ids.du) {
dq.ids.du->status_code = 502;
/* clear up the state. In the very unlikely event it was reused
in the meantime, so be it. */
auto cleared = ds->getState(idOffset);
if (cleared) {
dq.ids.du = std::move(cleared->du);
if (dq.ids.du) {
dq.ids.du->status_code = 502;
}
}
++g_stats.downstreamSendErrors;
++ds->sendErrors;
Expand Down Expand Up @@ -1667,7 +1618,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
return;
}

assignOutgoingUDPQueryToBackend(ss, dh->id, dq, std::move(query), dest);
assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, dest);
}
catch(const std::exception& e){
vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what());
Expand Down
Loading

0 comments on commit 8b093b3

Please sign in to comment.