Skip to content

Commit

Permalink
dnsdist: Prevent allocations and copies by using the right types
Browse files Browse the repository at this point in the history
  • Loading branch information
rgacogne committed Jan 11, 2021
1 parent 79a0ad0 commit 8130f43
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 107 deletions.
65 changes: 32 additions & 33 deletions pdns/dnsdist-ecs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ uint16_t g_ECSSourcePrefixV6 = 56;
bool g_ECSOverride{false};
bool g_addEDNSToSelfGeneratedResponses{true};

int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent)
int rewriteResponseWithoutEDNS(const std::vector<uint8_t>& initialPacket, vector<uint8_t>& newContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
Expand All @@ -52,7 +52,7 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
if (ntohs(dh->qdcount) == 0)
return ENOENT;

PacketReader pr(initialPacket);
GenericPacketReader<std::vector<uint8_t>> pr(initialPacket);

size_t idx = 0;
DNSName rrname;
Expand Down Expand Up @@ -149,7 +149,7 @@ static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>&
return true;
}

static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
static bool slowRewriteQueryWithExistingEDNS(const std::vector<uint8_t>& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
{
assert(initialPacket.size() >= sizeof(dnsheader));
const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
Expand All @@ -165,7 +165,7 @@ static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, v
throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
}

PacketReader pr(initialPacket);
GenericPacketReader<std::vector<uint8_t>> pr(initialPacket);

size_t idx = 0;
DNSName rrname;
Expand Down Expand Up @@ -317,7 +317,7 @@ static bool slowParseEDNSOptions(const std::vector<uint8_t>& packet, std::shared
return true;
}

int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last)
int locateEDNSOptRR(const std::vector<uint8_t>& packet, uint16_t * optStart, size_t * optLen, bool * last)
{
assert(optStart != NULL);
assert(optLen != NULL);
Expand All @@ -327,7 +327,8 @@ int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * opt
if (ntohs(dh->arcount) == 0)
return ENOENT;

PacketReader pr(packet);
GenericPacketReader<std::vector<uint8_t>> pr(packet);

size_t idx = 0;
DNSName rrname;
uint16_t qdcount = ntohs(dh->qdcount);
Expand Down Expand Up @@ -436,7 +437,7 @@ void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPref
generateEDNSOption(EDNSOptionCode::ECS, payload, res);
}

