Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
Expand All @@ -18,8 +19,10 @@
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/oauthserver"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/sbff"
"github.com/supabase/auth/internal/security"
"github.com/supabase/auth/internal/utilities"

Expand Down Expand Up @@ -61,7 +64,7 @@

var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")

func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
func (a *API) performRateLimitingWithHeader(lmt *limiter.Limiter, req *http.Request) error {
limitHeader := a.config.RateLimitHeader

// If no rate limit header was set, ignore rate limiting
Expand Down Expand Up @@ -112,6 +115,18 @@
return nil
}

func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
if sbffAddr, ok := sbff.GetSBForwardedForAddress(req); ok {
if err := tollbooth.LimitByKeys(lmt, []string{sbffAddr}); err != nil {
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
}

return nil
}

return a.performRateLimitingWithHeader(lmt, req)
}

func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
return req.Context(), a.performRateLimiting(lmt, req)
Expand Down Expand Up @@ -490,3 +505,34 @@
})
}
}

// SBForwardedForMiddleware returns a middleware function that parses the Sb-Forwarded-For header
// and adds the leftmost header value to the request context if GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED
// is true and the value is a valid IP address.
func sbForwardedForMiddleware(cfg *conf.SecurityConfiguration) func(http.Handler) http.Handler {

Check failure on line 512 in internal/api/middleware.go

View workflow job for this annotation

GitHub Actions / test

func sbForwardedForMiddleware is unused (U1000)
out := func(next http.Handler) http.Handler {
handlerFunc := func(rw http.ResponseWriter, r *http.Request) {
if !cfg.SbForwardedForEnabled {
next.ServeHTTP(rw, r)
return
}

reqWithSBFF, err := sbff.ParseSBForwardedForAddress(r)

switch {
case err == nil:
next.ServeHTTP(rw, reqWithSBFF)
case errors.Is(err, sbff.ErrHeaderNotFound):
next.ServeHTTP(rw, r)
default:
log := observability.GetLogEntry(r).Entry
log.WithField("header", sbff.HeaderNameSBFF).WithField("error", err.Error()).Warn("error processing Sb-Forwarded-For")
next.ServeHTTP(rw, r)
}
}

return http.HandlerFunc(handlerFunc)
}

return out
}
2 changes: 2 additions & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,8 @@ type SecurityConfiguration struct {
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`
SbForwardedForEnabled bool `json:"forwarded_ip_header_enabled" split_words:"true" default:"false"`
ForwardedIPHeader string `json:"forwarded_ip_header" split_words:"true" default:"false"`

DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"`
}
Expand Down
68 changes: 68 additions & 0 deletions internal/sbff/sbff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package sbff

import (
"context"
"errors"
"net"
"net/http"
"strings"
)

var (
ctxKeySBFF = &struct{}{}
HeaderNameSBFF = "sb-forwarded-for"

ErrHeaderNotFound = errors.New("Sb-Forwarded-For header not found")
ErrHeaderInvalid = errors.New("invalid Sb-Forwarded-For header value")
)

func parseSBFFHeader(headerVal string) (string, error) {
values := strings.SplitN(headerVal, ",", 2)

key := strings.TrimSpace(values[0])

if ipAddr := net.ParseIP(key); ipAddr != nil {
return ipAddr.String(), nil
}

return "", ErrHeaderInvalid
}

// GetSBForwardedForAddress returns the value of the IP address in Sb-Forwarded-For as defined by
// SBForwardedForMiddleware. If no value is present in the request context, this function will
// return ("", false).
func GetSBForwardedForAddress(r *http.Request) (addr string, found bool) {
value := r.Context().Value(ctxKeySBFF)

if value == nil {
return "", false
}

ipAddr, ok := value.(string)

return ipAddr, ok
}

// SetSBForwardedForAddress parses the Sb-Forwarded-For header and adds the leftmost value to the
// request context if it is a valid IP address, then returns a new request with modified context.
// If the leftmost value is not a valid IP address or the header is not set, this function returns
// an error.
func ParseSBForwardedForAddress(r *http.Request) (*http.Request, error) {
ctx := r.Context()
headerVal := r.Header.Get(HeaderNameSBFF)

if headerVal == "" {
return nil, ErrHeaderNotFound
}

parsedIPAddr, err := parseSBFFHeader(headerVal)

if err != nil {
return nil, err
}

newCtx := context.WithValue(ctx, ctxKeySBFF, parsedIPAddr)
out := r.WithContext(newCtx)

return out, nil
}
14 changes: 11 additions & 3 deletions internal/utilities/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ import (
"strings"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/sbff"
)

// GetIPAddress returns the real IP address of the HTTP request. It parses the
// X-Forwarded-For header.
func GetIPAddress(r *http.Request) string {
func getIPAddressWithXFF(r *http.Request) string {
if r.Header != nil {
xForwardedFor := r.Header.Get("X-Forwarded-For")
if xForwardedFor != "" {
Expand Down Expand Up @@ -45,6 +44,15 @@ func GetIPAddress(r *http.Request) string {
return ip
}

// GetIPAddress returns the real IP address of the HTTP request.
func GetIPAddress(r *http.Request) string {
if sbffAddr, ok := sbff.GetSBForwardedForAddress(r); ok {
return sbffAddr
}

return getIPAddressWithXFF(r)
}

// GetBodyBytes reads the whole request body properly into a byte array.
func GetBodyBytes(req *http.Request) ([]byte, error) {
if req.Body == nil || req.Body == http.NoBody {
Expand Down
Loading