Skip to content

Commit 33741e1

Browse files
authored
fix: amr claim should contain provider_id for sso method (#2033)
## What kind of change does this PR introduce? * When a SSO user has MFA enabled and signs in, the `sso/saml` amr claim does not contain the provider field * Usage of sort.Reverse was previously incorrect and the comments made did not match what was expected - sort.Reverse needs to be chained with sort.Sort and doesn't sort the array if used on it's own (see [docs](https://pkg.go.dev/sort#Reverse)) * Opted to remove the sorting structs because we can just do it with sort.Slice which is simpler
1 parent 657ea45 commit 33741e1

File tree

2 files changed

+25
-47
lines changed

2 files changed

+25
-47
lines changed

internal/models/sessions.go

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -71,22 +71,6 @@ type AMREntry struct {
7171
Provider string `json:"provider,omitempty"`
7272
}
7373

74-
type sortAMREntries struct {
75-
Array []AMREntry
76-
}
77-
78-
func (s sortAMREntries) Len() int {
79-
return len(s.Array)
80-
}
81-
82-
func (s sortAMREntries) Less(i, j int) bool {
83-
return s.Array[i].Timestamp < s.Array[j].Timestamp
84-
}
85-
86-
func (s sortAMREntries) Swap(i, j int) {
87-
s.Array[j], s.Array[i] = s.Array[i], s.Array[j]
88-
}
89-
9074
type Session struct {
9175
ID uuid.UUID `json:"-" db:"id"`
9276
UserID uuid.UUID `json:"user_id" db:"user_id"`
@@ -328,41 +312,27 @@ func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLeve
328312
if claim.IsAAL2Claim() {
329313
aal = AAL2
330314
}
331-
amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()})
332-
}
333-
334-
// makes sure that the AMR claims are always ordered most-recent first
335-
336-
// sort in ascending order
337-
sort.Sort(sortAMREntries{
338-
Array: amr,
339-
})
340-
341-
// now reverse for descending order
342-
_ = sort.Reverse(sortAMREntries{
343-
Array: amr,
344-
})
345-
346-
lastIndex := len(amr) - 1
347-
348-
if lastIndex > -1 && amr[lastIndex].Method == SSOSAML.String() {
349-
// initial AMR claim is from sso/saml, we need to add information
350-
// about the provider that was used for the authentication
351-
identities := user.Identities
352-
353-
if len(identities) == 1 {
354-
identity := identities[0]
355-
356-
if identity.IsForSSOProvider() {
357-
amr[lastIndex].Provider = strings.TrimPrefix(identity.Provider, "sso:")
315+
entry := AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()}
316+
if entry.Method == SSOSAML.String() {
317+
// SSO users should only have one identity since they are excluded from account linking
318+
// These checks act as a safeguard in the event future changes break this assumption.
319+
identities := user.Identities
320+
if len(identities) == 1 {
321+
identity := identities[0]
322+
if identity.IsForSSOProvider() {
323+
entry.Provider = strings.TrimPrefix(identity.Provider, "sso:")
324+
}
358325
}
359326
}
327+
amr = append(amr, entry)
360328

361-
// otherwise we can't identify that this user account has only
362-
// one SSO identity, so we are not encoding the provider at
363-
// this time
364329
}
365330

331+
// makes sure that the AMR claims are always ordered most-recent first
332+
sort.Slice(amr, func(i, j int) bool {
333+
return amr[i].Timestamp > amr[j].Timestamp
334+
})
335+
366336
return aal, amr, nil
367337
}
368338

internal/models/sessions_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package models
22

33
import (
4+
"strings"
45
"testing"
56
"time"
67

@@ -79,6 +80,13 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
7980

8081
session = ts.AddClaimAndReloadSession(session, TOTPSignIn)
8182

83+
identity, err := NewIdentity(u, "sso:95d4a792-4a2a-4523-ae63-bae0631de554", map[string]interface{}{
84+
"sub": u.GetEmail(),
85+
})
86+
require.NoError(ts.T(), err)
87+
require.NoError(ts.T(), ts.db.Create(identity))
88+
u.Identities = append(u.Identities, *identity)
89+
8290
session = ts.AddClaimAndReloadSession(session, SSOSAML)
8391

8492
aal, amr, err := session.CalculateAALAndAMR(u)
@@ -97,7 +105,7 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
97105

98106
for _, claim := range amr {
99107
if claim.Method == SSOSAML.String() {
100-
require.NotNil(ts.T(), claim.Provider)
108+
require.Equal(ts.T(), strings.TrimPrefix(identity.Provider, "sso:"), claim.Provider)
101109
}
102110
}
103111
require.True(ts.T(), found)

0 commit comments

Comments
 (0)