diff --git a/lib/libp2phttp.go b/lib/libp2phttp.go index c7317c8..845132f 100644 --- a/lib/libp2phttp.go +++ b/lib/libp2phttp.go @@ -13,6 +13,8 @@ import ( "net" "net/http" "net/http/httputil" + "net/url" + "strconv" "time" "github.com/libp2p/go-libp2p/core/host" @@ -29,44 +31,64 @@ func Libp2pHTTPSocketProxy(ctx context.Context, p multiaddr.Multiaddr, unixSocke httpHost := libp2phttp.Host{StreamHost: h} - ai := peer.AddrInfo{ - Addrs: []multiaddr.Multiaddr{p}, + ai, err := peer.AddrInfoFromP2pAddr(p) + if err == peer.ErrInvalidAddr { + ai = &peer.AddrInfo{Addrs: []multiaddr.Multiaddr{p}} // No peer id + err = nil } - idStr, err := p.ValueForProtocol(multiaddr.P_P2P) - if err == nil { - id, err := peer.Decode(idStr) - if err != nil { - return err - } - ai.ID = id + if err != nil { + return err } hasTLS := false hasHTTP := false + host := "" + port := 0 multiaddr.ForEach(p, func(c multiaddr.Component) bool { - if c.Protocol().Code == multiaddr.P_HTTP { + switch c.Protocol().Code { + case multiaddr.P_TLS: + hasTLS = true + case multiaddr.P_HTTP: hasHTTP = true - } - - if c.Protocol().Code == multiaddr.P_HTTPS { + case multiaddr.P_HTTPS: hasHTTP = true hasTLS = true + case multiaddr.P_IP4, multiaddr.P_IP6, multiaddr.P_DNS4, multiaddr.P_DNS6, multiaddr.P_DNS: + host = c.Value() + case multiaddr.P_TCP, multiaddr.P_UDP: + port, err = strconv.Atoi(c.Value()) return false } - - if c.Protocol().Code == multiaddr.P_TLS { - hasTLS = true - } return true }) + if err != nil { + return err + } + if port == 0 && hasHTTP { + port = 80 + if hasTLS { + port = 443 + } + } - rt, err := httpHost.NewConstrainedRoundTripper(ai) + rt, err := httpHost.NewConstrainedRoundTripper(*ai) if err != nil { return err } - rp := &httputil.ReverseProxy{ - Transport: rt, - Director: func(r *http.Request) {}, + + var rp http.Handler + if hasTLS && hasHTTP { + u, err := url.Parse("https://" + host + ":" + strconv.Itoa(port) + "/") + if err != nil { + return err + } + revProxy := httputil.NewSingleHostReverseProxy(u) + rp = revProxy + } else { + rp = &httputil.ReverseProxy{ + Transport: rt, + Director: func(r *http.Request) {}, + } } // Serves an HTTP server on the given path using unix sockets @@ -101,8 +123,8 @@ func Libp2pHTTPSocketProxy(ctx context.Context, p multiaddr.Multiaddr, unixSocke return server.Serve(l) } -// Libp2pHTTPServer serves an libp2p enabled HTTP server -func Libp2pHTTPServer() (host.Host, *libp2phttp.Host, error) { +// libp2pHTTPServer serves an libp2p enabled HTTP server +func libp2pHTTPServer() (host.Host, *libp2phttp.Host, error) { h, err := libp2pHost() if err != nil { return nil, nil, err diff --git a/lib/libp2phttp_test.go b/lib/libp2phttp_test.go index cc9e6f3..0ade3b4 100644 --- a/lib/libp2phttp_test.go +++ b/lib/libp2phttp_test.go @@ -13,15 +13,25 @@ import ( func TestHTTPProxyAndServer(t *testing.T) { // Start libp2p HTTP server - h, hh, err := Libp2pHTTPServer() + h, hh, err := libp2pHTTPServer() if err != nil { t.Fatal(err) } + hh.SetHTTPHandlerAtPath("/hello", "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) go hh.Serve() defer hh.Close() serverAddr := h.Addrs()[0].Encapsulate(multiaddr.StringCast("/p2p/" + h.ID().String())) + port, err := serverAddr.ValueForProtocol(multiaddr.P_TCP) + if err != nil || port == "" { + port, err = serverAddr.ValueForProtocol(multiaddr.P_UDP) + if err != nil || port == "" { + t.Fatal("could not get port from server address") + } + } ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -59,8 +69,7 @@ func TestHTTPProxyAndServer(t *testing.T) { }, } - // TODO update this when https://github.com/libp2p/go-libp2p/pull/2757 lands - resp, err := client.Get("http://example.com" + "/.well-known/libp2p") + resp, err := client.Get("http://127.0.0.1:" + port + "/") if err != nil { t.Fatal(err) } @@ -128,8 +137,7 @@ func TestHTTPProxyAndServerOverHTTPTransport(t *testing.T) { }, } - // TODO update this when https://github.com/libp2p/go-libp2p/pull/2757 lands - resp, err := client.Get("http://example.com/") + resp, err := client.Get("http://127.0.0.1:" + port + "/") if err != nil { t.Fatal(err) }