Skip to content

Commit

Permalink
Merge pull request #26 from multiformats/feat/rework-resolver
Browse files Browse the repository at this point in the history
refactor Resolver to support custom per-TLD resolvers
  • Loading branch information
vyzo authored Apr 9, 2021
2 parents dba25a2 + 45cdfcf commit 963a26a
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 65 deletions.
31 changes: 31 additions & 0 deletions mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package madns

import (
"context"
"net"
)

type MockResolver struct {
IP map[string][]net.IPAddr
TXT map[string][]string
}

var _ BasicResolver = (*MockResolver)(nil)

func (r *MockResolver) LookupIPAddr(ctx context.Context, name string) ([]net.IPAddr, error) {
results, ok := r.IP[name]
if ok {
return results, nil
} else {
return []net.IPAddr{}, nil
}
}

func (r *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
results, ok := r.TXT[name]
if ok {
return results, nil
} else {
return []string{}, nil
}
}
127 changes: 64 additions & 63 deletions resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,86 @@ import (
)

var ResolvableProtocols = []ma.Protocol{DnsaddrProtocol, Dns4Protocol, Dns6Protocol, DnsProtocol}
var DefaultResolver = &Resolver{Backend: net.DefaultResolver}
var DefaultResolver = &Resolver{def: net.DefaultResolver}

const dnsaddrTXTPrefix = "dnsaddr="

