Skip to content

Commit

Permalink
dnsdist: Use a view for parsing ALPN data, add a regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
rgacogne committed Mar 4, 2024
1 parent b599f69 commit 2a3c2b4
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 27 deletions.
1 change: 1 addition & 0 deletions pdns/dnsdistdist/views.hh
4 changes: 2 additions & 2 deletions pdns/dnsname.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static void checkLabelLength(uint8_t length)
}

// this parses a DNS name until a compression pointer is found
size_t DNSName::parsePacketUncompressed(const UnsignedCharView& view, size_t pos, bool uncompress)
size_t DNSName::parsePacketUncompressed(const pdns::views::UnsignedCharView& view, size_t pos, bool uncompress)
{
const size_t initialPos = pos;
size_t totalLength = 0;
Expand Down Expand Up @@ -189,7 +189,7 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc
}
unsigned char labellen{0};

UnsignedCharView view(qpos, len);
pdns::views::UnsignedCharView view(qpos, len);
auto pos = parsePacketUncompressed(view, offset, uncompress);

labellen = view.at(pos);
Expand Down
23 changes: 2 additions & 21 deletions pdns/dnsname.hh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ inline unsigned char dns_tolower(unsigned char c)
}

#include "burtle.hh"
#include "views.hh"

// #include "dns.hh"
// #include "logger.hh"
Expand Down Expand Up @@ -216,28 +217,8 @@ public:
private:
string_t d_storage;

class UnsignedCharView
{
public:
UnsignedCharView(const char* data_, size_t size_): view(data_, size_)
{
}
const unsigned char& at(std::string_view::size_type pos) const
{
return reinterpret_cast<const unsigned char&>(view.at(pos));
}

size_t size() const
{
return view.size();
}

private:
std::string_view view;
};

void packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset);
size_t parsePacketUncompressed(const UnsignedCharView& view, size_t position, bool uncompress);
size_t parsePacketUncompressed(const pdns::views::UnsignedCharView& view, size_t position, bool uncompress);
static void appendEscapedLabel(std::string& appendTo, const char* orig, size_t len);
static std::string unescapeLabel(const std::string& orig);
static void throwSafeRangeError(const std::string& msg, const char* buf, size_t length);
Expand Down
1 change: 1 addition & 0 deletions pdns/recursordist/views.hh
10 changes: 6 additions & 4 deletions pdns/tcpiohandler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -876,21 +876,23 @@ class OpenSSLTLSIOCtx: public TLSCtx
if (!arg) {
return SSL_TLSEXT_ERR_ALERT_WARNING;
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
OpenSSLTLSIOCtx* obj = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);

const pdns::views::UnsignedCharView inView(in, inlen);
// Server preference algorithm as per RFC 7301 section 3.2
for (const auto& tentative : obj->d_alpnProtos) {
size_t pos = 0;
while (pos < inlen) {
size_t protoLen = in[pos];
while (pos < inView.size()) {
size_t protoLen = inView.at(pos);
pos++;
if (protoLen > (inlen - pos)) {
/* something is very wrong */
return SSL_TLSEXT_ERR_ALERT_WARNING;
}

if (tentative.size() == protoLen && memcmp(in + pos, tentative.data(), tentative.size()) == 0) {
*out = in + pos;
if (tentative.size() == protoLen && memcmp(&inView.at(pos), tentative.data(), tentative.size()) == 0) {
*out = &inView.at(pos);
*outlen = protoLen;
return SSL_TLSEXT_ERR_OK;
}
Expand Down
55 changes: 55 additions & 0 deletions pdns/views.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* This file is part of PowerDNS or dnsdist.
* Copyright -- PowerDNS.COM B.V. and its contributors
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of version 2 of the GNU General Public License as
* published by the Free Software Foundation.
*
* In addition, for the avoidance of any doubt, permission is granted to
* link this program with OpenSSL and to (re)distribute the binaries
* produced as the result of such linking.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/
#pragma once

#include <string_view>

namespace pdns::views
{

class UnsignedCharView
{
public:
UnsignedCharView(const char* data_, size_t size_) :
view(data_, size_)
{
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): No unsigned char view in C++17
UnsignedCharView(const unsigned char* data_, size_t size_) :
view(reinterpret_cast<const char*>(data_), size_)
{
}
const unsigned char& at(std::string_view::size_type pos) const
{
return reinterpret_cast<const unsigned char&>(view.at(pos));
}

size_t size() const
{
return view.size();
}

private:
std::string_view view;
};

}
12 changes: 12 additions & 0 deletions regression-tests.dnsdist/test_DOH.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,18 @@ def testDOHHTTP1(self):
self.assertEqual(rcode, 400)
self.assertEqual(data, b'<html><body>This server implements RFC 8484 - DNS Queries over HTTP, and requires HTTP/2 in accordance with section 5.2 of the RFC.</body></html>\r\n')

def testDOHHTTP1NotSelectedOverH2(self):
"""
DOH: Check that HTTP/1.1 is not selected over H2 when offered in the wrong order by the client
"""
if self._dohLibrary == 'h2o':
raise unittest.SkipTest('h2o supports HTTP/1.1, this test is only relevant for nghttp2')
alpn = ['http/1.1', 'h2']
conn = self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert, alpn=alpn)
if not hasattr(conn, 'selected_alpn_protocol'):
raise unittest.SkipTest('Unable to check the selected ALPN, Python version is too old to support selected_alpn_protocol')
self.assertEqual(conn.selected_alpn_protocol(), 'h2')

def testDOHInvalid(self):
"""
DOH: Invalid DNS query
Expand Down

0 comments on commit 2a3c2b4

Please sign in to comment.