Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement kid support to tokens #2

Merged
merged 15 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions auth/access-token.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,13 @@ func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perm
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
}

// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
}

// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
}
26 changes: 26 additions & 0 deletions auth/access-token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,29 @@ func TestCreateAccessToken(t *testing.T) {
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
}

func TestCreateAccessTokenInvalid(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
}
20 changes: 20 additions & 0 deletions auth/pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,23 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
}
return accessToken, refreshToken, nil
}

// CreateTokenPairWithKID creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively using the specified kID
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
}

// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
if err != nil {
return "", "", err
}
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
if err != nil {
return "", "", err
}
return accessToken, refreshToken, nil
}
31 changes: 31 additions & 0 deletions auth/pair_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,34 @@ func TestCreateTokenPair(t *testing.T) {
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
}

func TestCreateTokenPairWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))

_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
}
10 changes: 10 additions & 0 deletions auth/refresh-token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,13 @@ func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
}

// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
}

// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
}
20 changes: 20 additions & 0 deletions auth/refresh-token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,23 @@ func TestCreateRefreshToken(t *testing.T) {
assert.Equal(t, "test", b.ID)
assert.Equal(t, "test2", b.Claims.AccessTokenId)
}

func TestCreateRefreshTokenWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.Equal(t, "test2", b.Claims.AccessTokenId)
}
34 changes: 16 additions & 18 deletions cmd/mjwt/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package main

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"github.com/google/subcommands"
"os"
Expand All @@ -18,15 +16,15 @@ import (
)

type accessCmd struct {
issuer, subject, id, audience, duration string
issuer, subject, id, audience, duration, kID string
}

func (s *accessCmd) Name() string { return "access" }
func (s *accessCmd) Synopsis() string {
return "Generates an access token with permissions using the private key"
}
func (s *accessCmd) Usage() string {
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] <private key path> <space separated permissions>
return `sign [-iss <issuer>] [-sub <subject>] [-id <id>] [-aud <audience>] [-dur <duration>] [-kid <name>] <private key path> <space separated permissions>
Output a signed MJWT token with the specified permissions.
`
}
Expand All @@ -37,6 +35,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.id, "id", "", "MJWT ID")
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key")
}

func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
Expand All @@ -46,7 +45,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
}

args := f.Args()
key, err := s.parseKey(args[0])
key, err := rsaprivate.Read(args[0])
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err)
return subcommands.ExitFailure
Expand All @@ -67,8 +66,17 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
return subcommands.ExitFailure
}

signer := mjwt.NewMJwtSigner(s.issuer, key)
token, err := signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
var token string
if s.kID == "\x00" {
signer := mjwt.NewMJwtSigner(s.issuer, key)
token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
} else {
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey(s.kID, key)
signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore)
token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID)
}

if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
return subcommands.ExitFailure
Expand All @@ -77,13 +85,3 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
fmt.Println(token)
return subcommands.ExitSuccess
}

func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) {
b, err := os.ReadFile(privKeyFile)
if err != nil {
return nil, err
}

p, _ := pem.Decode(b)
return x509.ParsePKCS1PrivateKey(p.Bytes)
}
23 changes: 4 additions & 19 deletions cmd/mjwt/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package main
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/google/subcommands"
"math/rand"
"os"
Expand Down Expand Up @@ -49,29 +49,14 @@ func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) s
}

func (g *genCmd) gen(privPath, pubPath string) error {
createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer createPriv.Close()

createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer createPub.Close()

key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits)
if err != nil {
return err
}

keyBytes := x509.MarshalPKCS1PrivateKey(key)
pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey)
err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
err = rsaprivate.Write(privPath, key)
if err != nil {
return err
}
err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes})
return err
return rsapublic.Write(pubPath, &key.PublicKey)
}
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
module github.com/1f349/mjwt

go 1.19
go 1.22

toolchain go1.22.3

require (
github.com/1f349/rsa-helper v0.0.1
github.com/becheran/wildmatch-go v1.0.0
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/subcommands v1.2.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/1f349/rsa-helper v0.0.1 h1:Ec/MXHR2eIpLgIR69eqhCV2o8OOBs2JZNAkEhW7HQks=
github.com/1f349/rsa-helper v0.0.1/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
Expand Down
16 changes: 16 additions & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,28 @@ type Signer interface {
Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error)
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
Issuer() string
PrivateKey() *rsa.PrivateKey
PrivateKeyOf(kID string) *rsa.PrivateKey
}

// Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey
PublicKeyOf(kID string) *rsa.PublicKey
GetKeyStore() KeyStore
}

// KeyStore is used for the kid header support in Signer and Verifier.
type KeyStore interface {
SetKey(kID string, prvKey *rsa.PrivateKey)
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
RemoveKey(kID string)
ListKeys() []string
GetKey(kID string) *rsa.PrivateKey
GetKeyPublic(kID string) *rsa.PublicKey
ClearKeys()
}
Loading
Loading