void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
bool generateOptRR(const std::string& optRData, std::vector<uint8_t>& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
{
const uint8_t name = 0;
dnsrecordheader dh;
Expand All @@ -445,15 +446,22 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload
edns0.version = 0;
edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;

if ((maximumSize - res.size()) < (sizeof(name) + sizeof(dh) + optRData.length())) {
return false;
}

dh.d_type = htons(QType::OPT);
dh.d_class = htons(udpPayloadSize);
static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
memcpy(&dh.d_ttl, &edns0, sizeof edns0);
dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
res.assign(reinterpret_cast<const char *>(&name), sizeof name);
res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
res.append(optRData.c_str(), optRData.length());

res.reserve(res.size() + sizeof(name) + sizeof(dh) + optRData.length());
res.insert(res.end(), reinterpret_cast<const uint8_t*>(&name), reinterpret_cast<const uint8_t*>(&name) + sizeof(name));
res.insert(res.end(), reinterpret_cast<const uint8_t*>(&dh), reinterpret_cast<const uint8_t*>(&dh) + sizeof(dh));
res.insert(res.end(), reinterpret_cast<const uint8_t*>(optRData.data()), reinterpret_cast<const uint8_t*>(optRData.data()) + optRData.length());

return true;
}

static bool replaceEDNSClientSubnetOption(std::vector<uint8_t>& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption)
Expand Down Expand Up @@ -556,17 +564,10 @@ static bool addECSToExistingOPT(std::vector<uint8_t>& packet, size_t maximumSize

static bool addEDNSWithECS(std::vector<uint8_t>& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded)
{
/* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
string EDNSRR;
generateOptRR(newECSOption, EDNSRR, g_EdnsUDPPayloadSize, 0, false);

if ((maximumSize - packet.size()) < EDNSRR.size()) {
if (!generateOptRR(newECSOption, packet, maximumSize, g_EdnsUDPPayloadSize, 0, false)) {
return false;
}

#warning FIXME: we can avoid a copy here by generating in place
packet.insert(packet.end(), EDNSRR.begin(), EDNSRR.end());

struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
uint16_t arcount = ntohs(dh->arcount);
arcount++;
Expand All @@ -587,7 +588,7 @@ bool handleEDNSClientSubnet(std::vector<uint8_t>& packet, const size_t maximumSi
vector<uint8_t> newContent;
newContent.reserve(packet.size());

if (!slowRewriteQueryWithExistingEDNS(std::string(reinterpret_cast<const char*>(packet.data()), packet.size()), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
if (!slowRewriteQueryWithExistingEDNS(packet, newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
ednsAdded = false;
ecsAdded = false;
return false;
Expand Down Expand Up @@ -708,7 +709,7 @@ int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optio
return 0;
}

bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
bool isEDNSOptionInOpt(const std::vector<uint8_t>& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
{
if (optLen < optRecordMinimumSize) {
return false;
Expand Down Expand Up @@ -747,7 +748,7 @@ bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const s
return false;
}

int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent)
int rewriteResponseWithoutEDNSOption(const std::vector<uint8_t>& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
Expand All @@ -758,7 +759,7 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin
if (ntohs(dh->qdcount) == 0)
return ENOENT;

PacketReader pr(initialPacket);
GenericPacketReader<std::vector<uint8_t>> pr(initialPacket);

size_t idx = 0;
DNSName rrname;
Expand Down Expand Up @@ -844,12 +845,12 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin
return 0;
}

bool addEDNS(std::vector<uint8_t>& packet, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
bool addEDNS(std::vector<uint8_t>& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
{
std::string optRecord;
generateOptRR(std::string(), optRecord, payloadSize, ednsrcode, dnssecOK);
if (!generateOptRR(std::string(), packet, maximumSize, payloadSize, ednsrcode, dnssecOK)) {
return false;
}

packet.insert(packet.end(), optRecord.begin(), optRecord.end());
auto dh = reinterpret_cast<dnsheader*>(packet.data());
dh->arcount = htons(ntohs(dh->arcount) + 1);

Expand Down Expand Up @@ -937,7 +938,7 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone,

if (hadEDNS) {
/* now we need to add a new OPT record */
return addEDNS(packet, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
}

return true;
Expand Down Expand Up @@ -976,7 +977,7 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)

if (g_addEDNSToSelfGeneratedResponses) {
/* now we need to add a new OPT record */
return addEDNS(packet, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
}

/* otherwise we are just fine */
Expand Down Expand Up @@ -1048,9 +1049,7 @@ bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
size_t optLen = 0;
bool last = false;
const auto& packet = dq.getData();
#warning FIXME: save an alloc+copy
std::string packetStr(reinterpret_cast<const char*>(packet.data()), packet.size());
int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
int res = locateEDNSOptRR(packet, &optStart, &optLen, &last);
if (res != 0) {
// no EDNS OPT RR
return false;
Expand All @@ -1060,7 +1059,7 @@ bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
return false;
}

if (optStart < packet.size() && packetStr.at(optStart) != 0) {
if (optStart < packet.size() && packet.at(optStart) != 0) {
// OPT RR Name != '.'
return false;
}
Expand Down
12 changes: 6 additions & 6 deletions pdns/dnsdist-ecs.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ static const size_t optRecordMinimumSize = 11;
extern size_t g_EdnsUDPPayloadSize;
extern uint16_t g_PayloadSizeSelfGenAnswers;

int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent);
int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last);
void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK);
int rewriteResponseWithoutEDNS(const std::vector<uint8_t>& initialPacket, vector<uint8_t>& newContent);
int locateEDNSOptRR(const std::vector<uint8_t> & packet, uint16_t * optStart, size_t * optLen, bool * last);
bool generateOptRR(const std::string& optRData, std::vector<uint8_t>& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK);
void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength);
int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove);
int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
int rewriteResponseWithoutEDNSOption(const std::vector<uint8_t>& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
int getEDNSOptionsStart(const std::vector<uint8_t>& packet, const size_t offset, uint16_t* optRDPosition, size_t * remaining);
bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart = nullptr, uint16_t* optContentLen = nullptr);
bool addEDNS(std::vector<uint8_t>& packet, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode);
bool isEDNSOptionInOpt(const std::vector<uint8_t>& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart = nullptr, uint16_t* optContentLen = nullptr);
bool addEDNS(std::vector<uint8_t>& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode);
bool addEDNSToQueryTurnedResponse(DNSQuestion& dq);
bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum);

