Skip to content

Commit e1d2ea3

Browse files
committed
fix: PR changes
1 parent af3e211 commit e1d2ea3

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
/*
2+
* Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved.
3+
*
4+
* This software is licensed under the Apache License, Version 2.0 (the
5+
* "License") as published by the Apache Software Foundation.
6+
*
7+
* You may not use this file except in compliance with the License. You may
8+
* obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations
14+
* under the License.
15+
*/
16+
17+
package session
18+
19+
import (
20+
"io"
21+
"net/http"
22+
"net/http/httptest"
23+
"testing"
24+
25+
"github.com/supertokens/supertokens-golang/recipe/session/claims"
26+
sessionErrors "github.com/supertokens/supertokens-golang/recipe/session/errors"
27+
"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
28+
"github.com/supertokens/supertokens-golang/supertokens"
29+
"github.com/supertokens/supertokens-golang/test/unittesting"
30+
31+
"github.com/stretchr/testify/assert"
32+
)
33+
34+
func TestSessionErrorHandlerOverides(t *testing.T) {
35+
BeforeEach()
36+
37+
customAntiCsrfVal := "VIA_TOKEN"
38+
configValue := supertokens.TypeInput{
39+
Supertokens: &supertokens.ConnectionInfo{
40+
ConnectionURI: "http://localhost:8080",
41+
},
42+
AppInfo: supertokens.AppInfo{
43+
AppName: "SuperTokens",
44+
WebsiteDomain: "supertokens.io",
45+
APIDomain: "api.supertokens.io",
46+
},
47+
RecipeList: []supertokens.Recipe{
48+
Init(&sessmodels.TypeInput{
49+
AntiCsrf: &customAntiCsrfVal,
50+
ErrorHandlers: &sessmodels.ErrorHandlers{
51+
OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error {
52+
res.WriteHeader(401)
53+
res.Write([]byte("unauthorised from errorHandler"))
54+
return nil
55+
},
56+
OnTokenTheftDetected: func(sessionHandle, userID string, req *http.Request, res http.ResponseWriter) error {
57+
res.WriteHeader(403)
58+
res.Write([]byte("token theft detected from errorHandler"))
59+
return nil
60+
},
61+
OnTryRefreshToken: func(message string, req *http.Request, res http.ResponseWriter) error {
62+
res.WriteHeader(401)
63+
res.Write([]byte("try refresh token from errorHandler"))
64+
return nil
65+
},
66+
OnInvalidClaim: func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error {
67+
res.WriteHeader(403)
68+
res.Write([]byte("invalid claim from errorHandler"))
69+
return nil
70+
},
71+
OnClearDuplicateSessionCookies: func(message string, req *http.Request, res http.ResponseWriter) error {
72+
res.WriteHeader(200)
73+
res.Write([]byte("clear duplicate session cookies from errorHandler"))
74+
return nil
75+
},
76+
},
77+
GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod {
78+
return sessmodels.CookieTransferMethod
79+
},
80+
}),
81+
},
82+
}
83+
84+
unittesting.StartUpST("localhost", "8080")
85+
defer AfterEach()
86+
err := supertokens.Init(configValue)
87+
if err != nil {
88+
t.Error(err.Error())
89+
}
90+
91+
mux := http.NewServeMux()
92+
93+
mux.HandleFunc("/test/unauthorized", func(rw http.ResponseWriter, r *http.Request) {
94+
supertokens.ErrorHandler(sessionErrors.UnauthorizedError{}, r, rw)
95+
})
96+
97+
mux.HandleFunc("/test/try-refresh", func(rw http.ResponseWriter, r *http.Request) {
98+
supertokens.ErrorHandler(sessionErrors.TryRefreshTokenError{}, r, rw)
99+
})
100+
101+
mux.HandleFunc("/test/token-theft", func(rw http.ResponseWriter, r *http.Request) {
102+
supertokens.ErrorHandler(sessionErrors.TokenTheftDetectedError{}, r, rw)
103+
})
104+
105+
mux.HandleFunc("/test/claim-validation", func(rw http.ResponseWriter, r *http.Request) {
106+
supertokens.ErrorHandler(sessionErrors.InvalidClaimError{}, r, rw)
107+
})
108+
109+
mux.HandleFunc("/test/clear-duplicate-session", func(rw http.ResponseWriter, r *http.Request) {
110+
supertokens.ErrorHandler(sessionErrors.ClearDuplicateSessionCookiesError{}, r, rw)
111+
})
112+
113+
testServer := httptest.NewServer(supertokens.Middleware(mux))
114+
defer func() {
115+
testServer.Close()
116+
}()
117+
118+
t.Run("should override session errorHandlers", func(t *testing.T) {
119+
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/test/unauthorized", nil)
120+
assert.NoError(t, err)
121+
122+
res, err := http.DefaultClient.Do(req)
123+
assert.NoError(t, err)
124+
assert.Equal(t, 401, res.StatusCode)
125+
126+
content, err := io.ReadAll(res.Body)
127+
assert.NoError(t, err)
128+
assert.Equal(t, `{"message":"unauthorised from errorHandler"}`, string(content))
129+
130+
req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/try-refresh", nil)
131+
assert.NoError(t, err)
132+
133+
res, err = http.DefaultClient.Do(req)
134+
assert.NoError(t, err)
135+
assert.Equal(t, 401, res.StatusCode)
136+
137+
content, err = io.ReadAll(res.Body)
138+
assert.NoError(t, err)
139+
assert.Equal(t, `{"message":"try refresh token from errorHandler"}`, string(content))
140+
141+
req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/token-theft", nil)
142+
assert.NoError(t, err)
143+
144+
res, err = http.DefaultClient.Do(req)
145+
assert.NoError(t, err)
146+
assert.Equal(t, 403, res.StatusCode)
147+
148+
content, err = io.ReadAll(res.Body)
149+
assert.NoError(t, err)
150+
assert.Equal(t, `{"message":"token theft detected from errorHandler"}`, string(content))
151+
152+
req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/claim-validation", nil)
153+
assert.NoError(t, err)
154+
155+
res, err = http.DefaultClient.Do(req)
156+
assert.NoError(t, err)
157+
assert.Equal(t, 403, res.StatusCode)
158+
159+
content, err = io.ReadAll(res.Body)
160+
assert.NoError(t, err)
161+
assert.Equal(t, `{"message":"invalid claim from errorHandler"}`, string(content))
162+
163+
req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/clear-duplicate-session", nil)
164+
assert.NoError(t, err)
165+
166+
res, err = http.DefaultClient.Do(req)
167+
assert.NoError(t, err)
168+
assert.Equal(t, 200, res.StatusCode)
169+
170+
content, err = io.ReadAll(res.Body)
171+
assert.NoError(t, err)
172+
assert.Equal(t, `{"message":"clear duplicate session cookies from errorHandler"}`, string(content))
173+
})
174+
}

0 commit comments

Comments
 (0)