diff --git a/README.md b/README.md index 77417b362..9b0405aee 100644 --- a/README.md +++ b/README.md @@ -888,6 +888,12 @@ Enforce reauthentication on password update. Use this to enable/disable anonymous sign-ins. +### IP address forwarding + +`GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED` - `bool` + +Enable IP address forwarding using the `Sb-Forwarded-For` HTTP request header. When enabled, Auth will parse the first value of this header as an IP address and use it for IP address tracking and rate limiting. Make sure this header is fully trusted before enabling this feature by only passing it from trustworthy clients or proxies. + ## Endpoints Auth exposes the following endpoints: diff --git a/internal/api/api.go b/internal/api/api.go index a728251c3..e656def35 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -19,6 +19,7 @@ import ( "github.com/supabase/auth/internal/mailer/templatemailer" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/sbff" "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/tokens" "github.com/supabase/auth/internal/utilities" @@ -152,8 +153,17 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r := newRouter() r.UseBypass(observability.AddRequestID(globalConfig)) r.UseBypass(logger) - r.UseBypass(xffmw.Handler) r.UseBypass(recoverer) + r.UseBypass( + sbff.Middleware( + &globalConfig.Security, + func(r *http.Request, err error) { + log := observability.GetLogEntry(r).Entry + log.WithField("error", err.Error()).Warn("error processing Sb-Forwarded-For") + }, + ), + ) + r.UseBypass(xffmw.Handler) if globalConfig.API.MaxRequestDuration > 0 { r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration)) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index ab28a8c58..96b323c53 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -20,6 +20,7 @@ import ( "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/sbff" "github.com/supabase/auth/internal/security" "github.com/supabase/auth/internal/utilities" @@ -61,7 +62,7 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error { var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered") -func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error { +func (a *API) performRateLimitingWithHeader(lmt *limiter.Limiter, req *http.Request) error { limitHeader := a.config.RateLimitHeader // If no rate limit header was set, ignore rate limiting @@ -112,6 +113,18 @@ func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error return nil } +func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error { + if sbffAddr, ok := sbff.GetIPAddress(req); ok { + if err := tollbooth.LimitByKeys(lmt, []string{sbffAddr}); err != nil { + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") + } + + return nil + } + + return a.performRateLimitingWithHeader(lmt, req) +} + func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { return req.Context(), a.performRateLimiting(lmt, req) diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 34db4c0f0..dd0f4da37 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/sbff" "github.com/supabase/auth/internal/storage" ) @@ -415,7 +416,166 @@ func TestTimeoutResponseWriter(t *testing.T) { require.Equal(t, w1.Result(), w2.Result()) } -func (ts *MiddlewareTestSuite) TestPerformRateLimiting() { +func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithSBFF() { + origRateLimitHeader := ts.Config.RateLimitHeader + origSBFFEnabled := ts.Config.Security.SbForwardedForEnabled + + defer func() { + ts.Config.RateLimitHeader = origRateLimitHeader + ts.Config.Security.SbForwardedForEnabled = origSBFFEnabled + }() + + ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting" + ts.Config.Security.SbForwardedForEnabled = true + + type headerSet struct { + rateLimiting string + sbForwardedFor string + } + + testCases := []struct { + name string + headerValues []headerSet + expErr error + }{ + { + name: "multiple SBFF values, single rate limiting value", + headerValues: []headerSet{ + { + sbForwardedFor: "192.168.1.100", + rateLimiting: "60.60.60.60", + }, + { + sbForwardedFor: "192.168.1.200", + rateLimiting: "60.60.60.60", + }, + }, + expErr: nil, + }, + { + name: "single SBFF value, multiple rate limiting values", + headerValues: []headerSet{ + { + sbForwardedFor: "192.168.1.100", + rateLimiting: "60.60.60.60", + }, + { + sbForwardedFor: "192.168.1.100", + rateLimiting: "70.70.70.70", + }, + }, + expErr: apierrors.NewTooManyRequestsError( + apierrors.ErrorCodeOverRequestRateLimit, + "Request rate limit reached", + ), + }, + { + name: "no SBFF value, multiple rate limiting values", + headerValues: []headerSet{ + { + sbForwardedFor: "", + rateLimiting: "60.60.60.60", + }, + { + sbForwardedFor: "", + rateLimiting: "70.70.70.70", + }, + }, + expErr: nil, + }, + { + name: "no SBFF value, single rate limiting value", + headerValues: []headerSet{ + { + sbForwardedFor: "", + rateLimiting: "60.60.60.60", + }, + { + sbForwardedFor: "", + rateLimiting: "60.60.60.60", + }, + }, + expErr: apierrors.NewTooManyRequestsError( + apierrors.ErrorCodeOverRequestRateLimit, + "Request rate limit reached", + ), + }, + { + name: "invalid SBFF value, multiple rate limiting values", + headerValues: []headerSet{ + { + sbForwardedFor: "invalid", + rateLimiting: "60.60.60.60", + }, + { + sbForwardedFor: "invalid", + rateLimiting: "70.70.70.70", + }, + }, + expErr: nil, + }, + { + name: "invalid SBFF value, single rate limiting value", + headerValues: []headerSet{ + { + sbForwardedFor: "invalid", + rateLimiting: "60.60.60.60", + }, + { + sbForwardedFor: "invalid", + rateLimiting: "60.60.60.60", + }, + }, + expErr: apierrors.NewTooManyRequestsError( + apierrors.ErrorCodeOverRequestRateLimit, + "Request rate limit reached", + ), + }, + } + + // This test uses the SBFF middleware to inject the Sb-Forwarded-For IP address value, then + // wraps a handler that calls performRateLimiting and stores the error value. + for _, tc := range testCases { + lmt := tollbooth.NewLimiter( + 1, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }, + ) + + var obsErr error + + var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { + obsErr = ts.API.performRateLimiting(lmt, r) + } + + errCallback := func(r *http.Request, err error) { + } + + middleware := sbff.Middleware(&ts.Config.Security, errCallback) + + wrappedHandler := middleware(handler) + + for _, h := range tc.headerValues { + r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + + if h.rateLimiting != "" { + r.Header.Set(ts.Config.RateLimitHeader, h.rateLimiting) + } + + if h.sbForwardedFor != "" { + r.Header.Set(sbff.HeaderName, h.sbForwardedFor) + } + + wrappedHandler.ServeHTTP(nil, r) + } + + require.ErrorIs(ts.T(), obsErr, tc.expErr) + } + +} + +func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithHeader() { ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting" tests := []struct { diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index b49e54f0c..3e397be69 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -731,6 +731,7 @@ type SecurityConfiguration struct { RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"` UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` + SbForwardedForEnabled bool `json:"sb_forwarded_for_enabled" split_words:"true" default:"false"` DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"` } diff --git a/internal/sbff/sbff.go b/internal/sbff/sbff.go new file mode 100644 index 000000000..33a512664 --- /dev/null +++ b/internal/sbff/sbff.go @@ -0,0 +1,94 @@ +package sbff + +import ( + "context" + "errors" + "net" + "net/http" + "strings" + + "github.com/supabase/auth/internal/conf" +) + +// HeaderName is the Sb-Forwarded-For header name. It is all lowercase here as HTTP header names +// are not case-sensitive. +const HeaderName = "sb-forwarded-for" + +var ( + ctxKeySBFF = &struct{}{} + + ErrHeaderNotFound = errors.New("Sb-Forwarded-For header not found") + ErrHeaderInvalid = errors.New("invalid Sb-Forwarded-For header value") +) + +func parseSBFFHeader(headerVal string) (string, error) { + values := strings.SplitN(headerVal, ",", 2) + key := strings.TrimSpace(values[0]) + if ipAddr := net.ParseIP(key); ipAddr != nil { + return ipAddr.String(), nil + } + + return "", ErrHeaderInvalid +} + +// GetIPAddress returns the value of the IP address in Sb-Forwarded-For as defined by +// SBForwardedForMiddleware. If no value is present in the request context, this function will +// return ("", false). +func GetIPAddress(r *http.Request) (addr string, found bool) { + if ipAddr, ok := r.Context().Value(ctxKeySBFF).(string); ok && ipAddr != "" { + return ipAddr, true + } + + return "", false +} + +// withIPAddress parses the Sb-Forwarded-For header and adds the leftmost value to the +// request context if it is a valid IP address, then returns a new request with modified context. +// If the leftmost value is not a valid IP address or the header is not set, this function returns +// an error. +func withIPAddress(r *http.Request) (*http.Request, error) { + headerVal := r.Header.Get(HeaderName) + if headerVal == "" { + return nil, ErrHeaderNotFound + } + + parsedIPAddr, err := parseSBFFHeader(headerVal) + if err != nil { + return nil, err + } + + ctx := r.Context() + newCtx := context.WithValue(ctx, ctxKeySBFF, parsedIPAddr) + out := r.WithContext(newCtx) + + return out, nil +} + +// Middleware returns a middleware function that parses the Sb-Forwarded-For header +// and adds the leftmost header value to the request context if GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED +// is true and the value is a valid IP address. +func Middleware(cfg *conf.SecurityConfiguration, errCallback func(*http.Request, error)) func(http.Handler) http.Handler { + out := func(next http.Handler) http.Handler { + handlerFunc := func(rw http.ResponseWriter, r *http.Request) { + if !cfg.SbForwardedForEnabled { + next.ServeHTTP(rw, r) + return + } + + reqWithSBFF, err := withIPAddress(r) + switch { + case err == nil: + next.ServeHTTP(rw, reqWithSBFF) + case errors.Is(err, ErrHeaderNotFound): + next.ServeHTTP(rw, r) + default: + errCallback(r, err) + next.ServeHTTP(rw, r) + } + } + + return http.HandlerFunc(handlerFunc) + } + + return out +} diff --git a/internal/sbff/sbff_test.go b/internal/sbff/sbff_test.go new file mode 100644 index 000000000..6f38bd96f --- /dev/null +++ b/internal/sbff/sbff_test.go @@ -0,0 +1,254 @@ +package sbff + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestParseHeader(t *testing.T) { + testCases := []struct { + name string + headerVal string + expAddr string + expErr error + }{ + { + name: "SingleAddressIPv4", + headerVal: "192.168.1.100", + expAddr: "192.168.1.100", + expErr: nil, + }, + + { + name: "SingleAddressIPv6", + headerVal: "2600:1000:cafe:bead::1", + expAddr: "2600:1000:cafe:bead::1", + expErr: nil, + }, + { + name: "MultipleAddressIPv4", + headerVal: "192.168.1.100,60.60.60.60", + expAddr: "192.168.1.100", + expErr: nil, + }, + { + name: "MultipleAddressIPv4WithWhitespace", + headerVal: "192.168.1.100 ,60.60.60.60", + expAddr: "192.168.1.100", + expErr: nil, + }, + { + name: "HeaderInvalid", + headerVal: "invalid, 60.60.60.60", + expAddr: "", + expErr: ErrHeaderInvalid, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + obsAddr, obsErr := parseSBFFHeader(tc.headerVal) + require.Equal(t, tc.expAddr, obsAddr) + require.ErrorIs(t, obsErr, tc.expErr) + }) + } +} + +func TestWithIPAddress(t *testing.T) { + testCases := []struct { + name string + headerVal string + expAddr string + expErr error + }{ + { + name: "WithHeader", + headerVal: "2600:cafe:bead::1", + expAddr: "2600:cafe:bead::1", + expErr: nil, + }, + { + name: "HeaderNotFound", + headerVal: "", + expAddr: "", + expErr: ErrHeaderNotFound, + }, + { + name: "HeaderInvalid", + headerVal: "invalid", + expAddr: "", + expErr: ErrHeaderInvalid, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + + if tc.headerVal != "" { + r.Header.Set(HeaderName, tc.headerVal) + } + + obsReq, obsErr := withIPAddress(r) + + if tc.expErr == nil { + require.NotNil(t, obsReq) + + obsAddr, ok := GetIPAddress(obsReq) + require.Equal(t, tc.expAddr, obsAddr) + require.Equal(t, true, ok) + } + + require.ErrorIs(t, obsErr, tc.expErr) + }) + } +} + +func TestGetIPAddress(t *testing.T) { + testCases := []struct { + name string + // ctxVal is any here because context.WithValue accepts any + ctxVal any + expAddr string + expFound bool + }{ + { + name: "WithAddress", + ctxVal: "2600:cafe:bead::1", + expAddr: "2600:cafe:bead::1", + expFound: true, + }, + { + name: "EmptyContext", + ctxVal: nil, + expAddr: "", + expFound: false, + }, + { + name: "NonStringValue", + ctxVal: 1, + expAddr: "", + expFound: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + originalReq := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + + var ctx context.Context + + if tc.ctxVal == nil { + ctx = originalReq.Context() + } else { + ctx = context.WithValue(originalReq.Context(), ctxKeySBFF, tc.ctxVal) + } + + r := originalReq.WithContext(ctx) + + obsAddr, obsFound := GetIPAddress(r) + + require.Equal(t, tc.expAddr, obsAddr) + require.Equal(t, tc.expFound, obsFound) + }) + } +} + +func TestMiddleware(t *testing.T) { + testCases := []struct { + name string + sbffEnabled bool + headerVal string + expAddr string + expFound bool + expErr error + }{ + { + name: "FlagDisabledHeaderEmpty", + sbffEnabled: false, + headerVal: "", + expAddr: "", + expFound: false, + expErr: nil, + }, + { + name: "FlagDisabledHeaderValid", + sbffEnabled: false, + headerVal: "192.168.1.100", + expAddr: "", + expFound: false, + expErr: nil, + }, + { + name: "FlagDisabledHeaderInvalid", + sbffEnabled: false, + headerVal: "invalid", + expAddr: "", + expFound: false, + expErr: nil, + }, + { + name: "FlagEnabledHeaderEmpty", + sbffEnabled: true, + headerVal: "", + expAddr: "", + expFound: false, + expErr: nil, + }, + { + name: "FlagEnabledHeaderValid", + sbffEnabled: true, + headerVal: "192.168.1.100", + expAddr: "192.168.1.100", + expFound: true, + expErr: nil, + }, + { + name: "FlagEnabledHeaderInvalid", + sbffEnabled: true, + headerVal: "invalid", + expAddr: "", + expFound: false, + expErr: ErrHeaderInvalid, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + + if tc.headerVal != "" { + r.Header.Set(HeaderName, tc.headerVal) + } + + var cfg conf.SecurityConfiguration + + var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { + obsAddr, obsFound := GetIPAddress(r) + require.Equal(t, tc.expAddr, obsAddr) + require.Equal(t, tc.expFound, obsFound) + } + + errCallback := func(r *http.Request, err error) { + if tc.expErr == nil { + t.Fatal("error callback called when expected error is nil") + } + + require.ErrorIs(t, err, tc.expErr) + } + + cfg.SbForwardedForEnabled = tc.sbffEnabled + + middlewareFn := Middleware(&cfg, errCallback) + + wrappedHandler := middlewareFn(handler) + + wrappedHandler.ServeHTTP(nil, r) + }) + } +} diff --git a/internal/utilities/request.go b/internal/utilities/request.go index fcfac8287..bd38c7381 100644 --- a/internal/utilities/request.go +++ b/internal/utilities/request.go @@ -10,11 +10,10 @@ import ( "strings" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/sbff" ) -// GetIPAddress returns the real IP address of the HTTP request. It parses the -// X-Forwarded-For header. -func GetIPAddress(r *http.Request) string { +func getIPAddressWithXFF(r *http.Request) string { if r.Header != nil { xForwardedFor := r.Header.Get("X-Forwarded-For") if xForwardedFor != "" { @@ -45,6 +44,15 @@ func GetIPAddress(r *http.Request) string { return ip } +// GetIPAddress returns the real IP address of the HTTP request. +func GetIPAddress(r *http.Request) string { + if sbffAddr, ok := sbff.GetIPAddress(r); ok { + return sbffAddr + } + + return getIPAddressWithXFF(r) +} + // GetBodyBytes reads the whole request body properly into a byte array. func GetBodyBytes(req *http.Request) ([]byte, error) { if req.Body == nil || req.Body == http.NoBody { diff --git a/internal/utilities/request_test.go b/internal/utilities/request_test.go index c1d1a6621..91ae97fac 100644 --- a/internal/utilities/request_test.go +++ b/internal/utilities/request_test.go @@ -7,9 +7,69 @@ import ( "github.com/stretchr/testify/require" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/sbff" ) -func TestGetIPAddress(t *tst.T) { +func TestGetIPAddressWithSBFF(t *tst.T) { + testCases := []struct { + name string + remoteAddr string + headerVal string + expAddr string + }{ + { + name: "ValidSBFF", + remoteAddr: "60.60.60.60", + headerVal: "192.168.1.100", + expAddr: "192.168.1.100", + }, + { + name: "MissingSBFF", + remoteAddr: "60.60.60.60", + headerVal: "", + expAddr: "60.60.60.60", + }, + { + name: "InvalidSBFF", + remoteAddr: "60.60.60.60", + headerVal: "invalid", + expAddr: "60.60.60.60", + }, + } + + config := conf.SecurityConfiguration{ + SbForwardedForEnabled: true, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *tst.T) { + var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { + obsAddr := GetIPAddress(r) + require.Equal(t, tc.expAddr, obsAddr) + } + + errCallback := func(r *http.Request, err error) { + } + + middleware := sbff.Middleware(&config, errCallback) + + wrappedHandler := middleware(handler) + + r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + + r.RemoteAddr = tc.remoteAddr + + if tc.headerVal != "" { + r.Header.Set(sbff.HeaderName, tc.headerVal) + } + + wrappedHandler.ServeHTTP(nil, r) + }) + + } +} + +func TestGetIPAddressWithXFF(t *tst.T) { examples := []func(r *http.Request) string{ func(r *http.Request) string { r.Header = nil