Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ type ProxyConfig struct {
HTTP HTTPConfig `json:"http" yaml:"http"`

TLS TLSConfig `json:"tls" yaml:"tls"`

ClientTLS ClientTLSConfig `json:"client_tls" yaml:"client_tls"`
}

func (c *ProxyConfig) Validate() error {
Expand Down Expand Up @@ -274,6 +276,8 @@ Timeout when forwarding incoming requests to the upstream.`,
c.Auth.RegisterFlags(fs, "proxy")

c.TLS.RegisterFlags(fs, "proxy")

c.ClientTLS.RegisterFlags(fs, "proxy.tls")
}

type UpstreamConfig struct {
Expand Down
16 changes: 16 additions & 0 deletions server/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ proxy:
cert: /piko/cert.pem
key: /piko/key.pem

client_tls:
cert: /piko/cert2.pem
key: /piko/key2.pem
cas: /piko/ca.pem

upstream:
bind_addr: 10.15.104.25:8001
advertise_addr: 1.2.3.4:8001
Expand Down Expand Up @@ -171,6 +176,11 @@ grace_period: 2m
Cert: "/piko/cert.pem",
Key: "/piko/key.pem",
},
ClientTLS: ClientTLSConfig{
Cert: "/piko/cert2.pem",
Key: "/piko/key2.pem",
CAs: "/piko/ca.pem",
},
},
Upstream: UpstreamConfig{
BindAddr: "10.15.104.25:8001",
Expand Down Expand Up @@ -274,6 +284,8 @@ func TestConfig_LoadFlags(t *testing.T) {
"--proxy.auth.issuer", "my-issuer",
"--proxy.tls.cert", "/piko/cert.pem",
"--proxy.tls.key", "/piko/key.pem",
"--proxy.tls.client.cas", "/piko/ca.pem",
"--proxy.tls.client.skip-verify",
"--upstream.bind-addr", "10.15.104.25:8001",
"--upstream.advertise-addr", "1.2.3.4:8001",
"--upstream.rebalance.threshold", "0.2",
Expand Down Expand Up @@ -349,6 +361,10 @@ func TestConfig_LoadFlags(t *testing.T) {
Cert: "/piko/cert.pem",
Key: "/piko/key.pem",
},
ClientTLS: ClientTLSConfig{
CAs: "/piko/ca.pem",
SkipVerify: true,
},
},
Upstream: UpstreamConfig{
BindAddr: "10.15.104.25:8001",
Expand Down
82 changes: 82 additions & 0 deletions server/config/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,85 @@ func (c *TLSConfig) Load() (*tls.Config, error) {
func (c *TLSConfig) enabled() bool {
return c.Cert != "" || c.Key != ""
}

type ClientTLSConfig struct {
Cert string `json:"cert" yaml:"cert"`
Key string `json:"key" yaml:"key"`
CAs string `json:"cas" yaml:"cas"`
SkipVerify bool `json:"skip_verify" yaml:"skip_verify"`
}

func (c *ClientTLSConfig) Load() (*tls.Config, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: c.SkipVerify,
}

if c.CAs != "" {
caCertPool, err := x509.SystemCertPool()
if err != nil {
caCertPool = x509.NewCertPool()
}

caCert, err := os.ReadFile(c.CAs)
if err != nil {
return nil, fmt.Errorf("open cas: %s: %w", c.CAs, err)
}

ok := caCertPool.AppendCertsFromPEM(caCert)
if !ok {
return nil, fmt.Errorf("parse cas: %s: %w", c.CAs, err)
}

tlsConfig.RootCAs = caCertPool
}

if c.Cert == "" && c.Key == "" {

cert, err := tls.LoadX509KeyPair(c.Cert, c.Key)
if err != nil {
return nil, fmt.Errorf("load key pair: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}

return tlsConfig, nil
}

func (c *ClientTLSConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) {
prefix += ".client."

fs.StringVar(
&c.Cert,
prefix+"cert",
c.Cert,
`
Path to the PEM encoded certificate file.

Used for communication between Piko servers if mTLS is expected`,
)
fs.StringVar(
&c.Key,
prefix+"key",
c.Key,
`
Path to the PEM encoded key file.`,
)
fs.StringVar(
&c.CAs,
prefix+"cas",
c.CAs,
`
A path to a certificate PEM file containing certificiate authorities to
verify the server certificates.

Required when the server is using non-public certificates and not skip-verify.`,
)

fs.BoolVar(
&c.SkipVerify,
prefix+"skip-verify",
c.SkipVerify,
`
Skip certificate verification between Piko servers.`,
)
}
20 changes: 15 additions & 5 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
Expand Down Expand Up @@ -118,14 +119,27 @@ func NewServer(conf *config.Config, logger log.Logger) (*Server, error) {

// Cluster.

proxyTLSConfig, err := conf.Proxy.TLS.Load()
if err != nil {
return nil, fmt.Errorf("proxy tls: %w", err)
}

s.clusterState = cluster.NewState(&cluster.Node{
ID: conf.Cluster.NodeID,
ProxyAddr: conf.Proxy.AdvertiseAddr,
AdminAddr: conf.Admin.AdvertiseAddr,
}, logger)
s.clusterState.Metrics().Register(registry)

upstreams := upstream.NewLoadBalancedManager(s.clusterState)
var clientTLSConfig *tls.Config
if proxyTLSConfig != nil {
clientTLSConfig, err = conf.Proxy.ClientTLS.Load()
if err != nil {
return nil, fmt.Errorf("proxy client tls: %w", err)
}
}

upstreams := upstream.NewLoadBalancedManager(s.clusterState, clientTLSConfig)
upstreams.Metrics().Register(registry)

// Proxy server.
Expand All @@ -140,10 +154,6 @@ func NewServer(conf *config.Config, logger log.Logger) (*Server, error) {
auth.NewJWTVerifier(verifierConf), nil,
)
}
proxyTLSConfig, err := conf.Proxy.TLS.Load()
if err != nil {
return nil, fmt.Errorf("proxy tls: %w", err)
}
s.proxyServer = proxy.NewServer(
upstreams,
conf.Proxy,
Expand Down
12 changes: 8 additions & 4 deletions server/upstream/manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package upstream

import (
"crypto/tls"
"sync"

"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -83,12 +84,15 @@ type LoadBalancedManager struct {
cluster *cluster.State

metrics *Metrics

clientTLSConfig *tls.Config
}

func NewLoadBalancedManager(cluster *cluster.State) *LoadBalancedManager {
func NewLoadBalancedManager(cluster *cluster.State, proxyClientTLSConfig *tls.Config) *LoadBalancedManager {
return &LoadBalancedManager{
localUpstreams: make(map[string]*loadBalancer),
cluster: cluster,
localUpstreams: make(map[string]*loadBalancer),
cluster: cluster,
clientTLSConfig: proxyClientTLSConfig,
usage: &Usage{
Requests: atomic.NewUint64(0),
Upstreams: atomic.NewUint64(0),
Expand Down Expand Up @@ -118,7 +122,7 @@ func (m *LoadBalancedManager) Select(endpointID string, allowRemote bool) (Upstr
"node_id": node.ID,
}).Inc()
m.usage.Requests.Inc()
return NewNodeUpstream(endpointID, node), true
return NewNodeUpstream(endpointID, node, m.clientTLSConfig), true
}

func (m *LoadBalancedManager) AddConn(u Upstream) {
Expand Down
9 changes: 8 additions & 1 deletion server/upstream/upstream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package upstream

import (
"crypto/tls"
"net"

"github.com/andydunstall/yamux"
Expand Down Expand Up @@ -50,12 +51,14 @@ func (u *ConnUpstream) Forward() bool {
type NodeUpstream struct {
endpointID string
node *cluster.Node
tlsConfig *tls.Config
}

func NewNodeUpstream(endpointID string, node *cluster.Node) *NodeUpstream {
func NewNodeUpstream(endpointID string, node *cluster.Node, tlsConfig *tls.Config) *NodeUpstream {
return &NodeUpstream{
endpointID: endpointID,
node: node,
tlsConfig: tlsConfig,
}
}

Expand All @@ -64,6 +67,10 @@ func (u *NodeUpstream) EndpointID() string {
}

func (u *NodeUpstream) Dial() (net.Conn, error) {
if u.tlsConfig != nil {
return tls.Dial("tcp", u.node.ProxyAddr, u.tlsConfig)
}

return net.Dial("tcp", u.node.ProxyAddr)
}

Expand Down
Loading