diff --git a/.gitignore b/.gitignore index a2804ed5..ceb3efec 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,4 @@ Session.vim # Auto-generated tag files tags +.rgignore diff --git a/.goreleaser.yml b/.goreleaser.yml index 7ad1696e..0d377a7b 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -12,7 +12,25 @@ builds: ignore: - goos: windows goarch: arm64 - ldflags: -s -w -X github.com/grepplabs/kafka-proxy/config.Version={{.Version}} + ldflags: -s -w -X github.com/limoges/kafka-proxy/config.Version={{.Version}} +kos: + - repositories: + - ghcr.io/limoges/kafka-proxy + platforms: + - linux/arm64 + - linux/amd64 + tags: + - latest + - "{{.Tag}}" + - "{{.Version}}" + bare: true + preserve_import_paths: false + labels: + "org.opencontainers.image.source": "https://github.com/limoges/kafka-proxy" + "org.opencontainers.image.description": "Proxy connections to Kafka cluster. Connect through SOCKS Proxy, HTTP Proxy or to cluster running in Kubernetes." + annotations: + "org.opencontainers.image.source": "https://github.com/limoges/kafka-proxy" + "org.opencontainers.image.description": "Proxy connections to Kafka cluster. Connect through SOCKS Proxy, HTTP Proxy or to cluster running in Kubernetes." archives: - name_template: "{{ .ProjectName }}-{{ .Tag }}-{{ .Os }}-{{ .Arch }}" files: diff --git a/README.md b/README.md index 68f385b9..472b9e31 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,7 @@ The Kafka Proxy is based on idea of [Cloud SQL Proxy](https://github.com/GoogleC It allows a service to connect to Kafka brokers without having to deal with SASL/PLAIN authentication and SSL certificates. It works by opening tcp sockets on the local machine and proxying connections to the associated Kafka brokers -when the sockets are used. The host and port in [Metadata](http://kafka.apache.org/protocol.html#The_Messages_Metadata) -and [FindCoordinator](http://kafka.apache.org/protocol.html#The_Messages_FindCoordinator) +when the sockets are used. The host and port in [Metadata](http://kafka.apache.org/protocol.html#The_Messages_Metadata), [FindCoordinator](http://kafka.apache.org/protocol.html#The_Messages_FindCoordinator) & [DescribeCluster](https://github.com/apache/kafka/blob/trunk/clients/src/main/resources/common/message/DescribeClusterResponse.json) responses received from the brokers are replaced by local counterparts. For discovered brokers (not configured as the boostrap servers), local listeners are started on random ports. The dynamic local listeners feature can be disabled and an additional list of external server mappings can be provided. diff --git a/cmd/kafka-proxy/server.go b/cmd/kafka-proxy/server.go index 6c23cd94..bef98db6 100644 --- a/cmd/kafka-proxy/server.go +++ b/cmd/kafka-proxy/server.go @@ -483,7 +483,7 @@ func SetLogger() { } logrus.SetFormatter(formatter) } else { - logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) + logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: false}) } level, err := logrus.ParseLevel(c.Log.Level) if err != nil { diff --git a/go.mod b/go.mod index 0e1653d3..34514135 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,9 @@ module github.com/grepplabs/kafka-proxy -go 1.23 +go 1.23.0 + +toolchain go1.23.5 + require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/aws/aws-sdk-go-v2 v1.36.1 @@ -20,6 +23,7 @@ require ( github.com/jcmturner/gofork v1.7.6 github.com/jcmturner/gokrb5/v8 v8.4.3 github.com/oklog/run v1.1.0 + github.com/pires/go-proxyproto v0.8.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.11.1 github.com/samber/slog-logrus/v2 v2.5.2 diff --git a/go.sum b/go.sum index 111fd660..7b869f2f 100644 --- a/go.sum +++ b/go.sum @@ -212,6 +212,8 @@ github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pires/go-proxyproto v0.8.0 h1:5unRmEAPbHXHuLjDg01CxJWf91cw3lKHc/0xzKpXEe0= +github.com/pires/go-proxyproto v0.8.0/go.mod h1:iknsfgnH8EkjrMeMyvfKByp9TiBZCKZM0jx2xmKqnVY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/main.go b/main.go index 06db1a04..0896d354 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,12 @@ import ( "github.com/grepplabs/kafka-proxy/cmd/kafka-proxy" "github.com/grepplabs/kafka-proxy/cmd/tools" "github.com/spf13/cobra" + "log" + "net/http" "os" + "runtime" + + _ "net/http/pprof" ) var RootCmd = &cobra.Command{ @@ -25,6 +30,13 @@ func init() { } func main() { + runtime.SetBlockProfileRate(1) // blocking sync primitives + runtime.SetMutexProfileFraction(1) // contended mutexes + + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + if err := RootCmd.Execute(); err != nil { fmt.Println(err) os.Exit(1) diff --git a/proxy/client.go b/proxy/client.go index eb27e699..91edc9b5 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -1,6 +1,7 @@ package proxy import ( + "crypto/tls" "crypto/x509" "fmt" "net" @@ -295,27 +296,27 @@ func (c *Client) handleConn(conn Conn) { logrus.Infof("Dial address changed from %s to %s", conn.BrokerAddress, dialAddress) } - server, err := c.DialAndAuth(dialAddress) + server, err := c.DialAndAuth(conn, dialAddress) if err != nil { logrus.Infof("couldn't connect to %s(%s): %v", dialAddress, conn.BrokerAddress, err) _ = conn.LocalConnection.Close() return } + if tcpConn, ok := server.(*net.TCPConn); ok { if err := c.tcpConnOptions.setTCPConnOptions(tcpConn); err != nil { logrus.Infof("WARNING: Error while setting TCP options for kafka connection %s on %v: %v", conn.BrokerAddress, server.LocalAddr(), err) } } c.conns.Add(conn.BrokerAddress, conn.LocalConnection) - localDesc := "local connection on " + conn.LocalConnection.LocalAddr().String() + " from " + conn.LocalConnection.RemoteAddr().String() + " (" + conn.BrokerAddress + ")" - copyThenClose(c.processorConfig, server, conn.LocalConnection, conn.BrokerAddress, conn.BrokerAddress, localDesc) + copyThenClose(c.processorConfig, server, conn.LocalConnection, conn.BrokerAddress, conn.LocalConnection.RemoteAddr(), conn.LocalConnection.LocalAddr()) if err := c.conns.Remove(conn.BrokerAddress, conn.LocalConnection); err != nil { logrus.Info(err) } } -func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) { - conn, err := c.dialer.Dial("tcp", brokerAddress) +func (c *Client) DialAndAuth(downstream Conn, brokerAddress string) (conn net.Conn, err error) { + conn, err = c.dialer.Dial("tcp", brokerAddress) if err != nil { return nil, err } @@ -323,10 +324,19 @@ func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) { _ = conn.Close() return nil, err } + + if tlsConn, ok := conn.(*tls.Conn); ok { + err := tlsConn.Handshake() + if err != nil { + return nil, fmt.Errorf("client handshake with upstream broker failed: %w", err) + } + } err = c.auth(conn, brokerAddress) if err != nil { return nil, err } + + logrus.Infof("%s: Client(%s) Successfully connected to upstream: %s", downstream.LocalConnection.LocalAddr(), downstream.LocalConnection.RemoteAddr(), brokerAddress) return conn, nil } diff --git a/proxy/common.go b/proxy/common.go index 406d641d..4a94ff53 100644 --- a/proxy/common.go +++ b/proxy/common.go @@ -4,11 +4,12 @@ import ( "bytes" "errors" "fmt" - "github.com/sirupsen/logrus" "io" "net" "sync" "time" + + "github.com/sirupsen/logrus" ) type DeadlineReadWriteCloser interface { @@ -112,8 +113,9 @@ func copyError(readDesc, writeDesc string, readErr bool, err error) { logrus.Infof("%v had error: %s", desc, err.Error()) } -func copyThenClose(cfg ProcessorConfig, remote, local DeadlineReadWriteCloser, brokerAddress string, remoteDesc, localDesc string) { +func copyThenClose(cfg ProcessorConfig, remote, local DeadlineReadWriteCloser, brokerAddress string, remoteAddr, localAddr net.Addr) { + localDesc := "local connection on " + localAddr.String() + " from " + remoteAddr.String() + " (" + brokerAddress + ")" processor := newProcessor(cfg, brokerAddress) firstErr := make(chan error, 1) @@ -123,9 +125,9 @@ func copyThenClose(cfg ProcessorConfig, remote, local DeadlineReadWriteCloser, b select { case firstErr <- err: if readErr && err == io.EOF { - logrus.Infof("Client closed %v", localDesc) + logrus.Infof("%s: Client(%s) closed connection for broker (%s)", localAddr, remoteAddr, brokerAddress) } else { - copyError(localDesc, remoteDesc, readErr, err) + copyError(localDesc, remoteAddr.String(), readErr, err) } remote.Close() local.Close() @@ -137,9 +139,9 @@ func copyThenClose(cfg ProcessorConfig, remote, local DeadlineReadWriteCloser, b select { case firstErr <- err: if readErr && err == io.EOF { - logrus.Infof("Server %v closed connection", remoteDesc) + logrus.Infof("Server %v closed connection", remoteAddr.String()) } else { - copyError(remoteDesc, localDesc, readErr, err) + copyError(remoteAddr.String(), localDesc, readErr, err) } remote.Close() local.Close() diff --git a/proxy/protocol/real_decoder.go b/proxy/protocol/real_decoder.go index d87b8dd4..8e7479e3 100644 --- a/proxy/protocol/real_decoder.go +++ b/proxy/protocol/real_decoder.go @@ -2,8 +2,9 @@ package protocol import ( "encoding/binary" - "github.com/google/uuid" "math" + + "github.com/google/uuid" ) var errInvalidArrayLength = PacketDecodingError{"invalid array length"} diff --git a/proxy/protocol/responses.go b/proxy/protocol/responses.go index e062176a..eb8b5726 100644 --- a/proxy/protocol/responses.go +++ b/proxy/protocol/responses.go @@ -10,8 +10,10 @@ import ( const ( apiKeyMetadata = 3 apiKeyFindCoordinator = 10 + apiKeyDescribeCluster = 60 brokersKeyName = "brokers" + brokerKeyName = "broker_id" hostKeyName = "host" portKeyName = "port" nodeKeyName = "node_id" @@ -23,6 +25,7 @@ const ( var ( metadataResponseSchemaVersions = createMetadataResponseSchemaVersions() findCoordinatorResponseSchemaVersions = createFindCoordinatorResponseSchemaVersions() + describeClusterResponseSchemaVersions = createDescribeClusterResponseSchemaVersions() ) func createMetadataResponseSchemaVersions() []Schema { @@ -300,6 +303,63 @@ func createFindCoordinatorResponseSchemaVersions() []Schema { return []Schema{findCoordinatorResponseV0, findCoordinatorResponseV1, findCoordinatorResponseV2, findCoordinatorResponseV3, findCoordinatorResponseV4, findCoordinatorResponseV5, findCoordinatorResponseV6} } +func createDescribeClusterResponseSchemaVersions() []Schema { + describeClusterBrokerV0 := NewSchema("describe_cluster_broker_v0", + &Mfield{Name: brokerKeyName, Ty: TypeInt32}, + &Mfield{Name: hostKeyName, Ty: TypeCompactStr}, + &Mfield{Name: portKeyName, Ty: TypeInt32}, + &Mfield{Name: "rack", Ty: TypeCompactNullableStr}, + // by the spec it should not be here. However, in kafka itself they don't check API version + // and put this field starting from v0 + // See https://github.com/tinaselenge/kafka/blob/814212f103c980f080593544079c4507f2557b08/core/src/main/scala/kafka/server/KafkaApis.scala#L3607 + &Mfield{Name: "is_fenced", Ty: TypeBool}, + ) + + describeClusterBrokerV2 := NewSchema("describe_cluster_broker_v2", + &Mfield{Name: brokerKeyName, Ty: TypeInt32}, + &Mfield{Name: hostKeyName, Ty: TypeStr}, + &Mfield{Name: portKeyName, Ty: TypeInt32}, + &Mfield{Name: "rack", Ty: TypeNullableStr}, + &Mfield{Name: "is_fenced", Ty: TypeBool}, + ) + + describeClusterV0 := NewSchema("describe_cluster_response_v0", + &Mfield{Name: "throttle_time_ms", Ty: TypeInt32}, + &Mfield{Name: "error_code", Ty: TypeInt16}, + &Mfield{Name: "error_message", Ty: TypeCompactNullableStr}, + &Mfield{Name: "cluster_id", Ty: TypeCompactStr}, + &Mfield{Name: "controller_id", Ty: TypeInt32}, + &CompactArray{Name: brokersKeyName, Ty: describeClusterBrokerV0}, + &Mfield{Name: "cluster_authorized_operations", Ty: TypeInt32}, + &SchemaTaggedFields{Name: "response_tagged_fields"}, + ) + + describeClusterV1 := NewSchema("describe_cluster_response_v1", + &Mfield{Name: "throttle_time_ms", Ty: TypeInt32}, + &Mfield{Name: "error_code", Ty: TypeInt16}, + &Mfield{Name: "error_message", Ty: TypeCompactNullableStr}, + &Mfield{Name: "endpoint_type", Ty: TypeInt8}, + &Mfield{Name: "cluster_id", Ty: TypeCompactStr}, + &Mfield{Name: "controller_id", Ty: TypeInt32}, + &CompactArray{Name: brokersKeyName, Ty: describeClusterBrokerV0}, + &Mfield{Name: "cluster_authorized_operations", Ty: TypeInt32}, + &SchemaTaggedFields{Name: "response_tagged_fields"}, + ) + + describeClusterV2 := NewSchema("describe_cluster_response_v2", + &Mfield{Name: "throttle_time_ms", Ty: TypeInt32}, + &Mfield{Name: "error_code", Ty: TypeInt16}, + &Mfield{Name: "error_message", Ty: TypeNullableStr}, + &Mfield{Name: "endpoint_type", Ty: TypeInt8}, + &Mfield{Name: "cluster_id", Ty: TypeStr}, + &Mfield{Name: "controller_id", Ty: TypeInt32}, + &CompactArray{Name: brokersKeyName, Ty: describeClusterBrokerV2}, + &Mfield{Name: "cluster_authorized_operations", Ty: TypeInt32}, + &SchemaTaggedFields{Name: "response_tagged_fields"}, + ) + return []Schema{describeClusterV0, describeClusterV1, describeClusterV2} +} + func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error { if decodedStruct == nil { return errors.New("decoded struct must not be nil") @@ -416,6 +476,56 @@ func modifyCoordinator(decodedStruct *Struct, fn config.NetAddressMappingFunc) e return nil } +func modifyDescribeClusterResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error { + if decodedStruct == nil { + return errors.New("decoded struct must not be nil") + } + if fn == nil { + return errors.New("net address mapper must not be nil") + } + brokersArray, ok := decodedStruct.Get(brokersKeyName).([]interface{}) + if !ok { + return errors.New("brokers list not found") + } + for _, brokerElement := range brokersArray { + broker := brokerElement.(*Struct) + host, ok := broker.Get(hostKeyName).(string) + if !ok { + return errors.New("broker.host not found") + } + port, ok := broker.Get(portKeyName).(int32) + if !ok { + return errors.New("broker.port not found") + } + brokerId, ok := broker.Get(brokerKeyName).(int32) + if !ok { + return errors.New("broker.broker_id not found") + } + + if host == "" && port <= 0 { + continue + } + + newHost, newPort, err := fn(host, port, brokerId) + if err != nil { + return err + } + if host != newHost { + err := broker.Replace(hostKeyName, newHost) + if err != nil { + return err + } + } + if port != newPort { + err = broker.Replace(portKeyName, newPort) + if err != nil { + return err + } + } + } + return nil +} + type ResponseModifier interface { Apply(resp []byte) ([]byte, error) } @@ -446,6 +556,8 @@ func GetResponseModifier(apiKey int16, apiVersion int16, addressMappingFunc conf return newResponseModifier(apiKey, apiVersion, addressMappingFunc, metadataResponseSchemaVersions, modifyMetadataResponse) case apiKeyFindCoordinator: return newResponseModifier(apiKey, apiVersion, addressMappingFunc, findCoordinatorResponseSchemaVersions, modifyFindCoordinatorResponse) + case apiKeyDescribeCluster: + return newResponseModifier(apiKey, apiVersion, addressMappingFunc, describeClusterResponseSchemaVersions, modifyDescribeClusterResponse) default: return nil, nil } diff --git a/proxy/protocol/responses_test.go b/proxy/protocol/responses_test.go index 1ba49d41..b9301939 100644 --- a/proxy/protocol/responses_test.go +++ b/proxy/protocol/responses_test.go @@ -2453,11 +2453,77 @@ func testMetadataResponse(t *testing.T, apiVersion int16, payload string, expect if err != nil { t.Fatal(err) } - /* - for _, av := range dc.AttrValues() { - fmt.Printf("\"%s\",\n", av) - } - */ + a.Equal(expectedInput, dc.AttrValues()) + resp, err := EncodeSchema(s, schema) + a.Nil(err) + a.Equal(bytes, resp) + + modifier, err := GetResponseModifier(apiKeyMetadata, apiVersion, testResponseModifier2) + if err != nil { + t.Fatal(err) + } + a.Nil(err) + resp, err = modifier.Apply(resp) + a.Nil(err) + s, err = DecodeSchema(resp, schema) + a.Nil(err) + dc = newDecodeCheck() + err = dc.Traverse(s) + if err != nil { + t.Fatal(err) + } + a.Equal(expectedModified, dc.AttrValues()) +} + +func TestDescribeClusterV0(t *testing.T) { + apiVersion := int16(0) + payload := "0000000004000000010a6c6f63616c686f737400004a940000000000020a6c6f63616c686f7374000071a40000000000030a6c6f63616c686f7374000098b4000000ffffffff040000135f5f636f6e73756d65725f6f6666736574730000000000000000000000000000000001020000000000010000000100000005040000000100000002000000030400000001000000020000000301008000000000000507746f706963323293e5edee894e2cb924edb932be3d6e00018000000000000007746f70696333000000000000000000000000000000000002000500000000ffffffffffffffff04000000010000000200000003040000000100000002000000030100800000000000" + expectedInput := []string{ + "throttle_time_ms int32 0", + "error_code int16 0", + "error_message *string ", + "cluster_id string my_cluster", + "controller_id int32 0", + "[brokers]", + "brokers struct", + "broker_id int32 0", + "host string localhost", + "port int32 9092", + "rack *string ", + } + expectedModified := []string{ + "throttle_time_ms int32 0", + "error_code int16 0", + "error_message *string ", + "cluster_id string my_cluster", + "controller_id int32 0", + "[brokers]", + "brokers struct", + "broker_id int32 0", + "host string localhost", + "port int32 9092", + "rack *string ", + } + testDescribeClusterResponse(t, apiVersion, payload, expectedInput, expectedModified) +} + +func testDescribeClusterResponse(t *testing.T, apiVersion int16, payload string, expectedInput, expectedModified []string) { + bytes, err := hex.DecodeString(payload) + if err != nil { + t.Fatal(err) + } + a := assert.New(t) + + schema := describeClusterResponseSchemaVersions[apiVersion] + + s, err := DecodeSchema(bytes, schema) + a.Nil(err) + + dc := newDecodeCheck() + err = dc.Traverse(s) + if err != nil { + t.Fatal(err) + } a.Equal(expectedInput, dc.AttrValues()) resp, err := EncodeSchema(s, schema) a.Nil(err) diff --git a/proxy/protocol/schema.go b/proxy/protocol/schema.go index 6810916c..aa824669 100644 --- a/proxy/protocol/schema.go +++ b/proxy/protocol/schema.go @@ -3,14 +3,16 @@ package protocol import ( "bytes" "fmt" - "github.com/google/uuid" "reflect" + "github.com/google/uuid" + "github.com/pkg/errors" ) var ( TypeBool = &Bool{} + TypeInt8 = &Int8{} TypeInt16 = &Int16{} TypeInt32 = &Int32{} TypeStr = &Str{} @@ -103,6 +105,34 @@ func (f *Bool) GetName() string { return "bool" } +// Field int8 + +type Int8 struct{} + +func (f *Int8) decode(pd packetDecoder) (interface{}, error) { + return pd.getInt8() +} +func (f *Int8) encode(pe packetEncoder, value interface{}) error { + in, ok := value.(int8) + if !ok { + return SchemaEncodingError{fmt.Sprintf("value %T not a int8", value)} + } + pe.putInt8(in) + return nil +} + +func (f *Int8) GetFields() []boundField { + return nil +} + +func (f *Int8) GetFieldsByName() map[string]*boundField { + return nil +} + +func (f *Int8) GetName() string { + return "int8" +} + // Field int16 type Int16 struct{} diff --git a/proxy/proxy.go b/proxy/proxy.go index 6fbf4190..da73d299 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -14,7 +14,7 @@ import ( "github.com/sirupsen/logrus" ) -type ListenFunc func(cfg *ListenerConfig) (l net.Listener, err error) +type ListenFunc func(listenAddress string) (l net.Listener, err error) type Listeners struct { // Source of new connections to Kafka broker. @@ -27,6 +27,7 @@ type Listeners struct { tcpConnOptions TCPConnOptions listenFunc ListenFunc + tlsConfig *tls.Config deterministicListeners bool disableDynamicListeners bool @@ -52,11 +53,12 @@ func NewListeners(cfg *config.Config) (*Listeners, error) { } } - listenFunc := func(cfg *ListenerConfig) (net.Listener, error) { + listenFunc := func(listenAddress string) (ln net.Listener, err error) { if tlsConfig != nil { - return tls.Listen("tcp", cfg.ListenerAddress, tlsConfig) + logrus.Infof("%s: Starting tls listener", listenAddress) + return tls.Listen("tcp", listenAddress, tlsConfig) } - return net.Listen("tcp", cfg.ListenerAddress) + return net.Listen("tcp", listenAddress) } brokerToListenerConfig, err := getBrokerToListenerConfig(cfg) @@ -69,6 +71,7 @@ func NewListeners(cfg *config.Config) (*Listeners, error) { dynamicAdvertisedListener: cfg.Proxy.DynamicAdvertisedListener, connSrc: make(chan Conn, 1), brokerToListenerConfig: brokerToListenerConfig, + tlsConfig: tlsConfig, tcpConnOptions: tcpConnOptions, listenFunc: listenFunc, deterministicListeners: cfg.Proxy.DeterministicListeners, @@ -118,7 +121,7 @@ func getBrokerToListenerConfig(cfg *config.Config) (map[string]*ListenerConfig, return brokerToListenerConfig, nil } -func (p *Listeners) GetNetAddressMapping(brokerHost string, brokerPort int32, brokerId int32) (listenerHost string, listenerPort int32, err error) { +func (p *Listeners) GetNetAddressMapping(brokerHost string, brokerPort int32, brokerId int32) (advertisedHost string, advertisedPort int32, err error) { if brokerHost == "" || brokerPort <= 0 { return "", 0, fmt.Errorf("broker address '%s:%d' is invalid", brokerHost, brokerPort) } @@ -137,7 +140,24 @@ func (p *Listeners) GetNetAddressMapping(brokerHost string, brokerPort int32, br logrus.Infof("Starting dynamic listener for broker %s", brokerAddress) return p.ListenDynamicInstance(brokerAddress, brokerId) } - return "", 0, fmt.Errorf("net address mapping for %s:%d was not found", brokerHost, brokerPort) + if len(p.dynamicAdvertisedListener) == 0 { + return "", 0, fmt.Errorf("net address mapping for %s:%d was not found", brokerHost, brokerPort) + } + p.lock.Lock() + defer p.lock.Unlock() + dynamicAdvertisedHost, dynamicAdvertisedPort, err := p.getDynamicAdvertisedAddress(brokerId, int(brokerPort)) + if err != nil { + return "", 0, err + } + // TODO: what should be the p.defaultListenerIP for the listener? + cfg := NewListenerConfig( + brokerAddress, + net.JoinHostPort(p.defaultListenerIP, fmt.Sprintf("%v", dynamicAdvertisedPort)), + net.JoinHostPort(dynamicAdvertisedHost, fmt.Sprint(dynamicAdvertisedPort)), + brokerId, + ) + p.brokerToListenerConfig[brokerAddress] = cfg + return util.SplitHostPort(cfg.GetAdvertisedAddress()) } func (p *Listeners) findListenerConfig(brokerId int32) *ListenerConfig { @@ -149,7 +169,10 @@ func (p *Listeners) findListenerConfig(brokerId int32) *ListenerConfig { return nil } +// ListenDynamicInstance creates a new listener for the upstream broker address and broker id. func (p *Listeners) ListenDynamicInstance(brokerAddress string, brokerId int32) (string, int32, error) { + + fmt.Println("proxy/proxy.go: ListenDynamicInstance:", "brokerAddress", brokerAddress, "brokerId", brokerId) p.lock.Lock() defer p.lock.Unlock() // double check @@ -184,8 +207,9 @@ func (p *Listeners) ListenDynamicInstance(brokerAddress string, brokerId int32) p.dynamicSequentialMinPort += 1 } } + cfg := NewListenerConfig(brokerAddress, listenerAddress, "", brokerId) - l, err := listenInstance(p.connSrc, cfg, p.tcpConnOptions, p.listenFunc) + l, err := p.listenInstance(p.connSrc, cfg, p.tcpConnOptions, p.listenFunc, p) if err != nil { return "", 0, err } @@ -245,14 +269,33 @@ func (p *Listeners) templateDynamicAdvertisedAddress(brokerID int32) (string, er return buf.String(), nil } +func (p *Listeners) GetBrokerAddressByAdvertisedHost(host string) (brokerAddress string, brokerId int32, err error) { + + // TODO: Use a better algorithm than a for loop + // Maybe use a bi-directional map? I believe downstream host-port always has only one upstream hostport. + for brokerAddr, c := range p.brokerToListenerConfig { + if c == nil { + continue + } + advertisedHost, _, err := net.SplitHostPort(c.GetAdvertisedAddress()) + if err != nil { + fmt.Println("failed to split address", err) + } + if advertisedHost == host { + return brokerAddr, c.GetBrokerID(), nil + } + } + + return "", 0, fmt.Errorf("broker not found for advertised host %q", host) +} + func (p *Listeners) ListenInstances(cfgs []config.ListenerConfig) (<-chan Conn, error) { p.lock.Lock() defer p.lock.Unlock() - // allows multiple local addresses to point to the remote for _, v := range cfgs { cfg := FromListenerConfig(v) - _, err := listenInstance(p.connSrc, cfg, p.tcpConnOptions, p.listenFunc) + _, err := p.listenInstance(p.connSrc, cfg, p.tcpConnOptions, p.listenFunc, p) if err != nil { return nil, err } @@ -260,8 +303,21 @@ func (p *Listeners) ListenInstances(cfgs []config.ListenerConfig) (<-chan Conn, return p.connSrc, nil } -func listenInstance(dst chan<- Conn, cfg *ListenerConfig, opts TCPConnOptions, listenFunc ListenFunc) (net.Listener, error) { - l, err := listenFunc(cfg) +type BrokerConfigMap interface { + ToListenerConfig() config.ListenerConfig + GetListenerAddress() string + GetBrokerAddress() string + GetAdvertisedAddress() string + GetBrokerID() int32 +} + +type HostBasedRouting interface { + GetBrokerAddressByAdvertisedHost(host string) (brokerAddress string, brokerId int32, err error) +} + +func (p *Listeners) listenInstance(dst chan<- Conn, cfg BrokerConfigMap, opts TCPConnOptions, listenFunc ListenFunc, brokers HostBasedRouting) (net.Listener, error) { + + l, err := listenFunc(cfg.GetListenerAddress()) if err != nil { return nil, err } @@ -269,28 +325,74 @@ func listenInstance(dst chan<- Conn, cfg *ListenerConfig, opts TCPConnOptions, l for { c, err := l.Accept() if err != nil { - logrus.Infof("Error in accept for %q on %v: %v", cfg.ToListenerConfig(), cfg.ListenerAddress, err) + logrus.Infof("Error in accept for %q on %v: %v", cfg.ToListenerConfig(), l.Addr(), err) l.Close() return } - if tcpConn, ok := c.(*net.TCPConn); ok { - if err := opts.setTCPConnOptions(tcpConn); err != nil { + + logrus.Infof("%s: Client(%s) Accepted connection", l.Addr().String(), c.RemoteAddr().String()) + + var ( + clientServerName string + underlyingConn net.Conn = c + ) + + if conn, ok := underlyingConn.(*tls.Conn); ok { + logrus.Infof("%s: Client(%s) Accepted tls connection", l.Addr().String(), c.RemoteAddr().String()) + + tlsState := conn.ConnectionState() + + if !tlsState.HandshakeComplete { + err := conn.Handshake() + if err != nil { + logrus.Warnf("downstream client tls handshake error: %s", err) + } + } + tlsState = conn.ConnectionState() + + logrus.Infof("%s: Client(%s) Accepted tls connection with server name: %q", l.Addr().String(), c.RemoteAddr().String(), tlsState.ServerName) + if tlsState.ServerName == "" { + logrus.Warnf("%s: Client(%s) tls server name could not be read, defaulting to advertised %s", l.Addr().String(), c.RemoteAddr().String(), cfg.GetAdvertisedAddress()) + advertisedHost, _, _ := net.SplitHostPort(cfg.GetAdvertisedAddress()) + clientServerName = advertisedHost + } else { + clientServerName = tlsState.ServerName + logrus.Infof("%s: Client(%s) Discovered server name from tls client hello: %s", l.Addr().String(), c.RemoteAddr().String(), clientServerName) + } + underlyingConn = conn.NetConn() + } + + if conn, ok := underlyingConn.(*net.TCPConn); ok { + if err := opts.setTCPConnOptions(conn); err != nil { logrus.Infof("WARNING: Error while setting TCP options for accepted connection %q on %v: %v", cfg.ToListenerConfig(), l.Addr().String(), err) } } + brokerAddress := cfg.GetBrokerAddress() - if cfg.BrokerID != UnknownBrokerID { - logrus.Infof("New connection for %s brokerId %d", brokerAddress, cfg.BrokerID) + brokerId := cfg.GetBrokerID() + + if clientServerName != "" { + if address, id, err := brokers.GetBrokerAddressByAdvertisedHost(clientServerName); err == nil { + brokerAddress = address + brokerId = id + } else { + logrus.Infof("%s: Failed to match host/authority %q with any advertised address", l.Addr(), clientServerName) + } + } + + if brokerId == UnknownBrokerID { + logrus.Infof("%s: Client(%s) connected with server name %q", l.Addr(), c.RemoteAddr().String(), clientServerName) } else { - logrus.Infof("New connection for %s", brokerAddress) + logrus.Infof("%s: Client(%s) connected for %s brokerId %d with server name %s", l.Addr(), c.RemoteAddr().String(), brokerAddress, brokerId, clientServerName) } + logrus.Infof("%s: Client(%s) proxying starting", l.Addr(), c.RemoteAddr().String()) dst <- Conn{BrokerAddress: brokerAddress, LocalConnection: c} } }) - if cfg.BrokerID != UnknownBrokerID { - logrus.Infof("Listening on %s (%s) for remote %s broker %d", cfg.ListenerAddress, l.Addr().String(), cfg.GetBrokerAddress(), cfg.BrokerID) + if cfg.GetBrokerID() != UnknownBrokerID { + logrus.Infof("Listening on %s for remote %s broker %d", l.Addr().String(), cfg.GetBrokerAddress(), cfg.GetBrokerID()) } else { - logrus.Infof("Listening on %s (%s) for remote %s", cfg.ListenerAddress, l.Addr().String(), cfg.GetBrokerAddress()) + logrus.Infof("Listening on %s for remote %s", l.Addr().String(), cfg.GetBrokerAddress()) } return l, nil } diff --git a/proxy/proxy_config.go b/proxy/proxy_config.go index 79a17322..a778be67 100644 --- a/proxy/proxy_config.go +++ b/proxy/proxy_config.go @@ -1,8 +1,10 @@ package proxy import ( - "github.com/grepplabs/kafka-proxy/config" + "fmt" "sync/atomic" + + "github.com/grepplabs/kafka-proxy/config" ) const UnknownBrokerID = -1 @@ -49,6 +51,22 @@ func (c *ListenerConfig) GetBrokerAddress() string { return *addressPtr } +func (c *ListenerConfig) GetListenerAddress() string { + return c.ListenerAddress +} + +func (c *ListenerConfig) GetBrokerID() int32 { + return c.BrokerID +} + +func (c *ListenerConfig) GetAdvertisedAddress() string { + return c.AdvertisedAddress +} + func (c *ListenerConfig) SetBrokerAddress(address string) { c.BrokerAddressPtr.Store(&address) } + +func (c *ListenerConfig) String() string { + return fmt.Sprintf("ListenerConfig[%d]: %s->%s->%s", c.BrokerID, c.AdvertisedAddress, c.ListenerAddress, c.GetBrokerAddress()) +} diff --git a/proxy/tls_test.go b/proxy/tls_test.go index 06223bfd..8c119046 100644 --- a/proxy/tls_test.go +++ b/proxy/tls_test.go @@ -119,7 +119,11 @@ func TestTLSUnknownAuthorityNoCAChainCert(t *testing.T) { c.Proxy.TLS.ListenerKeyFile = bundle.ServerKey.Name() _, _, _, err := makeTLSPipe(c, nil) - a.EqualError(err, "tls: failed to verify certificate: x509: certificate signed by unknown authority") + + // Can be + // - tls: failed to verify certificate: x509: “localhost” certificate is not standards compliant + // - tls: failed to verify certificate: x509: certificate signed by unknown authority + a.ErrorContains(err, "tls: failed to verify certificate: x509:") } func TestTLSUnknownAuthorityWrongCAChainCert(t *testing.T) { diff --git a/vendor/github.com/pires/go-proxyproto/.gitignore b/vendor/github.com/pires/go-proxyproto/.gitignore new file mode 100644 index 00000000..a2d2c301 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/.gitignore @@ -0,0 +1,11 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +.idea +bin +pkg + +*.out diff --git a/vendor/github.com/pires/go-proxyproto/LICENSE b/vendor/github.com/pires/go-proxyproto/LICENSE new file mode 100644 index 00000000..a65c05a6 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2016 Paulo Pires + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/pires/go-proxyproto/README.md b/vendor/github.com/pires/go-proxyproto/README.md new file mode 100644 index 00000000..982707cc --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/README.md @@ -0,0 +1,162 @@ +# go-proxyproto + +[![Actions Status](https://github.com/pires/go-proxyproto/workflows/test/badge.svg)](https://github.com/pires/go-proxyproto/actions) +[![Coverage Status](https://coveralls.io/repos/github/pires/go-proxyproto/badge.svg?branch=master)](https://coveralls.io/github/pires/go-proxyproto?branch=master) +[![Go Report Card](https://goreportcard.com/badge/github.com/pires/go-proxyproto)](https://goreportcard.com/report/github.com/pires/go-proxyproto) +[![](https://godoc.org/github.com/pires/go-proxyproto?status.svg)](https://pkg.go.dev/github.com/pires/go-proxyproto?tab=doc) + + +A Go library implementation of the [PROXY protocol, versions 1 and 2](https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt), +which provides, as per specification: +> (...) a convenient way to safely transport connection +> information such as a client's address across multiple layers of NAT or TCP +> proxies. It is designed to require little changes to existing components and +> to limit the performance impact caused by the processing of the transported +> information. + +This library is to be used in one of or both proxy clients and proxy servers that need to support said protocol. +Both protocol versions, 1 (text-based) and 2 (binary-based) are supported. + +## Installation + +```shell +$ go get -u github.com/pires/go-proxyproto +``` + +## Usage + +### Client + +```go +package main + +import ( + "io" + "log" + "net" + + proxyproto "github.com/pires/go-proxyproto" +) + +func chkErr(err error) { + if err != nil { + log.Fatalf("Error: %s", err.Error()) + } +} + +func main() { + // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto + target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:2319") + chkErr(err) + + conn, err := net.DialTCP("tcp", nil, target) + chkErr(err) + + defer conn.Close() + + // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you + // have two conn's + header := &proxyproto.Header{ + Version: 1, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP("20.2.2.2"), + Port: 2000, + }, + } + // After the connection was created write the proxy headers first + _, err = header.WriteTo(conn) + chkErr(err) + // Then your data... e.g.: + _, err = io.WriteString(conn, "HELO") + chkErr(err) +} +``` + +### Server + +```go +package main + +import ( + "log" + "net" + + proxyproto "github.com/pires/go-proxyproto" +) + +func main() { + // Create a listener + addr := "localhost:9876" + list, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error()) + } + + // Wrap listener in a proxyproto listener + proxyListener := &proxyproto.Listener{Listener: list} + defer proxyListener.Close() + + // Wait for a connection and accept it + conn, err := proxyListener.Accept() + defer conn.Close() + + // Print connection details + if conn.LocalAddr() == nil { + log.Fatal("couldn't retrieve local address") + } + log.Printf("local address: %q", conn.LocalAddr().String()) + + if conn.RemoteAddr() == nil { + log.Fatal("couldn't retrieve remote address") + } + log.Printf("remote address: %q", conn.RemoteAddr().String()) +} +``` + +### HTTP Server +```go +package main + +import ( + "net" + "net/http" + "time" + + "github.com/pires/go-proxyproto" +) + +func main() { + server := http.Server{ + Addr: ":8080", + } + + ln, err := net.Listen("tcp", server.Addr) + if err != nil { + panic(err) + } + + proxyListener := &proxyproto.Listener{ + Listener: ln, + ReadHeaderTimeout: 10 * time.Second, + } + defer proxyListener.Close() + + server.Serve(proxyListener) +} +``` + +## Special notes + +### AWS + +AWS Network Load Balancer (NLB) does not push the PPV2 header until the client starts sending the data. This is a problem if your server speaks first. e.g. SMTP, FTP, SSH etc. + +By default, NLB target group attribute `proxy_protocol_v2.client_to_server.header_placement` has the value `on_first_ack_with_payload`. You need to contact AWS support to change it to `on_first_ack`, instead. + +Just to be clear, you need this fix only if your server is designed to speak first. diff --git a/vendor/github.com/pires/go-proxyproto/addr_proto.go b/vendor/github.com/pires/go-proxyproto/addr_proto.go new file mode 100644 index 00000000..d254fc41 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/addr_proto.go @@ -0,0 +1,62 @@ +package proxyproto + +// AddressFamilyAndProtocol represents address family and transport protocol. +type AddressFamilyAndProtocol byte + +const ( + UNSPEC AddressFamilyAndProtocol = '\x00' + TCPv4 AddressFamilyAndProtocol = '\x11' + UDPv4 AddressFamilyAndProtocol = '\x12' + TCPv6 AddressFamilyAndProtocol = '\x21' + UDPv6 AddressFamilyAndProtocol = '\x22' + UnixStream AddressFamilyAndProtocol = '\x31' + UnixDatagram AddressFamilyAndProtocol = '\x32' +) + +// IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. +func (ap AddressFamilyAndProtocol) IsIPv4() bool { + return ap&0xF0 == 0x10 +} + +// IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. +func (ap AddressFamilyAndProtocol) IsIPv6() bool { + return ap&0xF0 == 0x20 +} + +// IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. +func (ap AddressFamilyAndProtocol) IsUnix() bool { + return ap&0xF0 == 0x30 +} + +// IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. +func (ap AddressFamilyAndProtocol) IsStream() bool { + return ap&0x0F == 0x01 +} + +// IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. +func (ap AddressFamilyAndProtocol) IsDatagram() bool { + return ap&0x0F == 0x02 +} + +// IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise. +func (ap AddressFamilyAndProtocol) IsUnspec() bool { + return (ap&0xF0 == 0x00) || (ap&0x0F == 0x00) +} + +func (ap AddressFamilyAndProtocol) toByte() byte { + if ap.IsIPv4() && ap.IsStream() { + return byte(TCPv4) + } else if ap.IsIPv4() && ap.IsDatagram() { + return byte(UDPv4) + } else if ap.IsIPv6() && ap.IsStream() { + return byte(TCPv6) + } else if ap.IsIPv6() && ap.IsDatagram() { + return byte(UDPv6) + } else if ap.IsUnix() && ap.IsStream() { + return byte(UnixStream) + } else if ap.IsUnix() && ap.IsDatagram() { + return byte(UnixDatagram) + } + + return byte(UNSPEC) +} diff --git a/vendor/github.com/pires/go-proxyproto/header.go b/vendor/github.com/pires/go-proxyproto/header.go new file mode 100644 index 00000000..209c2ccf --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/header.go @@ -0,0 +1,280 @@ +// Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification: +// https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt +package proxyproto + +import ( + "bufio" + "bytes" + "errors" + "io" + "net" + "time" +) + +var ( + // Protocol + SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'} + SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'} + + ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header") + ErrVersion1HeaderTooLong = errors.New("proxyproto: version 1 header must be 107 bytes or less") + ErrLineMustEndWithCrlf = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n") + ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command") + ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol") + ErrCantReadLength = errors.New("proxyproto: can't read length") + ErrCantResolveSourceUnixAddress = errors.New("proxyproto: can't resolve source Unix address") + ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address") + ErrNoProxyProtocol = errors.New("proxyproto: proxy protocol signature not present") + ErrUnknownProxyProtocolVersion = errors.New("proxyproto: unknown proxy protocol version") + ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command") + ErrUnsupportedAddressFamilyAndProtocol = errors.New("proxyproto: unsupported address family and protocol") + ErrInvalidLength = errors.New("proxyproto: invalid length") + ErrInvalidAddress = errors.New("proxyproto: invalid address") + ErrInvalidPortNumber = errors.New("proxyproto: invalid port number") + ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one") +) + +// Header is the placeholder for proxy protocol header. +type Header struct { + Version byte + Command ProtocolVersionAndCommand + TransportProtocol AddressFamilyAndProtocol + SourceAddr net.Addr + DestinationAddr net.Addr + rawTLVs []byte +} + +// HeaderProxyFromAddrs creates a new PROXY header from a source and a +// destination address. If version is zero, the latest protocol version is +// used. +// +// The header is filled on a best-effort basis: if hints cannot be inferred +// from the provided addresses, the header will be left unspecified. +func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { + if version < 1 || version > 2 { + version = 2 + } + h := &Header{ + Version: version, + Command: LOCAL, + TransportProtocol: UNSPEC, + } + switch sourceAddr := sourceAddr.(type) { + case *net.TCPAddr: + if _, ok := destAddr.(*net.TCPAddr); !ok { + break + } + if len(sourceAddr.IP.To4()) == net.IPv4len { + h.TransportProtocol = TCPv4 + } else if len(sourceAddr.IP) == net.IPv6len { + h.TransportProtocol = TCPv6 + } + case *net.UDPAddr: + if _, ok := destAddr.(*net.UDPAddr); !ok { + break + } + if len(sourceAddr.IP.To4()) == net.IPv4len { + h.TransportProtocol = UDPv4 + } else if len(sourceAddr.IP) == net.IPv6len { + h.TransportProtocol = UDPv6 + } + case *net.UnixAddr: + if _, ok := destAddr.(*net.UnixAddr); !ok { + break + } + switch sourceAddr.Net { + case "unix": + h.TransportProtocol = UnixStream + case "unixgram": + h.TransportProtocol = UnixDatagram + } + } + if h.TransportProtocol != UNSPEC { + h.Command = PROXY + h.SourceAddr = sourceAddr + h.DestinationAddr = destAddr + } + return h +} + +func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) { + if !header.TransportProtocol.IsStream() { + return nil, nil, false + } + sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) + destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) + return sourceAddr, destAddr, sourceOK && destOK +} + +func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) { + if !header.TransportProtocol.IsDatagram() { + return nil, nil, false + } + sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr) + destAddr, destOK := header.DestinationAddr.(*net.UDPAddr) + return sourceAddr, destAddr, sourceOK && destOK +} + +func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) { + if !header.TransportProtocol.IsUnix() { + return nil, nil, false + } + sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr) + destAddr, destOK := header.DestinationAddr.(*net.UnixAddr) + return sourceAddr, destAddr, sourceOK && destOK +} + +func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) { + if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { + return sourceAddr.IP, destAddr.IP, true + } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { + return sourceAddr.IP, destAddr.IP, true + } else { + return nil, nil, false + } +} + +func (header *Header) Ports() (sourcePort, destPort int, ok bool) { + if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { + return sourceAddr.Port, destAddr.Port, true + } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { + return sourceAddr.Port, destAddr.Port, true + } else { + return 0, 0, false + } +} + +// EqualTo returns true if headers are equivalent, false otherwise. +// Deprecated: use EqualsTo instead. This method will eventually be removed. +func (header *Header) EqualTo(otherHeader *Header) bool { + return header.EqualsTo(otherHeader) +} + +// EqualsTo returns true if headers are equivalent, false otherwise. +func (header *Header) EqualsTo(otherHeader *Header) bool { + if otherHeader == nil { + return false + } + if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { + return false + } + // TLVs only exist for version 2 + if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { + return false + } + // Return early for header with LOCAL command, which contains no address information + if header.Command == LOCAL { + return true + } + return header.SourceAddr.String() == otherHeader.SourceAddr.String() && + header.DestinationAddr.String() == otherHeader.DestinationAddr.String() +} + +// WriteTo renders a proxy protocol header in a format and writes it to an io.Writer. +func (header *Header) WriteTo(w io.Writer) (int64, error) { + buf, err := header.Format() + if err != nil { + return 0, err + } + + return bytes.NewBuffer(buf).WriteTo(w) +} + +// Format renders a proxy protocol header in a format to write over the wire. +func (header *Header) Format() ([]byte, error) { + switch header.Version { + case 1: + return header.formatVersion1() + case 2: + return header.formatVersion2() + default: + return nil, ErrUnknownProxyProtocolVersion + } +} + +// TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol. +func (header *Header) TLVs() ([]TLV, error) { + return SplitTLVs(header.rawTLVs) +} + +// SetTLVs sets the TLVs stored in this header. This method replaces any +// previous TLV. +func (header *Header) SetTLVs(tlvs []TLV) error { + raw, err := JoinTLVs(tlvs) + if err != nil { + return err + } + header.rawTLVs = raw + return nil +} + +// Read identifies the proxy protocol version and reads the remaining of +// the header, accordingly. +// +// If proxy protocol header signature is not present, the reader buffer remains untouched +// and is safe for reading outside of this code. +// +// If proxy protocol header signature is present but an error is raised while processing +// the remaining header, assume the reader buffer to be in a corrupt state. +// Also, this operation will block until enough bytes are available for peeking. +func Read(reader *bufio.Reader) (*Header, error) { + // In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone. + b1, err := reader.Peek(1) + if err != nil { + if err == io.EOF { + return nil, ErrNoProxyProtocol + } + return nil, err + } + + if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) { + signature, err := reader.Peek(5) + if err != nil { + if err == io.EOF { + return nil, ErrNoProxyProtocol + } + return nil, err + } + if bytes.Equal(signature[:5], SIGV1) { + return parseVersion1(reader) + } + + signature, err = reader.Peek(12) + if err != nil { + if err == io.EOF { + return nil, ErrNoProxyProtocol + } + return nil, err + } + if bytes.Equal(signature[:12], SIGV2) { + return parseVersion2(reader) + } + } + + return nil, ErrNoProxyProtocol +} + +// ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed +// there's no proxy protocol header. +func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) { + type header struct { + h *Header + e error + } + read := make(chan *header, 1) + + go func() { + h := &header{} + h.h, h.e = Read(reader) + read <- h + }() + + timer := time.NewTimer(timeout) + select { + case result := <-read: + timer.Stop() + return result.h, result.e + case <-timer.C: + return nil, ErrNoProxyProtocol + } +} diff --git a/vendor/github.com/pires/go-proxyproto/policy.go b/vendor/github.com/pires/go-proxyproto/policy.go new file mode 100644 index 00000000..ebef8b98 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/policy.go @@ -0,0 +1,206 @@ +package proxyproto + +import ( + "fmt" + "net" + "strings" +) + +// PolicyFunc can be used to decide whether to trust the PROXY info from +// upstream. If set, the connecting address is passed in as an argument. +// +// See below for the different policies. +// +// In case an error is returned the connection is denied. +type PolicyFunc func(upstream net.Addr) (Policy, error) + +// ConnPolicyFunc can be used to decide whether to trust the PROXY info +// based on connection policy options. If set, the connecting addresses +// (remote and local) are passed in as argument. +// +// See below for the different policies. +// +// In case an error is returned the connection is denied. +type ConnPolicyFunc func(connPolicyOptions ConnPolicyOptions) (Policy, error) + +// ConnPolicyOptions contains the remote and local addresses of a connection. +type ConnPolicyOptions struct { + Upstream net.Addr + Downstream net.Addr +} + +// Policy defines how a connection with a PROXY header address is treated. +type Policy int + +const ( + // USE address from PROXY header + USE Policy = iota + // IGNORE address from PROXY header, but accept connection + IGNORE + // REJECT connection when PROXY header is sent + // Note: even though the first read on the connection returns an error if + // a PROXY header is present, subsequent reads do not. It is the task of + // the code using the connection to handle that case properly. + REJECT + // REQUIRE connection to send PROXY header, reject if not present + // Note: even though the first read on the connection returns an error if + // a PROXY header is not present, subsequent reads do not. It is the task + // of the code using the connection to handle that case properly. + REQUIRE + // SKIP accepts a connection without requiring the PROXY header + // Note: an example usage can be found in the SkipProxyHeaderForCIDR + // function. + SKIP +) + +// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a +// connection from a skipHeaderCIDR without requiring a PROXY header, e.g. +// Kubernetes pods local traffic. The def is a policy to use when an upstream +// address doesn't match the skipHeaderCIDR. +func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { + return func(upstream net.Addr) (Policy, error) { + ip, err := ipFromAddr(upstream) + if err != nil { + return def, err + } + + if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) { + return SKIP, nil + } + + return def, nil + } +} + +// WithPolicy adds given policy to a connection when passed as option to NewConn() +func WithPolicy(p Policy) func(*Conn) { + return func(c *Conn) { + c.ProxyHeaderPolicy = p + } +} + +// LaxWhiteListPolicy returns a PolicyFunc which decides whether the +// upstream ip is allowed to send a proxy header based on a list of allowed +// IP addresses and IP ranges. In case upstream IP is not in list the proxy +// header will be ignored. If one of the provided IP addresses or IP ranges +// is invalid it will return an error instead of a PolicyFunc. +func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) { + allowFrom, err := parse(allowed) + if err != nil { + return nil, err + } + + return whitelistPolicy(allowFrom, IGNORE), nil +} + +// MustLaxWhiteListPolicy returns a LaxWhiteListPolicy but will panic if one +// of the provided IP addresses or IP ranges is invalid. +func MustLaxWhiteListPolicy(allowed []string) PolicyFunc { + pfunc, err := LaxWhiteListPolicy(allowed) + if err != nil { + panic(err) + } + + return pfunc +} + +// StrictWhiteListPolicy returns a PolicyFunc which decides whether the +// upstream ip is allowed to send a proxy header based on a list of allowed +// IP addresses and IP ranges. In case upstream IP is not in list reading on +// the connection will be refused on the first read. Please note: subsequent +// reads do not error. It is the task of the code using the connection to +// handle that case properly. If one of the provided IP addresses or IP +// ranges is invalid it will return an error instead of a PolicyFunc. +func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { + allowFrom, err := parse(allowed) + if err != nil { + return nil, err + } + + return whitelistPolicy(allowFrom, REJECT), nil +} + +// MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic +// if one of the provided IP addresses or IP ranges is invalid. +func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { + pfunc, err := StrictWhiteListPolicy(allowed) + if err != nil { + panic(err) + } + + return pfunc +} + +func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc { + return func(upstream net.Addr) (Policy, error) { + upstreamIP, err := ipFromAddr(upstream) + if err != nil { + // something is wrong with the source IP, better reject the connection + return REJECT, err + } + + for _, allowFrom := range allowed { + if allowFrom(upstreamIP) { + return USE, nil + } + } + + return def, nil + } +} + +func parse(allowed []string) ([]func(net.IP) bool, error) { + a := make([]func(net.IP) bool, len(allowed)) + for i, allowFrom := range allowed { + if strings.LastIndex(allowFrom, "/") > 0 { + _, ipRange, err := net.ParseCIDR(allowFrom) + if err != nil { + return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP range: %v", allowFrom, err) + } + + a[i] = ipRange.Contains + } else { + allowed := net.ParseIP(allowFrom) + if allowed == nil { + return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP address", allowFrom) + } + + a[i] = allowed.Equal + } + } + + return a, nil +} + +func ipFromAddr(upstream net.Addr) (net.IP, error) { + upstreamString, _, err := net.SplitHostPort(upstream.String()) + if err != nil { + return nil, err + } + + upstreamIP := net.ParseIP(upstreamString) + if nil == upstreamIP { + return nil, fmt.Errorf("proxyproto: invalid IP address") + } + + return upstreamIP, nil +} + +// IgnoreProxyHeaderNotOnInterface retuns a ConnPolicyFunc which can be used to +// decide whether to use or ignore PROXY headers depending on the connection +// being made on a specific interface. This policy can be used when the server +// is bound to multiple interfaces but wants to allow on only one interface. +func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) ConnPolicyFunc { + return func(connOpts ConnPolicyOptions) (Policy, error) { + ip, err := ipFromAddr(connOpts.Downstream) + if err != nil { + return REJECT, err + } + + if allowedIP.Equal(ip) { + return USE, nil + } + + return IGNORE, nil + } +} diff --git a/vendor/github.com/pires/go-proxyproto/protocol.go b/vendor/github.com/pires/go-proxyproto/protocol.go new file mode 100644 index 00000000..270b90d2 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/protocol.go @@ -0,0 +1,389 @@ +package proxyproto + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +var ( + // DefaultReadHeaderTimeout is how long header processing waits for header to + // be read from the wire, if Listener.ReaderHeaderTimeout is not set. + // It's kept as a global variable so to make it easier to find and override, + // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s" + DefaultReadHeaderTimeout = 10 * time.Second + + // ErrInvalidUpstream should be returned when an upstream connection address + // is not trusted, and therefore is invalid. + ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information") +) + +// Listener is used to wrap an underlying listener, +// whose connections may be using the HAProxy Proxy Protocol. +// If the connection is using the protocol, the RemoteAddr() will return +// the correct client address. ReadHeaderTimeout will be applied to all +// connections in order to prevent blocking operations. If no ReadHeaderTimeout +// is set, a default of 200ms will be used. This can be disabled by setting the +// timeout to < 0. +// +// Only one of Policy or ConnPolicy should be provided. If both are provided then +// a panic would occur during accept. +type Listener struct { + Listener net.Listener + // Deprecated: use ConnPolicyFunc instead. This will be removed in future release. + Policy PolicyFunc + ConnPolicy ConnPolicyFunc + ValidateHeader Validator + ReadHeaderTimeout time.Duration +} + +// Conn is used to wrap and underlying connection which +// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will +// return the address of the client instead of the proxy address. Each connection +// will have its own readHeaderTimeout and readDeadline set by the Accept() call. +type Conn struct { + readDeadline atomic.Value // time.Time + once sync.Once + readErr error + conn net.Conn + bufReader *bufio.Reader + reader io.Reader + header *Header + ProxyHeaderPolicy Policy + Validate Validator + readHeaderTimeout time.Duration +} + +// Validator receives a header and decides whether it is a valid one +// In case the header is not deemed valid it should return an error. +type Validator func(*Header) error + +// ValidateHeader adds given validator for proxy headers to a connection when passed as option to NewConn() +func ValidateHeader(v Validator) func(*Conn) { + return func(c *Conn) { + if v != nil { + c.Validate = v + } + } +} + +// SetReadHeaderTimeout sets the readHeaderTimeout for a connection when passed as option to NewConn() +func SetReadHeaderTimeout(t time.Duration) func(*Conn) { + return func(c *Conn) { + if t >= 0 { + c.readHeaderTimeout = t + } + } +} + +// Accept waits for and returns the next valid connection to the listener. +func (p *Listener) Accept() (net.Conn, error) { + for { + // Get the underlying connection + conn, err := p.Listener.Accept() + if err != nil { + return nil, err + } + + proxyHeaderPolicy := USE + if p.Policy != nil && p.ConnPolicy != nil { + panic("only one of policy or connpolicy must be provided.") + } + if p.Policy != nil || p.ConnPolicy != nil { + if p.Policy != nil { + proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) + } else { + proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{ + Upstream: conn.RemoteAddr(), + Downstream: conn.LocalAddr(), + }) + } + if err != nil { + // can't decide the policy, we can't accept the connection + conn.Close() + + if errors.Is(err, ErrInvalidUpstream) { + // keep listening for other connections + continue + } + + return nil, err + } + // Handle a connection as a regular one + if proxyHeaderPolicy == SKIP { + return conn, nil + } + } + + newConn := NewConn( + conn, + WithPolicy(proxyHeaderPolicy), + ValidateHeader(p.ValidateHeader), + ) + + // If the ReadHeaderTimeout for the listener is unset, use the default timeout. + if p.ReadHeaderTimeout == 0 { + p.ReadHeaderTimeout = DefaultReadHeaderTimeout + } + + // Set the readHeaderTimeout of the new conn to the value of the listener + newConn.readHeaderTimeout = p.ReadHeaderTimeout + + return newConn, nil + } +} + +// Close closes the underlying listener. +func (p *Listener) Close() error { + return p.Listener.Close() +} + +// Addr returns the underlying listener's network address. +func (p *Listener) Addr() net.Addr { + return p.Listener.Addr() +} + +// NewConn is used to wrap a net.Conn that may be speaking +// the proxy protocol into a proxyproto.Conn +func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { + // For v1 the header length is at most 108 bytes. + // For v2 the header length is at most 52 bytes plus the length of the TLVs. + // We use 256 bytes to be safe. + const bufSize = 256 + br := bufio.NewReaderSize(conn, bufSize) + + pConn := &Conn{ + bufReader: br, + reader: io.MultiReader(br, conn), + conn: conn, + } + + for _, opt := range opts { + opt(pConn) + } + + return pConn +} + +// Read is check for the proxy protocol header when doing +// the initial scan. If there is an error parsing the header, +// it is returned and the socket is closed. +func (p *Conn) Read(b []byte) (int, error) { + p.once.Do(func() { + p.readErr = p.readHeader() + }) + if p.readErr != nil { + return 0, p.readErr + } + + return p.reader.Read(b) +} + +// Write wraps original conn.Write +func (p *Conn) Write(b []byte) (int, error) { + return p.conn.Write(b) +} + +// Close wraps original conn.Close +func (p *Conn) Close() error { + return p.conn.Close() +} + +// ProxyHeader returns the proxy protocol header, if any. If an error occurs +// while reading the proxy header, nil is returned. +func (p *Conn) ProxyHeader() *Header { + p.once.Do(func() { p.readErr = p.readHeader() }) + return p.header +} + +// LocalAddr returns the address of the server if the proxy +// protocol is being used, otherwise just returns the address of +// the socket server. In case an error happens on reading the +// proxy header the original LocalAddr is returned, not the one +// from the proxy header even if the proxy header itself is +// syntactically correct. +func (p *Conn) LocalAddr() net.Addr { + p.once.Do(func() { p.readErr = p.readHeader() }) + if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { + return p.conn.LocalAddr() + } + + return p.header.DestinationAddr +} + +// RemoteAddr returns the address of the client if the proxy +// protocol is being used, otherwise just returns the address of +// the socket peer. In case an error happens on reading the +// proxy header the original RemoteAddr is returned, not the one +// from the proxy header even if the proxy header itself is +// syntactically correct. +func (p *Conn) RemoteAddr() net.Addr { + p.once.Do(func() { p.readErr = p.readHeader() }) + if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { + return p.conn.RemoteAddr() + } + + return p.header.SourceAddr +} + +// Raw returns the underlying connection which can be casted to +// a concrete type, allowing access to specialized functions. +// +// Use this ONLY if you know exactly what you are doing. +func (p *Conn) Raw() net.Conn { + return p.conn +} + +// TCPConn returns the underlying TCP connection, +// allowing access to specialized functions. +// +// Use this ONLY if you know exactly what you are doing. +func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) { + conn, ok = p.conn.(*net.TCPConn) + return +} + +// UnixConn returns the underlying Unix socket connection, +// allowing access to specialized functions. +// +// Use this ONLY if you know exactly what you are doing. +func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) { + conn, ok = p.conn.(*net.UnixConn) + return +} + +// UDPConn returns the underlying UDP connection, +// allowing access to specialized functions. +// +// Use this ONLY if you know exactly what you are doing. +func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) { + conn, ok = p.conn.(*net.UDPConn) + return +} + +// SetDeadline wraps original conn.SetDeadline +func (p *Conn) SetDeadline(t time.Time) error { + p.readDeadline.Store(t) + return p.conn.SetDeadline(t) +} + +// SetReadDeadline wraps original conn.SetReadDeadline +func (p *Conn) SetReadDeadline(t time.Time) error { + // Set a local var that tells us the desired deadline. This is + // needed in order to reset the read deadline to the one that is + // desired by the user, rather than an empty deadline. + p.readDeadline.Store(t) + return p.conn.SetReadDeadline(t) +} + +// SetWriteDeadline wraps original conn.SetWriteDeadline +func (p *Conn) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} + +func (p *Conn) readHeader() error { + // If the connection's readHeaderTimeout is more than 0, + // push our deadline back to now plus the timeout. This should only + // run on the connection, as we don't want to override the previous + // read deadline the user may have used. + if p.readHeaderTimeout > 0 { + if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil { + return err + } + } + + header, err := Read(p.bufReader) + + // If the connection's readHeaderTimeout is more than 0, undo the change to the + // deadline that we made above. Because we retain the readDeadline as part of our + // SetReadDeadline override, we know the user's desired deadline so we use that. + // Therefore, we check whether the error is a net.Timeout and if it is, we decide + // the proxy proto does not exist and set the error accordingly. + if p.readHeaderTimeout > 0 { + t := p.readDeadline.Load() + if t == nil { + t = time.Time{} + } + if err := p.conn.SetReadDeadline(t.(time.Time)); err != nil { + return err + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + err = ErrNoProxyProtocol + } + } + + // For the purpose of this wrapper shamefully stolen from armon/go-proxyproto + // let's act as if there was no error when PROXY protocol is not present. + if err == ErrNoProxyProtocol { + // but not if it is required that the connection has one + if p.ProxyHeaderPolicy == REQUIRE { + return err + } + + return nil + } + + // proxy protocol header was found + if err == nil && header != nil { + switch p.ProxyHeaderPolicy { + case REJECT: + // this connection is not allowed to send one + return ErrSuperfluousProxyHeader + case USE, REQUIRE: + if p.Validate != nil { + err = p.Validate(header) + if err != nil { + return err + } + } + + p.header = header + } + } + + return err +} + +// ReadFrom implements the io.ReaderFrom ReadFrom method +func (p *Conn) ReadFrom(r io.Reader) (int64, error) { + if rf, ok := p.conn.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + return io.Copy(p.conn, r) +} + +// WriteTo implements io.WriterTo +func (p *Conn) WriteTo(w io.Writer) (int64, error) { + p.once.Do(func() { p.readErr = p.readHeader() }) + if p.readErr != nil { + return 0, p.readErr + } + + b := make([]byte, p.bufReader.Buffered()) + if _, err := p.bufReader.Read(b); err != nil { + return 0, err // this should never as we read buffered data + } + + var n int64 + { + nn, err := w.Write(b) + n += int64(nn) + if err != nil { + return n, err + } + } + { + nn, err := io.Copy(w, p.conn) + n += nn + if err != nil { + return n, err + } + } + + return n, nil +} diff --git a/vendor/github.com/pires/go-proxyproto/tlv.go b/vendor/github.com/pires/go-proxyproto/tlv.go new file mode 100644 index 00000000..7cc2fb37 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/tlv.go @@ -0,0 +1,132 @@ +// Type-Length-Value splitting and parsing for proxy protocol V2 +// See spec https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt sections 2.2 to 2.7 and + +package proxyproto + +import ( + "encoding/binary" + "errors" + "fmt" + "math" +) + +const ( + // Section 2.2 + PP2_TYPE_ALPN PP2Type = 0x01 + PP2_TYPE_AUTHORITY PP2Type = 0x02 + PP2_TYPE_CRC32C PP2Type = 0x03 + PP2_TYPE_NOOP PP2Type = 0x04 + PP2_TYPE_UNIQUE_ID PP2Type = 0x05 + PP2_TYPE_SSL PP2Type = 0x20 + PP2_SUBTYPE_SSL_VERSION PP2Type = 0x21 + PP2_SUBTYPE_SSL_CN PP2Type = 0x22 + PP2_SUBTYPE_SSL_CIPHER PP2Type = 0x23 + PP2_SUBTYPE_SSL_SIG_ALG PP2Type = 0x24 + PP2_SUBTYPE_SSL_KEY_ALG PP2Type = 0x25 + PP2_TYPE_NETNS PP2Type = 0x30 + + // Section 2.2.7, reserved types + PP2_TYPE_MIN_CUSTOM PP2Type = 0xE0 + PP2_TYPE_MAX_CUSTOM PP2Type = 0xEF + PP2_TYPE_MIN_EXPERIMENT PP2Type = 0xF0 + PP2_TYPE_MAX_EXPERIMENT PP2Type = 0xF7 + PP2_TYPE_MIN_FUTURE PP2Type = 0xF8 + PP2_TYPE_MAX_FUTURE PP2Type = 0xFF +) + +var ( + ErrTruncatedTLV = errors.New("proxyproto: truncated TLV") + ErrMalformedTLV = errors.New("proxyproto: malformed TLV Value") + ErrIncompatibleTLV = errors.New("proxyproto: incompatible TLV type") +) + +// PP2Type is the proxy protocol v2 type +type PP2Type byte + +// TLV is a uninterpreted Type-Length-Value for V2 protocol, see section 2.2 +type TLV struct { + Type PP2Type + Value []byte +} + +// SplitTLVs splits the Type-Length-Value vector, returns the vector or an error. +func SplitTLVs(raw []byte) ([]TLV, error) { + var tlvs []TLV + for i := 0; i < len(raw); { + tlv := TLV{ + Type: PP2Type(raw[i]), + } + if len(raw)-i <= 2 { + return nil, ErrTruncatedTLV + } + tlvLen := int(binary.BigEndian.Uint16(raw[i+1 : i+3])) // Max length = 65K + i += 3 + if i+tlvLen > len(raw) { + return nil, ErrTruncatedTLV + } + // Ignore no-op padding + if tlv.Type != PP2_TYPE_NOOP { + tlv.Value = make([]byte, tlvLen) + copy(tlv.Value, raw[i:i+tlvLen]) + } + i += tlvLen + tlvs = append(tlvs, tlv) + } + return tlvs, nil +} + +// JoinTLVs joins multiple Type-Length-Value records. +func JoinTLVs(tlvs []TLV) ([]byte, error) { + var raw []byte + for _, tlv := range tlvs { + if len(tlv.Value) > math.MaxUint16 { + return nil, fmt.Errorf("proxyproto: cannot format TLV %v with length %d", tlv.Type, len(tlv.Value)) + } + var length [2]byte + binary.BigEndian.PutUint16(length[:], uint16(len(tlv.Value))) + raw = append(raw, byte(tlv.Type)) + raw = append(raw, length[:]...) + raw = append(raw, tlv.Value...) + } + return raw, nil +} + +// Registered is true if the type is registered in the spec, see section 2.2 +func (p PP2Type) Registered() bool { + switch p { + case PP2_TYPE_ALPN, + PP2_TYPE_AUTHORITY, + PP2_TYPE_CRC32C, + PP2_TYPE_NOOP, + PP2_TYPE_UNIQUE_ID, + PP2_TYPE_SSL, + PP2_SUBTYPE_SSL_VERSION, + PP2_SUBTYPE_SSL_CN, + PP2_SUBTYPE_SSL_CIPHER, + PP2_SUBTYPE_SSL_SIG_ALG, + PP2_SUBTYPE_SSL_KEY_ALG, + PP2_TYPE_NETNS: + return true + } + return false +} + +// App is true if the type is reserved for application specific data, see section 2.2.7 +func (p PP2Type) App() bool { + return p >= PP2_TYPE_MIN_CUSTOM && p <= PP2_TYPE_MAX_CUSTOM +} + +// Experiment is true if the type is reserved for temporary experimental use by application developers, see section 2.2.7 +func (p PP2Type) Experiment() bool { + return p >= PP2_TYPE_MIN_EXPERIMENT && p <= PP2_TYPE_MAX_EXPERIMENT +} + +// Future is true is the type is reserved for future use, see section 2.2.7 +func (p PP2Type) Future() bool { + return p >= PP2_TYPE_MIN_FUTURE +} + +// Spec is true if the type is covered by the spec, see section 2.2 and 2.2.7 +func (p PP2Type) Spec() bool { + return p.Registered() || p.App() || p.Experiment() || p.Future() +} diff --git a/vendor/github.com/pires/go-proxyproto/v1.go b/vendor/github.com/pires/go-proxyproto/v1.go new file mode 100644 index 00000000..0d34ba52 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/v1.go @@ -0,0 +1,243 @@ +package proxyproto + +import ( + "bufio" + "bytes" + "fmt" + "net" + "net/netip" + "strconv" + "strings" +) + +const ( + crlf = "\r\n" + separator = " " +) + +func initVersion1() *Header { + header := new(Header) + header.Version = 1 + // Command doesn't exist in v1 + header.Command = PROXY + return header +} + +func parseVersion1(reader *bufio.Reader) (*Header, error) { + //The header cannot be more than 107 bytes long. Per spec: + // + // (...) + // - worst case (optional fields set to 0xff) : + // "PROXY UNKNOWN ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n" + // => 5 + 1 + 7 + 1 + 39 + 1 + 39 + 1 + 5 + 1 + 5 + 2 = 107 chars + // + // So a 108-byte buffer is always enough to store all the line and a + // trailing zero for string processing. + // + // It must also be CRLF terminated, as above. The header does not otherwise + // contain a CR or LF byte. + // + // ISSUE #69 + // We can't use Peek here as it will block trying to fill the buffer, which + // will never happen if the header is TCP4 or TCP6 (max. 56 and 104 bytes + // respectively) and the server is expected to speak first. + // + // Similarly, we can't use ReadString or ReadBytes as these will keep reading + // until the delimiter is found; an abusive client could easily disrupt a + // server by sending a large amount of data that do not contain a LF byte. + // Another means of attack would be to start connections and simply not send + // data after the initial PROXY signature bytes, accumulating a large + // number of blocked goroutines on the server. ReadSlice will also block for + // a delimiter when the internal buffer does not fill up. + // + // A plain Read is also problematic since we risk reading past the end of the + // header without being able to easily put the excess bytes back into the reader's + // buffer (with the current implementation's design). + // + // So we use a ReadByte loop, which solves the overflow problem and avoids + // reading beyond the end of the header. However, we need one more trick to harden + // against partial header attacks (slow loris) - per spec: + // + // (..) The sender must always ensure that the header is sent at once, so that + // the transport layer maintains atomicity along the path to the receiver. The + // receiver may be tolerant to partial headers or may simply drop the connection + // when receiving a partial header. Recommendation is to be tolerant, but + // implementation constraints may not always easily permit this. + // + // We are subject to such implementation constraints. So we return an error if + // the header cannot be fully extracted with a single read of the underlying + // reader. + buf := make([]byte, 0, 107) + for { + b, err := reader.ReadByte() + if err != nil { + return nil, fmt.Errorf(ErrCantReadVersion1Header.Error()+": %v", err) + } + buf = append(buf, b) + if b == '\n' { + // End of header found + break + } + if len(buf) == 107 { + // No delimiter in first 107 bytes + return nil, ErrVersion1HeaderTooLong + } + if reader.Buffered() == 0 { + // Header was not buffered in a single read. Since we can't + // differentiate between genuine slow writers and DoS agents, + // we abort. On healthy networks, this should never happen. + return nil, ErrCantReadVersion1Header + } + } + + // Check for CR before LF. + if len(buf) < 2 || buf[len(buf)-2] != '\r' { + return nil, ErrLineMustEndWithCrlf + } + + // Check full signature. + tokens := strings.Split(string(buf[:len(buf)-2]), separator) + + // Expect at least 2 tokens: "PROXY" and the transport protocol. + if len(tokens) < 2 { + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // Read address family and protocol + var transportProtocol AddressFamilyAndProtocol + switch tokens[1] { + case "TCP4": + transportProtocol = TCPv4 + case "TCP6": + transportProtocol = TCPv6 + case "UNKNOWN": + transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN + default: + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // Expect 6 tokens only when UNKNOWN is not present. + if transportProtocol != UNSPEC && len(tokens) < 6 { + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // When a signature is found, allocate a v1 header with Command set to PROXY. + // Command doesn't exist in v1 but set it for other parts of this library + // to rely on it for determining connection details. + header := initVersion1() + + // Transport protocol has been processed already. + header.TransportProtocol = transportProtocol + + // When UNKNOWN, set the command to LOCAL and return early + if header.TransportProtocol == UNSPEC { + header.Command = LOCAL + return header, nil + } + + // Otherwise, continue to read addresses and ports + sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) + if err != nil { + return nil, err + } + destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) + if err != nil { + return nil, err + } + sourcePort, err := parseV1PortNumber(tokens[4]) + if err != nil { + return nil, err + } + destPort, err := parseV1PortNumber(tokens[5]) + if err != nil { + return nil, err + } + header.SourceAddr = &net.TCPAddr{ + IP: sourceIP, + Port: sourcePort, + } + header.DestinationAddr = &net.TCPAddr{ + IP: destIP, + Port: destPort, + } + + return header, nil +} + +func (header *Header) formatVersion1() ([]byte, error) { + // As of version 1, only "TCP4" ( \x54 \x43 \x50 \x34 ) for TCP over IPv4, + // and "TCP6" ( \x54 \x43 \x50 \x36 ) for TCP over IPv6 are allowed. + var proto string + switch header.TransportProtocol { + case TCPv4: + proto = "TCP4" + case TCPv6: + proto = "TCP6" + default: + // Unknown connection (short form) + return []byte("PROXY UNKNOWN" + crlf), nil + } + + sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) + destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) + if !sourceOK || !destOK { + return nil, ErrInvalidAddress + } + + sourceIP, destIP := sourceAddr.IP, destAddr.IP + switch header.TransportProtocol { + case TCPv4: + sourceIP = sourceIP.To4() + destIP = destIP.To4() + case TCPv6: + sourceIP = sourceIP.To16() + destIP = destIP.To16() + } + if sourceIP == nil || destIP == nil { + return nil, ErrInvalidAddress + } + + buf := bytes.NewBuffer(make([]byte, 0, 108)) + buf.Write(SIGV1) + buf.WriteString(separator) + buf.WriteString(proto) + buf.WriteString(separator) + buf.WriteString(sourceIP.String()) + buf.WriteString(separator) + buf.WriteString(destIP.String()) + buf.WriteString(separator) + buf.WriteString(strconv.Itoa(sourceAddr.Port)) + buf.WriteString(separator) + buf.WriteString(strconv.Itoa(destAddr.Port)) + buf.WriteString(crlf) + + return buf.Bytes(), nil +} + +func parseV1PortNumber(portStr string) (int, error) { + port, err := strconv.Atoi(portStr) + if err != nil || port < 0 || port > 65535 { + return 0, ErrInvalidPortNumber + } + return port, nil +} + +func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) { + addr, err := netip.ParseAddr(addrStr) + if err != nil { + return nil, ErrInvalidAddress + } + + switch protocol { + case TCPv4: + if addr.Is4() { + return net.IP(addr.AsSlice()), nil + } + case TCPv6: + if addr.Is6() || addr.Is4In6() { + return net.IP(addr.AsSlice()), nil + } + } + + return nil, ErrInvalidAddress +} diff --git a/vendor/github.com/pires/go-proxyproto/v2.go b/vendor/github.com/pires/go-proxyproto/v2.go new file mode 100644 index 00000000..74bf3f07 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/v2.go @@ -0,0 +1,285 @@ +package proxyproto + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "io" + "net" +) + +var ( + lengthUnspec = uint16(0) + lengthV4 = uint16(12) + lengthV6 = uint16(36) + lengthUnix = uint16(216) + lengthUnspecBytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthUnspec) + return a + }() + lengthV4Bytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthV4) + return a + }() + lengthV6Bytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthV6) + return a + }() + lengthUnixBytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthUnix) + return a + }() + errUint16Overflow = errors.New("proxyproto: uint16 overflow") +) + +type _ports struct { + SrcPort uint16 + DstPort uint16 +} + +type _addr4 struct { + Src [4]byte + Dst [4]byte + SrcPort uint16 + DstPort uint16 +} + +type _addr6 struct { + Src [16]byte + Dst [16]byte + _ports +} + +type _addrUnix struct { + Src [108]byte + Dst [108]byte +} + +func parseVersion2(reader *bufio.Reader) (header *Header, err error) { + // Skip first 12 bytes (signature) + for i := 0; i < 12; i++ { + if _, err = reader.ReadByte(); err != nil { + return nil, ErrCantReadProtocolVersionAndCommand + } + } + + header = new(Header) + header.Version = 2 + + // Read the 13th byte, protocol version and command + b13, err := reader.ReadByte() + if err != nil { + return nil, ErrCantReadProtocolVersionAndCommand + } + header.Command = ProtocolVersionAndCommand(b13) + if _, ok := supportedCommand[header.Command]; !ok { + return nil, ErrUnsupportedProtocolVersionAndCommand + } + + // Read the 14th byte, address family and protocol + b14, err := reader.ReadByte() + if err != nil { + return nil, ErrCantReadAddressFamilyAndProtocol + } + header.TransportProtocol = AddressFamilyAndProtocol(b14) + // UNSPEC is only supported when LOCAL is set. + if header.TransportProtocol == UNSPEC && header.Command != LOCAL { + return nil, ErrUnsupportedAddressFamilyAndProtocol + } + + // Make sure there are bytes available as specified in length + var length uint16 + if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil { + return nil, ErrCantReadLength + } + if !header.validateLength(length) { + return nil, ErrInvalidLength + } + + // Return early if the length is zero, which means that + // there's no address information and TLVs present for UNSPEC. + if length == 0 { + return header, nil + } + + if _, err := reader.Peek(int(length)); err != nil { + return nil, ErrInvalidLength + } + + // Length-limited reader for payload section + payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader) + + // Read addresses and ports for protocols other than UNSPEC. + // Ignore address information for UNSPEC, and skip straight to read TLVs, + // since the length is greater than zero. + if header.TransportProtocol != UNSPEC { + if header.TransportProtocol.IsIPv4() { + var addr _addr4 + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + return nil, ErrInvalidAddress + } + header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) + header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) + } else if header.TransportProtocol.IsIPv6() { + var addr _addr6 + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + return nil, ErrInvalidAddress + } + header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) + header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) + } else if header.TransportProtocol.IsUnix() { + var addr _addrUnix + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + return nil, ErrInvalidAddress + } + + network := "unix" + if header.TransportProtocol.IsDatagram() { + network = "unixgram" + } + + header.SourceAddr = &net.UnixAddr{ + Net: network, + Name: parseUnixName(addr.Src[:]), + } + header.DestinationAddr = &net.UnixAddr{ + Net: network, + Name: parseUnixName(addr.Dst[:]), + } + } + } + + // Copy bytes for optional Type-Length-Value vector + header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice + if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF { + return nil, err + } + + return header, nil +} + +func (header *Header) formatVersion2() ([]byte, error) { + var buf bytes.Buffer + buf.Write(SIGV2) + buf.WriteByte(header.Command.toByte()) + buf.WriteByte(header.TransportProtocol.toByte()) + if header.TransportProtocol.IsUnspec() { + // For UNSPEC, write no addresses and ports but only TLVs if they are present + hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs)) + if err != nil { + return nil, err + } + buf.Write(hdrLen) + } else { + var addrSrc, addrDst []byte + if header.TransportProtocol.IsIPv4() { + hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs)) + if err != nil { + return nil, err + } + buf.Write(hdrLen) + sourceIP, destIP, _ := header.IPs() + addrSrc = sourceIP.To4() + addrDst = destIP.To4() + } else if header.TransportProtocol.IsIPv6() { + hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs)) + if err != nil { + return nil, err + } + buf.Write(hdrLen) + sourceIP, destIP, _ := header.IPs() + addrSrc = sourceIP.To16() + addrDst = destIP.To16() + } else if header.TransportProtocol.IsUnix() { + buf.Write(lengthUnixBytes) + sourceAddr, destAddr, ok := header.UnixAddrs() + if !ok { + return nil, ErrInvalidAddress + } + addrSrc = formatUnixName(sourceAddr.Name) + addrDst = formatUnixName(destAddr.Name) + } + + if addrSrc == nil || addrDst == nil { + return nil, ErrInvalidAddress + } + buf.Write(addrSrc) + buf.Write(addrDst) + + if sourcePort, destPort, ok := header.Ports(); ok { + portBytes := make([]byte, 2) + + binary.BigEndian.PutUint16(portBytes, uint16(sourcePort)) + buf.Write(portBytes) + + binary.BigEndian.PutUint16(portBytes, uint16(destPort)) + buf.Write(portBytes) + } + } + + if len(header.rawTLVs) > 0 { + buf.Write(header.rawTLVs) + } + + return buf.Bytes(), nil +} + +func (header *Header) validateLength(length uint16) bool { + if header.TransportProtocol.IsIPv4() { + return length >= lengthV4 + } else if header.TransportProtocol.IsIPv6() { + return length >= lengthV6 + } else if header.TransportProtocol.IsUnix() { + return length >= lengthUnix + } else if header.TransportProtocol.IsUnspec() { + return length >= lengthUnspec + } + return false +} + +// addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow. +func addTLVLen(cur []byte, tlvLen int) ([]byte, error) { + if tlvLen == 0 { + return cur, nil + } + curLen := binary.BigEndian.Uint16(cur) + newLen := int(curLen) + tlvLen + if newLen >= 1<<16 { + return nil, errUint16Overflow + } + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, uint16(newLen)) + return a, nil +} + +func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr { + if transport.IsStream() { + return &net.TCPAddr{IP: ip, Port: int(port)} + } else if transport.IsDatagram() { + return &net.UDPAddr{IP: ip, Port: int(port)} + } else { + return nil + } +} + +func parseUnixName(b []byte) string { + i := bytes.IndexByte(b, 0) + if i < 0 { + return string(b) + } + return string(b[:i]) +} + +func formatUnixName(name string) []byte { + n := int(lengthUnix) / 2 + if len(name) >= n { + return []byte(name[:n]) + } + pad := make([]byte, n-len(name)) + return append([]byte(name), pad...) +} diff --git a/vendor/github.com/pires/go-proxyproto/version_cmd.go b/vendor/github.com/pires/go-proxyproto/version_cmd.go new file mode 100644 index 00000000..59f20420 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/version_cmd.go @@ -0,0 +1,47 @@ +package proxyproto + +// ProtocolVersionAndCommand represents the command in proxy protocol v2. +// Command doesn't exist in v1 but it should be set since other parts of +// this library may rely on it for determining connection details. +type ProtocolVersionAndCommand byte + +const ( + // LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1, + // in which case no address information is expected. + LOCAL ProtocolVersionAndCommand = '\x20' + // PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1, + // in which case valid local/remote address and port information is expected. + PROXY ProtocolVersionAndCommand = '\x21' +) + +var supportedCommand = map[ProtocolVersionAndCommand]bool{ + LOCAL: true, + PROXY: true, +} + +// IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN, +// i.e. when no address information is expected, false otherwise. +func (pvc ProtocolVersionAndCommand) IsLocal() bool { + return LOCAL == pvc +} + +// IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN, +// i.e. when valid local/remote address and port information is expected, false otherwise. +func (pvc ProtocolVersionAndCommand) IsProxy() bool { + return PROXY == pvc +} + +// IsUnspec returns true if the command is unspecified, false otherwise. +func (pvc ProtocolVersionAndCommand) IsUnspec() bool { + return !(pvc.IsLocal() || pvc.IsProxy()) +} + +func (pvc ProtocolVersionAndCommand) toByte() byte { + if pvc.IsLocal() { + return byte(LOCAL) + } else if pvc.IsProxy() { + return byte(PROXY) + } + + return byte(LOCAL) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 482c91e7..d84cc13e 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -296,6 +296,9 @@ github.com/oklog/run # github.com/pelletier/go-toml v1.2.0 ## explicit github.com/pelletier/go-toml +# github.com/pires/go-proxyproto v0.8.0 +## explicit; go 1.18 +github.com/pires/go-proxyproto # github.com/pkg/errors v0.9.1 ## explicit github.com/pkg/errors