Skip to content

Commit bfa2f4e

Browse files
committed
refactor: simplify handler and test logic
1 parent 9bf46bb commit bfa2f4e

File tree

4 files changed

+68
-66
lines changed

4 files changed

+68
-66
lines changed

device_request_handler.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic
2727
return request, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method))
2828
}
2929
if err := r.ParseForm(); err != nil {
30-
return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error()))
30+
return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error()))
3131
}
3232
if len(r.PostForm) == 0 {
3333
return request, errorsx.WithStack(ErrInvalidRequest.WithHint("The POST body can not be empty."))
@@ -44,11 +44,11 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic
4444
request.Client = client
4545

4646
if !client.GetGrantTypes().Has(string(GrantTypeDeviceCode)) {
47-
return nil, errorsx.WithStack(ErrInvalidGrant.WithHint("The requested OAuth 2.0 Client does not have the 'urn:ietf:params:oauth:grant-type:device_code' grant."))
47+
return request, errorsx.WithStack(ErrInvalidGrant.WithHint("The requested OAuth 2.0 Client does not have the 'urn:ietf:params:oauth:grant-type:device_code' grant."))
4848
}
4949

5050
if err := f.validateDeviceScope(ctx, r, request); err != nil {
51-
return nil, err
51+
return request, err
5252
}
5353

5454
if err := f.validateAudience(ctx, r, request); err != nil {
@@ -59,12 +59,13 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic
5959
}
6060

6161
func (f *Fosite) validateDeviceScope(ctx context.Context, r *http.Request, request *DeviceRequest) error {
62-
scope := RemoveEmpty(strings.Split(request.Form.Get("scope"), " "))
63-
for _, permission := range scope {
64-
if !f.Config.GetScopeStrategy(ctx)(request.Client.GetScopes(), permission) {
65-
return errorsx.WithStack(ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", permission))
62+
scopes := RemoveEmpty(strings.Split(request.Form.Get("scope"), " "))
63+
scopeStrategy := f.Config.GetScopeStrategy(ctx)
64+
for _, scope := range scopes {
65+
if !scopeStrategy(request.Client.GetScopes(), scope) {
66+
return errorsx.WithStack(ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope))
6667
}
6768
}
68-
request.SetRequestedScopes(scope)
69+
request.SetRequestedScopes(scopes)
6970
return nil
7071
}

device_request_handler_test.go

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,17 @@ import (
2222
func TestNewDeviceRequestWithPublicClient(t *testing.T) {
2323
ctrl := gomock.NewController(t)
2424
store := internal.NewMockStorage(ctrl)
25-
client := &DefaultClient{ID: "client_id"}
25+
deviceClient := &DefaultClient{ID: "client_id"}
26+
deviceClient.Public = true
27+
deviceClient.Scopes = []string{"17", "42"}
28+
deviceClient.Audience = []string{"aud2"}
29+
deviceClient.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
30+
31+
authCodeClient := &DefaultClient{ID: "client_id_2"}
32+
authCodeClient.Public = true
33+
authCodeClient.Scopes = []string{"17", "42"}
34+
authCodeClient.GrantTypes = []string{"authorization_code"}
35+
2636
defer ctrl.Finish()
2737
config := &Config{ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
2838
fosite := &Fosite{Store: store, Config: config}
@@ -63,40 +73,30 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
6373
},
6474
method: "POST",
6575
mock: func() {
66-
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
67-
client.Public = true
68-
client.Scopes = []string{"17", "42"}
69-
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
76+
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
7077
},
7178
expectedError: ErrInvalidScope,
7279
}, {
7380
description: "fails because audience not allowed",
7481
form: url.Values{
7582
"client_id": {"client_id"},
7683
"scope": {"17 42"},
77-
"audience": {"aud"},
84+
"audience": {"random_aud"},
7885
},
7986
method: "POST",
8087
mock: func() {
81-
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
82-
client.Public = true
83-
client.Scopes = []string{"17", "42"}
84-
client.Audience = []string{"aud2"}
85-
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
88+
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
8689
},
8790
expectedError: ErrInvalidRequest,
8891
}, {
8992
description: "fails because it doesn't have the proper grant",
9093
form: url.Values{
91-
"client_id": {"client_id"},
94+
"client_id": {"client_id_2"},
9295
"scope": {"17 42"},
9396
},
9497
method: "POST",
9598
mock: func() {
96-
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
97-
client.Public = true
98-
client.Scopes = []string{"17", "42"}
99-
client.GrantTypes = []string{"authorization_code"}
99+
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id_2")).Return(authCodeClient, nil)
100100
},
101101
expectedError: ErrInvalidGrant,
102102
}, {
@@ -107,10 +107,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
107107
},
108108
method: "POST",
109109
mock: func() {
110-
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
111-
client.Public = true
112-
client.Scopes = []string{"17", "42"}
113-
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
110+
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
114111
},
115112
}} {
116113
t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) {
@@ -123,10 +120,8 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
123120
}
124121

125122
ar, err := fosite.NewDeviceRequest(context.Background(), r)
126-
if c.expectedError != nil {
127-
assert.EqualError(t, err, c.expectedError.Error())
128-
} else {
129-
require.NoError(t, err)
123+
require.ErrorIs(t, err, c.expectedError)
124+
if c.expectedError == nil {
130125
assert.NotNil(t, ar.GetRequestedAt())
131126
}
132127
})
@@ -141,15 +136,21 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
141136
defer ctrl.Finish()
142137
config := &Config{ClientSecretsHasher: hasher, ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
143138
fosite := &Fosite{Store: store, Config: config}
139+
140+
client.Public = false
141+
client.Secret = []byte("client_secret")
142+
client.Scopes = []string{"foo", "bar"}
143+
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
144+
144145
for k, c := range []struct {
145146
header http.Header
146147
form url.Values
147148
method string
148149
expectedError error
149150
mock func()
150151
expect DeviceRequester
152+
description string
151153
}{
152-
// No client authn provided
153154
{
154155
form: url.Values{
155156
"client_id": {"client_id"},
@@ -159,14 +160,26 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
159160
method: "POST",
160161
mock: func() {
161162
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
162-
client.Public = false
163-
client.Secret = []byte("client_secret")
164-
client.Scopes = []string{"foo", "bar"}
165-
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
166163
hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New(""))
167164
},
165+
description: "Should failed becaue no client authn provided.",
166+
},
167+
{
168+
form: url.Values{
169+
"client_id": {"client_id2"},
170+
"scope": {"foo bar"},
171+
},
172+
header: http.Header{
173+
"Authorization": {basicAuth("client_id", "client_secret")},
174+
},
175+
expectedError: ErrInvalidRequest,
176+
method: "POST",
177+
mock: func() {
178+
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
179+
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil)
180+
},
181+
description: "should fail because different client is used in authn than in form",
168182
},
169-
// success
170183
{
171184
form: url.Values{
172185
"client_id": {"client_id"},
@@ -178,15 +191,12 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
178191
method: "POST",
179192
mock: func() {
180193
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
181-
client.Public = false
182-
client.Secret = []byte("client_secret")
183-
client.Scopes = []string{"foo", "bar"}
184-
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
185194
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil)
186195
},
196+
description: "should succeed",
187197
},
188198
} {
189-
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
199+
t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) {
190200
c.mock()
191201
r := &http.Request{
192202
Header: c.header,
@@ -196,11 +206,9 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
196206
}
197207

198208
req, err := fosite.NewDeviceRequest(context.Background(), r)
199-
if c.expectedError != nil {
200-
assert.EqualError(t, err, c.expectedError.Error())
201-
} else {
202-
require.NoError(t, err)
203-
assert.NotNil(t, req.GetRequestedAt())
209+
require.ErrorIs(t, err, c.expectedError)
210+
if c.expectedError == nil {
211+
assert.NotZero(t, req.GetRequestedAt())
204212
}
205213
})
206214
}

device_request_test.go

Lines changed: 0 additions & 18 deletions
This file was deleted.

fosite_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
. "github.com/ory/fosite"
1414
"github.com/ory/fosite/handler/oauth2"
1515
"github.com/ory/fosite/handler/par"
16+
"github.com/ory/fosite/handler/rfc8628"
1617
)
1718

1819
func TestAuthorizeEndpointHandlers(t *testing.T) {
@@ -25,6 +26,16 @@ func TestAuthorizeEndpointHandlers(t *testing.T) {
2526
assert.Equal(t, hs[0], h)
2627
}
2728

29+
func TestDeviceAuthorizeEndpointHandlers(t *testing.T) {
30+
h := &rfc8628.DeviceAuthHandler{}
31+
hs := DeviceEndpointHandlers{}
32+
hs.Append(h)
33+
hs.Append(h)
34+
hs.Append(&rfc8628.DeviceAuthHandler{})
35+
assert.Len(t, hs, 1)
36+
assert.Equal(t, hs[0], h)
37+
}
38+
2839
func TestTokenEndpointHandlers(t *testing.T) {
2940
h := &oauth2.AuthorizeExplicitGrantHandler{}
3041
hs := TokenEndpointHandlers{}

0 commit comments

Comments
 (0)