Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/supabase/auth/internal/mailer/templatemailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/sbff"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/tokens"
"github.com/supabase/auth/internal/utilities"
Expand Down Expand Up @@ -152,8 +153,17 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r := newRouter()
r.UseBypass(observability.AddRequestID(globalConfig))
r.UseBypass(logger)
r.UseBypass(xffmw.Handler)
r.UseBypass(recoverer)
r.UseBypass(
sbff.Middleware(
&globalConfig.Security,
func(r *http.Request, err error) {
log := observability.GetLogEntry(r).Entry
log.WithField("error", err.Error()).Warn("error processing Sb-Forwarded-For")
},
),
)
r.UseBypass(xffmw.Handler)

if globalConfig.API.MaxRequestDuration > 0 {
r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration))
Expand Down
15 changes: 14 additions & 1 deletion internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/sbff"
"github.com/supabase/auth/internal/security"
"github.com/supabase/auth/internal/utilities"

Expand Down Expand Up @@ -61,7 +62,7 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error {

var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")

func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
func (a *API) performRateLimitingWithHeader(lmt *limiter.Limiter, req *http.Request) error {
limitHeader := a.config.RateLimitHeader

// If no rate limit header was set, ignore rate limiting
Expand Down Expand Up @@ -112,6 +113,18 @@ func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error
return nil
}

func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
if sbffAddr, ok := sbff.GetIPAddress(req); ok {
if err := tollbooth.LimitByKeys(lmt, []string{sbffAddr}); err != nil {
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
}

return nil
}

return a.performRateLimitingWithHeader(lmt, req)
}

func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
return req.Context(), a.performRateLimiting(lmt, req)
Expand Down
1 change: 1 addition & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ type SecurityConfiguration struct {
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`
SbForwardedForEnabled bool `json:"forwarded_ip_header_enabled" split_words:"true" default:"false"`

DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"`
}
Expand Down
98 changes: 98 additions & 0 deletions internal/sbff/sbff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package sbff

import (
"context"
"errors"
"net"
"net/http"
"strings"

"github.com/supabase/auth/internal/conf"
)

var (
ctxKeySBFF = &struct{}{}
headerName = "sb-forwarded-for"

ErrHeaderNotFound = errors.New("Sb-Forwarded-For header not found")
ErrHeaderInvalid = errors.New("invalid Sb-Forwarded-For header value")
)

func parseSBFFHeader(headerVal string) (string, error) {
values := strings.SplitN(headerVal, ",", 2)

key := strings.TrimSpace(values[0])

if ipAddr := net.ParseIP(key); ipAddr != nil {
return ipAddr.String(), nil
}

return "", ErrHeaderInvalid
}

// GetIPAddress returns the value of the IP address in Sb-Forwarded-For as defined by
// SBForwardedForMiddleware. If no value is present in the request context, this function will
// return ("", false).
func GetIPAddress(r *http.Request) (addr string, found bool) {
value := r.Context().Value(ctxKeySBFF)

if value == nil {
return "", false
}

ipAddr, ok := value.(string)

return ipAddr, ok
}

// WithIPAddress parses the Sb-Forwarded-For header and adds the leftmost value to the
// request context if it is a valid IP address, then returns a new request with modified context.
// If the leftmost value is not a valid IP address or the header is not set, this function returns
// an error.
func WithIPAddress(r *http.Request) (*http.Request, error) {
ctx := r.Context()
headerVal := r.Header.Get(headerName)
if headerVal == "" {
return nil, ErrHeaderNotFound
}

parsedIPAddr, err := parseSBFFHeader(headerVal)
if err != nil {
return nil, err
}

newCtx := context.WithValue(ctx, ctxKeySBFF, parsedIPAddr)
out := r.WithContext(newCtx)

return out, nil
}

// Middleware returns a middleware function that parses the Sb-Forwarded-For header
// and adds the leftmost header value to the request context if GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED
// is true and the value is a valid IP address.
func Middleware(cfg *conf.SecurityConfiguration, errCallback func(*http.Request, error)) func(http.Handler) http.Handler {
out := func(next http.Handler) http.Handler {
handlerFunc := func(rw http.ResponseWriter, r *http.Request) {
if !cfg.SbForwardedForEnabled {
next.ServeHTTP(rw, r)
return
}

reqWithSBFF, err := WithIPAddress(r)

switch {
case err == nil:
next.ServeHTTP(rw, reqWithSBFF)
case errors.Is(err, ErrHeaderNotFound):
next.ServeHTTP(rw, r)
default:
errCallback(r, err)
next.ServeHTTP(rw, r)
}
}

return http.HandlerFunc(handlerFunc)
}

return out
}
14 changes: 11 additions & 3 deletions internal/utilities/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ import (
"strings"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/sbff"
)

// GetIPAddress returns the real IP address of the HTTP request. It parses the
// X-Forwarded-For header.
func GetIPAddress(r *http.Request) string {
func getIPAddressWithXFF(r *http.Request) string {
if r.Header != nil {
xForwardedFor := r.Header.Get("X-Forwarded-For")
if xForwardedFor != "" {
Expand Down Expand Up @@ -45,6 +44,15 @@ func GetIPAddress(r *http.Request) string {
return ip
}

// GetIPAddress returns the real IP address of the HTTP request.
func GetIPAddress(r *http.Request) string {
if sbffAddr, ok := sbff.GetIPAddress(r); ok {
return sbffAddr
}

return getIPAddressWithXFF(r)
}

// GetBodyBytes reads the whole request body properly into a byte array.
func GetBodyBytes(req *http.Request) ([]byte, error) {
if req.Body == nil || req.Body == http.NoBody {
Expand Down