type Backend interface {
// BasicResolver is a low level interface for DNS resolution
type BasicResolver interface {
LookupIPAddr(context.Context, string) ([]net.IPAddr, error)
LookupTXT(context.Context, string) ([]string, error)
}

// Resolver is an object capable of resolving dns multiaddrs by using one or more BasicResolvers;
// it supports custom per domain/TLD resolvers.
// It also implements the BasicResolver interface so that it can act as a custom per domain/TLD
// resolver.
type Resolver struct {
Backend Backend
def BasicResolver
custom map[string]BasicResolver
}

var _ Backend = (*MockBackend)(nil)
var _ BasicResolver = (*Resolver)(nil)

type MockBackend struct {
IP map[string][]net.IPAddr
TXT map[string][]string
// NewResolver creates a new Resolver instance with the specified options
func NewResolver(opts ...Option) (*Resolver, error) {
r := &Resolver{def: net.DefaultResolver}
for _, opt := range opts {
err := opt(r)
if err != nil {
return nil, err
}
}

return r, nil
}

func (r *MockBackend) LookupIPAddr(ctx context.Context, name string) ([]net.IPAddr, error) {
results, ok := r.IP[name]
if ok {
return results, nil
} else {
return []net.IPAddr{}, nil
type Option func(*Resolver) error

// WithDefaultResolver is an option that specifies the default basic resolver,
// which resolves any TLD that doesn't have a custom resolver.
// Defaults to net.DefaultResolver
func WithDefaultResolver(def BasicResolver) Option {
return func(r *Resolver) error {
r.def = def
return nil
}
}

func (r *MockBackend) LookupTXT(ctx context.Context, name string) ([]string, error) {
results, ok := r.TXT[name]
if ok {
return results, nil
} else {
return []string{}, nil
// WithDomainResolver specifies a custom resolver for a domain/TLD.
// Custom resolver selection matches domains left to right, with more specific resolvers
// superseding generic ones.
func WithDomainResolver(domain string, rslv BasicResolver) Option {
return func(r *Resolver) error {
if r.custom == nil {
r.custom = make(map[string]BasicResolver)
}
r.custom[domain] = rslv
return nil
}
}

func Matches(maddr ma.Multiaddr) (matches bool) {
ma.ForEach(maddr, func(c ma.Component) bool {
switch c.Protocol().Code {
case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code:
matches = true
func (r *Resolver) getResolver(domain string) BasicResolver {
// we match left-to-right, with more specific resolvers superseding generic ones.
// So for a domain a.b.c, we will try a.b,c, b.c, c, and fallback to the default if
// there is no match
rslv, ok := r.custom[domain]
if ok {
return rslv
}

for i := strings.Index(domain, "."); i != -1; i = strings.Index(domain, ".") {
domain = domain[i+1:]
rslv, ok = r.custom[domain]
if ok {
return rslv
}
return !matches
})
return matches
}
}

func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) {
return DefaultResolver.Resolve(ctx, maddr)
return r.def
}

// Resolve resolves a DNS multiaddr.
func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) {
var results []ma.Multiaddr
for i := 0; maddr != nil; i++ {
Expand Down Expand Up @@ -99,6 +126,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia

proto := resolve.Protocol()
value := resolve.Value()
rslv := r.getResolver(value)

// resolve the dns component
var resolved []ma.Multiaddr
Expand All @@ -114,7 +142,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia
// differentiating between IPv6 and IPv4. A v4-in-v6
// AAAA record will _look_ like an A record to us and
// there's nothing we can do about that.
records, err := r.Backend.LookupIPAddr(ctx, value)
records, err := rslv.LookupIPAddr(ctx, value)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -155,7 +183,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia
// matching the result of step 2.

// First, lookup the TXT record
records, err := r.Backend.LookupTXT(ctx, "_dnsaddr."+value)
records, err := rslv.LookupTXT(ctx, "_dnsaddr."+value)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -235,37 +263,10 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia
return results, nil
}

// counts the number of components in the multiaddr
func addrLen(maddr ma.Multiaddr) int {
length := 0
ma.ForEach(maddr, func(_ ma.Component) bool {
length++
return true
})
return length
}

// trims `offset` components from the beginning of the multiaddr.
func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr {
_, after := ma.SplitFunc(maddr, func(c ma.Component) bool {
if offset == 0 {
return true
}
offset--
return false
})
return after
func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) {
return r.getResolver(domain).LookupIPAddr(ctx, domain)
}

// takes the cross product of two sets of multiaddrs
//
// assumes `a` is non-empty.
func cross(a, b []ma.Multiaddr) []ma.Multiaddr {
res := make([]ma.Multiaddr, 0, len(a)*len(b))
for _, x := range a {
for _, y := range b {
res = append(res, x.Encapsulate(y))
}
}
return res
func (r *Resolver) LookupTXT(ctx context.Context, txt string) ([]string, error) {
return r.getResolver(txt).LookupTXT(ctx, txt)
}
91 changes: 89 additions & 2 deletions resolve_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package madns

import (
"bytes"
"context"
"net"
"testing"
Expand Down Expand Up @@ -29,7 +30,7 @@ var txtd = "dnsaddr=" + txtmd.String()
var txte = "dnsaddr=" + txtme.String()

func makeResolver() *Resolver {
mock := &MockBackend{
mock := &MockResolver{
IP: map[string][]net.IPAddr{
"example.com": []net.IPAddr{ip4a, ip4b, ip6a, ip6b},
},
Expand All @@ -38,7 +39,7 @@ func makeResolver() *Resolver {
"_dnsaddr.matching.com": []string{txtc, txtd, txte, "not a dnsaddr", "dnsaddr=/foobar"},
},
}
resolver := &Resolver{Backend: mock}
resolver := &Resolver{def: mock}
return resolver
}

Expand Down Expand Up @@ -234,3 +235,89 @@ func TestBadDomain(t *testing.T) {
t.Error("expected malformed address to fail to parse")
}
}

func TestCustomResolver(t *testing.T) {
ip1 := net.IPAddr{IP: net.ParseIP("1.2.3.4")}
ip2 := net.IPAddr{IP: net.ParseIP("2.3.4.5")}
ip3 := net.IPAddr{IP: net.ParseIP("3.4.5.6")}
ip4 := net.IPAddr{IP: net.ParseIP("4.5.6.8")}
ip5 := net.IPAddr{IP: net.ParseIP("5.6.8.9")}
ip6 := net.IPAddr{IP: net.ParseIP("6.8.9.10")}
def := &MockResolver{
IP: map[string][]net.IPAddr{
"example.com": []net.IPAddr{ip1},
},
}
custom1 := &MockResolver{
IP: map[string][]net.IPAddr{
"custom.test": []net.IPAddr{ip2},
"another.custom.test": []net.IPAddr{ip3},
"more.custom.test": []net.IPAddr{ip6},
},
}
custom2 := &MockResolver{
IP: map[string][]net.IPAddr{
"more.custom.test": []net.IPAddr{ip4},
"some.more.custom.test": []net.IPAddr{ip5},
},
}

rslv, err := NewResolver(
WithDefaultResolver(def),
WithDomainResolver("custom.test", custom1),
WithDomainResolver("more.custom.test", custom2),
)
if err != nil {
t.Fatal(err)
}

sameIP := func(ip1, ip2 net.IPAddr) bool {
return bytes.Equal(ip1.IP, ip2.IP)
}

ctx := context.Background()
res, err := rslv.LookupIPAddr(ctx, "example.com")
if err != nil {
t.Fatal(err)
}

if len(res) != 1 || !sameIP(res[0], ip1) {
t.Fatal("expected result to be ip1")
}

res, err = rslv.LookupIPAddr(ctx, "custom.test")
if err != nil {
t.Fatal(err)
}

if len(res) != 1 || !sameIP(res[0], ip2) {
t.Fatal("expected result to be ip2")
}

res, err = rslv.LookupIPAddr(ctx, "another.custom.test")
if err != nil {
t.Fatal(err)
}

if len(res) != 1 || !sameIP(res[0], ip3) {
t.Fatal("expected result to be ip3")
}

res, err = rslv.LookupIPAddr(ctx, "more.custom.test")
if err != nil {
t.Fatal(err)
}

if len(res) != 1 || !sameIP(res[0], ip4) {
t.Fatal("expected result to be ip4")
}

res, err = rslv.LookupIPAddr(ctx, "some.more.custom.test")
if err != nil {
t.Fatal(err)
}

if len(res) != 1 || !sameIP(res[0], ip5) {
t.Fatal("expected result to be ip5")
}
}
57 changes: 57 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package madns

import (
"context"

ma "github.com/multiformats/go-multiaddr"
)

func Matches(maddr ma.Multiaddr) (matches bool) {
ma.ForEach(maddr, func(c ma.Component) bool {
switch c.Protocol().Code {
case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code:
matches = true
}
return !matches
})
return matches
}

func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) {
return DefaultResolver.Resolve(ctx, maddr)
}

// counts the number of components in the multiaddr
func addrLen(maddr ma.Multiaddr) int {
length := 0
ma.ForEach(maddr, func(_ ma.Component) bool {
length++
return true
})
return length
}

// trims `offset` components from the beginning of the multiaddr.
func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr {
_, after := ma.SplitFunc(maddr, func(c ma.Component) bool {
if offset == 0 {
return true
}
offset--
return false
})
return after
}

// takes the cross product of two sets of multiaddrs
//
// assumes `a` is non-empty.
func cross(a, b []ma.Multiaddr) []ma.Multiaddr {
res := make([]ma.Multiaddr, 0, len(a)*len(b))
for _, x := range a {
for _, y := range b {
res = append(res, x.Encapsulate(y))
}
}
return res
}

0 comments on commit 963a26a

Please sign in to comment.