diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5ca0eef417..a54d9bfeee 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -49,10 +49,13 @@ jobs: - name: Run static check run: | set -x + go install honnef.co/go/tools/cmd/staticcheck@latest + go install github.com/nishanths/exhaustive/cmd/exhaustive@latest make static - name: Check gosec run: | set -x + go install github.com/securego/gosec/v2/cmd/gosec@latest make sec - name: Init Database run: psql -f hack/init_postgres.sql postgresql://postgres:root@localhost:5432/postgres diff --git a/Dockerfile b/Dockerfile index 3900a025fa..6fd9fc3006 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25.5-alpine3.23 as build +FROM golang:1.23.7-alpine3.20 as build ENV GO111MODULE=on ENV CGO_ENABLED=0 ENV GOOS=linux diff --git a/Dockerfile.dev b/Dockerfile.dev index 99a8c0d5cb..9f0a304212 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -1,4 +1,4 @@ -FROM golang:1.25.5-alpine3.23 +FROM golang:1.23.7-alpine3.20 ENV GO111MODULE=on ENV CGO_ENABLED=0 ENV GOOS=linux diff --git a/Makefile b/Makefile index 14d41aa1c5..cf768b39b6 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .PHONY: all build deps image migrate test vet sec format unused -.PHONY: check-gosec check-oapi-codegen check-staticcheck +.PHONY: check-exhaustive check-gosec check-oapi-codegen check-staticcheck CHECK_FILES?=./... ifdef RELEASE_VERSION @@ -66,13 +66,18 @@ unused: | check-staticcheck # Look for unused code @echo "Code used only in _test.go (do move it in those files):" staticcheck -checks U1000 -tests=false $(CHECK_FILES) -static: | check-staticcheck +static: | check-staticcheck check-exhaustive staticcheck ./... + exhaustive ./... check-staticcheck: @command -v staticcheck >/dev/null 2>&1 \ || go install honnef.co/go/tools/cmd/staticcheck@latest +check-exhaustive: + @command -v exhaustive >/dev/null 2>&1 \ + || go install github.com/nishanths/exhaustive/cmd/exhaustive@latest + generate: | check-oapi-codegen go generate ./... diff --git a/README.md b/README.md index 9b0405aee8..77417b3627 100644 --- a/README.md +++ b/README.md @@ -888,12 +888,6 @@ Enforce reauthentication on password update. Use this to enable/disable anonymous sign-ins. -### IP address forwarding - -`GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED` - `bool` - -Enable IP address forwarding using the `Sb-Forwarded-For` HTTP request header. When enabled, Auth will parse the first value of this header as an IP address and use it for IP address tracking and rate limiting. Make sure this header is fully trusted before enabling this feature by only passing it from trustworthy clients or proxies. - ## Endpoints Auth exposes the following endpoints: diff --git a/cmd/migrate_cmd.go b/cmd/migrate_cmd.go index 511f53b27d..c9b80dfb87 100644 --- a/cmd/migrate_cmd.go +++ b/cmd/migrate_cmd.go @@ -2,6 +2,7 @@ package cmd import ( "embed" + "fmt" "net/url" "os" @@ -22,12 +23,12 @@ var migrateCmd = cobra.Command{ func migrate(cmd *cobra.Command, args []string) { globalConfig := loadGlobalConfig(cmd.Context()) - u, err := url.Parse(globalConfig.DB.URL) - if err != nil { - logrus.Fatalf("%+v", errors.Wrap(err, "parsing db connection url")) - } if globalConfig.DB.Driver == "" && globalConfig.DB.URL != "" { + u, err := url.Parse(globalConfig.DB.URL) + if err != nil { + logrus.Fatalf("%+v", errors.Wrap(err, "parsing db connection url")) + } globalConfig.DB.Driver = u.Scheme } @@ -52,12 +53,16 @@ func migrate(cmd *cobra.Command, args []string) { } } - q := u.Query() - q.Add("application_name", "auth_migrations") - u.RawQuery = q.Encode() + u, _ := url.Parse(globalConfig.DB.URL) + processedUrl := globalConfig.DB.URL + if len(u.Query()) != 0 { + processedUrl = fmt.Sprintf("%s&application_name=gotrue_migrations", processedUrl) + } else { + processedUrl = fmt.Sprintf("%s?application_name=gotrue_migrations", processedUrl) + } deets := &pop.ConnectionDetails{ Dialect: globalConfig.DB.Driver, - URL: u.String(), + URL: processedUrl, } deets.Options = map[string]string{ "migration_table_name": "schema_migrations", diff --git a/go.mod b/go.mod index 966c5074d0..88ed5cf408 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.10.0 - golang.org/x/crypto v0.40.0 + golang.org/x/crypto v0.36.0 golang.org/x/oauth2 v0.27.0 gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df ) @@ -71,8 +71,8 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect - golang.org/x/mod v0.26.0 // indirect - golang.org/x/tools v0.35.0 // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/tools v0.29.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect ) @@ -169,10 +169,10 @@ require ( github.com/stretchr/objx v0.5.2 // indirect go.opentelemetry.io/proto/otlp v1.2.0 // indirect golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb - golang.org/x/net v0.42.0 // indirect - golang.org/x/sync v0.16.0 - golang.org/x/sys v0.34.0 - golang.org/x/text v0.27.0 + golang.org/x/net v0.38.0 // indirect + golang.org/x/sync v0.12.0 + golang.org/x/sys v0.31.0 + golang.org/x/text v0.23.0 golang.org/x/time v0.9.0 google.golang.org/grpc v1.63.2 // indirect google.golang.org/protobuf v1.34.2 // indirect @@ -181,4 +181,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.25.5 +go 1.23.7 diff --git a/go.sum b/go.sum index 0e7a44a10e..8f420c6729 100644 --- a/go.sum +++ b/go.sum @@ -558,8 +558,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8CjkyJmnU1wFmg59CUjFA= golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -568,8 +568,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20161007143504-f4b625ec9b21/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -587,8 +587,8 @@ golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -597,8 +597,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -632,8 +632,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -650,8 +650,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/time v0.0.0-20160926182426-711ca1cb8763/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= @@ -668,8 +668,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= +golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/api/admin.go b/internal/api/admin.go index ad50b81e53..0d53406def 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -390,7 +390,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if err != nil { if errors.Is(err, bcrypt.ErrPasswordTooLong) { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "%s", err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } return apierrors.NewInternalServerError("Error creating user").WithInternalError(err) } diff --git a/internal/api/anonymous_test.go b/internal/api/anonymous_test.go index 628206954e..81d900de85 100644 --- a/internal/api/anonymous_test.go +++ b/internal/api/anonymous_test.go @@ -16,7 +16,6 @@ import ( "github.com/supabase/auth/internal/conf" mail "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/models" - "github.com/supabase/auth/internal/storage" ) type AnonymousTestSuite struct { @@ -26,14 +25,9 @@ type AnonymousTestSuite struct { } func TestAnonymous(t *testing.T) { - cb := func(cfg *conf.GlobalConfiguration, _ *storage.Connection) { - if cfg != nil { - cfg.RateLimitAnonymousUsers = 5 - } - } - - api, config, err := setupAPIForTestWithCallback(cb) + api, config, err := setupAPIForTest() require.NoError(t, err) + ts := &AnonymousTestSuite{ API: api, Config: config, diff --git a/internal/api/api.go b/internal/api/api.go index e656def35f..a728251c30 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -19,7 +19,6 @@ import ( "github.com/supabase/auth/internal/mailer/templatemailer" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" - "github.com/supabase/auth/internal/sbff" "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/tokens" "github.com/supabase/auth/internal/utilities" @@ -153,17 +152,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r := newRouter() r.UseBypass(observability.AddRequestID(globalConfig)) r.UseBypass(logger) - r.UseBypass(recoverer) - r.UseBypass( - sbff.Middleware( - &globalConfig.Security, - func(r *http.Request, err error) { - log := observability.GetLogEntry(r).Entry - log.WithField("error", err.Error()).Warn("error processing Sb-Forwarded-For") - }, - ), - ) r.UseBypass(xffmw.Handler) + r.UseBypass(recoverer) if globalConfig.API.MaxRequestDuration > 0 { r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration)) diff --git a/internal/api/apierrors/apierrors_test.go b/internal/api/apierrors/apierrors_test.go index aa47066d3a..515a79a9ef 100644 --- a/internal/api/apierrors/apierrors_test.go +++ b/internal/api/apierrors/apierrors_test.go @@ -144,7 +144,7 @@ func TestHTTPErrors(t *testing.T) { ErrorCodeBadJSON, "Unable to parse JSON: %v", errors.New("bad syntax"), - ).WithInternalError(sentinel).WithInternalMessage("%s", sentinel.Error()) + ).WithInternalError(sentinel).WithInternalMessage(sentinel.Error()) require.Equal(t, err.Error(), sentinel.Error()) require.Equal(t, err.Cause(), sentinel) @@ -171,7 +171,7 @@ func TestOAuthErrors(t *testing.T) { err := NewOAuthError( "oauth error", "oauth desc", - ).WithInternalError(sentinel).WithInternalMessage("%s", sentinel.Error()) + ).WithInternalError(sentinel).WithInternalMessage(sentinel.Error()) require.Error(t, err) require.Equal(t, err.Error(), sentinel.Error()) diff --git a/internal/api/auth.go b/internal/api/auth.go index abcb529a21..448212beb6 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -58,10 +58,7 @@ func (a *API) requireAdmin(ctx context.Context) (context.Context, error) { return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil } - return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeNotAdmin, "User not allowed"). - WithInternalMessage( - "this token needs to have one of the following roles: %v", - strings.Join(adminRoles, ", ")) + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", "))) } func (a *API) extractBearerToken(r *http.Request) (string, error) { @@ -146,7 +143,7 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro session, err = models.FindSessionByID(db, sessionId, false) if err != nil { if models.IsNotFoundError(err) { - return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage("session id (%s) doesn't exist", sessionId) + return ctx, apierrors.NewForbiddenError(apierrors.ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId)) } return ctx, err } diff --git a/internal/api/e2e_test.go b/internal/api/e2e_test.go index 81c2cc7ec8..2ac3221118 100644 --- a/internal/api/e2e_test.go +++ b/internal/api/e2e_test.go @@ -124,31 +124,23 @@ func runVerifyAfterUserCreatedHook( return latest } -func getEmailAccessToken( +func getAccessToken( ctx context.Context, t *testing.T, inst *e2ehooks.Instance, email, pass string, ) *api.AccessTokenResponse { - return getAccessToken(ctx, t, inst, &api.PasswordGrantParams{ + req := &api.PasswordGrantParams{ Email: email, Password: pass, - }) -} + } -func getAccessToken( - ctx context.Context, - t *testing.T, - inst *e2ehooks.Instance, - req *api.PasswordGrantParams, -) *api.AccessTokenResponse { res := new(api.AccessTokenResponse) err := e2eapi.Do(ctx, http.MethodPost, inst.APIServer.URL+"/token?grant_type=password", req, res) require.NoError(t, err) return res } - -func signupAndConfirmEmail( +func signupAndConfirm( ctx context.Context, t *testing.T, inst *e2ehooks.Instance, @@ -257,180 +249,21 @@ func TestE2EHooks(t *testing.T) { }) t.Run("SignupPhone", func(t *testing.T) { - defer inst.HookRecorder.AfterUserCreated.ClearCalls() defer inst.HookRecorder.BeforeUserCreated.ClearCalls() - defer inst.HookRecorder.SendSMS.ClearCalls() - - var currentUser *models.User - { - phone := genPhone() - req := &api.SignupParams{ - Phone: phone, - Password: defaultPassword, - } - signupUser := new(models.User) - { - err := e2eapi.Do( - ctx, http.MethodPost, inst.APIServer.URL+"/signup", req, signupUser) - require.NoError(t, err) - require.Equal(t, phone, signupUser.Phone.String()) - } - - // load the hook call - calls := inst.HookRecorder.SendSMS.GetCalls() - require.Equal(t, 1, len(calls)) - call := calls[0] - - hookReq := &v0hooks.SendSMSInput{} - err = call.Unmarshal(hookReq) - require.NoError(t, err) - latestUser, err := models.FindUserByID(inst.Conn, signupUser.ID) - require.NoError(t, err) - require.NotNil(t, latestUser) - - otp := hookReq.SMS.OTP - otpHash := crypto.GenerateTokenHash( - signupUser.GetPhone(), hookReq.SMS.OTP) - - ott, err := models.FindOneTimeToken( - inst.Conn, - otpHash, - models.ConfirmationToken) - require.NoError(t, err) - require.Equal(t, signupUser.ID.String(), ott.UserID.String()) - require.Equal(t, signupUser.Phone.String(), ott.RelatesTo) - - { - req := &api.VerifyParams{ - Type: "sms", - Token: otp, - Phone: phone, - } - res := new(models.User) - - body := new(bytes.Buffer) - err = json.NewEncoder(body).Encode(req) - require.NoError(t, err) - - httpReq, err := http.NewRequestWithContext( - ctx, "POST", "/verify", body) - require.NoError(t, err) - - httpRes, err := inst.Do(httpReq) - require.NoError(t, err) - require.Equal(t, 200, httpRes.StatusCode) - - err = json.NewDecoder(httpRes.Body).Decode(res) - require.NoError(t, err) - } - - { - // setup phone change - latestUser, err = models.FindUserByID(inst.Conn, signupUser.ID) - require.NoError(t, err) - require.NotNil(t, latestUser) - currentUser = latestUser - } + phone := genPhone() + req := &api.SignupParams{ + Phone: phone, + Password: defaultPassword, } + res := new(models.User) + err := e2eapi.Do(ctx, http.MethodPost, inst.APIServer.URL+"/signup", req, res) + require.NoError(t, err) + require.Equal(t, phone, res.Phone.String()) - t.Run("PhoneChange", func(t *testing.T) { - inst.HookRecorder.BeforeUserCreated.ClearCalls() - inst.HookRecorder.SendSMS.ClearCalls() - - currentAccessToken := getAccessToken(ctx, t, inst, - &api.PasswordGrantParams{ - Phone: string(currentUser.Phone), - Password: defaultPassword, - }) - - curPhone := currentUser.Phone.String() - newPhone := genPhone() - { - req := &api.UserUpdateParams{ - Phone: newPhone, - } - res := new(models.User) - - body := new(bytes.Buffer) - err = json.NewEncoder(body).Encode(req) - require.NoError(t, err) - - httpReq, err := http.NewRequestWithContext( - ctx, "PUT", "/user", body) - require.NoError(t, err) - - httpRes, err := inst.DoAuth(httpReq, currentAccessToken.Token) - require.NoError(t, err) - require.Equal(t, 200, httpRes.StatusCode) - - err = json.NewDecoder(httpRes.Body).Decode(res) - require.NoError(t, err) - - currentUser = res - } - - var otp string - { - require.Equal(t, curPhone, currentUser.Phone.String()) - require.Equal(t, newPhone, currentUser.PhoneChange) - - calls := inst.HookRecorder.SendSMS.GetCalls() - require.Equal(t, 1, len(calls)) - call := calls[0] - - hookReq := &v0hooks.SendSMSInput{} - err = call.Unmarshal(hookReq) - require.NoError(t, err) - - require.Equal(t, currentUser.ID, hookReq.User.ID) - require.Equal(t, currentUser.Aud, hookReq.User.Aud) - require.Equal(t, currentUser.Phone, hookReq.User.Phone) - require.Equal(t, currentUser.AppMetaData, hookReq.User.AppMetaData) - - otp = hookReq.SMS.OTP - otpHash := crypto.GenerateTokenHash( - currentUser.PhoneChange, hookReq.SMS.OTP) - - ott, err := models.FindOneTimeToken( - inst.Conn, - otpHash, - models.PhoneChangeToken) - require.NoError(t, err) - require.Equal(t, currentUser.ID.String(), ott.UserID.String()) - require.Equal(t, currentUser.PhoneChange, ott.RelatesTo) - - latestUser, err := models.FindUserByID(inst.Conn, currentUser.ID) - require.NoError(t, err) - require.NotNil(t, latestUser) - - currentUser = latestUser - } - - { - req := &api.VerifyParams{ - Type: "phone_change", - Token: otp, - Phone: currentUser.PhoneChange, - } - res := new(models.User) - - body := new(bytes.Buffer) - err = json.NewEncoder(body).Encode(req) - require.NoError(t, err) - - httpReq, err := http.NewRequestWithContext( - ctx, "POST", "/verify", body) - require.NoError(t, err) - - httpRes, err := inst.Do(httpReq) - require.NoError(t, err) - require.Equal(t, 200, httpRes.StatusCode) + runVerifyBeforeUserCreatedHook(t, inst, res) + runVerifyAfterUserCreatedHook(t, inst, res) - err = json.NewDecoder(httpRes.Body).Decode(res) - require.NoError(t, err) - } - }) }) t.Run("SignupAnonymously", func(t *testing.T) { @@ -589,7 +422,7 @@ func TestE2EHooks(t *testing.T) { mfaUser = runVerifyBeforeUserCreatedHook(t, inst, mfaUser) runVerifyAfterUserCreatedHook(t, inst, mfaUser) require.NotNil(t, mfaUser) - mfaUserAccessToken = getEmailAccessToken( + mfaUserAccessToken = getAccessToken( ctx, t, inst, string(mfaUser.Email), defaultPassword) phone := genPhone() @@ -971,90 +804,6 @@ func TestE2EHooks(t *testing.T) { } }) } - - t.Run("AMRStringArrayUnmarshalling", func(t *testing.T) { - defer inst.HookRecorder.CustomizeAccessToken.ClearCalls() - - // Setup hook that returns amr as array of strings - var claimsIn M - hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("content-type", "application/json") - w.WriteHeader(http.StatusOK) - - err := json.NewDecoder(r.Body).Decode(&claimsIn) - require.NoError(t, err) - - // Modify amr to be array of strings instead of objects - claimsOut := copyMap(t, claimsIn) - claimsOut["claims"].(M)["amr"] = []string{"password", "totp"} - - err = json.NewEncoder(w).Encode(claimsOut) - require.NoError(t, err) - }) - - inst.HookRecorder.CustomizeAccessToken.ClearCalls() - inst.HookRecorder.CustomizeAccessToken.SetHandler(hr) - - // Get token with modified amr - req := &api.PasswordGrantParams{ - Email: string(currentUser.Email), - Password: defaultPassword, - } - - res := new(api.AccessTokenResponse) - err := e2eapi.Do(ctx, http.MethodPost, inst.APIServer.URL+"/token?grant_type=password", req, res) - require.NoError(t, err) - require.True(t, len(res.Token) > 0) - - // Verify hook was called - { - calls := inst.HookRecorder.CustomizeAccessToken.GetCalls() - require.Equal(t, 1, len(calls)) - } - - // Parse token to verify it can be unmarshalled - p := jwt.NewParser(jwt.WithValidMethods(globalCfg.JWT.ValidMethods)) - token, err := p.ParseWithClaims( - res.Token, - &api.AccessTokenClaims{}, - func(token *jwt.Token) (any, error) { - if kid, ok := token.Header["kid"]; ok { - if kidStr, ok := kid.(string); ok { - return conf.FindPublicKeyByKid(kidStr, &globalCfg.JWT) - } - } - if alg, ok := token.Header["alg"]; ok { - if alg == jwt.SigningMethodHS256.Name { - return []byte(globalCfg.JWT.Secret), nil - } - } - return nil, fmt.Errorf("missing kid") - }) - require.NoError(t, err, "Token should parse successfully even with string array amr") - - fmt.Println("token hereee", res.Token) - // Verify claims were unmarshalled correctly - claims, ok := token.Claims.(*api.AccessTokenClaims) - require.True(t, ok, "Claims should be AccessTokenClaims type") - require.NotNil(t, claims.AuthenticationMethodReference, "AMR should not be nil") - require.Len(t, claims.AuthenticationMethodReference, 2, "AMR should have 2 entries") - require.Equal(t, "password", claims.AuthenticationMethodReference[0].Method) - require.Equal(t, "totp", claims.AuthenticationMethodReference[1].Method) - - // Call /user endpoint with the token to verify it works end-to-end - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "/user", nil) - require.NoError(t, err) - - httpRes, err := inst.DoAuth(httpReq, res.Token) - require.NoError(t, err, "Should be able to call /user endpoint with token containing string array amr") - require.Equal(t, http.StatusOK, httpRes.StatusCode, "/user endpoint should return 200 OK") - - // Verify we got user data back - var userData models.User - err = json.NewDecoder(httpRes.Body).Decode(&userData) - require.NoError(t, err, "Should be able to decode user response") - require.Equal(t, currentUser.ID, userData.ID, "Should get the correct user") - }) }) t.Run("SendEmail", func(t *testing.T) { @@ -1064,7 +813,7 @@ func TestE2EHooks(t *testing.T) { require.NoError(t, err) defer inst.Close() - signupAndConfirmEmail(ctx, t, inst) + signupAndConfirm(ctx, t, inst) }) t.Run("SecureEmailChange=Enabled", func(t *testing.T) { @@ -1080,11 +829,11 @@ func TestE2EHooks(t *testing.T) { // test requires this flag require.True(t, inst.Config.Mailer.SecureEmailChangeEnabled) - signupUser := signupAndConfirmEmail(ctx, t, inst) + signupUser := signupAndConfirm(ctx, t, inst) currentUser := signupUser // get access token - currentAccessToken := getEmailAccessToken( + currentAccessToken := getAccessToken( ctx, t, inst, string(currentUser.Email), defaultPassword) // update email @@ -1296,11 +1045,11 @@ func TestE2EHooks(t *testing.T) { // test requires this flag require.False(t, inst.Config.Mailer.SecureEmailChangeEnabled) - signupUser := signupAndConfirmEmail(ctx, t, inst) + signupUser := signupAndConfirm(ctx, t, inst) currentUser := signupUser // get access token - currentAccessToken := getEmailAccessToken( + currentAccessToken := getAccessToken( ctx, t, inst, string(currentUser.Email), defaultPassword) // update email diff --git a/internal/api/errors.go b/internal/api/errors.go index a9b467f36a..7479f9f032 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -15,13 +15,10 @@ import ( ) // Common error messages during signup flow -const ( - DuplicateEmailMsg = "A user with this email address has already been registered" - DuplicatePhoneMsg = "A user with this phone number has already been registered" -) - var ( - UserExistsError error = errors.New("user already exists") + DuplicateEmailMsg = "A user with this email address has already been registered" + DuplicatePhoneMsg = "A user with this phone number has already been registered" + UserExistsError error = errors.New("user already exists") ) const InvalidChannelError = "Invalid channel, supported values are 'sms' or 'whatsapp'. 'whatsapp' is only supported if Twilio or Twilio Verify is used as the provider." diff --git a/internal/api/external.go b/internal/api/external.go index 8392797d59..dc2fd6e008 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -441,19 +441,10 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. if !config.Mailer.AllowUnverifiedEmailSignIns { if emailConfirmationSent { - err := apierrors.NewUnprocessableEntityError( - apierrors.ErrorCodeProviderEmailNeedsVerification, - "Unverified email with %v. A confirmation email has been sent to your %v email", - providerType, providerType, - ) - return 0, nil, storage.NewCommitWithError(err) + return 0, nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) } - err := apierrors.NewUnprocessableEntityError( - apierrors.ErrorCodeProviderEmailNeedsVerification, - "Unverified email with %v. Verify the email with %v in order to sign in", - providerType, providerType) - return 0, nil, storage.NewCommitWithError(err) + return 0, nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) } } } else { diff --git a/internal/api/logout.go b/internal/api/logout.go index 310b7c0d60..fa3742ccc3 100644 --- a/internal/api/logout.go +++ b/internal/api/logout.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "net/http" "github.com/sirupsen/logrus" @@ -36,7 +37,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { scope = LogoutOthers default: - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unsupported logout scope %q", r.URL.Query().Get("scope")) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) } } @@ -51,6 +52,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { if s == nil { logrus.Infof("user has an empty session_id claim: %s", u.ID) } else { + //exhaustive:ignore Default case is handled below. switch scope { case LogoutLocal: return models.LogoutSession(tx, s.ID) diff --git a/internal/api/mail.go b/internal/api/mail.go index 7106a5a3a2..d25462a519 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -335,7 +335,7 @@ func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *model }); err != nil { u.ConfirmationToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -370,7 +370,7 @@ func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User if err != nil { u.ConfirmationToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -413,7 +413,7 @@ func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *m if err != nil { u.RecoveryToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -455,7 +455,7 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u if err != nil { u.ReauthenticationToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -498,7 +498,7 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U }); err != nil { u.RecoveryToken = oldToken if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -550,7 +550,7 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -590,7 +590,7 @@ func (a *API) sendPasswordChangedNotification(r *http.Request, tx *storage.Conne }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -607,7 +607,7 @@ func (a *API) sendEmailChangedNotification(r *http.Request, tx *storage.Connecti }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -624,7 +624,7 @@ func (a *API) sendPhoneChangedNotification(r *http.Request, tx *storage.Connecti }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -641,7 +641,7 @@ func (a *API) sendIdentityLinkedNotification(r *http.Request, tx *storage.Connec }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -658,7 +658,7 @@ func (a *API) sendIdentityUnlinkedNotification(r *http.Request, tx *storage.Conn }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -675,7 +675,7 @@ func (a *API) sendMFAFactorEnrolledNotification(r *http.Request, tx *storage.Con }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -692,7 +692,7 @@ func (a *API) sendMFAFactorUnenrolledNotification(r *http.Request, tx *storage.C }) if err != nil { if errors.Is(err, EmailRateLimitExceeded) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", EmailRateLimitExceeded.Error()) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) } else if herr, ok := err.(*HTTPError); ok { return herr } @@ -710,7 +710,7 @@ func (a *API) validateEmail(email string) (string, error) { return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "An email address is too long") } if err := checkmail.ValidateFormat(email); err != nil { - return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unable to validate email address: %s", err.Error()) + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) } return strings.ToLower(email), nil @@ -718,7 +718,7 @@ func (a *API) validateEmail(email string) (string, error) { func validateSentWithinFrequencyLimit(sentAt *time.Time, frequency time.Duration) error { if sentAt != nil && sentAt.Add(frequency).After(time.Now()) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", generateFrequencyLimitErrorMessage(sentAt, frequency)) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) } return nil } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 26654c0160..81523363f4 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -140,7 +140,7 @@ func validateFactors(db *storage.Connection, user *models.User, newFactorName st if factor.FriendlyName == newFactorName { return apierrors.NewUnprocessableEntityError( apierrors.ErrorCodeMFAFactorNameConflict, - "A factor with the friendly name %q for this user already exists", newFactorName, + fmt.Sprintf("A factor with the friendly name %q for this user already exists", newFactorName), ) } if factor.IsVerified() { @@ -389,7 +389,7 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error if factor.IsPhoneFactor() && factor.LastChallengedAt != nil { if !factor.LastChallengedAt.Add(config.MFA.Phone.MaxFrequency).Before(time.Now()) { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, "%s", generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency)) + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency)) } } @@ -670,7 +670,7 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V output.Message = v0hooks.DefaultMFAHookRejectionMessage } - return apierrors.NewForbiddenError(apierrors.ErrorCodeMFAVerificationRejected, "%s", output.Message) + return apierrors.NewForbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { @@ -821,7 +821,7 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * output.Message = v0hooks.DefaultMFAHookRejectionMessage } - return apierrors.NewForbiddenError(apierrors.ErrorCodeMFAVerificationRejected, "%s", output.Message) + return apierrors.NewForbiddenError(apierrors.ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 92003e6d7b..e41ae80c30 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -20,7 +20,6 @@ import ( "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" - "github.com/supabase/auth/internal/sbff" "github.com/supabase/auth/internal/security" "github.com/supabase/auth/internal/utilities" @@ -62,67 +61,22 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error { var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered") -func (a *API) performRateLimitingWithHeader(lmt *limiter.Limiter, req *http.Request) error { - limitHeader := a.config.RateLimitHeader - - // If no rate limit header was set, ignore rate limiting - if limitHeader == "" { - return nil - } - - valuesStr := req.Header.Get(limitHeader) - - // If a rate limit header was set, but has no value, ignore rate limiting but warn with an error - if valuesStr == "" { - log := observability.GetLogEntry(req).Entry - log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") - - return nil - } - - // According to RFC 7230 section 3.2.2, multiple headers with the same name are equivalent - // to a single header with that name where each value is separated by a comma and whitespace. - // - // Note that there is some ambiguity in RFC 7230 where section 3.2.4 states that - // header field values (which can contain commas) are processed independently of the header - // field name, and thus it is not always clear if a comma is a list delimiter or simply par - // of a single value. - // - // Given that this function is primarily for use with headers like X-Forwarded-For which - // vendors generally combine into comma-separated lists, we opt for the simpler approach - // here and split the header value by commas before taking the first value. - values := strings.SplitN(valuesStr, ",", 2) - - // We will always get at least one value back, so this operation is safe - key := strings.TrimSpace(values[0]) - - // If the rate limit header has at least one value, but the first value is all whitespace, return a warning. - // This will happen if the header is something like "X-Foo-Bar: ,baz". - if key == "" { - log := observability.GetLogEntry(req).Entry - log.WithField("header", limitHeader).Warn("first rate limit header value is empty, rate limiting is not applied") - - return nil - } - - // Otherwise, apply rate limiting based on the first rate limit header value - if err := tollbooth.LimitByKeys(lmt, []string{key}); err != nil { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") - } - - return nil -} - func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error { - if sbffAddr, ok := sbff.GetIPAddress(req); ok { - if err := tollbooth.LimitByKeys(lmt, []string{sbffAddr}); err != nil { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") + if limitHeader := a.config.RateLimitHeader; limitHeader != "" { + key := req.Header.Get(limitHeader) + + if key == "" { + log := observability.GetLogEntry(req).Entry + log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") + } else { + err := tollbooth.LimitByKeys(lmt, []string{key}) + if err != nil { + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") + } } - - return nil } - return a.performRateLimitingWithHeader(lmt, req) + return nil } func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { @@ -138,7 +92,7 @@ func (a *API) requireOAuthClientAuth(w http.ResponseWriter, r *http.Request) (co clientID, clientSecret, err := oauthserver.ExtractClientCredentials(r) if err != nil { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials: %s", err.Error()) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials: "+err.Error()) } // If no client credentials provided, continue without client authentication @@ -164,7 +118,7 @@ func (a *API) requireOAuthClientAuth(w http.ResponseWriter, r *http.Request) (co // Validate authentication using centralized logic if err := oauthserver.ValidateClientAuthentication(client, clientSecret); err != nil { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "%s", err.Error()) + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, err.Error()) } // Add authenticated client to context diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index dd0f4da374..68dbabb7cb 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/sbff" "github.com/supabase/auth/internal/storage" ) @@ -416,267 +415,6 @@ func TestTimeoutResponseWriter(t *testing.T) { require.Equal(t, w1.Result(), w2.Result()) } -func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithSBFF() { - origRateLimitHeader := ts.Config.RateLimitHeader - origSBFFEnabled := ts.Config.Security.SbForwardedForEnabled - - defer func() { - ts.Config.RateLimitHeader = origRateLimitHeader - ts.Config.Security.SbForwardedForEnabled = origSBFFEnabled - }() - - ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting" - ts.Config.Security.SbForwardedForEnabled = true - - type headerSet struct { - rateLimiting string - sbForwardedFor string - } - - testCases := []struct { - name string - headerValues []headerSet - expErr error - }{ - { - name: "multiple SBFF values, single rate limiting value", - headerValues: []headerSet{ - { - sbForwardedFor: "192.168.1.100", - rateLimiting: "60.60.60.60", - }, - { - sbForwardedFor: "192.168.1.200", - rateLimiting: "60.60.60.60", - }, - }, - expErr: nil, - }, - { - name: "single SBFF value, multiple rate limiting values", - headerValues: []headerSet{ - { - sbForwardedFor: "192.168.1.100", - rateLimiting: "60.60.60.60", - }, - { - sbForwardedFor: "192.168.1.100", - rateLimiting: "70.70.70.70", - }, - }, - expErr: apierrors.NewTooManyRequestsError( - apierrors.ErrorCodeOverRequestRateLimit, - "Request rate limit reached", - ), - }, - { - name: "no SBFF value, multiple rate limiting values", - headerValues: []headerSet{ - { - sbForwardedFor: "", - rateLimiting: "60.60.60.60", - }, - { - sbForwardedFor: "", - rateLimiting: "70.70.70.70", - }, - }, - expErr: nil, - }, - { - name: "no SBFF value, single rate limiting value", - headerValues: []headerSet{ - { - sbForwardedFor: "", - rateLimiting: "60.60.60.60", - }, - { - sbForwardedFor: "", - rateLimiting: "60.60.60.60", - }, - }, - expErr: apierrors.NewTooManyRequestsError( - apierrors.ErrorCodeOverRequestRateLimit, - "Request rate limit reached", - ), - }, - { - name: "invalid SBFF value, multiple rate limiting values", - headerValues: []headerSet{ - { - sbForwardedFor: "invalid", - rateLimiting: "60.60.60.60", - }, - { - sbForwardedFor: "invalid", - rateLimiting: "70.70.70.70", - }, - }, - expErr: nil, - }, - { - name: "invalid SBFF value, single rate limiting value", - headerValues: []headerSet{ - { - sbForwardedFor: "invalid", - rateLimiting: "60.60.60.60", - }, - { - sbForwardedFor: "invalid", - rateLimiting: "60.60.60.60", - }, - }, - expErr: apierrors.NewTooManyRequestsError( - apierrors.ErrorCodeOverRequestRateLimit, - "Request rate limit reached", - ), - }, - } - - // This test uses the SBFF middleware to inject the Sb-Forwarded-For IP address value, then - // wraps a handler that calls performRateLimiting and stores the error value. - for _, tc := range testCases { - lmt := tollbooth.NewLimiter( - 1, - &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }, - ) - - var obsErr error - - var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { - obsErr = ts.API.performRateLimiting(lmt, r) - } - - errCallback := func(r *http.Request, err error) { - } - - middleware := sbff.Middleware(&ts.Config.Security, errCallback) - - wrappedHandler := middleware(handler) - - for _, h := range tc.headerValues { - r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) - - if h.rateLimiting != "" { - r.Header.Set(ts.Config.RateLimitHeader, h.rateLimiting) - } - - if h.sbForwardedFor != "" { - r.Header.Set(sbff.HeaderName, h.sbForwardedFor) - } - - wrappedHandler.ServeHTTP(nil, r) - } - - require.ErrorIs(ts.T(), obsErr, tc.expErr) - } - -} - -func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithHeader() { - ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting" - - tests := []struct { - name string - headerValues []string - expError error - }{ - { - name: "no value", - headerValues: []string{ - "", - "", - }, - expError: nil, - }, - { - name: "single end user value", - headerValues: []string{ - "192.168.1.100", - "192.168.1.100", - }, - expError: apierrors.NewTooManyRequestsError( - apierrors.ErrorCodeOverRequestRateLimit, - "Request rate limit reached", - ), - }, - { - name: "same end user value, multiple proxies", - headerValues: []string{ - "2600:cafe:beef::1,192.168.1.100", - "2600:cafe:beef::1,192.168.1.200", - }, - expError: apierrors.NewTooManyRequestsError( - apierrors.ErrorCodeOverRequestRateLimit, - "Request rate limit reached", - ), - }, - { - name: "multiple end user values, single proxy", - headerValues: []string{ - "2600:cafe:beef::1,192.168.1.100", - "3700:dead:abcd::2,192.168.1.100", - }, - expError: nil, - }, - { - name: "same end user value, multiple proxies, with whitespace", - headerValues: []string{ - "2600:cafe:beef::1 ,192.168.1.100", - "2600:cafe:beef::1 , 192.168.1.200", - }, - expError: apierrors.NewTooManyRequestsError( - apierrors.ErrorCodeOverRequestRateLimit, - "Request rate limit reached", - ), - }, - { - name: "empty header, all whitespace", - headerValues: []string{ - " ", - }, - expError: nil, - }, - { - name: "empty first key, no whitespace", - headerValues: []string{ - ",192.168.1.100", - }, - expError: nil, - }, - { - name: "empty first key, with whitespace", - headerValues: []string{ - " ,192.168.1.100", - }, - expError: nil, - }, - } - - for _, tt := range tests { - // Trigger a rate limiting error if we see the same end-user key twice in the same - // test case - lmt := tollbooth.NewLimiter( - 1, - &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }, - ) - - var obsError error - - for _, h := range tt.headerValues { - req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - req.Header.Add(ts.Config.RateLimitHeader, h) - obsError = ts.API.performRateLimiting(lmt, req) - } - - require.ErrorIs(ts.T(), obsError, tt.expError, "error for test '%s'", tt.name) - } -} - func (ts *MiddlewareTestSuite) TestLimitHandler() { ts.Config.RateLimitHeader = "X-Rate-Limit" lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{ diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index c61de40852..770f3f50c3 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -127,7 +127,7 @@ func (s *Server) AdminOAuthServerClientRegister(w http.ResponseWriter, r *http.R client, plaintextSecret, err := s.registerOAuthServerClient(ctx, ¶ms) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "%s", err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } response := oauthServerClientToResponse(client) @@ -156,7 +156,7 @@ func (s *Server) OAuthServerClientDynamicRegister(w http.ResponseWriter, r *http client, plaintextSecret, err := s.registerOAuthServerClient(ctx, ¶ms) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "%s", err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } response := oauthServerClientToResponse(client) diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index dff27c7741..3af95d0644 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -9,7 +9,6 @@ import ( "fmt" "net/url" "slices" - "strings" "time" "github.com/gofrs/uuid" @@ -165,20 +164,13 @@ func (p *OAuthServerClientRegisterParams) validate() error { // Validate consistency between client_type and token_endpoint_auth_method if err := ValidateClientTypeConsistency(p.ClientType, p.TokenEndpointAuthMethod); err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "%s", err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } return nil } -// validateRedirectURI validates OAuth 2.1 redirect URIs as specific in -// -// * https://tools.ietf.org/html/rfc6749#section-3.1.2 -// - The redirection endpoint URI MUST be an absolute URI as defined by [RFC3986] Section 4.3. -// - The endpoint URI MUST NOT include a fragment component. -// - https://tools.ietf.org/html/rfc3986#section-4.3 -// absolute-URI = scheme ":" hier-part [ "?" query ] -// - https://tools.ietf.org/html/rfc6819#section-5.1.1 +// validateRedirectURI validates OAuth 2.1 redirect URIs func validateRedirectURI(uri string) error { if uri == "" { return fmt.Errorf("redirect URI cannot be empty") @@ -194,23 +186,16 @@ func validateRedirectURI(uri string) error { return fmt.Errorf("must have scheme and host") } - // Block dangerous URI schemes that can lead to XSS or token leakage - dangerousSchemes := []string{"javascript", "data", "file", "vbscript", "about", "blob"} - for _, dangerous := range dangerousSchemes { - if strings.EqualFold(parsedURL.Scheme, dangerous) { - return fmt.Errorf("scheme '%s' is not allowed for security reasons", parsedURL.Scheme) - } - } - - // Only restrict HTTP (not HTTPS or custom schemes) - // HTTP is only allowed for localhost/loopback addresses + // Check scheme requirements if parsedURL.Scheme == "http" { + // HTTP only allowed for localhost host := parsedURL.Hostname() - if host != "localhost" && host != "127.0.0.1" && host != "::1" { + if host != "localhost" && host != "127.0.0.1" { return fmt.Errorf("HTTP scheme only allowed for localhost") } + } else if parsedURL.Scheme != "https" { + return fmt.Errorf("scheme must be HTTPS or HTTP (localhost only)") } - // All other schemes (https, custom schemes like myapp://* etc.) are allowed // Must not have fragment if parsedURL.Fragment != "" { diff --git a/internal/api/oauthserver/service_test.go b/internal/api/oauthserver/service_test.go index 96d23ef3cb..d409128c5e 100644 --- a/internal/api/oauthserver/service_test.go +++ b/internal/api/oauthserver/service_test.go @@ -254,60 +254,21 @@ func (ts *OAuthServiceTestSuite) TestRedirectURIValidation() { shouldError bool errorMsg string }{ - // Valid HTTPS URIs { name: "Valid HTTPS URI", uri: "https://example.com/callback", shouldError: false, }, - { - name: "Valid HTTPS URI with port", - uri: "https://example.com:8443/callback", - shouldError: false, - }, - { - name: "Valid HTTPS URI with query params", - uri: "https://example.com/callback?foo=bar", - shouldError: false, - }, - // Valid HTTP localhost URIs { name: "Valid localhost HTTP URI", uri: "http://localhost:3000/callback", shouldError: false, }, - { - name: "Valid localhost HTTP URI without port", - uri: "http://localhost/callback", - shouldError: false, - }, { name: "Valid 127.0.0.1 HTTP URI", uri: "http://127.0.0.1:8080/callback", shouldError: false, }, - { - name: "Valid IPv6 localhost HTTP URI", - uri: "http://[::1]:8080/callback", - shouldError: false, - }, - // Valid custom URI schemes (native apps) - { - name: "Valid custom scheme - myapp", - uri: "myapp://callback", - shouldError: false, - }, - { - name: "Valid custom scheme - com.example.app", - uri: "com.example.app://oauth/callback", - shouldError: false, - }, - { - name: "Valid custom scheme with port and path", - uri: "myapp://localhost:8080/callback", - shouldError: false, - }, - // Invalid cases { name: "Invalid empty URI", uri: "", @@ -315,14 +276,14 @@ func (ts *OAuthServiceTestSuite) TestRedirectURIValidation() { errorMsg: "redirect URI cannot be empty", }, { - name: "Invalid HTTP non-localhost", - uri: "http://example.com/callback", + name: "Invalid scheme", + uri: "ftp://example.com/callback", shouldError: true, - errorMsg: "HTTP scheme only allowed for localhost", + errorMsg: "scheme must be HTTPS or HTTP (localhost only)", }, { - name: "Invalid HTTP with IP address (not loopback)", - uri: "http://192.168.1.1/callback", + name: "Invalid HTTP non-localhost", + uri: "http://example.com/callback", shouldError: true, errorMsg: "HTTP scheme only allowed for localhost", }, @@ -333,72 +294,11 @@ func (ts *OAuthServiceTestSuite) TestRedirectURIValidation() { errorMsg: "fragment not allowed in redirect URI", }, { - name: "Invalid custom scheme with fragment", - uri: "myapp://callback#fragment", - shouldError: true, - errorMsg: "fragment not allowed in redirect URI", - }, - { - name: "Invalid URI format - no scheme", - uri: "example.com/callback", - shouldError: true, - errorMsg: "must have scheme and host", - }, - { - name: "Invalid URI format - no host", - uri: "https:///callback", - shouldError: true, - errorMsg: "must have scheme and host", - }, - { - name: "Invalid URI format - completely invalid", + name: "Invalid URI format", uri: "not-a-uri", shouldError: true, errorMsg: "must have scheme and host", }, - // Dangerous URI schemes - { - name: "Invalid dangerous scheme - javascript", - uri: "javascript://example.com/alert(1)", - shouldError: true, - errorMsg: "scheme 'javascript' is not allowed for security reasons", - }, - { - name: "Invalid dangerous scheme - data", - uri: "data://text/html,", - shouldError: true, - errorMsg: "scheme 'data' is not allowed for security reasons", - }, - { - name: "Invalid dangerous scheme - file", - uri: "file://localhost/etc/passwd", - shouldError: true, - errorMsg: "scheme 'file' is not allowed for security reasons", - }, - { - name: "Invalid dangerous scheme - vbscript", - uri: "vbscript://example.com/malicious", - shouldError: true, - errorMsg: "scheme 'vbscript' is not allowed for security reasons", - }, - { - name: "Invalid dangerous scheme - about", - uri: "about://blank", - shouldError: true, - errorMsg: "scheme 'about' is not allowed for security reasons", - }, - { - name: "Invalid dangerous scheme - blob", - uri: "blob://example.com/something", - shouldError: true, - errorMsg: "scheme 'blob' is not allowed for security reasons", - }, - { - name: "Invalid dangerous scheme - case insensitive JAVASCRIPT", - uri: "JAVASCRIPT://example.com/alert(1)", - shouldError: true, - errorMsg: "is not allowed for security reasons", - }, } for _, tc := range testCases { diff --git a/internal/api/password.go b/internal/api/password.go index 5ce20149b2..47cc6755df 100644 --- a/internal/api/password.go +++ b/internal/api/password.go @@ -29,11 +29,7 @@ func (a *API) checkPasswordStrength(ctx context.Context, password string) error config := a.config if len(password) > MaxPasswordLength { - return apierrors.NewBadRequestError( - apierrors.ErrorCodeValidationFailed, - "Password cannot be longer than %v characters", - MaxPasswordLength, - ) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Password cannot be longer than %v characters", MaxPasswordLength)) } var messages, reasons []string diff --git a/internal/api/phone.go b/internal/api/phone.go index fbd940bcbf..ab21bf6cd7 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -71,7 +71,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use // intentionally keeping this before the test OTP, so that the behavior // of regular and test OTPs is similar if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { - return "", apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, "%s", generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency)) + return "", apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency)) } now := time.Now() diff --git a/internal/api/provider/provider.go b/internal/api/provider/provider.go index f7acb91e6c..9ced63937b 100644 --- a/internal/api/provider/provider.go +++ b/internal/api/provider/provider.go @@ -134,7 +134,7 @@ func makeRequest(ctx context.Context, tok *oauth2.Token, g *oauth2.Config, url s res.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices { - return httpError(res.StatusCode, "%s", string(bodyBytes)) + return httpError(res.StatusCode, string(bodyBytes)) } if err := json.NewDecoder(res.Body).Decode(dst); err != nil { diff --git a/internal/api/signup.go b/internal/api/signup.go index 3a9773ce0a..89c79f889a 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -168,7 +168,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { msg = "Sign up with this provider not possible" } - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "%s", msg) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, msg) } if err != nil && !models.IsNotFoundError(err) { diff --git a/internal/api/token.go b/internal/api/token.go index 4345d0def7..43863784cf 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -170,7 +170,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return err } } - return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "%s", output.Message) + return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, output.Message) } } if !isValidPassword { @@ -246,7 +246,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return err } if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeBadCodeVerifier, "%s", err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadCodeVerifier, err.Error()) } var token *AccessTokenResponse diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index c6e37ff78f..e62940abc9 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -122,7 +122,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !allowed { - return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Custom OIDC provider %q not allowed", p.Provider) + return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) } cfg = &conf.OAuthProviderConfiguration{ @@ -132,7 +132,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !cfg.Enabled { - return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeProviderDisabled, "Provider (issuer %q) is not enabled", issuer) + return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) } oidcCtx := ctx diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 356edc0da3..e3390c94e2 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -741,43 +741,6 @@ end; $$ language plpgsql;`, "user_metadata": nil, }, shouldError: false, - }, { - desc: "Modify amr to be array of strings", - uri: "pg-functions://postgres/auth/custom_access_token_amr_strings", - hookFunctionSQL: ` -create or replace function custom_access_token_amr_strings(input jsonb) -returns jsonb as $$ -declare - result jsonb; -begin - input := jsonb_set(input, '{claims,amr}', '["password", "mfa"]'::jsonb); - result := jsonb_build_object('claims', input->'claims'); - return result; -end; $$ language plpgsql;`, - expectedClaims: map[string]interface{}{ - "amr": []interface{}{"password", "mfa"}, - }, - shouldError: false, - }, { - desc: "Modify amr to be array of objects", - uri: "pg-functions://postgres/auth/custom_access_token_amr_objects", - hookFunctionSQL: ` -create or replace function custom_access_token_amr_objects(input jsonb) -returns jsonb as $$ -declare - result jsonb; -begin - input := jsonb_set(input, '{claims,amr}', '[{"method": "password"}, {"method": "mfa"}]'::jsonb); - result := jsonb_build_object('claims', input->'claims'); - return result; -end; $$ language plpgsql;`, - expectedClaims: map[string]interface{}{ - "amr": []interface{}{ - map[string]interface{}{"method": "password"}, - map[string]interface{}{"method": "mfa"}, - }, - }, - shouldError: false, }, } for _, c := range cases { diff --git a/internal/api/web3.go b/internal/api/web3.go index 8971846e80..3928b1a25f 100644 --- a/internal/api/web3.go +++ b/internal/api/web3.go @@ -73,7 +73,7 @@ func (a *API) web3GrantSolana(ctx context.Context, w http.ResponseWriter, r *htt parsedMessage, err := siws.ParseMessage(params.Message) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "%s", err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } if !parsedMessage.VerifySignature(signatureBytes) { @@ -219,7 +219,7 @@ func (a *API) web3GrantEthereum(ctx context.Context, w http.ResponseWriter, r *h parsedMessage, err := siwe.ParseMessage(params.Message) if err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "%s", err.Error()) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) } if !parsedMessage.VerifySignature(params.Signature) { diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 3e397be69c..b49e54f0c4 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -731,7 +731,6 @@ type SecurityConfiguration struct { RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"` UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` - SbForwardedForEnabled bool `json:"sb_forwarded_for_enabled" split_words:"true" default:"false"` DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"` } diff --git a/internal/conf/saml.go b/internal/conf/saml.go index 9f0ec1ce15..a38929ebe5 100644 --- a/internal/conf/saml.go +++ b/internal/conf/saml.go @@ -75,26 +75,16 @@ func (c *SAMLConfiguration) Validate() error { // PopulateFields fills the configuration details based off the provided // parameters. func (c *SAMLConfiguration) PopulateFields(externalURL string) error { - certTemplate, err := c.populateFields(externalURL) - if err != nil { - return err - } - return c.createCertificate(certTemplate) -} - -// PopulateFields fills the configuration details based off the provided -// parameters. -func (c *SAMLConfiguration) populateFields(externalURL string) (*x509.Certificate, error) { // errors are intentionally ignored since they should have been handled // within #Validate() bytes, err := base64.StdEncoding.DecodeString(c.PrivateKey) if err != nil { - return nil, fmt.Errorf("saml: PopulateFields: invalid base64: %w", err) + return fmt.Errorf("saml: PopulateFields: invalid base64: %w", err) } privateKey, err := x509.ParsePKCS1PrivateKey(bytes) if err != nil { - return nil, fmt.Errorf("saml: PopulateFields: invalid private key: %w", err) + return fmt.Errorf("saml: PopulateFields: invalid private key: %w", err) } c.RSAPrivateKey = privateKey @@ -102,7 +92,7 @@ func (c *SAMLConfiguration) populateFields(externalURL string) (*x509.Certificat parsedURL, err := url.ParseRequestURI(externalURL) if err != nil { - return nil, fmt.Errorf("saml: unable to parse external URL for SAML, check API_EXTERNAL_URL: %w", err) + return fmt.Errorf("saml: unable to parse external URL for SAML, check API_EXTERNAL_URL: %w", err) } host := "" @@ -135,7 +125,7 @@ func (c *SAMLConfiguration) populateFields(externalURL string) (*x509.Certificat if c.AllowEncryptedAssertions { certTemplate.KeyUsage = certTemplate.KeyUsage | x509.KeyUsageDataEncipherment } - return certTemplate, nil + return c.createCertificate(certTemplate) } func (c *SAMLConfiguration) createCertificate(certTemplate *x509.Certificate) error { diff --git a/internal/conf/saml_test.go b/internal/conf/saml_test.go index 24657a9345..aa9c1262c1 100644 --- a/internal/conf/saml_test.go +++ b/internal/conf/saml_test.go @@ -4,7 +4,6 @@ import ( "crypto/x509" "encoding/base64" "fmt" - "math/big" "testing" "github.com/stretchr/testify/require" @@ -98,13 +97,11 @@ func TestSAMLConfiguration(t *testing.T) { t.Run("PopulateFieldInvalidCreateCertificate", func(t *testing.T) { c := &SAMLConfiguration{ Enabled: true, - PrivateKey: validPrivateKey, + PrivateKey: base64.StdEncoding.EncodeToString([]byte("INVALID")), } - certTemplate, err := c.populateFields("https://projectref.supabase.co") - require.NoError(t, err) - certTemplate.SerialNumber = big.NewInt(-1) - err = c.createCertificate(certTemplate) + tmpl := &x509.Certificate{} + err := c.createCertificate(tmpl) require.Error(t, err) }) diff --git a/internal/e2e/e2ehooks/e2ehooks.go b/internal/e2e/e2ehooks/e2ehooks.go index 2b56ca4f70..637f4f090d 100644 --- a/internal/e2e/e2ehooks/e2ehooks.go +++ b/internal/e2e/e2ehooks/e2ehooks.go @@ -76,6 +76,7 @@ func NewHook(name v0hooks.Name) *Hook { name: name, } + //exhaustive:ignore switch name { case v0hooks.CustomizeAccessToken: // This hooks returns the exact claims given. @@ -187,6 +188,7 @@ func NewHookRecorder() *HookRecorder { } o.mux.HandleFunc("POST /hooks/{hook}", func(w http.ResponseWriter, r *http.Request) { + //exhaustive:ignore switch v0hooks.Name(r.PathValue("hook")) { case v0hooks.BeforeUserCreated: o.BeforeUserCreated.ServeHTTP(w, r) diff --git a/internal/hooks/hookshttp/hookshttp.go b/internal/hooks/hookshttp/hookshttp.go index 1f7d8016b5..36df0977fe 100644 --- a/internal/hooks/hookshttp/hookshttp.go +++ b/internal/hooks/hookshttp/hookshttp.go @@ -157,10 +157,11 @@ func (o *Dispatcher) runHTTPHook( rsp, err := client.Do(req) if err != nil && errors.Is(err, context.DeadlineExceeded) { - return nil, apierrors.NewUnprocessableEntityError( - apierrors.ErrorCodeHookTimeout, + msg := fmt.Sprintf( "Failed to reach hook within maximum time of %f seconds", o.hookTimeout.Seconds()) + return nil, apierrors.NewUnprocessableEntityError( + apierrors.ErrorCodeHookTimeout, msg) } else if err != nil { if terr, ok := err.(net.Error); ok && terr.Timeout() || i < o.hookRetries-1 { @@ -168,10 +169,11 @@ func (o *Dispatcher) runHTTPHook( "Request timed out for attempt %d with err %s", i, err) select { case <-ctx.Done(): - return nil, apierrors.NewUnprocessableEntityError( - apierrors.ErrorCodeHookTimeout, + msg := fmt.Sprintf( "Failed to reach hook within maximum time of %f seconds", o.hookTimeout.Seconds()) + return nil, apierrors.NewUnprocessableEntityError( + apierrors.ErrorCodeHookTimeout, msg) case <-time.After(o.hookBackoff): } continue @@ -196,18 +198,14 @@ func (o *Dispatcher) runHTTPHook( mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { + msg := fmt.Sprintf("Invalid Content-Type header: %s", err.Error()) return nil, apierrors.NewBadRequestError( - apierrors.ErrorCodeHookPayloadInvalidContentType, - "Invalid Content-Type header: %s", - err.Error(), - ) + apierrors.ErrorCodeHookPayloadInvalidContentType, msg) } if mediaType != "application/json" { return nil, apierrors.NewBadRequestError( apierrors.ErrorCodeHookPayloadInvalidContentType, - "Invalid JSON response. Received content-type: %s", - contentType, - ) + "Invalid JSON response. Received content-type: "+contentType) } limitedReader := io.LimitedReader{R: rsp.Body, N: o.limitResponse} @@ -218,11 +216,11 @@ func (o *Dispatcher) runHTTPHook( if limitedReader.N <= 0 { // check if the response body still has excess bytes to be read if n, _ := rsp.Body.Read(make([]byte, 1)); n > 0 { - return nil, apierrors.NewUnprocessableEntityError( - apierrors.ErrorCodeHookPayloadOverSizeLimit, + msg := fmt.Sprintf( "Payload size exceeded size limit of %d bytes", - o.limitResponse, - ) + o.limitResponse) + return nil, apierrors.NewUnprocessableEntityError( + apierrors.ErrorCodeHookPayloadOverSizeLimit, msg) } } return body, nil diff --git a/internal/indexworker/indexworker.go b/internal/indexworker/indexworker.go index 0011da7249..bed42f47fd 100644 --- a/internal/indexworker/indexworker.go +++ b/internal/indexworker/indexworker.go @@ -2,6 +2,7 @@ package indexworker import ( "context" + "database/sql" "errors" "fmt" "log" @@ -10,6 +11,7 @@ import ( "time" "github.com/gobuffalo/pop/v6" + pkgerrors "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/conf" ) @@ -33,21 +35,24 @@ const ( // Returns an error either from index creation failure (partial or complete) or if the advisory lock // could not be acquired. func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *logrus.Entry) error { - u, err := url.Parse(config.DB.URL) - if err != nil { - le.WithError(err).Fatal("Error parsing db connection url") - } - if config.DB.Driver == "" && config.DB.URL != "" { + u, err := url.Parse(config.DB.URL) + if err != nil { + le.Fatalf("Error parsing db connection url: %+v", err) + } config.DB.Driver = u.Scheme } - q := u.Query() - q.Add("application_name", "auth_indexworker") - u.RawQuery = q.Encode() + u, _ := url.Parse(config.DB.URL) + processedUrl := config.DB.URL + if len(u.Query()) != 0 { + processedUrl = fmt.Sprintf("%s&application_name=auth_index_worker", processedUrl) + } else { + processedUrl = fmt.Sprintf("%s?application_name=auth_index_worker", processedUrl) + } deets := &pop.ConnectionDetails{ Dialect: config.DB.Driver, - URL: u.String(), + URL: processedUrl, } deets.Options = map[string]string{ "Namespace": config.DB.Namespace, @@ -104,7 +109,28 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo } }() - indexes := getUsersIndexes(config.DB.Namespace) + // Ensure either auth_trgm or pg_trgm extension is installed + extName, err := ensureTrgmExtension(db, config.DB.Namespace, le) + if err != nil { + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "trgm_extension_unavailable", + }).WithError(err).Error("Failed to ensure trgm extension is available") + return err + } + + // Look up which schema the trgm extension is installed in + trgmSchema, err := getTrgmExtensionSchema(db, extName) + if err != nil { + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "extension_schema_not_found", + "extension": extName, + }).WithError(err).Error("Failed to find extension schema") + return ErrExtensionNotFound + } + + indexes := getUsersIndexes(config.DB.Namespace, trgmSchema) indexNames := make([]string, len(indexes)) for i, idx := range indexes { indexNames[i] = idx.name @@ -214,8 +240,135 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo return nil } +// getTrgmExtensionSchema looks up which schema the specified trgm extension is installed in +func getTrgmExtensionSchema(db *pop.Connection, extName string) (string, error) { + var schema string + query := ` + SELECT extnamespace::regnamespace::text AS schema_name + FROM pg_extension + WHERE extname = $1 + LIMIT 1 + ` + + if err := db.RawQuery(query, extName).First(&schema); err != nil { + return "", fmt.Errorf("failed to find %s extension schema: %w", extName, err) + } + + return schema, nil +} + +// extensionStatus represents the status of an extension from pg_available_extensions +type extensionStatus struct { + Available bool + Installed bool +} + +// getExtensionStatus checks if an extension is available and/or installed +func getExtensionStatus(db *pop.Connection, extName string) (extensionStatus, error) { + var result struct { + Name *string `db:"name"` + InstalledVersion *string `db:"installed_version"` + } + + query := ` + SELECT name, installed_version + FROM pg_available_extensions + WHERE name = $1 + ` + + if err := db.RawQuery(query, extName).First(&result); err != nil { + // If no rows returned, extension is not available + if pkgerrors.Cause(err) == sql.ErrNoRows { + return extensionStatus{Available: false, Installed: false}, nil + } + return extensionStatus{}, fmt.Errorf("failed to check extension status for %s: %w", extName, err) + } + + return extensionStatus{ + Available: result.Name != nil, + Installed: result.InstalledVersion != nil, + }, nil +} + +// installExtension installs the specified extension in the provided schema +func installExtension(db *pop.Connection, extName string, schema string) error { + query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s SCHEMA %s", extName, schema) + if err := db.RawQuery(query).Exec(); err != nil { + return fmt.Errorf("failed to install extension %s in schema %s: %w", extName, schema, err) + } + return nil +} + +// ensureTrgmExtension ensures that either auth_trgm or pg_trgm extension is installed +// It prefers auth_trgm if available, otherwise falls back to pg_trgm +// Returns the name of the installed extension +func ensureTrgmExtension(db *pop.Connection, authSchema string, le *logrus.Entry) (string, error) { + authTrgmStatus, err := getExtensionStatus(db, "auth_trgm") + if err != nil { + return "", fmt.Errorf("failed to check auth_trgm extension status: %w", err) + } + + if authTrgmStatus.Available { + if !authTrgmStatus.Installed { + le.Debug("auth_trgm extension is available but not installed, installing") + + if err := installExtension(db, "auth_trgm", authSchema); err != nil { + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "extension_install_failed", + "extension": "auth_trgm", + }).WithError(err).Error("Failed to install auth_trgm extension") + + return "", fmt.Errorf("auth_trgm extension is available but failed to install: %w", err) + } + + le.WithFields(logrus.Fields{ + "code": "extension_installed", + "extension": "auth_trgm", + }).Info("Successfully installed auth_trgm extension") + } else { + le.Debug("auth_trgm extension is already installed") + } + + return "auth_trgm", nil + } + + le.Debug("auth_trgm extension is not available, checking pg_trgm") + + pgTrgmStatus, err := getExtensionStatus(db, "pg_trgm") + if err != nil { + return "", fmt.Errorf("failed to check pg_trgm extension status: %w", err) + } + + if !pgTrgmStatus.Available { + return "", fmt.Errorf("neither auth_trgm nor pg_trgm extensions are available") + } + + if !pgTrgmStatus.Installed { + le.Debug("pg_trgm extension is available but not installed, installing") + + if err := installExtension(db, "pg_trgm", "pg_catalog"); err != nil { + le.WithFields(logrus.Fields{ + "code": "extension_install_failed", + "extension": "pg_trgm", + }).WithError(err).Error("Failed to install pg_trgm extension") + + return "", fmt.Errorf("pg_trgm extension is available but failed to install: %w", err) + } + + le.WithFields(logrus.Fields{ + "code": "extension_installed", + "extension": "pg_trgm", + }).Info("Successfully installed pg_trgm extension") + } else { + le.Debug("pg_trgm extension is already installed") + } + + return "pg_trgm", nil +} + // getUsersIndexes returns the list of indexes to create on the users table -func getUsersIndexes(namespace string) []struct { +func getUsersIndexes(namespace, trgmSchema string) []struct { name string query string } { @@ -225,12 +378,17 @@ func getUsersIndexes(namespace string) []struct { name string query string }{ - // for exact-match queries, sorting, and prefix searches on email (e.g., email LIKE 'term%') + // for exact-match queries, sorting, and LIKE '%term%' (trigram) searches on email { name: "idx_users_email", query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_email ON %q.users USING btree (email);`, namespace), }, + { + name: "idx_users_email_trgm", + query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_email_trgm + ON %q.users USING gin (email %s.gin_trgm_ops);`, namespace, trgmSchema), + }, // for range queries and sorting on created_at and last_sign_in_at { name: "idx_users_created_at_desc", @@ -242,12 +400,12 @@ func getUsersIndexes(namespace string) []struct { query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_last_sign_in_at_desc ON %q.users (last_sign_in_at DESC);`, namespace), }, - // for exact-match, sorting, and prefix searches on raw_user_meta_data->>'name' + // trigram indexes on name field in raw_user_meta_data JSONB - enables fast LIKE '%term%' searches { - name: "idx_users_name", - query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_name - ON %q.users USING btree ((raw_user_meta_data->>'name')) - WHERE (raw_user_meta_data->>'name') IS NOT NULL;`, namespace), + name: "idx_users_name_trgm", + query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_name_trgm + ON %q.users USING gin ((raw_user_meta_data->>'name') %s.gin_trgm_ops) + WHERE raw_user_meta_data->>'name' IS NOT NULL;`, namespace, trgmSchema), }, } } diff --git a/internal/indexworker/indexworker_test.go b/internal/indexworker/indexworker_test.go index 2431d47baf..5bd4779deb 100644 --- a/internal/indexworker/indexworker_test.go +++ b/internal/indexworker/indexworker_test.go @@ -55,6 +55,10 @@ func (ts *IndexWorkerTestSuite) SetupSuite() { // Ensure we have a clean state for testing ts.cleanupIndexes() + + // Ensure trigram extension is available + err = ts.db.RawQuery("CREATE EXTENSION IF NOT EXISTS pg_trgm").Exec() + require.NoError(ts.T(), err) } func (ts *IndexWorkerTestSuite) TearDownSuite() { @@ -73,7 +77,7 @@ func (ts *IndexWorkerTestSuite) SetupTest() { } func (ts *IndexWorkerTestSuite) cleanupIndexes() { - indexes := getUsersIndexes(ts.namespace) + indexes := getUsersIndexes(ts.namespace, ts.namespace) for _, idx := range indexes { // Drop any existing indexes (valid or invalid) dropQuery := fmt.Sprintf("DROP INDEX IF EXISTS %q.%s", ts.namespace, idx.name) @@ -87,7 +91,7 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesHappyPath() { err := CreateIndexes(ctx, ts.config, ts.logger) require.NoError(ts.T(), err) - indexes := getUsersIndexes(ts.namespace) + indexes := getUsersIndexes(ts.namespace, ts.namespace) existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) require.NoError(ts.T(), err) @@ -131,7 +135,7 @@ func (ts *IndexWorkerTestSuite) TestIdempotency() { require.NoError(ts.T(), err) // Get the state after first run - indexes := getUsersIndexes(ts.namespace) + indexes := getUsersIndexes(ts.namespace, ts.namespace) firstRunIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) require.NoError(ts.T(), err) require.Equal(ts.T(), len(indexes), len(firstRunIndexes)) @@ -187,7 +191,7 @@ func (ts *IndexWorkerTestSuite) TestOutOfBandIndexRemoval() { require.NoError(ts.T(), err) // Verify all indexes exist - indexes := getUsersIndexes(ts.namespace) + indexes := getUsersIndexes(ts.namespace, ts.namespace) existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) require.NoError(ts.T(), err) assert.Equal(ts.T(), len(indexes), len(existingIndexes)) @@ -273,7 +277,7 @@ func (ts *IndexWorkerTestSuite) TestConcurrentWorkers() { assert.Equal(ts.T(), numWorkers-1, lockSkipCount, "Other workers should skip due to lock") // Verify all indexes were created successfully - indexes := getUsersIndexes(ts.namespace) + indexes := getUsersIndexes(ts.namespace, ts.namespace) existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) require.NoError(ts.T(), err) assert.Equal(ts.T(), len(indexes), len(existingIndexes), "All indexes should be created") @@ -302,7 +306,7 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() { require.NoError(ts.T(), err, "Initial CreateIndexes should succeed") // Verify all indexes were created and are valid - indexes := getUsersIndexes(ts.namespace) + indexes := getUsersIndexes(ts.namespace, ts.namespace) initialIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) require.NoError(ts.T(), err) assert.Equal(ts.T(), len(indexes), len(initialIndexes), "All indexes should be created initially") @@ -333,7 +337,7 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() { defer manipulatorDB.Close() // Select the first 2 indexes to mark as invalid - allIndexes := getUsersIndexes(ts.namespace) + allIndexes := getUsersIndexes(ts.namespace, ts.namespace) indexesToInvalidate := []string{allIndexes[0].name, allIndexes[1].name} for _, indexName := range indexesToInvalidate { @@ -389,6 +393,54 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() { ts.logger.Infof("Successfully recovered from %d invalid indexes", len(indexesToInvalidate)) } +// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes installs pg_trgm extension +// when it's available but not installed, and then successfully creates indexes. +func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() { + ctx := context.Background() + + // Drop the pg_trgm extension to simulate it not being installed + dropExtQuery := "DROP EXTENSION IF EXISTS pg_trgm CASCADE" + err := ts.db.RawQuery(dropExtQuery).Exec() + require.NoError(ts.T(), err, "Should be able to drop pg_trgm extension") + + // Verify the extension is dropped + var extensionExists bool + checkExtQuery := "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm')" + err = ts.db.RawQuery(checkExtQuery).First(&extensionExists) + require.NoError(ts.T(), err) + assert.False(ts.T(), extensionExists, "pg_trgm extension should not exist") + + // Verify no indexes exist initially + indexes := getUsersIndexes(ts.namespace, ts.namespace) + existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) + require.NoError(ts.T(), err) + assert.Empty(ts.T(), existingIndexes, "No indexes should exist initially") + + // Run CreateIndexes - it should install the pg_trgm extension and create indexes + err = CreateIndexes(ctx, ts.config, ts.logger) + require.NoError(ts.T(), err, "CreateIndexes should succeed by installing the pg_trgm extension") + + // Verify that pg_trgm is now installed + err = ts.db.RawQuery(checkExtQuery).First(&extensionExists) + require.NoError(ts.T(), err) + assert.True(ts.T(), extensionExists, "pg_trgm extension should have been installed") + + // Verify all indexes were created successfully + existingIndexes, err = getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), len(indexes), len(existingIndexes), "All indexes should have been created") + + for _, idx := range existingIndexes { + assert.True(ts.T(), idx.IsValid, "Index %s should be valid", idx.IndexName) + assert.True(ts.T(), idx.IsReady, "Index %s should be ready", idx.IndexName) + } + + // Restore pg_trgm extension for other tests + createExtQuery := "CREATE EXTENSION IF NOT EXISTS pg_trgm" + err = ts.db.RawQuery(createExtQuery).Exec() + require.NoError(ts.T(), err, "Should be able to restore pg_trgm extension") +} + // Run the test suite func TestIndexWorker(t *testing.T) { suite.Run(t, new(IndexWorkerTestSuite)) diff --git a/internal/mailer/validateclient/validateclient.go b/internal/mailer/validateclient/validateclient.go index 1f5bfd61a7..b194bf1114 100644 --- a/internal/mailer/validateclient/validateclient.go +++ b/internal/mailer/validateclient/validateclient.go @@ -56,51 +56,12 @@ var invalidHostMap = map[string]bool{ // Hundreds of typos per day for this. "gamil.com": true, - "gamai.com": true, // These are not email providers, but people often use them. "anonymous.com": true, "email.com": true, } -// We skip checking hosts for some of the biggest well known public email -// providers, generated via: -// -// test -f "gmass.html" \ -// || wget -O "gmass.html" https://www.gmass.co/domains -// cat "gmass.html" \ -// | pup 'div#status-details a json{}' \ -// | jq -r 'map([(.text | split(" "))[1], .children[0].text]) -// | map("`" + .[0] + ".` : true, // " + .[1]) | join("\n")' \ -// | sed 's| emails sent||g' \ -// | head -20 -// -// Note: -// This only affects the validateHost code, if we have an exact match we don't -// bother to make a dns request. -var hostAllowList = map[string]bool{ - `gmail.com.`: true, // 563,185,814 - `yahoo.com.`: true, // 107,413,999 - `hotmail.com.`: true, // 98,895,904 - `aol.com.`: true, // 31,839,178 - `outlook.com.`: true, // 11,826,511 - `comcast.net.`: true, // 9,663,112 - `icloud.com.`: true, // 9,274,437 - `msn.com.`: true, // 7,101,124 - `hotmail.co.uk.`: true, // 5,456,609 - `sbcglobal.net.`: true, // 5,167,305 - `live.com.`: true, // 5,140,589 - `yahoo.co.in.`: true, // 4,091,798 - `me.com.`: true, // 3,920,969 - `att.net.`: true, // 3,688,388 - `mail.ru.`: true, // 3,583,276 - `bellsouth.net.`: true, // 3,455,683 - `rediffmail.com`: true, // 3,400,300 - `cox.net.`: true, // 3,254,227 - `yahoo.co.uk.`: true, // 3,218,049 - `verizon.net.`: true, // 3,046,288 -} - const ( validateEmailTimeout = 3 * time.Second ) @@ -261,12 +222,6 @@ func (ev *emailValidator) validateStatic(email string) (string, error) { return "", ErrInvalidEmailFormat } - // The mail package supports RFC 5322 addresses which are not valid for - // signup users (e.g. Chris Stockton ). - if ea.Address != email { - return "", ErrInvalidEmailFormat - } - i := strings.LastIndex(ea.Address, "@") if i == -1 { return "", ErrInvalidEmailFormat @@ -336,16 +291,13 @@ func (ev *emailValidator) validateService(ctx context.Context, email string) err return nil } - // 32 bytes is plenty for the payload: {"valid": true|false} dec := json.NewDecoder(io.LimitReader(res.Body, 1<<5)) if err := dec.Decode(&resObject); err != nil { return nil } - // If the resObject contained no "valid" key we ignore the service and - // return a nil error. If the Valid key is present AND set to true we - // will return a nil error, otherwise the valid key was present & false - // so we fall through to ErrInvalidEmailAddress. + // If the object did not contain a valid key we consider the check as + // failed. We _must_ get a valid JSON response with a "valid" field. if resObject.Valid == nil || *resObject.Valid { return nil } @@ -368,29 +320,7 @@ func (ev *emailValidator) validateProviders(name, host string) error { return nil } -// NOTE(cstockton): We could consider using[1] in the future for an additional -// data point. -// -// [1] https://pkg.go.dev/golang.org/x/net/publicsuffix func (ev *emailValidator) validateHost(ctx context.Context, host string) error { - - // As far as I know there is no such thing as valid single label hosts for - // email. This will block anything like: email@a, email@mycompanygltd and - // so on. - if !strings.Contains(host, ".") { - return ErrInvalidEmailDNS - } - - // Require a FQDN to remove possible implict search behavior. - if !strings.HasSuffix(host, ".") { - host = host + "." - } - - // If the host is in the allow list skip mx check all together. - if hostAllowList[host] { - return nil - } - mxs, err := validateEmailResolver.LookupMX(ctx, host) if !isHostNotFound(err) { return ev.validateMXRecords(mxs, nil) diff --git a/internal/mailer/validateclient/validateclient_test.go b/internal/mailer/validateclient/validateclient_test.go index 281826bb80..43e2fbc112 100644 --- a/internal/mailer/validateclient/validateclient_test.go +++ b/internal/mailer/validateclient/validateclient_test.go @@ -3,7 +3,6 @@ package validateclient import ( "context" "fmt" - "net" "net/http" "net/http/httptest" "sync/atomic" @@ -202,10 +201,7 @@ func TestValidateEmailExtended(t *testing.T) { // valid (has mx record) {email: "a@supabase.io"}, {email: "support@supabase.io"}, - {email: "abc@supabase.io"}, - - // valid (RFC 5321 fallback, supabase.co has no mx, but valid A) - {email: "invalid@supabase.co"}, + {email: "chris.stockton@supabase.io"}, // bad format {email: "", err: "invalid_email_format"}, @@ -214,25 +210,6 @@ func TestValidateEmailExtended(t *testing.T) { {email: "@supabase.io", err: "invalid_email_format"}, {email: "test@.supabase.io", err: "invalid_email_format"}, - // invalid providers check doesn't allow short gmails - {email: "short@gmail.com", err: "invalid_email_address"}, - {email: "short@hotmail.com"}, // allow other providers - - // ensure the mail parser does not result in mutated addr - {email: "Chris Stockton ", - err: "invalid_email_format"}, - - // Check dot suffixes are invalid (mutations) - {email: "a@example.org.", err: "invalid_email_format"}, - {email: "a@", err: "invalid_email_format"}, - {email: "a@.", err: "invalid_email_format"}, - {email: "a@a.", err: "invalid_email_format"}, - {email: "a@gmail.com.", err: "invalid_email_format"}, - {email: "a@gmail.com..", err: "invalid_email_format"}, - {email: "aaaaaaaa@.abc", err: "invalid_email_format"}, - {email: "aaaaaaaa@.abc.", err: "invalid_email_format"}, - {email: "aaaaaaaa@.abc.abc", err: "invalid_email_format"}, - // invalid: valid mx records, but invalid and often typed // (invalidEmailMap) {email: "test@email.com", err: "invalid_email_address"}, @@ -258,23 +235,23 @@ func TestValidateEmailExtended(t *testing.T) { // valid but not actually valid and typed a lot {email: "a@invalid", err: "invalid_email_dns"}, {email: "a@a.invalid", err: "invalid_email_dns"}, + {email: "test@invalid", err: "invalid_email_dns"}, // various invalid emails {email: "test@test.localhost", err: "invalid_email_dns"}, {email: "test@invalid.example.com", err: "invalid_email_dns"}, {email: "test@no.such.email.host.supabase.io", err: "invalid_email_dns"}, + + // test blocked mx records + {email: "test@hotmail.com", err: "invalid_email_mx"}, + // this low timeout should simulate a dns timeout, which should // not be treated as an invalid email. {email: "validemail@probablyaaaaaaaanotarealdomain.com", timeout: time.Millisecond}, // likewise for a valid email - {email: "timeout@supabase.io", timeout: time.Millisecond}, - - // invalid dns - {email: "a@a", err: "invalid_email_dns"}, - {email: "a@a.a", err: "invalid_email_dns"}, - {email: "a@abcd", err: "invalid_email_dns"}, + {email: "support@supabase.io", timeout: time.Millisecond}, } cfg := conf.MailerConfiguration{ @@ -291,7 +268,6 @@ func TestValidateEmailExtended(t *testing.T) { ev := newEmailValidator(cfg) - seen := make(map[string]bool) for idx, tc := range cases { func(timeout time.Duration) { if timeout == 0 { @@ -301,11 +277,6 @@ func TestValidateEmailExtended(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - if seen[tc.email] { - t.Fatalf("duplicate email %v in test cases", tc.email) - } - seen[tc.email] = true - now := time.Now() err := ev.Validate(ctx, tc.email) dur := time.Since(now) @@ -323,24 +294,4 @@ func TestValidateEmailExtended(t *testing.T) { }(tc.timeout) } - - // test blocked mx list - { - for _, v := range []string{ - "hotmail-com.olc.protection.outlook.com", - "hotmail-com.olc.protection.outlook.com.", - } { - { - err := ev.validateMXRecords([]*net.MX{{Host: v}}, nil) - require.Error(t, err) - require.Contains(t, err.Error(), ErrInvalidEmailMX.Error()) - } - - { - err := ev.validateMXRecords(nil, []string{v}) - require.Error(t, err) - require.Contains(t, err.Error(), ErrInvalidEmailMX.Error()) - } - } - } } diff --git a/internal/sbff/sbff.go b/internal/sbff/sbff.go deleted file mode 100644 index 33a5126643..0000000000 --- a/internal/sbff/sbff.go +++ /dev/null @@ -1,94 +0,0 @@ -package sbff - -import ( - "context" - "errors" - "net" - "net/http" - "strings" - - "github.com/supabase/auth/internal/conf" -) - -// HeaderName is the Sb-Forwarded-For header name. It is all lowercase here as HTTP header names -// are not case-sensitive. -const HeaderName = "sb-forwarded-for" - -var ( - ctxKeySBFF = &struct{}{} - - ErrHeaderNotFound = errors.New("Sb-Forwarded-For header not found") - ErrHeaderInvalid = errors.New("invalid Sb-Forwarded-For header value") -) - -func parseSBFFHeader(headerVal string) (string, error) { - values := strings.SplitN(headerVal, ",", 2) - key := strings.TrimSpace(values[0]) - if ipAddr := net.ParseIP(key); ipAddr != nil { - return ipAddr.String(), nil - } - - return "", ErrHeaderInvalid -} - -// GetIPAddress returns the value of the IP address in Sb-Forwarded-For as defined by -// SBForwardedForMiddleware. If no value is present in the request context, this function will -// return ("", false). -func GetIPAddress(r *http.Request) (addr string, found bool) { - if ipAddr, ok := r.Context().Value(ctxKeySBFF).(string); ok && ipAddr != "" { - return ipAddr, true - } - - return "", false -} - -// withIPAddress parses the Sb-Forwarded-For header and adds the leftmost value to the -// request context if it is a valid IP address, then returns a new request with modified context. -// If the leftmost value is not a valid IP address or the header is not set, this function returns -// an error. -func withIPAddress(r *http.Request) (*http.Request, error) { - headerVal := r.Header.Get(HeaderName) - if headerVal == "" { - return nil, ErrHeaderNotFound - } - - parsedIPAddr, err := parseSBFFHeader(headerVal) - if err != nil { - return nil, err - } - - ctx := r.Context() - newCtx := context.WithValue(ctx, ctxKeySBFF, parsedIPAddr) - out := r.WithContext(newCtx) - - return out, nil -} - -// Middleware returns a middleware function that parses the Sb-Forwarded-For header -// and adds the leftmost header value to the request context if GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED -// is true and the value is a valid IP address. -func Middleware(cfg *conf.SecurityConfiguration, errCallback func(*http.Request, error)) func(http.Handler) http.Handler { - out := func(next http.Handler) http.Handler { - handlerFunc := func(rw http.ResponseWriter, r *http.Request) { - if !cfg.SbForwardedForEnabled { - next.ServeHTTP(rw, r) - return - } - - reqWithSBFF, err := withIPAddress(r) - switch { - case err == nil: - next.ServeHTTP(rw, reqWithSBFF) - case errors.Is(err, ErrHeaderNotFound): - next.ServeHTTP(rw, r) - default: - errCallback(r, err) - next.ServeHTTP(rw, r) - } - } - - return http.HandlerFunc(handlerFunc) - } - - return out -} diff --git a/internal/sbff/sbff_test.go b/internal/sbff/sbff_test.go deleted file mode 100644 index 6f38bd96f2..0000000000 --- a/internal/sbff/sbff_test.go +++ /dev/null @@ -1,254 +0,0 @@ -package sbff - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" - "github.com/supabase/auth/internal/conf" -) - -func TestParseHeader(t *testing.T) { - testCases := []struct { - name string - headerVal string - expAddr string - expErr error - }{ - { - name: "SingleAddressIPv4", - headerVal: "192.168.1.100", - expAddr: "192.168.1.100", - expErr: nil, - }, - - { - name: "SingleAddressIPv6", - headerVal: "2600:1000:cafe:bead::1", - expAddr: "2600:1000:cafe:bead::1", - expErr: nil, - }, - { - name: "MultipleAddressIPv4", - headerVal: "192.168.1.100,60.60.60.60", - expAddr: "192.168.1.100", - expErr: nil, - }, - { - name: "MultipleAddressIPv4WithWhitespace", - headerVal: "192.168.1.100 ,60.60.60.60", - expAddr: "192.168.1.100", - expErr: nil, - }, - { - name: "HeaderInvalid", - headerVal: "invalid, 60.60.60.60", - expAddr: "", - expErr: ErrHeaderInvalid, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - obsAddr, obsErr := parseSBFFHeader(tc.headerVal) - require.Equal(t, tc.expAddr, obsAddr) - require.ErrorIs(t, obsErr, tc.expErr) - }) - } -} - -func TestWithIPAddress(t *testing.T) { - testCases := []struct { - name string - headerVal string - expAddr string - expErr error - }{ - { - name: "WithHeader", - headerVal: "2600:cafe:bead::1", - expAddr: "2600:cafe:bead::1", - expErr: nil, - }, - { - name: "HeaderNotFound", - headerVal: "", - expAddr: "", - expErr: ErrHeaderNotFound, - }, - { - name: "HeaderInvalid", - headerVal: "invalid", - expAddr: "", - expErr: ErrHeaderInvalid, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) - - if tc.headerVal != "" { - r.Header.Set(HeaderName, tc.headerVal) - } - - obsReq, obsErr := withIPAddress(r) - - if tc.expErr == nil { - require.NotNil(t, obsReq) - - obsAddr, ok := GetIPAddress(obsReq) - require.Equal(t, tc.expAddr, obsAddr) - require.Equal(t, true, ok) - } - - require.ErrorIs(t, obsErr, tc.expErr) - }) - } -} - -func TestGetIPAddress(t *testing.T) { - testCases := []struct { - name string - // ctxVal is any here because context.WithValue accepts any - ctxVal any - expAddr string - expFound bool - }{ - { - name: "WithAddress", - ctxVal: "2600:cafe:bead::1", - expAddr: "2600:cafe:bead::1", - expFound: true, - }, - { - name: "EmptyContext", - ctxVal: nil, - expAddr: "", - expFound: false, - }, - { - name: "NonStringValue", - ctxVal: 1, - expAddr: "", - expFound: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - originalReq := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) - - var ctx context.Context - - if tc.ctxVal == nil { - ctx = originalReq.Context() - } else { - ctx = context.WithValue(originalReq.Context(), ctxKeySBFF, tc.ctxVal) - } - - r := originalReq.WithContext(ctx) - - obsAddr, obsFound := GetIPAddress(r) - - require.Equal(t, tc.expAddr, obsAddr) - require.Equal(t, tc.expFound, obsFound) - }) - } -} - -func TestMiddleware(t *testing.T) { - testCases := []struct { - name string - sbffEnabled bool - headerVal string - expAddr string - expFound bool - expErr error - }{ - { - name: "FlagDisabledHeaderEmpty", - sbffEnabled: false, - headerVal: "", - expAddr: "", - expFound: false, - expErr: nil, - }, - { - name: "FlagDisabledHeaderValid", - sbffEnabled: false, - headerVal: "192.168.1.100", - expAddr: "", - expFound: false, - expErr: nil, - }, - { - name: "FlagDisabledHeaderInvalid", - sbffEnabled: false, - headerVal: "invalid", - expAddr: "", - expFound: false, - expErr: nil, - }, - { - name: "FlagEnabledHeaderEmpty", - sbffEnabled: true, - headerVal: "", - expAddr: "", - expFound: false, - expErr: nil, - }, - { - name: "FlagEnabledHeaderValid", - sbffEnabled: true, - headerVal: "192.168.1.100", - expAddr: "192.168.1.100", - expFound: true, - expErr: nil, - }, - { - name: "FlagEnabledHeaderInvalid", - sbffEnabled: true, - headerVal: "invalid", - expAddr: "", - expFound: false, - expErr: ErrHeaderInvalid, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) - - if tc.headerVal != "" { - r.Header.Set(HeaderName, tc.headerVal) - } - - var cfg conf.SecurityConfiguration - - var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { - obsAddr, obsFound := GetIPAddress(r) - require.Equal(t, tc.expAddr, obsAddr) - require.Equal(t, tc.expFound, obsFound) - } - - errCallback := func(r *http.Request, err error) { - if tc.expErr == nil { - t.Fatal("error callback called when expected error is nil") - } - - require.ErrorIs(t, err, tc.expErr) - } - - cfg.SbForwardedForEnabled = tc.sbffEnabled - - middlewareFn := Middleware(&cfg, errCallback) - - wrappedHandler := middlewareFn(handler) - - wrappedHandler.ServeHTTP(nil, r) - }) - } -} diff --git a/internal/tokens/service.go b/internal/tokens/service.go index 3e3d7cd08d..a7552865d1 100644 --- a/internal/tokens/service.go +++ b/internal/tokens/service.go @@ -2,7 +2,6 @@ package tokens import ( "context" - "encoding/json" "fmt" mathRand "math/rand" "net/http" @@ -27,47 +26,6 @@ import ( const retryLoopDuration = 5.0 -// AMRClaim supports unmarshalling AMR as either strings or AMREntry objects. -type AMRClaim []models.AMREntry - -// UnmarshalJSON accepts either an array of strings or AMREntry objects. -func (a *AMRClaim) UnmarshalJSON(data []byte) error { - // Handle null explicitly - null cannot be unmarshaled into a slice - if len(data) > 0 { - trimmed := strings.TrimSpace(string(data)) - if trimmed == "null" { - *a = AMRClaim{} - return nil - } - } - - var rawItems []json.RawMessage - if err := json.Unmarshal(data, &rawItems); err != nil { - return err - } - - entries := make([]models.AMREntry, 0, len(rawItems)) - for _, item := range rawItems { - var method string - if err := json.Unmarshal(item, &method); err == nil { - entries = append(entries, models.AMREntry{ - Method: method, - Timestamp: time.Now().Unix(), - }) - continue - } - - var entry models.AMREntry - if err := json.Unmarshal(item, &entry); err != nil { - return err - } - entries = append(entries, entry) - } - - *a = entries - return nil -} - // AccessTokenClaims is a struct thats used for JWT claims type AccessTokenClaims struct { jwt.RegisteredClaims @@ -77,7 +35,7 @@ type AccessTokenClaims struct { UserMetaData map[string]interface{} `json:"user_metadata"` Role string `json:"role"` AuthenticatorAssuranceLevel string `json:"aal,omitempty"` - AuthenticationMethodReference AMRClaim `json:"amr,omitempty"` + AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"` SessionId string `json:"session_id,omitempty"` IsAnonymous bool `json:"is_anonymous"` ClientID string `json:"client_id,omitempty"` @@ -206,7 +164,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, if models.IsNotFoundError(err) { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenNotFound, "Invalid Refresh Token: Refresh Token Not Found") } - return nil, apierrors.NewInternalServerError("%s", err.Error()) + return nil, apierrors.NewInternalServerError(err.Error()) } responseHeaders.Set("sb-auth-user-id", user.ID.String()) @@ -283,7 +241,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, retry = true return terr } - return apierrors.NewInternalServerError("%s", terr.Error()) + return apierrors.NewInternalServerError(terr.Error()) } // Validate OAuth client consistency between session and current request @@ -317,7 +275,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, retry = true return terr } else if terr != nil { - return apierrors.NewInternalServerError("%s", terr.Error()) + return apierrors.NewInternalServerError(terr.Error()) } sessionTag := session.DetermineTag(config.Sessions.Tags) @@ -367,7 +325,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, if token.Revoked { activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx) if terr != nil && !models.IsNotFoundError(terr) { - return apierrors.NewInternalServerError("%s", terr.Error()) + return apierrors.NewInternalServerError(terr.Error()) } if activeRefreshToken != nil && activeRefreshToken.Parent.String() == token.Token { @@ -391,7 +349,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, if config.Security.RefreshTokenRotationEnabled { // Revoke all tokens in token family if err := models.RevokeTokenFamily(tx, token); err != nil { - return apierrors.NewInternalServerError("%s", err.Error()) + return apierrors.NewInternalServerError(err.Error()) } } @@ -993,10 +951,7 @@ const MinimumViableTokenSchema = `{ "amr": { "type": "array", "items": { - "anyOf": [ - {"type": "string"}, - {"type": "object"} - ] + "type": "object" } }, "session_id": { diff --git a/internal/tokens/service_test.go b/internal/tokens/service_test.go index 8c409934b8..2948b9cd4f 100644 --- a/internal/tokens/service_test.go +++ b/internal/tokens/service_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "encoding/base64" - "encoding/json" "net/http" "strconv" "strings" @@ -1023,69 +1022,3 @@ func (ts *IDTokenTestSuite) TestIDTokenWithMultipleScopes() { phoneNumber, hasPhone := claims["phone_number"] require.False(ts.T(), hasPhone || (phoneNumber != nil && phoneNumber != ""), "phone_number claim should not be present without phone scope") } - -func TestAMRClaimUnmarshal(t *testing.T) { - t.Run("mixed string and object formats", func(t *testing.T) { - var claim AMRClaim - before := time.Now().Unix() - - err := json.Unmarshal([]byte(`["password", {"method":"totp","timestamp":123,"provider":"webauthn"}]`), &claim) - require.NoError(t, err) - require.Len(t, claim, 2) - - require.Equal(t, "password", claim[0].Method) - require.GreaterOrEqual(t, claim[0].Timestamp, before) - require.LessOrEqual(t, claim[0].Timestamp, time.Now().Unix()) - require.Empty(t, claim[0].Provider, "string format should not have provider") - - require.Equal(t, "totp", claim[1].Method) - require.Equal(t, int64(123), claim[1].Timestamp) - require.Equal(t, "webauthn", claim[1].Provider, "provider should be preserved from object format") - }) - - t.Run("object with provider", func(t *testing.T) { - var claim AMRClaim - err := json.Unmarshal([]byte(`[{"method":"sso","timestamp":456,"provider":"saml"}]`), &claim) - require.NoError(t, err) - require.Len(t, claim, 1) - require.Equal(t, "sso", claim[0].Method) - require.Equal(t, int64(456), claim[0].Timestamp) - require.Equal(t, "saml", claim[0].Provider, "provider should be preserved") - }) - - t.Run("object without provider", func(t *testing.T) { - var claim AMRClaim - err := json.Unmarshal([]byte(`[{"method":"password","timestamp":789}]`), &claim) - require.NoError(t, err) - require.Len(t, claim, 1) - require.Equal(t, "password", claim[0].Method) - require.Equal(t, int64(789), claim[0].Timestamp) - require.Empty(t, claim[0].Provider, "provider should be empty when not provided") - }) - - t.Run("all strings", func(t *testing.T) { - var claim AMRClaim - before := time.Now().Unix() - err := json.Unmarshal([]byte(`["password", "totp"]`), &claim) - require.NoError(t, err) - require.Len(t, claim, 2) - require.Equal(t, "password", claim[0].Method) - require.Equal(t, "totp", claim[1].Method) - require.GreaterOrEqual(t, claim[0].Timestamp, before) - require.Empty(t, claim[0].Provider) - require.Empty(t, claim[1].Provider) - }) - - t.Run("all objects", func(t *testing.T) { - var claim AMRClaim - err := json.Unmarshal([]byte(`[{"method":"password","timestamp":100},{"method":"totp","timestamp":200,"provider":"webauthn"}]`), &claim) - require.NoError(t, err) - require.Len(t, claim, 2) - require.Equal(t, "password", claim[0].Method) - require.Equal(t, int64(100), claim[0].Timestamp) - require.Empty(t, claim[0].Provider) - require.Equal(t, "totp", claim[1].Method) - require.Equal(t, int64(200), claim[1].Timestamp) - require.Equal(t, "webauthn", claim[1].Provider, "provider should be preserved") - }) -} diff --git a/internal/utilities/request.go b/internal/utilities/request.go index bd38c73819..fcfac8287d 100644 --- a/internal/utilities/request.go +++ b/internal/utilities/request.go @@ -10,10 +10,11 @@ import ( "strings" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/sbff" ) -func getIPAddressWithXFF(r *http.Request) string { +// GetIPAddress returns the real IP address of the HTTP request. It parses the +// X-Forwarded-For header. +func GetIPAddress(r *http.Request) string { if r.Header != nil { xForwardedFor := r.Header.Get("X-Forwarded-For") if xForwardedFor != "" { @@ -44,15 +45,6 @@ func getIPAddressWithXFF(r *http.Request) string { return ip } -// GetIPAddress returns the real IP address of the HTTP request. -func GetIPAddress(r *http.Request) string { - if sbffAddr, ok := sbff.GetIPAddress(r); ok { - return sbffAddr - } - - return getIPAddressWithXFF(r) -} - // GetBodyBytes reads the whole request body properly into a byte array. func GetBodyBytes(req *http.Request) ([]byte, error) { if req.Body == nil || req.Body == http.NoBody { diff --git a/internal/utilities/request_test.go b/internal/utilities/request_test.go index 91ae97fac6..c1d1a66219 100644 --- a/internal/utilities/request_test.go +++ b/internal/utilities/request_test.go @@ -7,69 +7,9 @@ import ( "github.com/stretchr/testify/require" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/sbff" ) -func TestGetIPAddressWithSBFF(t *tst.T) { - testCases := []struct { - name string - remoteAddr string - headerVal string - expAddr string - }{ - { - name: "ValidSBFF", - remoteAddr: "60.60.60.60", - headerVal: "192.168.1.100", - expAddr: "192.168.1.100", - }, - { - name: "MissingSBFF", - remoteAddr: "60.60.60.60", - headerVal: "", - expAddr: "60.60.60.60", - }, - { - name: "InvalidSBFF", - remoteAddr: "60.60.60.60", - headerVal: "invalid", - expAddr: "60.60.60.60", - }, - } - - config := conf.SecurityConfiguration{ - SbForwardedForEnabled: true, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *tst.T) { - var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) { - obsAddr := GetIPAddress(r) - require.Equal(t, tc.expAddr, obsAddr) - } - - errCallback := func(r *http.Request, err error) { - } - - middleware := sbff.Middleware(&config, errCallback) - - wrappedHandler := middleware(handler) - - r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) - - r.RemoteAddr = tc.remoteAddr - - if tc.headerVal != "" { - r.Header.Set(sbff.HeaderName, tc.headerVal) - } - - wrappedHandler.ServeHTTP(nil, r) - }) - - } -} - -func TestGetIPAddressWithXFF(t *tst.T) { +func TestGetIPAddress(t *tst.T) { examples := []func(r *http.Request) string{ func(r *http.Request) string { r.Header = nil diff --git a/migrations/20221003041349_add_mfa_schema.up.sql b/migrations/20221003041349_add_mfa_schema.up.sql index 26c4e4b69b..a44654aed3 100644 --- a/migrations/20221003041349_add_mfa_schema.up.sql +++ b/migrations/20221003041349_add_mfa_schema.up.sql @@ -1,18 +1,8 @@ -- see: https://stackoverflow.com/questions/7624919/check-if-a-user-defined-type-already-exists-in-postgresql/48382296#48382296 do $$ begin - create type {{ index .Options "Namespace" }}.factor_type as enum('totp', 'webauthn'); -exception - when duplicate_object then null; -end $$; - -do $$ begin - create type {{ index .Options "Namespace" }}.factor_status as enum('unverified', 'verified'); -exception - when duplicate_object then null; -end $$; - -do $$ begin - create type {{ index .Options "Namespace" }}.aal_level as enum('aal1', 'aal2', 'aal3'); + create type factor_type as enum('totp', 'webauthn'); + create type factor_status as enum('unverified', 'verified'); + create type aal_level as enum('aal1', 'aal2', 'aal3'); exception when duplicate_object then null; end $$; @@ -22,8 +12,8 @@ create table if not exists {{ index .Options "Namespace" }}.mfa_factors( id uuid not null, user_id uuid not null, friendly_name text null, - factor_type {{ index .Options "Namespace" }}.factor_type not null, - status {{ index .Options "Namespace" }}.factor_status not null, + factor_type factor_type not null, + status factor_status not null, created_at timestamptz not null, updated_at timestamptz not null, secret text null, diff --git a/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql b/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql index 426a42f591..cc8a2096d9 100644 --- a/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql +++ b/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql @@ -1,3 +1,3 @@ -- add factor_id to sessions alter table {{ index .Options "Namespace" }}.sessions add column if not exists factor_id uuid null; - alter table {{ index .Options "Namespace" }}.sessions add column if not exists aal {{ index .Options "Namespace" }}.aal_level null; + alter table {{ index .Options "Namespace" }}.sessions add column if not exists aal aal_level null; diff --git a/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql b/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql index 8975f1eb97..19ec79e9e3 100644 --- a/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql +++ b/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql @@ -9,5 +9,5 @@ update {{ index .Options "Namespace" }}.identities created_at = '2022-11-25' and updated_at = '2022-11-25' and provider = 'email' and - id::text = user_id::text; + id = user_id::text; end $$; diff --git a/migrations/20230322519590_add_flow_state_table.up.sql b/migrations/20230322519590_add_flow_state_table.up.sql index c54455f903..a8842e5b0d 100644 --- a/migrations/20230322519590_add_flow_state_table.up.sql +++ b/migrations/20230322519590_add_flow_state_table.up.sql @@ -1,6 +1,6 @@ -- see: https://stackoverflow.com/questions/7624919/check-if-a-user-defined-type-already-exists-in-postgresql/48382296#48382296 do $$ begin - create type {{ index .Options "Namespace" }}.code_challenge_method as enum('s256', 'plain'); + create type code_challenge_method as enum('s256', 'plain'); exception when duplicate_object then null; end $$; @@ -8,7 +8,7 @@ create table if not exists {{ index .Options "Namespace" }}.flow_state( id uuid primary key, user_id uuid null, auth_code text not null, - code_challenge_method {{ index .Options "Namespace" }}.code_challenge_method not null, + code_challenge_method code_challenge_method not null, code_challenge text not null, provider_type text not null, provider_access_token text null, diff --git a/migrations/20240427152123_add_one_time_tokens_table.up.sql b/migrations/20240427152123_add_one_time_tokens_table.up.sql index 38100c08cf..be7312656f 100644 --- a/migrations/20240427152123_add_one_time_tokens_table.up.sql +++ b/migrations/20240427152123_add_one_time_tokens_table.up.sql @@ -1,5 +1,5 @@ do $$ begin - create type {{ index .Options "Namespace" }}.one_time_token_type as enum ( + create type one_time_token_type as enum ( 'confirmation_token', 'reauthentication_token', 'recovery_token', @@ -16,7 +16,7 @@ do $$ begin create table if not exists {{ index .Options "Namespace" }}.one_time_tokens ( id uuid primary key, user_id uuid not null references {{ index .Options "Namespace" }}.users on delete cascade, - token_type {{ index .Options "Namespace" }}.one_time_token_type not null, + token_type one_time_token_type not null, token_hash text not null, relates_to text not null, created_at timestamp without time zone not null default now(),