Expand Down
14 changes: 4 additions & 10 deletions pdns/dnsdist-lua-actions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu
dq->getHeader()->ancount = htons(dq->getHeader()->ancount);

if (hadEDNS && raw == false) {
addEDNS(dq->getMutableData(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0);
addEDNS(dq->getMutableData(), dq->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0);
}

return Action::HeaderModify;
Expand All @@ -658,16 +658,10 @@ class MacAddrAction : public DNSAction
std::string optRData;
generateEDNSOption(d_code, mac, optRData);

std::string res;
generateOptRR(optRData, res, g_EdnsUDPPayloadSize, 0, false);

if (!dq->hasRoomFor(res.length())) {
return Action::None;
}

dq->getHeader()->arcount = htons(1);
auto& data = dq->getMutableData();
data.insert(data.end(), res.begin(), res.end());
if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
dq->getHeader()->arcount = htons(1);
}

return Action::None;
}
Expand Down
4 changes: 0 additions & 4 deletions pdns/dnsdist-lua-bindings-dnsquestion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
luaCtx.registerMember<dnsheader* (DNSQuestion::*)>("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast<DNSQuestion&>(dq).getHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { *(dq.getHeader()) = *dh; });
luaCtx.registerMember<uint16_t (DNSQuestion::*)>("len", [](const DNSQuestion& dq) -> uint16_t { return dq.getData().size(); }, [](DNSQuestion& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
luaCtx.registerMember<uint8_t (DNSQuestion::*)>("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; });
#warning FIXME we need to provide Lua with a way to update the size
//luaCtx.registerMember<size_t (DNSQuestion::*)>("size", [](const DNSQuestion& dq) -> size_t { return dq.getData().size(); }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; });
luaCtx.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.tcp; }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
luaCtx.registerMember<bool (DNSQuestion::*)>("skipCache", [](const DNSQuestion& dq) -> bool { return dq.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.skipCache = newSkipCache; });
luaCtx.registerMember<bool (DNSQuestion::*)>("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; });
Expand Down Expand Up @@ -141,8 +139,6 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
luaCtx.registerMember<dnsheader* (DNSResponse::*)>("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast<DNSResponse&>(dr).getHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { *(dr.getHeader()) = *dh; });
luaCtx.registerMember<uint16_t (DNSResponse::*)>("len", [](const DNSResponse& dq) -> uint16_t { return dq.getData().size(); }, [](DNSResponse& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
luaCtx.registerMember<uint8_t (DNSResponse::*)>("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; });
#warning FIXME we need to provide Lua with a way to update the size
//luaCtx.registerMember<size_t (DNSResponse::*)>("size", [](const DNSResponse& dq) -> size_t { return dq.size; }, [](DNSResponse& dq, size_t newSize) { (void) newSize; });
luaCtx.registerMember<bool (DNSResponse::*)>("tcp", [](const DNSResponse& dq) -> bool { return dq.tcp; }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; });
luaCtx.registerMember<bool (DNSResponse::*)>("skipCache", [](const DNSResponse& dq) -> bool { return dq.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.skipCache = newSkipCache; });
luaCtx.registerFunction<void(DNSResponse::*)(std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc)>("editTTLs", [](DNSResponse& dr, std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc) {
Expand Down
Loading

0 comments on commit 8130f43

Please sign in to comment.