Skip to content

Commit 9753bcd

Browse files
committed
refactor: enhance deviceRequest struct
1 parent 94653ee commit 9753bcd

15 files changed

+134
-188
lines changed

device_request.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,41 @@
33

44
package fosite
55

6+
type UserCodeState int16
7+
8+
const (
9+
// User code is active
10+
UserCodeUnused = UserCodeState(0)
11+
// User code has been accepted
12+
UserCodeAccepted = UserCodeState(1)
13+
// User code has been rejected
14+
UserCodeRejected = UserCodeState(2)
15+
)
16+
617
// DeviceRequest is an implementation of DeviceRequester
718
type DeviceRequest struct {
19+
UserCodeState UserCodeState
820
Request
921
}
1022

23+
func (d *DeviceRequest) GetUserCodeState() UserCodeState {
24+
return d.UserCodeState
25+
}
26+
27+
func (d *DeviceRequest) SetUserCodeState(state UserCodeState) {
28+
d.UserCodeState = state
29+
}
30+
31+
func (d *DeviceRequest) Sanitize(allowedParameters []string) Requester {
32+
r, _ := d.Request.Sanitize(allowedParameters).(*Request)
33+
d.Request = *r
34+
return d
35+
}
36+
1137
// NewDeviceRequest returns a new device request
1238
func NewDeviceRequest() *DeviceRequest {
1339
return &DeviceRequest{
14-
Request: *NewRequest(),
40+
UserCodeState: UserCodeUnused,
41+
Request: *NewRequest(),
1542
}
1643
}

device_write_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestWriteDeviceUserResponse(t *testing.T) {
2626
ctx := context.Background()
2727

2828
rw := httptest.NewRecorder()
29-
ar := &Request{}
29+
ar := &DeviceRequest{}
3030
resp := &DeviceResponse{}
3131
resp.SetUserCode("AAAA")
3232
resp.SetDeviceCode("BBBB")

handler/rfc8628/auth_handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (d *DeviceAuthHandler) handleDeviceAuthSession(ctx context.Context, dar fos
6666
return "", "", err
6767
}
6868

69-
if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil)); err == nil {
69+
if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil).(fosite.DeviceRequester)); err == nil {
7070
return deviceCode, userCode, nil
7171
}
7272
}

handler/rfc8628/storage.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ type RFC8628CoreStorage interface {
2020
// DeviceAuthStorage handles the device auth session storage
2121
type DeviceAuthStorage interface {
2222
// CreateDeviceAuthSession stores the device auth request session.
23-
CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.Requester) (err error)
23+
CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.DeviceRequester) (err error)
2424

2525
// GetDeviceCodeSession hydrates the session based on the given device code and returns the device request.
2626
// If the device code has been invalidated with `InvalidateDeviceCodeSession`, this
2727
// method should return the ErrInvalidatedDeviceCode error.
2828
//
2929
// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedDeviceCode error!
30-
GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error)
30+
GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.DeviceRequester, err error)
3131

3232
// InvalidateDeviceCodeSession is called when a device code is being used. The state of the device
3333
// code should be set to invalid and consecutive requests to GetDeviceCodeSession should return the

handler/rfc8628/strategy.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,30 @@ type RFC8628CodeStrategy interface {
1818

1919
// DeviceRateLimitStrategy handles the rate limiting strategy
2020
type DeviceRateLimitStrategy interface {
21+
// ShouldRateLimit checks whether the token request should be rate-limited
2122
ShouldRateLimit(ctx context.Context, code string) (bool, error)
2223
}
2324

2425
// DeviceCodeStrategy handles the device_code strategy
2526
type DeviceCodeStrategy interface {
27+
// DeviceCodeSignature calculates the signature of a device_code
2628
DeviceCodeSignature(ctx context.Context, code string) (signature string, err error)
29+
30+
// GenerateDeviceCode generates a new device code and signature
2731
GenerateDeviceCode(ctx context.Context) (code string, signature string, err error)
28-
ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) (err error)
32+
33+
// ValidateDeviceCode validates the device_code
34+
ValidateDeviceCode(ctx context.Context, r fosite.DeviceRequester, code string) (err error)
2935
}
3036

3137
// UserCodeStrategy handles the user_code strategy
3238
type UserCodeStrategy interface {
39+
// UserCodeSignature calculates the signature of a user_code
3340
UserCodeSignature(ctx context.Context, code string) (signature string, err error)
41+
42+
// GenerateUserCode generates a new user code and signature
3443
GenerateUserCode(ctx context.Context) (code string, signature string, err error)
35-
ValidateUserCode(ctx context.Context, r fosite.Requester, code string) (err error)
44+
45+
// ValidateUserCode validates the user_code
46+
ValidateUserCode(ctx context.Context, r fosite.DeviceRequester, code string) (err error)
3647
}

handler/rfc8628/strategy_hmacsha.go

