diff --git a/auth/access-token.go b/auth/access-token.go index 7b54e4d..17e8ab7 100644 --- a/auth/access-token.go +++ b/auth/access-token.go @@ -25,3 +25,13 @@ func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perm func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) { return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}) } + +// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID +func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) { + return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID) +} + +// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID +func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) { + return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID) +} diff --git a/auth/access-token_test.go b/auth/access-token_test.go index acf53aa..11b2523 100644 --- a/auth/access-token_test.go +++ b/auth/access-token_test.go @@ -31,3 +31,29 @@ func TestCreateAccessToken(t *testing.T) { assert.True(t, b.Claims.Perms.Has("mjwt:test2")) assert.False(t, b.Claims.Perms.Has("mjwt:test3")) } + +func TestCreateAccessTokenInvalid(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + kStore := mjwt.NewMJwtKeyStore() + kStore.SetKey("test", key) + + ps := claims.NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + + s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) + + accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test") + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.True(t, b.Claims.Perms.Has("mjwt:test")) + assert.True(t, b.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b.Claims.Perms.Has("mjwt:test3")) +} diff --git a/auth/pair.go b/auth/pair.go index b2ce969..7d50e55 100644 --- a/auth/pair.go +++ b/auth/pair.go @@ -26,3 +26,23 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat } return accessToken, refreshToken, nil } + +// CreateTokenPairWithKID creates an access and refresh token pair using the default +// 15 minute and 7 day durations respectively using the specified kID +func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) { + return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID) +} + +// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using +// custom durations for the access and refresh tokens +func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) { + accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID) + if err != nil { + return "", "", err + } + refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID) + if err != nil { + return "", "", err + } + return accessToken, refreshToken, nil +} diff --git a/auth/pair_test.go b/auth/pair_test.go index 5a7dd77..0b9d135 100644 --- a/auth/pair_test.go +++ b/auth/pair_test.go @@ -36,3 +36,34 @@ func TestCreateTokenPair(t *testing.T) { assert.Equal(t, "1", b2.Subject) assert.Equal(t, "test2", b2.ID) } + +func TestCreateTokenPairWithKID(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + kStore := mjwt.NewMJwtKeyStore() + kStore.SetKey("test", key) + + ps := claims.NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + + s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) + + accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test") + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.True(t, b.Claims.Perms.Has("mjwt:test")) + assert.True(t, b.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b.Claims.Perms.Has("mjwt:test3")) + + _, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + assert.NoError(t, err) + assert.Equal(t, "1", b2.Subject) + assert.Equal(t, "test2", b2.ID) +} diff --git a/auth/refresh-token.go b/auth/refresh-token.go index 5667885..b0b4675 100644 --- a/auth/refresh-token.go +++ b/auth/refresh-token.go @@ -24,3 +24,13 @@ func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) { return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}) } + +// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID +func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) { + return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID) +} + +// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID +func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) { + return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID) +} diff --git a/auth/refresh-token_test.go b/auth/refresh-token_test.go index 4765c35..bcb7521 100644 --- a/auth/refresh-token_test.go +++ b/auth/refresh-token_test.go @@ -24,3 +24,23 @@ func TestCreateRefreshToken(t *testing.T) { assert.Equal(t, "test", b.ID) assert.Equal(t, "test2", b.Claims.AccessTokenId) } + +func TestCreateRefreshTokenWithKID(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + kStore := mjwt.NewMJwtKeyStore() + kStore.SetKey("test", key) + + s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) + + refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test") + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.Equal(t, "test2", b.Claims.AccessTokenId) +} diff --git a/cmd/mjwt/access.go b/cmd/mjwt/access.go index 32a3dd3..a2cb35e 100644 --- a/cmd/mjwt/access.go +++ b/cmd/mjwt/access.go @@ -2,14 +2,12 @@ package main import ( "context" - "crypto/rsa" - "crypto/x509" - "encoding/pem" "flag" "fmt" "github.com/1f349/mjwt" "github.com/1f349/mjwt/auth" "github.com/1f349/mjwt/claims" + "github.com/1f349/rsa-helper/rsaprivate" "github.com/golang-jwt/jwt/v4" "github.com/google/subcommands" "os" @@ -18,7 +16,7 @@ import ( ) type accessCmd struct { - issuer, subject, id, audience, duration string + issuer, subject, id, audience, duration, kID string } func (s *accessCmd) Name() string { return "access" } @@ -26,7 +24,7 @@ func (s *accessCmd) Synopsis() string { return "Generates an access token with permissions using the private key" } func (s *accessCmd) Usage() string { - return `sign [-iss ] [-sub ] [-id ] [-aud ] [-dur ] + return `sign [-iss ] [-sub ] [-id ] [-aud ] [-dur ] [-kid ] Output a signed MJWT token with the specified permissions. ` } @@ -37,6 +35,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&s.id, "id", "", "MJWT ID") f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT") f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)") + f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key") } func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { @@ -46,7 +45,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{} } args := f.Args() - key, err := s.parseKey(args[0]) + key, err := rsaprivate.Read(args[0]) if err != nil { _, _ = fmt.Fprintln(os.Stderr, "Error: Failed to parse private key: ", err) return subcommands.ExitFailure @@ -67,8 +66,17 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{} return subcommands.ExitFailure } - signer := mjwt.NewMJwtSigner(s.issuer, key) - token, err := signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}) + var token string + if s.kID == "\x00" { + signer := mjwt.NewMJwtSigner(s.issuer, key) + token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}) + } else { + kStore := mjwt.NewMJwtKeyStore() + kStore.SetKey(s.kID, key) + signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore) + token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID) + } + if err != nil { _, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err) return subcommands.ExitFailure @@ -77,13 +85,3 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{} fmt.Println(token) return subcommands.ExitSuccess } - -func (s *accessCmd) parseKey(privKeyFile string) (*rsa.PrivateKey, error) { - b, err := os.ReadFile(privKeyFile) - if err != nil { - return nil, err - } - - p, _ := pem.Decode(b) - return x509.ParsePKCS1PrivateKey(p.Bytes) -} diff --git a/cmd/mjwt/gen.go b/cmd/mjwt/gen.go index 9eb0b47..b56e789 100644 --- a/cmd/mjwt/gen.go +++ b/cmd/mjwt/gen.go @@ -3,10 +3,10 @@ package main import ( "context" "crypto/rsa" - "crypto/x509" - "encoding/pem" "flag" "fmt" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" "github.com/google/subcommands" "math/rand" "os" @@ -49,29 +49,14 @@ func (g *genCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) s } func (g *genCmd) gen(privPath, pubPath string) error { - createPriv, err := os.OpenFile(privPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - defer createPriv.Close() - - createPub, err := os.OpenFile(pubPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - defer createPub.Close() - key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), g.bits) if err != nil { return err } - keyBytes := x509.MarshalPKCS1PrivateKey(key) - pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey) - err = pem.Encode(createPriv, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}) + err = rsaprivate.Write(privPath, key) if err != nil { return err } - err = pem.Encode(createPub, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes}) - return err + return rsapublic.Write(pubPath, &key.PublicKey) } diff --git a/go.mod b/go.mod index 150ebd1..6d52a03 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,11 @@ module github.com/1f349/mjwt -go 1.19 +go 1.22 + +toolchain go1.22.3 require ( + github.com/1f349/rsa-helper v0.0.1 github.com/becheran/wildmatch-go v1.0.0 github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/subcommands v1.2.0 diff --git a/go.sum b/go.sum index 239b42d..551cfd8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/1f349/rsa-helper v0.0.1 h1:Ec/MXHR2eIpLgIR69eqhCV2o8OOBs2JZNAkEhW7HQks= +github.com/1f349/rsa-helper v0.0.1/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg= github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA= github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/interfaces.go b/interfaces.go index 939a466..aeef231 100644 --- a/interfaces.go +++ b/interfaces.go @@ -12,12 +12,28 @@ type Signer interface { Verifier GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) SignJwt(claims jwt.Claims) (string, error) + GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) + SignJwtWithKID(claims jwt.Claims, kID string) (string, error) Issuer() string PrivateKey() *rsa.PrivateKey + PrivateKeyOf(kID string) *rsa.PrivateKey } // Verifier is used to verify the validity MJWT tokens and extract the claim values. type Verifier interface { VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) PublicKey() *rsa.PublicKey + PublicKeyOf(kID string) *rsa.PublicKey + GetKeyStore() KeyStore +} + +// KeyStore is used for the kid header support in Signer and Verifier. +type KeyStore interface { + SetKey(kID string, prvKey *rsa.PrivateKey) + SetKeyPublic(kID string, pubKey *rsa.PublicKey) + RemoveKey(kID string) + ListKeys() []string + GetKey(kID string) *rsa.PrivateKey + GetKeyPublic(kID string) *rsa.PublicKey + ClearKeys() } diff --git a/key_store.go b/key_store.go new file mode 100644 index 0000000..082fa3c --- /dev/null +++ b/key_store.go @@ -0,0 +1,185 @@ +package mjwt + +import ( + "crypto/rsa" + "errors" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" + "os" + "path" + "strings" + "sync" +) + +// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey +// or with rsa.PrivateKey instances as well. +type defaultMJwtKeyStore struct { + rwLocker *sync.RWMutex + store map[string]*rsa.PrivateKey + storePub map[string]*rsa.PublicKey +} + +var _ KeyStore = &defaultMJwtKeyStore{} + +// NewMJwtKeyStore creates a new defaultMJwtKeyStore. +func NewMJwtKeyStore() KeyStore { + return &defaultMJwtKeyStore{ + rwLocker: new(sync.RWMutex), + store: make(map[string]*rsa.PrivateKey), + storePub: make(map[string]*rsa.PublicKey), + } +} + +// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private +// rsa keys; the kID is the filename of the key up to the first . +func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) { + // Create empty KeyStore + ks := NewMJwtKeyStore().(*defaultMJwtKeyStore) + // List directory contents + dirEntries, err := os.ReadDir(directory) + if err != nil { + return nil, err + } + errs := make([]error, 0, len(dirEntries)/2) + // Import keys from files, based on extension + for _, entry := range dirEntries { + if entry.IsDir() { + continue + } + kID, _, _ := strings.Cut(entry.Name(), ".") + if kID == "" { + continue + } + pExt := path.Ext(entry.Name()) + if pExt == "."+keyPrvExt { + // Load rsa private key with the file name as the kID (Up to the first .) + key, err2 := rsaprivate.Read(path.Join(directory, entry.Name())) + if err2 == nil { + ks.store[kID] = key + ks.storePub[kID] = &key.PublicKey + } + errs = append(errs, err2) + } else if pExt == "."+keyPubExt { + // Load rsa public key with the file name as the kID (Up to the first .) + key, err2 := rsapublic.Read(path.Join(directory, entry.Name())) + if err2 == nil { + _, exs := ks.store[kID] + if !exs { + ks.store[kID] = nil + } + ks.storePub[kID] = key + } + errs = append(errs, err2) + } + } + return ks, errors.Join(errs...) +} + +// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified +// extensions for public and private keys +func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error { + if ks == nil { + return errors.New("ks is nil") + } + + // Create directory + err := os.MkdirAll(directory, 0700) + if err != nil { + return err + } + + errs := make([]error, 0, len(ks.ListKeys())/2) + // Export all keys + for _, kID := range ks.ListKeys() { + kPrv := ks.GetKey(kID) + if kPrv != nil { + err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv) + errs = append(errs, err2) + } + kPub := ks.GetKeyPublic(kID) + if kPub != nil { + err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub) + errs = append(errs, err2) + } + } + return errors.Join(errs...) +} + +// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore. +func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) { + if prvKey == nil { + return + } + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + d.store[kID] = prvKey + d.storePub[kID] = &prvKey.PublicKey + return +} + +// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore. +func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) { + if pubKey == nil { + return + } + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + _, exs := d.store[kID] + if !exs { + d.store[kID] = nil + } + d.storePub[kID] = pubKey + return +} + +// RemoveKey removes a specified kID from the KeyStore. +func (d *defaultMJwtKeyStore) RemoveKey(kID string) { + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + delete(d.store, kID) + delete(d.storePub, kID) + return +} + +// ListKeys lists the kIDs of all the keys in the KeyStore. +func (d *defaultMJwtKeyStore) ListKeys() []string { + d.rwLocker.RLock() + defer d.rwLocker.RUnlock() + lKeys := make([]string, len(d.store)) + i := 0 + for k := range d.store { + lKeys[i] = k + i++ + } + return lKeys +} + +// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found. +func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey { + d.rwLocker.RLock() + defer d.rwLocker.RUnlock() + kPrv, ok := d.store[kID] + if ok { + return kPrv + } + return nil +} + +// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found. +func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey { + d.rwLocker.RLock() + defer d.rwLocker.RUnlock() + kPub, ok := d.storePub[kID] + if ok { + return kPub + } + return nil +} + +// ClearKeys removes all the stored keys in the KeyStore. +func (d *defaultMJwtKeyStore) ClearKeys() { + d.rwLocker.Lock() + defer d.rwLocker.Unlock() + clear(d.store) + clear(d.storePub) +} diff --git a/key_store_test.go b/key_store_test.go new file mode 100644 index 0000000..264f689 --- /dev/null +++ b/key_store_test.go @@ -0,0 +1,152 @@ +package mjwt + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" + "github.com/stretchr/testify/assert" + "os" + "path" + "testing" +) + +const kst_prvExt = "prv" +const kst_pubExt = "pub" + +func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) { + tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") + assert.NoError(t, err) + + if genKeys { + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1) + assert.NoError(t, err) + + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey) + assert.NoError(t, err) + + key3, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey) + assert.NoError(t, err) + } + + return tempDir, func(t *testing.T) { + err := os.RemoveAll(tempDir) + assert.NoError(t, err) + } +} + +func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) { + key4, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + key5, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + const extraKID1 = "key4" + const extraKID2 = "key5" + + t.Run("TestSetKey", func(t *testing.T) { + kStore.SetKey(extraKID1, key4) + assert.Contains(t, kStore.ListKeys(), extraKID1) + }) + + t.Run("TestSetKeyPublic", func(t *testing.T) { + kStore.SetKeyPublic(extraKID2, &key5.PublicKey) + assert.Contains(t, kStore.ListKeys(), extraKID2) + }) + + t.Run("TestGetKey", func(t *testing.T) { + oKey := kStore.GetKey(extraKID1) + assert.Same(t, key4, oKey) + pKey := kStore.GetKey(extraKID2) + assert.Nil(t, pKey) + aKey := kStore.GetKey("key1") + assert.NotNil(t, aKey) + bKey := kStore.GetKey("key2") + assert.NotNil(t, bKey) + cKey := kStore.GetKey("key3") + assert.Nil(t, cKey) + }) + + t.Run("TestGetKeyPublic", func(t *testing.T) { + oKey := kStore.GetKeyPublic(extraKID1) + assert.Same(t, &key4.PublicKey, oKey) + pKey := kStore.GetKeyPublic(extraKID2) + assert.Same(t, &key5.PublicKey, pKey) + aKey := kStore.GetKeyPublic("key1") + assert.NotNil(t, aKey) + bKey := kStore.GetKeyPublic("key2") + assert.NotNil(t, bKey) + cKey := kStore.GetKeyPublic("key3") + assert.NotNil(t, cKey) + }) + + t.Run("TestRemoveKey", func(t *testing.T) { + kStore.RemoveKey(extraKID1) + assert.NotContains(t, kStore.ListKeys(), extraKID1) + oKey1 := kStore.GetKey(extraKID1) + assert.Nil(t, oKey1) + oKey2 := kStore.GetKeyPublic(extraKID1) + assert.Nil(t, oKey2) + }) + + t.Run("TestClearKeys", func(t *testing.T) { + kStore.ClearKeys() + assert.Empty(t, kStore.ListKeys()) + }) +} + +func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { + t.Parallel() + + tempDir, cleaner := setupTestDirKeyStore(t, true) + defer cleaner(t) + + kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt) + assert.NoError(t, err) + + assert.Len(t, kStore.ListKeys(), 3) + kIDsToFind := []string{"key1", "key2", "key3"} + for _, k := range kIDsToFind { + assert.Contains(t, kStore.ListKeys(), k) + } + + commonSubTestsKeyStore(t, kStore) +} + +func TestExportKeyStore(t *testing.T) { + t.Parallel() + + tempDir, cleaner := setupTestDirKeyStore(t, true) + defer cleaner(t) + tempDir2, cleaner2 := setupTestDirKeyStore(t, false) + defer cleaner2(t) + + kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt) + assert.NoError(t, err) + + const prvExt2 = "v" + const pubExt2 = "b" + + err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2) + assert.NoError(t, err) + + kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2) + assert.NoError(t, err) + + kIDsToFind := kStore.ListKeys() + assert.Len(t, kStore2.ListKeys(), len(kIDsToFind)) + for _, k := range kIDsToFind { + assert.Contains(t, kStore2.ListKeys(), k) + } + + commonSubTestsKeyStore(t, kStore2) +} diff --git a/mjwt_test.go b/mjwt_test.go index 3a83fe8..1c059eb 100644 --- a/mjwt_test.go +++ b/mjwt_test.go @@ -9,6 +9,8 @@ import ( "time" ) +var mt_ExtraKID = "tester" + type testClaims struct{ TestValue string } func (t testClaims) Valid() error { @@ -31,31 +33,111 @@ func (t testClaims2) Valid() error { func (t testClaims2) Type() string { return "testClaims2" } -func TestExtractClaims(t *testing.T) { - t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) +func setupTestKeyStoreMJWT(t *testing.T) (ks KeyStore, a, b, c *rsa.PrivateKey) { + ks = NewMJwtKeyStore() + var err error + + a, err = rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) + ks.SetKey("key1", a) - s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + b, err = rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) + ks.SetKey("key2", b) - m := NewMJwtVerifier(&key.PublicKey) - _, _, err = ExtractClaims[testClaims](m, token) + c, err = rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) + ks.SetKey("key3", c) + + return +} + +func TestExtractClaims(t *testing.T) { + t.Parallel() + kStore, key, _, _ := setupTestKeyStoreMJWT(t) + + t.Run("TestNoKID", func(t *testing.T) { + t.Parallel() + s := NewMJwtSigner("mjwt.test", key) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) + + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims](m, token) + assert.NoError(t, err) + }) + + t.Run("TestKID", func(t *testing.T) { + t.Parallel() + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) + token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}, "key2") + assert.NoError(t, err) + + m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore) + _, _, err = ExtractClaims[testClaims](m, token1) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](m, token2) + assert.NoError(t, err) + }) } func TestExtractClaimsFail(t *testing.T) { t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) + kStore, key, key2, _ := setupTestKeyStoreMJWT(t) - s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) - assert.NoError(t, err) + t.Run("TestInvalidClaims", func(t *testing.T) { + t.Parallel() + s := NewMJwtSigner("mjwt.test", key) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) + assert.NoError(t, err) + + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims2](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrClaimTypeMismatch) + }) + + t.Run("TestDefaultKeyNoKID", func(t *testing.T) { + t.Parallel() + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, "key1") + assert.NoError(t, err) + + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoPublicKeyFound) + }) + + t.Run("TestNoDefaultKey", func(t *testing.T) { + t.Parallel() + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) + assert.NoError(t, err) + + m := NewMJwtVerifierWithKeyStore(nil, kStore) + _, _, err = ExtractClaims[testClaims](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoPublicKeyFound) + }) + + t.Run("TestKIDNonExist", func(t *testing.T) { + t.Parallel() + kStore.SetKey(mt_ExtraKID, key2) + assert.Contains(t, kStore.ListKeys(), mt_ExtraKID) + + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, mt_ExtraKID) + assert.NoError(t, err) + + kStore.RemoveKey(mt_ExtraKID) + assert.NotContains(t, kStore.ListKeys(), mt_ExtraKID) - m := NewMJwtVerifier(&key.PublicKey) - _, _, err = ExtractClaims[testClaims2](m, token) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrClaimTypeMismatch) + m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore) + _, _, err = ExtractClaims[testClaims](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoPublicKeyFound) + }) } diff --git a/signer.go b/signer.go index fa7947b..f3a2553 100644 --- a/signer.go +++ b/signer.go @@ -1,16 +1,18 @@ package mjwt import ( + "bytes" "crypto/rsa" - "crypto/x509" - "encoding/pem" - "fmt" + "errors" + "github.com/1f349/rsa-helper/rsaprivate" "github.com/golang-jwt/jwt/v4" "io" "os" "time" ) +var ErrNoPrivateKeyFound = errors.New("no private key found") + // defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name // to generate MJWT tokens type defaultMJwtSigner struct { @@ -24,10 +26,20 @@ var _ Verifier = &defaultMJwtSigner{} // NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer { + return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore()) +} + +// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey +// for no kID and a KeyStore for kID based keys +func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer { + var pKey *rsa.PublicKey = nil + if key != nil { + pKey = &key.PublicKey + } return &defaultMJwtSigner{ issuer: issuer, key: key, - verify: newMJwtVerifier(&key.PublicKey), + verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier), } } @@ -45,50 +57,101 @@ func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits i // NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a // rsa.PrivateKey file. func NewMJwtSignerFromFile(issuer, file string) (Signer, error) { - // read file - raw, err := os.ReadFile(file) - if err != nil { - return nil, err - } + return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "") +} + +// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to +// load the keys into a KeyStore; there is no default rsa.PrivateKey +func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) { + return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt) +} + +// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey +// file as the non kID key and the path of a directory to load the keys into a KeyStore +func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) { + var err error - // decode pem block - block, _ := pem.Decode(raw) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return nil, fmt.Errorf("invalid rsa private key pem block") + // read key + var prv *rsa.PrivateKey = nil + if file != "" { + prv, err = rsaprivate.Read(file) + if err != nil { + return nil, err + } } - // parse private key from pem block - key, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, err + // read KeyStore + var kStore KeyStore = nil + if directory != "" { + kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt) + if err != nil { + return nil, err + } } - // create signer using rsa.PrivateKey - return NewMJwtSigner(issuer, key), nil + return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil } // Issuer returns the name of the issuer -func (d *defaultMJwtSigner) Issuer() string { return d.issuer } +func (d *defaultMJwtSigner) Issuer() string { + return d.issuer +} -// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims +// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) { return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims)) } // SignJwt signs a jwt.Claims compatible struct, this is used internally by -// GenerateJwt but is available for signing custom structs +// GenerateJwt but is available for signing custom structs; uses the default key func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { + if d.key == nil { + return "", ErrNoPrivateKeyFound + } token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) return token.SignedString(d.key) } +// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID +func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) { + return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID) +} + +// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by +// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID +func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) { + pKey := d.verify.GetKeyStore().GetKey(kID) + if pKey == nil { + return "", ErrNoPrivateKeyFound + } + token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) + token.Header["kid"] = kID + return token.SignedString(pKey) +} + // VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt() func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { return d.verify.VerifyJwt(token, claims) } -func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { return d.key } -func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { return d.verify.pub } +func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { + return d.key +} +func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { + return d.verify.pub +} + +func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey { + return d.verify.kStore.GetKeyPublic(kID) +} + +func (d *defaultMJwtSigner) GetKeyStore() KeyStore { + return d.verify.GetKeyStore() +} + +func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey { + return d.verify.kStore.GetKey(kID) +} // readOrCreatePrivateKey returns the private key it the file already exists, // generates a new private key and saves it to the file, or returns an error if @@ -106,26 +169,15 @@ func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.Priva return nil, err } - keyBytes := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) + // save key to file + err = rsaprivate.Write(file, key) if err != nil { return nil, err } - - // write the key to the file - err = os.WriteFile(file, keyBytes, 0600) return key, err } else { - // decode pem block - block, _ := pem.Decode(f) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return nil, fmt.Errorf("invalid rsa private key pem block") - } - - // try to parse the private key - return x509.ParsePKCS1PrivateKey(block.Bytes) + // return key + return rsaprivate.Decode(bytes.NewReader(f)) } } diff --git a/signer_test.go b/signer_test.go index 64d8fda..f04fbae 100644 --- a/signer_test.go +++ b/signer_test.go @@ -5,11 +5,45 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" "github.com/stretchr/testify/assert" "os" + "path" "testing" ) +const st_prvExt = "prv" +const st_pubExt = "pub" + +func setupTestDirSigner(t *testing.T) (string, *rsa.PrivateKey, func(t *testing.T)) { + tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") + assert.NoError(t, err) + + var key3 *rsa.PrivateKey = nil + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1) + assert.NoError(t, err) + + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey) + assert.NoError(t, err) + + key3, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey) + assert.NoError(t, err) + + return tempDir, key3, func(t *testing.T) { + err := os.RemoveAll(tempDir) + assert.NoError(t, err) + } +} + func TestNewMJwtSigner(t *testing.T) { t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) @@ -17,6 +51,16 @@ func TestNewMJwtSigner(t *testing.T) { NewMJwtSigner("Test", key) } +func TestNewMJwtSignerWithKeyStore(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + kStore := NewMJwtKeyStore() + kStore.SetKey("test", key) + assert.Contains(t, kStore.ListKeys(), "test") + NewMJwtSignerWithKeyStore("Test", nil, kStore) +} + func TestNewMJwtSignerFromFile(t *testing.T) { t.Parallel() tempKey, err := os.CreateTemp("", "key-test-*.pem") @@ -67,3 +111,38 @@ func TestReadOrCreatePrivateKey(t *testing.T) { assert.NoError(t, err) assert.NoError(t, key3.Validate()) } + +func TestNewMJwtSignerFromDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirSigner(t) + defer cleaner(t) + + signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt) + assert.NoError(t, err) + + assert.Len(t, signer.GetKeyStore().ListKeys(), 3) + kIDsToFind := []string{"key1", "key2", "key3"} + for _, k := range kIDsToFind { + assert.Contains(t, signer.GetKeyStore().ListKeys(), k) + } + assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3"))) +} + +func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirSigner(t) + defer cleaner(t) + + signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt) + assert.NoError(t, err) + + assert.Len(t, signer.GetKeyStore().ListKeys(), 3) + kIDsToFind := []string{"key1", "key2", "key3"} + for _, k := range kIDsToFind { + assert.Contains(t, signer.GetKeyStore().ListKeys(), k) + } + assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3"))) + assert.True(t, signer.PrivateKey().Equal(signer.GetKeyStore().GetKey("key1"))) +} diff --git a/verifier.go b/verifier.go index 0fe9411..d3d483b 100644 --- a/verifier.go +++ b/verifier.go @@ -2,54 +2,91 @@ package mjwt import ( "crypto/rsa" - "crypto/x509" - "encoding/pem" + "errors" + "github.com/1f349/rsa-helper/rsapublic" "github.com/golang-jwt/jwt/v4" - "os" ) +var ErrNoPublicKeyFound = errors.New("no public key found") +var ErrKIDInvalid = errors.New("kid invalid") + // defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate // MJWT tokens type defaultMJwtVerifier struct { - pub *rsa.PublicKey + pub *rsa.PublicKey + kStore KeyStore } var _ Verifier = &defaultMJwtVerifier{} // NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey func NewMJwtVerifier(key *rsa.PublicKey) Verifier { - return newMJwtVerifier(key) + return NewMJwtVerifierWithKeyStore(key, NewMJwtKeyStore()) } -func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier { - return &defaultMJwtVerifier{pub: key} +// NewMJwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key +// and a KeyStore for kID based keys +func NewMJwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier { + return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore} } // NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a // rsa.PublicKey file func NewMJwtVerifierFromFile(file string) (Verifier, error) { - // read file - f, err := os.ReadFile(file) - if err != nil { - return nil, err - } + return NewMJwtVerifierFromFileAndDirectory(file, "", "", "") +} - // decode pem block - block, _ := pem.Decode(f) +// NewMJwtVerifierFromDirectory creates a new defaultMJwtVerifier using the path of a directory to +// load the keys into a KeyStore; there is no default rsa.PublicKey +func NewMJwtVerifierFromDirectory(directory, prvExt, pubExt string) (Verifier, error) { + return NewMJwtVerifierFromFileAndDirectory("", directory, prvExt, pubExt) +} - // parse public key from pem block - pub, err := x509.ParsePKCS1PublicKey(block.Bytes) - if err != nil { - return nil, err +// NewMJwtVerifierFromFileAndDirectory creates a new defaultMJwtVerifier using the path of a rsa.PublicKey +// file as the non kID key and the path of a directory to load the keys into a KeyStore +func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) (Verifier, error) { + var err error + + // read key + var pub *rsa.PublicKey = nil + if file != "" { + pub, err = rsapublic.Read(file) + if err != nil { + return nil, err + } + } + + // read KeyStore + var kStore KeyStore = nil + if directory != "" { + kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt) + if err != nil { + return nil, err + } } - // create verifier using rsa.PublicKey - return NewMJwtVerifier(pub), nil + return NewMJwtVerifierWithKeyStore(pub, kStore), nil } // VerifyJwt validates and parses MJWT tokens and returns the claims func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + kIDI, exs := token.Header["kid"] + if exs { + kID, ok := kIDI.(string) + if !ok { + return nil, ErrKIDInvalid + } + key := d.kStore.GetKeyPublic(kID) + if key == nil { + return nil, ErrNoPublicKeyFound + } else { + return key, nil + } + } + if d.pub == nil { + return nil, ErrNoPublicKeyFound + } return d.pub, nil }) if err != nil { @@ -58,4 +95,14 @@ func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jw return withClaims, claims.Valid() } -func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { return d.pub } +func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { + return d.pub +} + +func (d *defaultMJwtVerifier) PublicKeyOf(kID string) *rsa.PublicKey { + return d.kStore.GetKeyPublic(kID) +} + +func (d *defaultMJwtVerifier) GetKeyStore() KeyStore { + return d.kStore +} diff --git a/verifier_test.go b/verifier_test.go index 8e448dd..378ce1a 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -5,12 +5,49 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" "github.com/stretchr/testify/assert" "os" + "path" "testing" "time" ) +const vt_prvExt = "prv" +const vt_pubExt = "pub" + +func setupTestDirVerifier(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) { + tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") + assert.NoError(t, err) + + var key3 *rsa.PrivateKey = nil + + if genKeys { + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+vt_prvExt), key1) + assert.NoError(t, err) + + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+vt_prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+vt_pubExt), &key2.PublicKey) + assert.NoError(t, err) + + key3, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+vt_pubExt), &key3.PublicKey) + assert.NoError(t, err) + } + + return tempDir, key3, func(t *testing.T) { + err := os.RemoveAll(tempDir) + assert.NoError(t, err) + } +} + func TestNewMJwtVerifierFromFile(t *testing.T) { t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) @@ -32,3 +69,43 @@ func TestNewMJwtVerifierFromFile(t *testing.T) { err = os.Remove(temp.Name()) assert.NoError(t, err) } + +func TestNewMJwtVerifierFromDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true) + defer cleaner(t) + + s, err := NewMJwtSignerFromDirectory("mjwt.test", tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + s.GetKeyStore().SetKey("key3", prvKey3) + token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3") + assert.NoError(t, err) + + v, err := NewMJwtVerifierFromDirectory(tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](v, token) + assert.NoError(t, err) +} + +func TestNewMJwtVerifierFromFileAndDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true) + defer cleaner(t) + + s, err := NewMJwtSignerFromFileAndDirectory("mjwt.test", path.Join(tempDir, "key2.pem."+vt_prvExt), tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + s.GetKeyStore().SetKey("key3", prvKey3) + token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}) + assert.NoError(t, err) + token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3") + assert.NoError(t, err) + + v, err := NewMJwtVerifierFromFileAndDirectory(path.Join(tempDir, "key2.pem."+vt_pubExt), tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](v, token1) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](v, token2) + assert.NoError(t, err) +}