Lines changed: 2 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ import (
88
"strings"
99
"time"
1010

11-
"github.com/mohae/deepcopy"
12-
1311
"github.com/ory/x/errorsx"
1412

1513
"github.com/ory/x/randx"
@@ -20,83 +18,6 @@ import (
2018

2119
const POLLING_RATE_LIMITING_LEEWAY = 200 * time.Millisecond
2220

23-
// DeviceFlowSession is a fosite.Session container specific for the device flow.
24-
type DeviceFlowSession interface {
25-
// GetBrowserFlowCompleted returns the flag indicating whether user has completed the browser flow or not.
26-
GetBrowserFlowCompleted() bool
27-
28-
// SetBrowserFlowCompleted allows client to mark user has completed the browser flow.
29-
SetBrowserFlowCompleted(flag bool)
30-
31-
fosite.Session
32-
}
33-
34-
// DefaultDeviceFlowSession is a DeviceFlowSession implementation for the device flow.
35-
type DefaultDeviceFlowSession struct {
36-
ExpiresAt map[fosite.TokenType]time.Time `json:"expires_at"`
37-
Username string `json:"username"`
38-
Subject string `json:"subject"`
39-
Extra map[string]interface{} `json:"extra"`
40-
BrowserFlowCompleted bool `json:"browser_flow_completed"`
41-
}
42-
43-
func (s *DefaultDeviceFlowSession) SetExpiresAt(key fosite.TokenType, exp time.Time) {
44-
if s.ExpiresAt == nil {
45-
s.ExpiresAt = make(map[fosite.TokenType]time.Time)
46-
}
47-
s.ExpiresAt[key] = exp
48-
}
49-
50-
func (s *DefaultDeviceFlowSession) GetExpiresAt(key fosite.TokenType) time.Time {
51-
if s.ExpiresAt == nil {
52-
s.ExpiresAt = make(map[fosite.TokenType]time.Time)
53-
}
54-
55-
if _, ok := s.ExpiresAt[key]; !ok {
56-
return time.Time{}
57-
}
58-
return s.ExpiresAt[key]
59-
}
60-
61-
func (s *DefaultDeviceFlowSession) GetUsername() string {
62-
if s == nil {
63-
return ""
64-
}
65-
return s.Username
66-
}
67-
68-
func (s *DefaultDeviceFlowSession) SetSubject(subject string) {
69-
s.Subject = subject
70-
}
71-
72-
func (s *DefaultDeviceFlowSession) GetSubject() string {
73-
if s == nil {
74-
return ""
75-
}
76-
77-
return s.Subject
78-
}
79-
80-
func (s *DefaultDeviceFlowSession) Clone() fosite.Session {
81-
if s == nil {
82-
return nil
83-
}
84-
85-
return deepcopy.Copy(s).(fosite.Session)
86-
}
87-
88-
func (s *DefaultDeviceFlowSession) GetBrowserFlowCompleted() bool {
89-
if s == nil {
90-
return false
91-
}
92-
93-
return s.BrowserFlowCompleted
94-
}
95-
96-
func (s *DefaultDeviceFlowSession) SetBrowserFlowCompleted(flag bool) {
97-
s.BrowserFlowCompleted = flag
98-
}
99-
10021
// DefaultDeviceStrategy implements the default device strategy
10122
type DefaultDeviceStrategy struct {
10223
Enigma *enigma.HMACStrategy
@@ -129,7 +50,7 @@ func (h *DefaultDeviceStrategy) UserCodeSignature(ctx context.Context, token str
12950
}
13051

13152
// ValidateUserCode validates a user_code
132-
func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.Requester, code string) error {
53+
func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.DeviceRequester, code string) error {
13354
exp := r.GetSession().GetExpiresAt(fosite.UserCode)
13455
if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) {
13556
return errorsx.WithStack(fosite.ErrDeviceExpiredToken.WithHintf("User code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx))))
@@ -156,7 +77,7 @@ func (h *DefaultDeviceStrategy) DeviceCodeSignature(ctx context.Context, token s
15677
}
15778

15879
// ValidateDeviceCode validates a device_code
159-
func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) error {
80+
func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.DeviceRequester, code string) error {
16081
exp := r.GetSession().GetExpiresAt(fosite.DeviceCode)
16182
if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) {
16283
return errorsx.WithStack(fosite.ErrDeviceExpiredToken.WithHintf("Device code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx))))

handler/rfc8628/strategy_hmacsha_test.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,27 @@ var hmacshaStrategy = DefaultDeviceStrategy{
2828
},
2929
}
3030

31-
var hmacValidCase = fosite.Request{
32-
Client: &fosite.DefaultClient{
33-
Secret: []byte("foobarfoobarfoobarfoobar"),
34-
},
35-
Session: &fosite.DefaultSession{
36-
ExpiresAt: map[fosite.TokenType]time.Time{
37-
fosite.UserCode: time.Now().UTC().Add(time.Hour),
38-
fosite.DeviceCode: time.Now().UTC().Add(time.Hour),
31+
var hmacValidCase = fosite.DeviceRequest{
32+
Request: fosite.Request{
33+
Client: &fosite.DefaultClient{
34+
Secret: []byte("foobarfoobarfoobarfoobar"),
35+
},
36+
Session: &fosite.DefaultSession{
37+
ExpiresAt: map[fosite.TokenType]time.Time{
38+
fosite.UserCode: time.Now().UTC().Add(time.Hour),
39+
fosite.DeviceCode: time.Now().UTC().Add(time.Hour),
40+
},
3941
},
4042
},
4143
}
4244

4345
func TestHMACUserCode(t *testing.T) {
4446
for k, c := range []struct {
45-
r fosite.Request
47+
r fosite.DeviceRequester
4648
pass bool
4749
}{
4850
{
49-
r: hmacValidCase,
51+
r: &hmacValidCase,
5052
pass: true,
5153
},
5254
} {
@@ -56,7 +58,7 @@ func TestHMACUserCode(t *testing.T) {
5658
regex := regexp.MustCompile("[ABCDEFGHIJKLMNOPQRSTUVWXYZ]{8}")
5759
assert.Equal(t, len(regex.FindString(userCode)), len(userCode))
5860

59-
err = hmacshaStrategy.ValidateUserCode(context.TODO(), &c.r, userCode)
61+
err = hmacshaStrategy.ValidateUserCode(context.TODO(), c.r, userCode)
6062
if c.pass {
6163
assert.NoError(t, err)
6264
validate, _ := hmacshaStrategy.Enigma.GenerateHMACForString(context.TODO(), userCode)
@@ -73,11 +75,11 @@ func TestHMACUserCode(t *testing.T) {
7375

7476
func TestHMACDeviceCode(t *testing.T) {
7577
for k, c := range []struct {
76-
r fosite.Request
78+
r fosite.DeviceRequester
7779
pass bool
7880
}{
7981
{
80-
r: hmacValidCase,
82+
r: &hmacValidCase,
8183
pass: true,
8284
},
8385
} {
@@ -92,7 +94,7 @@ func TestHMACDeviceCode(t *testing.T) {
9294
strings.TrimPrefix(token, "ory_dc_"),
9395
} {
9496
t.Run(fmt.Sprintf("prefix=%v", k == 0), func(t *testing.T) {
95-
err = hmacshaStrategy.ValidateDeviceCode(context.TODO(), &c.r, token)
97+
err = hmacshaStrategy.ValidateDeviceCode(context.TODO(), c.r, token)
9698
if c.pass {
9799
assert.NoError(t, err)
98100
validate := hmacshaStrategy.Enigma.Signature(token)

handler/rfc8628/token_handler.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ func (c *DeviceCodeTokenEndpointHandler) PopulateTokenEndpointResponse(ctx conte
6262
return err
6363
}
6464

65-
var ar fosite.Requester
65+
var ar fosite.DeviceRequester
6666
if ar, err = c.session(ctx, requester, signature); err != nil {
6767
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
6868
}
6969

70-
if err = c.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code); err != nil {
70+
if err = c.DeviceCodeStrategy.ValidateDeviceCode(ctx, ar, code); err != nil {
7171
return errorsx.WithStack(err)
7272
}
7373

@@ -154,7 +154,7 @@ func (c *DeviceCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context.
154154
return errorsx.WithStack(err)
155155
}
156156

157-
var ar fosite.Requester
157+
var ar fosite.DeviceRequester
158158
if ar, err = c.session(ctx, requester, signature); err != nil {
159159
if ar != nil && (errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) || errors.Is(err, fosite.ErrInvalidatedDeviceCode)) {
160160
return c.revokeTokens(ctx, requester.GetID())
@@ -252,7 +252,7 @@ func (c DeviceCodeTokenEndpointHandler) validateCode(ctx context.Context, reques
252252
return nil
253253
}
254254

255-
func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.Requester, error) {
255+
func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.DeviceRequester, error) {
256256
req, err := s.CoreStorage.GetDeviceCodeSession(ctx, codeSignature, requester.GetSession())
257257

258258
if err != nil && errors.Is(err, fosite.ErrInvalidatedDeviceCode) {
@@ -265,10 +265,6 @@ func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester f
265265
WithDebug("\"GetDeviceCodeSession\" must return a value for \"fosite.Requester\" when returning \"ErrInvalidatedDeviceCode\".")
266266
}
267267

268-
if err != nil && errors.Is(err, fosite.ErrAuthorizationPending) {
269-
return nil, err
270-
}
271-
272268
if err != nil && errors.Is(err, fosite.ErrNotFound) {
273269
return nil, errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error()))
274270
}
@@ -277,14 +273,14 @@ func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester f
277273
return nil, errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
278274
}
279275

280-
session, ok := req.GetSession().(DeviceFlowSession)
281-
if !ok {
282-
return nil, fosite.ErrServerError.WithHint("Wrong authorization request session.")
283-
}
276+
state := req.GetUserCodeState()
284277

285-
if !session.GetBrowserFlowCompleted() {
278+
if state == fosite.UserCodeUnused {
286279
return nil, fosite.ErrAuthorizationPending
287280
}
281+
if state == fosite.UserCodeRejected {
282+
return nil, fosite.ErrAccessDenied
283+
}
288284

289285
return req, err
290286
}

0 commit comments

Comments
 (0)