diff --git a/README.md b/README.md index 745961e6..24406733 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,31 @@ -# Athenz policy updater -[![License: Apache](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square)](https://opensource.org/licenses/Apache-2.0) [![release](https://img.shields.io/github/release/yahoojapan/athenz-policy-updater.svg?style=flat-square)](https://github.com/yahoojapan/athenz-policy-updater/releases/latest) [![CircleCI](https://circleci.com/gh/yahoojapan/athenz-policy-updater.svg)](https://circleci.com/gh/yahoojapan/athenz-policy-updater) [![codecov](https://codecov.io/gh/yahoojapan/athenz-policy-updater/branch/master/graph/badge.svg?token=2CzooNJtUu&style=flat-square)](https://codecov.io/gh/yahoojapan/athenz-policy-updater) [![Go Report Card](https://goreportcard.com/badge/github.com/yahoojapan/athenz-policy-updater)](https://goreportcard.com/report/github.com/yahoojapan/athenz-policy-updater) [![GolangCI](https://golangci.com/badges/github.com/yahoojapan/athenz-policy-updater.svg?style=flat-square)](https://golangci.com/r/github.com/yahoojapan/athenz-policy-updater) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/828220605c43419e92fb0667876dd2d0)](https://www.codacy.com/app/i.can.feel.gravity/athenz-policy-updater?utm_source=github.com&utm_medium=referral&utm_content=yahoojapan/athenz-policy-updater&utm_campaign=Badge_Grade) [![GoDoc](http://godoc.org/github.com/yahoojapan/athenz-policy-updater?status.svg)](http://godoc.org/github.com/yahoojapan/athenz-policy-updater) -## What is Athenz policy updater +# Athenz authorizer +[![License: Apache](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square)](https://opensource.org/licenses/Apache-2.0) [![release](https://img.shields.io/github/release/yahoojapan/athenz-authorizer.svg?style=flat-square)](https://github.com/yahoojapan/athenz-authorizer/releases/latest) [![CircleCI](https://circleci.com/gh/yahoojapan/athenz-authorizer.svg)](https://circleci.com/gh/yahoojapan/athenz-authorizer) [![codecov](https://codecov.io/gh/yahoojapan/athenz-authorizer/branch/master/graph/badge.svg?token=2CzooNJtUu&style=flat-square)](https://codecov.io/gh/yahoojapan/athenz-authorizer) [![Go Report Card](https://goreportcard.com/badge/github.com/yahoojapan/athenz-authorizer)](https://goreportcard.com/report/github.com/yahoojapan/athenz-authorizer) [![GolangCI](https://golangci.com/badges/github.com/yahoojapan/athenz-authorizer.svg?style=flat-square)](https://golangci.com/r/github.com/yahoojapan/athenz-authorizer) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/828220605c43419e92fb0667876dd2d0)](https://www.codacy.com/app/i.can.feel.gravity/athenz-authorizer?utm_source=github.com&utm_medium=referral&utm_content=yahoojapan/athenz-authorizer&utm_campaign=Badge_Grade) [![GoDoc](http://godoc.org/github.com/yahoojapan/athenz-authorizer?status.svg)](http://godoc.org/github.com/yahoojapan/athenz-authorizer) +## What is Athenz authorizer -Athenz policy updater is a library to cache the policies of [Athenz](https://github.com/yahoo/athenz) to provider authenication and authorization check of user request. +Athenz authorizer is a library to cache the policies of [Athenz](https://github.com/yahoo/athenz) to authorizer authenication and authorization check of user request. ![Overview](./doc/policy_updater_overview.png) ## Usage -To initialize policy updater. +To initialize authorizer. ```golang -// Initialize providerd -daemon, err := providerd.New( - providerd.AthenzURL("www.athenz.io"), // set athenz URL - providerd.AthenzDomains("domain1", "domain2" ... "domain N"), // set athenz domains - providerd.PubkeyRefreshDuration(time.Hour * 24), // set athenz public key refresh duration - providerd.PolicyRefreshDuration(time.Hour), // set policy refresh duration +// Initialize authorizerd +daemon, err := authorizerd.New( + authorizerd.WithAthenzURL("www.athenz.io"), // set athenz URL + authorizerd.WithAthenzDomains("domain1", "domain2" ... "domain N"), // set athenz domains + authorizerd.WithPubkeyRefreshDuration(time.Hour * 24), // set athenz public key refresh duration + authorizerd.WithPolicyRefreshDuration(time.Hour), // set policy refresh duration ) if err != nil { - // cannot initialize policy updater daemon + // cannot initialize authorizer daemon } -// Start policy updater daemon -ctx := context.Background() // user can control policy updator daemon lifetime using this context -errs := daemon.StartProviderd(ctx) +// Start authorizer daemon +ctx := context.Background() // user can control authorizer daemon lifetime using this context +errs := daemon.Start(ctx) go func() { err := <-errs // user should handle errors return from the daemon @@ -39,9 +39,9 @@ if err := daemon.VerifyRoleToken(ctx, roleTok, act, res); err != nil { ## How it works -To do the authentication and authorization check, the user needs to specify which [domain data](https://github.com/yahoo/athenz/blob/master/docs/data_model.md#data-model) to be cache. The policy updater will periodically refresh the policies and Athenz public key data to [verify and decode]((https://github.com/yahoo/athenz/blob/master/docs/zpu_policy_file.md#zts-signature-validation)) the domain data. The verified domain data will cache into the memory, and use for authentication and authorization check. +To do the authentication and authorization check, the user needs to specify which [domain data](https://github.com/yahoo/athenz/blob/master/docs/data_model.md#data-model) to be cache. The authorizer will periodically refresh the policies and Athenz public key data to [verify and decode]((https://github.com/yahoo/athenz/blob/master/docs/zpu_policy_file.md#zts-signature-validation)) the domain data. The verified domain data will cache into the memory, and use for authentication and authorization check. -The policy updater contains two sub-module, Athenz pubkey daemon (pubkeyd) and Athenz policy daemon (policyd). +The authorizer contains two sub-module, Athenz pubkey daemon (pubkeyd) and Athenz policy daemon (policyd). ### Athenz pubkey daemon @@ -53,7 +53,7 @@ Athenz policy daemon (policyd) is responsible for periodically update the policy ## Configuratrion -The policy updater uses functional options pattern to initialize the instance. All the options are defined [here](./option.go). +The authorizer uses functional options pattern to initialize the instance. All the options are defined [here](./option.go). | Option name | Description | Default Value | Required | Example | |---------------------------|---------------------------------------------------------------------------------------------------------------------|-------------------------|----------|------------------------| diff --git a/authorizerd.go b/authorizerd.go new file mode 100644 index 00000000..a181632c --- /dev/null +++ b/authorizerd.go @@ -0,0 +1,295 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package authorizerd + +import ( + "context" + "crypto/x509" + "net/http" + "strings" + "time" + + "github.com/kpango/gache" + "github.com/kpango/glg" + + "github.com/pkg/errors" + "github.com/yahoojapan/athenz-authorizer/jwk" + "github.com/yahoojapan/athenz-authorizer/policy" + "github.com/yahoojapan/athenz-authorizer/pubkey" + "github.com/yahoojapan/athenz-authorizer/role" +) + +// Authorizerd represents a daemon for user to verify the role token +type Authorizerd interface { + Start(ctx context.Context) <-chan error + VerifyRoleToken(ctx context.Context, tok, act, res string) error + VerifyRoleJWT(ctx context.Context, tok, act, res string) error + VerifyRoleCert(ctx context.Context, peerCerts []*x509.Certificate, act, res string) error + GetPolicyCache(ctx context.Context) map[string]interface{} +} + +type authorizer struct { + // + pubkeyd pubkey.Daemon + policyd policy.Daemon + jwkd jwk.Daemon + roleProcessor role.Processor + + // common parameters + athenzURL string + client *http.Client + + // successful result cache + cache gache.Gache + cacheExp time.Duration + + // roleCertURIPrefix + roleCertURIPrefix string + + // pubkeyd parameters + disablePubkeyd bool + pubkeyRefreshDuration string + pubkeySysAuthDomain string + pubkeyEtagExpTime string + pubkeyEtagFlushDur string + + // policyd parameters + disablePolicyd bool + policyExpireMargin string + athenzDomains []string + policyRefreshDuration string + policyEtagFlushDur string + policyEtagExpTime string + + // jwkd parameters + disableJwkd bool + jwkRefreshDuration string +} + +type mode uint8 + +const ( + token mode = iota + jwt +) + +// New return Authorizerd +// This function will initialize the Authorizerd object with the options +func New(opts ...Option) (Authorizerd, error) { + var ( + prov = &authorizer{ + cache: gache.New(), + } + err error + + pubkeyProvider pubkey.Provider + jwkProvider jwk.Provider + ) + + for _, opt := range append(defaultOptions, opts...) { + if err = opt(prov); err != nil { + return nil, errors.Wrap(err, "error creating authorizerd") + } + } + + if !prov.disablePubkeyd { + if prov.pubkeyd, err = pubkey.New( + pubkey.WithAthenzURL(prov.athenzURL), + pubkey.WithSysAuthDomain(prov.pubkeySysAuthDomain), + pubkey.WithEtagExpTime(prov.pubkeyEtagExpTime), + pubkey.WithEtagFlushDuration(prov.pubkeyEtagFlushDur), + pubkey.WithRefreshDuration(prov.pubkeyRefreshDuration), + pubkey.WithHTTPClient(prov.client), + ); err != nil { + return nil, errors.Wrap(err, "error create pubkeyd") + } + + pubkeyProvider = prov.pubkeyd.GetProvider() + } + + if !prov.disablePolicyd { + if prov.policyd, err = policy.New( + policy.WithExpireMargin(prov.policyExpireMargin), + policy.WithEtagFlushDuration(prov.policyEtagFlushDur), + policy.WithEtagExpTime(prov.policyEtagExpTime), + policy.WithAthenzURL(prov.athenzURL), + policy.WithAthenzDomains(prov.athenzDomains...), + policy.WithRefreshDuration(prov.policyRefreshDuration), + policy.WithHTTPClient(prov.client), + policy.WithPubKeyProvider(prov.pubkeyd.GetProvider()), + ); err != nil { + return nil, errors.Wrap(err, "error create policyd") + } + } + + if !prov.disableJwkd { + if prov.jwkd, err = jwk.New( + jwk.WithAthenzURL(prov.athenzURL), + jwk.WithRefreshDuration(prov.jwkRefreshDuration), + jwk.WithHTTPClient(prov.client), + ); err != nil { + return nil, errors.Wrap(err, "error create jwkd") + } + + jwkProvider = prov.jwkd.GetProvider() + } + + prov.roleProcessor = role.New( + role.WithPubkeyProvider(pubkeyProvider), + role.WithJWKProvider(jwkProvider)) + + return prov, nil +} + +// Start starts authorizer daemon. +func (p *authorizer) Start(ctx context.Context) <-chan error { + var ( + ech = make(chan error, 200) + g = p.cache.StartExpired(ctx, p.cacheExp/2) + cech, pech, jech <-chan error + ) + + if !p.disablePubkeyd { + cech = p.pubkeyd.Start(ctx) + } + if !p.disablePolicyd { + pech = p.policyd.Start(ctx) + } + if !p.disableJwkd { + jech = p.jwkd.Start(ctx) + } + + go func() { + for { + select { + case <-ctx.Done(): + g.Stop() + g.Clear() + ech <- ctx.Err() + return + case err := <-cech: + if err != nil { + ech <- errors.Wrap(err, "update pubkey error") + } + case err := <-pech: + if err != nil { + ech <- errors.Wrap(err, "update policy error") + } + case err := <-jech: + if err != nil { + ech <- errors.Wrap(err, "update jwk error") + } + } + } + }() + + return ech +} + +// VerifyRoleToken verifies the role token for specific resource and return and verification error. +func (p *authorizer) VerifyRoleToken(ctx context.Context, tok, act, res string) error { + return p.verify(ctx, token, tok, act, res) +} + +func (p *authorizer) VerifyRoleJWT(ctx context.Context, tok, act, res string) error { + return p.verify(ctx, jwt, tok, act, res) +} + +func (p *authorizer) verify(ctx context.Context, m mode, tok, act, res string) error { + if act == "" || res == "" { + return errors.Wrap(ErrInvalidParameters, "empty action / resource") + } + + // check if exists in verification success cache + _, ok := p.cache.Get(tok + act + res) + if ok { + glg.Debugf("use cached result. tok: %s, act: %s, res: %s", tok, act, res) + return nil + } + + var ( + domain string + roles []string + ) + + switch m { + case token: + rt, err := p.roleProcessor.ParseAndValidateRoleToken(tok) + if err != nil { + glg.Debugf("error parse and validate role token, err: %v", err) + return errors.Wrap(err, "error verify role token") + } + domain = rt.Domain + roles = rt.Roles + case jwt: + rc, err := p.roleProcessor.ParseAndValidateRoleJWT(tok) + if err != nil { + glg.Debugf("error parse and validate role jwt, err: %v", err) + return errors.Wrap(err, "error verify role jwt") + } + domain = rc.Domain + roles = strings.Split(strings.TrimSpace(rc.Role), ",") + } + + if err := p.policyd.CheckPolicy(ctx, domain, roles, act, res); err != nil { + glg.Debugf("error check, err: %v", err) + return errors.Wrap(err, "token unauthorizate") + } + glg.Debugf("set roletoken result. tok: %s, act: %s, res: %s", tok, act, res) + p.cache.SetWithExpire(tok+act+res, struct{}{}, p.cacheExp) + return nil +} + +func (p *authorizer) VerifyRoleCert(ctx context.Context, peerCerts []*x509.Certificate, act, res string) error { + dr := make([]string, 0, 2) + drcheck := make(map[string]struct{}) + domainRoles := make(map[string][]string) + for _, cert := range peerCerts { + for _, uri := range cert.URIs { + if strings.HasPrefix(uri.String(), p.roleCertURIPrefix) { + dr = strings.SplitN(strings.TrimPrefix(uri.String(), p.roleCertURIPrefix), "/", 2) // domain/role + if len(dr) != 2 { + continue + } + domain, roleName := dr[0], dr[1] + // duplicated role check + if _, ok := drcheck[domain+roleName]; !ok { + domainRoles[domain] = append(domainRoles[domain], roleName) + drcheck[domain+roleName] = struct{}{} + } + } + } + } + + if len(domainRoles) == 0 { + return errors.New("not valid role certificate") + } + + var err error + for domain, roles := range domainRoles { + // TODO futurework + if err = p.policyd.CheckPolicy(ctx, domain, roles, act, res); err == nil { + return nil + } + } + + return errors.Wrap(err, "role certificates unauthorizate") +} + +func (p *authorizer) GetPolicyCache(ctx context.Context) map[string]interface{} { + return p.policyd.GetPolicyCache(ctx) + +} diff --git a/authorizerd_mock_test.go b/authorizerd_mock_test.go new file mode 100644 index 00000000..c066fdf3 --- /dev/null +++ b/authorizerd_mock_test.go @@ -0,0 +1,117 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package authorizerd + +import ( + "context" + "time" + + "github.com/pkg/errors" + "github.com/yahoojapan/athenz-authorizer/jwk" + "github.com/yahoojapan/athenz-authorizer/pubkey" + "github.com/yahoojapan/athenz-authorizer/role" +) + +type ConfdMock struct { + pubkey.Daemon + confdExp time.Duration +} + +func (cm *ConfdMock) Start(ctx context.Context) <-chan error { + ech := make(chan error, 1) + go func() { + time.Sleep(cm.confdExp) + ech <- errors.New("pubkey error") + }() + return ech +} + +type PolicydMock struct { + UpdateFunc func(context.Context) error + CheckPolicyFunc func(ctx context.Context, domain string, roles []string, action, resource string) error + + policydExp time.Duration + policyCache map[string]interface{} +} + +func (pm *PolicydMock) Start(context.Context) <-chan error { + ech := make(chan error, 1) + go func() { + time.Sleep(pm.policydExp) + ech <- errors.New("policyd error") + }() + return ech +} + +func (pm *PolicydMock) Update(ctx context.Context) error { + if pm.UpdateFunc != nil { + return pm.UpdateFunc(ctx) + } + return nil +} + +func (pm *PolicydMock) CheckPolicy(ctx context.Context, domain string, roles []string, action, resource string) error { + if pm.CheckPolicyFunc != nil { + return pm.CheckPolicyFunc(ctx, domain, roles, action, resource) + } + return nil +} + +func (pm *PolicydMock) GetPolicyCache(ctx context.Context) map[string]interface{} { + return pm.policyCache +} + +type TokenMock struct { + role.Processor + wantErr error + rt *role.Token + c *role.Claim +} + +func (rm *TokenMock) ParseAndValidateRoleToken(tok string) (*role.Token, error) { + return rm.rt, rm.wantErr +} + +func (rm *TokenMock) ParseAndValidateRoleJWT(cred string) (*role.Claim, error) { + return rm.c, rm.wantErr +} + +type JwkdMock struct { + StartFunc func(context.Context) <-chan error + UpdateFunc func(context.Context) error + GetProviderFunc func() jwk.Provider +} + +func (jm *JwkdMock) Start(ctx context.Context) <-chan error { + if jm.StartFunc != nil { + return jm.StartFunc(ctx) + } + return nil +} + +func (jm *JwkdMock) Update(ctx context.Context) error { + if jm.UpdateFunc != nil { + return jm.UpdateFunc(ctx) + } + return nil +} + +func (jm *JwkdMock) GetProvider() jwk.Provider { + if jm.GetProviderFunc != nil { + return jm.GetProviderFunc() + } + return nil +} diff --git a/authorizerd_test.go b/authorizerd_test.go new file mode 100644 index 00000000..dc658d03 --- /dev/null +++ b/authorizerd_test.go @@ -0,0 +1,1070 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package authorizerd + +import ( + "context" + "crypto/x509" + "encoding/pem" + "net/http" + "reflect" + "testing" + "time" + + "github.com/kpango/gache" + "github.com/pkg/errors" + "github.com/yahoojapan/athenz-authorizer/jwk" + "github.com/yahoojapan/athenz-authorizer/policy" + "github.com/yahoojapan/athenz-authorizer/pubkey" + "github.com/yahoojapan/athenz-authorizer/role" +) + +func TestNew(t *testing.T) { + type args struct { + opts []Option + } + tests := []struct { + name string + args args + checkFunc func(Authorizerd, error) error + }{ + { + name: "test new success", + args: args{ + []Option{}, + }, + checkFunc: func(prov Authorizerd, err error) error { + if err != nil { + return errors.Wrap(err, "unexpected error") + } + if prov.(*authorizer).athenzURL != "www.athenz.com/zts/v1" { + return errors.New("invalid url") + } + if prov.(*authorizer).pubkeyd == nil { + return errors.New("cannot new pubkeyd") + } + if prov.(*authorizer).policyd == nil { + return errors.New("cannot new policyd") + } + return nil + }, + }, + { + name: "test new success with options", + args: args{ + []Option{WithAthenzURL("www.dummy.com")}, + }, + checkFunc: func(prov Authorizerd, err error) error { + if err != nil { + return errors.Wrap(err, "unexpected error") + } + if prov.(*authorizer).athenzURL != "www.dummy.com" { + return errors.New("invalid url") + } + return nil + }, + }, + { + name: "test New returns error", + args: args{ + []Option{WithPubkeyEtagExpTime("dummy")}, + }, + checkFunc: func(prov Authorizerd, err error) error { + want := "error create pubkeyd: invalid etag expire time: time: invalid duration dummy" + if err.Error() != want { + return errors.Errorf("Unexpected error: %s, expected: %s", err, want) + } + return nil + }, + }, + { + name: "test NewPolicy returns error", + args: args{ + []Option{WithPolicyEtagExpTime("dummy")}, + }, + checkFunc: func(prov Authorizerd, err error) error { + if err.Error() != "error create policyd: error create policyd: invalid etag expire time: time: invalid duration dummy" { + return errors.Wrap(err, "unexpected error") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, goter := New(tt.args.opts...) + if err := tt.checkFunc(got, goter); err != nil { + t.Errorf("New() error = %v", err) + } + }) + } +} + +func TestStart(t *testing.T) { + type fields struct { + pubkeyd pubkey.Daemon + policyd policy.Daemon + jwkd jwk.Daemon + cache gache.Gache + cacheExp time.Duration + } + type args struct { + ctx context.Context + } + type test struct { + name string + fields fields + args args + checkFunc func(Authorizerd, error) error + afterFunc func() + } + tests := []test{ + func() test { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*10)) + cm := &ConfdMock{ + confdExp: time.Second, + } + pm := &PolicydMock{ + policydExp: time.Second, + } + jd := &JwkdMock{} + return test{ + name: "test context done", + fields: fields{ + pubkeyd: cm, + policyd: pm, + jwkd: jd, + cache: gache.New(), + cacheExp: time.Minute, + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(prov Authorizerd, err error) error { + if err.Error() != "context deadline exceeded" { + return errors.Wrap(err, "unexpected err") + } + return nil + }, + afterFunc: func() { + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + cm := &ConfdMock{ + confdExp: time.Millisecond * 10, + } + pm := &PolicydMock{ + policydExp: time.Second, + } + jd := &JwkdMock{} + return test{ + name: "test context pubkey updater returns error", + fields: fields{ + pubkeyd: cm, + policyd: pm, + jwkd: jd, + cache: gache.New(), + cacheExp: time.Minute, + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(prov Authorizerd, err error) error { + if err.Error() != "update pubkey error: pubkey error" { + return errors.Wrap(err, "unexpected err") + } + return nil + }, + afterFunc: func() { + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + cm := &ConfdMock{ + confdExp: time.Second, + } + pm := &PolicydMock{ + policydExp: time.Millisecond * 10, + } + jd := &JwkdMock{} + return test{ + name: "test policyd returns error", + fields: fields{ + pubkeyd: cm, + policyd: pm, + jwkd: jd, + cache: gache.New(), + cacheExp: time.Minute, + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(prov Authorizerd, err error) error { + if err.Error() != "update policy error: policyd error" { + return errors.Wrap(err, "unexpected err") + } + return nil + }, + afterFunc: func() { + cancel() + }, + } + }(), + func() test { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*500)) + cm := &ConfdMock{ + confdExp: time.Second, + } + pm := &PolicydMock{ + policydExp: time.Second, + } + jd := &JwkdMock{ + StartFunc: func(ctx context.Context) <-chan error { + ch := make(chan error, 1) + go func() { + time.Sleep(time.Millisecond * 20) + ch <- errors.New("dummy") + }() + return ch + }, + } + return test{ + name: "test jwkd returns error", + fields: fields{ + pubkeyd: cm, + policyd: pm, + jwkd: jd, + cache: gache.New(), + cacheExp: time.Minute, + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(prov Authorizerd, err error) error { + if err.Error() != "update jwk error: dummy" { + return errors.Errorf("unexpected error: %s", err) + } + return nil + }, + afterFunc: func() { + cancel() + }, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prov := &authorizer{ + pubkeyd: tt.fields.pubkeyd, + policyd: tt.fields.policyd, + jwkd: tt.fields.jwkd, + cache: tt.fields.cache, + cacheExp: tt.fields.cacheExp, + } + ch := prov.Start(tt.args.ctx) + goter := <-ch + if err := tt.checkFunc(prov, goter); err != nil { + t.Errorf("Start() error = %v", err) + } + tt.afterFunc() + }) + } +} + +func TestVerifyRoleToken(t *testing.T) { + type args struct { + ctx context.Context + tok string + act string + res string + } + type fields struct { + policyd policy.Daemon + cache gache.Gache + cacheExp time.Duration + roleTokenProcessor role.Processor + } + type test struct { + name string + args args + fields fields + wantErr string + checkFunc func(*authorizer) error + } + tests := []test{ + func() test { + c := gache.New() + rm := &TokenMock{ + rt: &role.Token{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test verify success", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleTokenProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "", + checkFunc: func(prov *authorizer) error { + _, ok := prov.cache.Get("dummyTokdummyActdummyRes") + if !ok { + return errors.New("cannot get dummyTokdummyActdummyRes from cache") + } + return nil + }, + } + }(), + func() test { + c := gache.New() + c.Set("dummyTokdummyActdummyRes", "dummy") + rm := &TokenMock{ + rt: &role.Token{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test use cache success", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleTokenProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "", + } + }(), + func() test { + c := gache.New() + c.Set("dummyTokdummyActdummyRes", "dummy") + rm := &TokenMock{ + rt: &role.Token{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test empty action", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleTokenProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "empty action / resource: Access denied due to invalid/empty action/resource values", + } + }(), + func() test { + c := gache.New() + c.Set("dummyTokdummyActdummyRes", "dummy") + rm := &TokenMock{ + rt: &role.Token{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test empty res", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "", + }, + fields: fields{ + policyd: cm, + roleTokenProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "empty action / resource: Access denied due to invalid/empty action/resource values", + } + }(), + func() test { + c := gache.New() + rm := &TokenMock{ + wantErr: errors.New("cannot parse roletoken"), + } + cm := &PolicydMock{} + return test{ + name: "test parse roletoken error", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleTokenProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "error verify role token: cannot parse roletoken", + } + }(), + func() test { + c := gache.New() + rm := &TokenMock{ + rt: &role.Token{}, + } + cm := &PolicydMock{ + CheckPolicyFunc: func(context.Context, string, []string, string, string) error { + return errors.New("deny") + }, + } + return test{ + name: "test return deny", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleTokenProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "token unauthorizate: deny", + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prov := &authorizer{ + policyd: tt.fields.policyd, + roleProcessor: tt.fields.roleTokenProcessor, + cache: tt.fields.cache, + cacheExp: tt.fields.cacheExp, + } + err := prov.VerifyRoleToken(tt.args.ctx, tt.args.tok, tt.args.act, tt.args.res) + if err != nil { + if err.Error() != tt.wantErr { + t.Errorf("VerifyRoleToken() unexpected error want:%s, result:%s", tt.wantErr, err.Error()) + return + } + } else { + if tt.wantErr != "" { + t.Errorf("VerifyRoleToken() return nil. want %s", tt.wantErr) + return + } + } + if tt.checkFunc != nil { + if err := tt.checkFunc(prov); err != nil { + t.Errorf("VerifyRoleToken() error: %v", err) + } + } + }) + } +} + +func Test_authorizer_VerifyRoleJWT(t *testing.T) { + type fields struct { + pubkeyd pubkey.Daemon + policyd policy.Daemon + jwkd jwk.Daemon + roleProcessor role.Processor + athenzURL string + client *http.Client + cache gache.Gache + cacheExp time.Duration + roleCertURIPrefix string + pubkeyRefreshDuration string + pubkeySysAuthDomain string + pubkeyEtagExpTime string + pubkeyEtagFlushDur string + policyExpireMargin string + athenzDomains []string + policyRefreshDuration string + policyEtagFlushDur string + policyEtagExpTime string + } + type args struct { + ctx context.Context + tok string + act string + res string + } + type test struct { + name string + args args + fields fields + wantErr string + checkFunc func(*authorizer) error + } + tests := []test{ + func() test { + c := gache.New() + rm := &TokenMock{ + c: &role.Claim{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test verify success", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "", + checkFunc: func(prov *authorizer) error { + _, ok := prov.cache.Get("dummyTokdummyActdummyRes") + if !ok { + return errors.New("cannot get dummyTokdummyActdummyRes from cache") + } + return nil + }, + } + }(), + func() test { + c := gache.New() + c.Set("dummyTokdummyActdummyRes", "dummy") + rm := &TokenMock{ + c: &role.Claim{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test use cache success", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "", + } + }(), + func() test { + c := gache.New() + c.Set("dummyTokdummyActdummyRes", "dummy") + rm := &TokenMock{ + c: &role.Claim{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test empty action", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "empty action / resource: Access denied due to invalid/empty action/resource values", + } + }(), + func() test { + c := gache.New() + c.Set("dummyTokdummyActdummyRes", "dummy") + rm := &TokenMock{ + c: &role.Claim{}, + wantErr: nil, + } + cm := &PolicydMock{} + return test{ + name: "test empty res", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "", + }, + fields: fields{ + policyd: cm, + roleProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "empty action / resource: Access denied due to invalid/empty action/resource values", + } + }(), + func() test { + c := gache.New() + rm := &TokenMock{ + wantErr: errors.New("cannot parse role jwt"), + } + cm := &PolicydMock{} + return test{ + name: "test parse role jwt error", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "error verify role jwt: cannot parse role jwt", + } + }(), + func() test { + c := gache.New() + rm := &TokenMock{ + c: &role.Claim{}, + } + cm := &PolicydMock{ + CheckPolicyFunc: func(context.Context, string, []string, string, string) error { + return errors.New("deny") + }, + } + return test{ + name: "test return deny", + args: args{ + ctx: context.Background(), + tok: "dummyTok", + act: "dummyAct", + res: "dummyRes", + }, + fields: fields{ + policyd: cm, + roleProcessor: rm, + cache: c, + cacheExp: time.Minute, + }, + wantErr: "token unauthorizate: deny", + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &authorizer{ + pubkeyd: tt.fields.pubkeyd, + policyd: tt.fields.policyd, + jwkd: tt.fields.jwkd, + roleProcessor: tt.fields.roleProcessor, + athenzURL: tt.fields.athenzURL, + client: tt.fields.client, + cache: tt.fields.cache, + cacheExp: tt.fields.cacheExp, + roleCertURIPrefix: tt.fields.roleCertURIPrefix, + pubkeyRefreshDuration: tt.fields.pubkeyRefreshDuration, + pubkeySysAuthDomain: tt.fields.pubkeySysAuthDomain, + pubkeyEtagExpTime: tt.fields.pubkeyEtagExpTime, + pubkeyEtagFlushDur: tt.fields.pubkeyEtagFlushDur, + policyExpireMargin: tt.fields.policyExpireMargin, + athenzDomains: tt.fields.athenzDomains, + policyRefreshDuration: tt.fields.policyRefreshDuration, + policyEtagFlushDur: tt.fields.policyEtagFlushDur, + policyEtagExpTime: tt.fields.policyEtagExpTime, + } + err := p.VerifyRoleJWT(tt.args.ctx, tt.args.tok, tt.args.act, tt.args.res) + if err != nil { + if err.Error() != tt.wantErr { + t.Errorf("VerifyRoleJWT() unexpected error want:%s, result:%s", tt.wantErr, err.Error()) + return + } + } else { + if tt.wantErr != "" { + t.Errorf("VerifyRoleJWT() return nil. want %s", tt.wantErr) + return + } + } + if tt.checkFunc != nil { + if err := tt.checkFunc(p); err != nil { + t.Errorf("VerifyRoleJWT() error: %v", err) + } + } + }) + } +} + +func Test_authorizer_verify(t *testing.T) { + type fields struct { + pubkeyd pubkey.Daemon + policyd policy.Daemon + jwkd jwk.Daemon + roleProcessor role.Processor + athenzURL string + client *http.Client + cache gache.Gache + cacheExp time.Duration + roleCertURIPrefix string + pubkeyRefreshDuration string + pubkeySysAuthDomain string + pubkeyEtagExpTime string + pubkeyEtagFlushDur string + policyExpireMargin string + athenzDomains []string + policyRefreshDuration string + policyEtagFlushDur string + policyEtagExpTime string + } + type args struct { + ctx context.Context + m mode + tok string + act string + res string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &authorizer{ + pubkeyd: tt.fields.pubkeyd, + policyd: tt.fields.policyd, + jwkd: tt.fields.jwkd, + roleProcessor: tt.fields.roleProcessor, + athenzURL: tt.fields.athenzURL, + client: tt.fields.client, + cache: tt.fields.cache, + cacheExp: tt.fields.cacheExp, + roleCertURIPrefix: tt.fields.roleCertURIPrefix, + pubkeyRefreshDuration: tt.fields.pubkeyRefreshDuration, + pubkeySysAuthDomain: tt.fields.pubkeySysAuthDomain, + pubkeyEtagExpTime: tt.fields.pubkeyEtagExpTime, + pubkeyEtagFlushDur: tt.fields.pubkeyEtagFlushDur, + policyExpireMargin: tt.fields.policyExpireMargin, + athenzDomains: tt.fields.athenzDomains, + policyRefreshDuration: tt.fields.policyRefreshDuration, + policyEtagFlushDur: tt.fields.policyEtagFlushDur, + policyEtagExpTime: tt.fields.policyEtagExpTime, + } + if err := p.verify(tt.args.ctx, tt.args.m, tt.args.tok, tt.args.act, tt.args.res); (err != nil) != tt.wantErr { + t.Errorf("authorizer.verify() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_authorizer_VerifyRoleCert(t *testing.T) { + type fields struct { + pubkeyd pubkey.Daemon + policyd policy.Daemon + jwkd jwk.Daemon + roleProcessor role.Processor + athenzURL string + client *http.Client + cache gache.Gache + cacheExp time.Duration + roleCertURIPrefix string + pubkeyRefreshDuration string + pubkeySysAuthDomain string + pubkeyEtagExpTime string + pubkeyEtagFlushDur string + policyExpireMargin string + athenzDomains []string + policyRefreshDuration string + policyEtagFlushDur string + policyEtagExpTime string + } + type args struct { + ctx context.Context + peerCerts []*x509.Certificate + act string + res string + } + type test struct { + name string + fields fields + args args + wantErr bool + } + tests := []test{ + func() test { + crt := `-----BEGIN CERTIFICATE----- +MIICGTCCAcOgAwIBAgIJALLML3PdJAZ1MA0GCSqGSIb3DQEBCwUAMFwxCzAJBgNV +BAYTAlVTMQswCQYDVQQIEwJDQTEPMA0GA1UEChMGQXRoZW56MRcwFQYDVQQLEw5U +ZXN0aW5nIERvbWFpbjEWMBQGA1UEAxMNYXRoZW56LnN5bmNlcjAeFw0xOTA0Mjcw +MjQ2MjNaFw0yOTA0MjQwMjQ2MjNaMFwxCzAJBgNVBAYTAlVTMQswCQYDVQQIEwJD +QTEPMA0GA1UEChMGQXRoZW56MRcwFQYDVQQLEw5UZXN0aW5nIERvbWFpbjEWMBQG +A1UEAxMNYXRoZW56LnN5bmNlcjBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQCvv27a +SNAnK0vcN8fqqQgMHwb0EhfVWMwoRTBQFrCmA9mH/84QgI/0kR3ZI+DlDNBCgDHd +rEJZVPyX2V41VOX3AgMBAAGjaDBmMGQGA1UdEQRdMFuGGXNwaWZmZTovL2F0aGVu +ei9zYS9zeW5jZXKGHmF0aGVuejovL3JvbGUvY29yZXRlY2gvcmVhZGVyc4YeYXRo +ZW56Oi8vcm9sZS9jb3JldGVjaC93cml0ZXJzMA0GCSqGSIb3DQEBCwUAA0EAa3Ra +Wo7tEDFBGqSVYSVuoh0GpsWC0VBAYYi9vhAGfp+g5M2oszvRuxOHYsQmYAjYroTJ +bu80CwTnWhmdBo36Ig== +-----END CERTIFICATE----- +` + block, _ := pem.Decode([]byte(crt)) + cert, _ := x509.ParseCertificate(block.Bytes) + + pm := &PolicydMock{ + CheckPolicyFunc: func(ctx context.Context, domain string, roles []string, act, res string) error { + containRole := func(r string) bool { + for _, role := range roles { + if role == r { + return true + } + } + return false + } + if domain != "coretech" { + return errors.Errorf("invalid domain, got: %s, want: %s", domain, "coretech") + } + if !containRole("readers") || !containRole("writers") { + return errors.Errorf("invalid role, got: %s", roles) + } + return nil + }, + } + + return test{ + name: "parse and verify role cert success", + fields: fields{ + roleCertURIPrefix: "athenz://role/", + policyd: pm, + }, + args: args{ + ctx: context.Background(), + peerCerts: []*x509.Certificate{ + cert, + }, + act: "abc", + res: "def", + }, + } + }(), + func() test { + crt := ` +-----BEGIN CERTIFICATE----- +MIICLjCCAZegAwIBAgIBADANBgkqhkiG9w0BAQ0FADA0MQswCQYDVQQGEwJ1czEL +MAkGA1UECAwCSEsxCzAJBgNVBAoMAkhLMQswCQYDVQQDDAJISzAeFw0xOTA3MDQw +NjU2MTJaFw0yMDA3MDMwNjU2MTJaMDQxCzAJBgNVBAYTAnVzMQswCQYDVQQIDAJI +SzELMAkGA1UECgwCSEsxCzAJBgNVBAMMAkhLMIGfMA0GCSqGSIb3DQEBAQUAA4GN +ADCBiQKBgQDdUHpdYo/UeYvzB4Z3WvUe2yHsuxrhh7x/D2A5OPb19+ZZy4cdMDUW +qd3hw/tvBWxSUYueL75AifVAQdncUJ+7of3WByFYVSemDrdlD9K/+PyGFZotA+Xj +GmNWjAsGBYuU5roxJZI2c78vJzKj2DU1a9hq/PJ9WGvX4i1Xwf0FKwIDAQABo1Aw +TjAdBgNVHQ4EFgQUiLEo7+nigzdGft2ZEbpkZFxgU+MwHwYDVR0jBBgwFoAUiLEo +7+nigzdGft2ZEbpkZFxgU+MwDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQ0FAAOB +gQCiedWe2DXuE0ak1oGV+28qLpyc/Ff9RNNwUbCKB6L/+OWoROVdaz/DoZjfE9vr +ilcIAqkugYyMzW4cY2RexOLYrkyyjLjMj5C2ff4m13gqRLHU0rFpaKpjYr8KYiGD +KSdPh6TRd/kYpv7t6cVm1Orll4O5jh+IdoguGkOCxheMaQ== +-----END CERTIFICATE-----` + block, _ := pem.Decode([]byte(crt)) + cert, _ := x509.ParseCertificate(block.Bytes) + + pm := &PolicydMock{ + CheckPolicyFunc: func(ctx context.Context, domain string, roles []string, act, res string) error { + return nil + }, + } + + return test{ + name: "invalid athenz role certificate", + fields: fields{ + roleCertURIPrefix: "athenz://role/", + policyd: pm, + }, + args: args{ + ctx: context.Background(), + peerCerts: []*x509.Certificate{ + cert, + }, + act: "abc", + res: "def", + }, + wantErr: true, + } + }(), + func() test { + crt := `-----BEGIN CERTIFICATE----- +MIICGTCCAcOgAwIBAgIJALLML3PdJAZ1MA0GCSqGSIb3DQEBCwUAMFwxCzAJBgNV +BAYTAlVTMQswCQYDVQQIEwJDQTEPMA0GA1UEChMGQXRoZW56MRcwFQYDVQQLEw5U +ZXN0aW5nIERvbWFpbjEWMBQGA1UEAxMNYXRoZW56LnN5bmNlcjAeFw0xOTA0Mjcw +MjQ2MjNaFw0yOTA0MjQwMjQ2MjNaMFwxCzAJBgNVBAYTAlVTMQswCQYDVQQIEwJD +QTEPMA0GA1UEChMGQXRoZW56MRcwFQYDVQQLEw5UZXN0aW5nIERvbWFpbjEWMBQG +A1UEAxMNYXRoZW56LnN5bmNlcjBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQCvv27a +SNAnK0vcN8fqqQgMHwb0EhfVWMwoRTBQFrCmA9mH/84QgI/0kR3ZI+DlDNBCgDHd +rEJZVPyX2V41VOX3AgMBAAGjaDBmMGQGA1UdEQRdMFuGGXNwaWZmZTovL2F0aGVu +ei9zYS9zeW5jZXKGHmF0aGVuejovL3JvbGUvY29yZXRlY2gvcmVhZGVyc4YeYXRo +ZW56Oi8vcm9sZS9jb3JldGVjaC93cml0ZXJzMA0GCSqGSIb3DQEBCwUAA0EAa3Ra +Wo7tEDFBGqSVYSVuoh0GpsWC0VBAYYi9vhAGfp+g5M2oszvRuxOHYsQmYAjYroTJ +bu80CwTnWhmdBo36Ig== +-----END CERTIFICATE----- +` + block, _ := pem.Decode([]byte(crt)) + cert, _ := x509.ParseCertificate(block.Bytes) + + pm := &PolicydMock{ + CheckPolicyFunc: func(ctx context.Context, domain string, roles []string, act, res string) error { + return errors.New("dummy") + }, + } + + return test{ + name: "parse and verify role cert success", + fields: fields{ + roleCertURIPrefix: "athenz://role/", + policyd: pm, + }, + args: args{ + ctx: context.Background(), + peerCerts: []*x509.Certificate{ + cert, + }, + act: "abc", + res: "def", + }, + wantErr: true, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &authorizer{ + pubkeyd: tt.fields.pubkeyd, + policyd: tt.fields.policyd, + jwkd: tt.fields.jwkd, + roleProcessor: tt.fields.roleProcessor, + athenzURL: tt.fields.athenzURL, + client: tt.fields.client, + cache: tt.fields.cache, + cacheExp: tt.fields.cacheExp, + roleCertURIPrefix: tt.fields.roleCertURIPrefix, + pubkeyRefreshDuration: tt.fields.pubkeyRefreshDuration, + pubkeySysAuthDomain: tt.fields.pubkeySysAuthDomain, + pubkeyEtagExpTime: tt.fields.pubkeyEtagExpTime, + pubkeyEtagFlushDur: tt.fields.pubkeyEtagFlushDur, + policyExpireMargin: tt.fields.policyExpireMargin, + athenzDomains: tt.fields.athenzDomains, + policyRefreshDuration: tt.fields.policyRefreshDuration, + policyEtagFlushDur: tt.fields.policyEtagFlushDur, + policyEtagExpTime: tt.fields.policyEtagExpTime, + } + if err := p.VerifyRoleCert(tt.args.ctx, tt.args.peerCerts, tt.args.act, tt.args.res); (err != nil) != tt.wantErr { + t.Errorf("authorizer.VerifyRoleCert() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_authorizer_GetPolicyCache(t *testing.T) { + type fields struct { + pubkeyd pubkey.Daemon + policyd policy.Daemon + jwkd jwk.Daemon + roleProcessor role.Processor + athenzURL string + client *http.Client + cache gache.Gache + cacheExp time.Duration + roleCertURIPrefix string + pubkeyRefreshDuration string + pubkeySysAuthDomain string + pubkeyEtagExpTime string + pubkeyEtagFlushDur string + policyExpireMargin string + athenzDomains []string + policyRefreshDuration string + policyEtagFlushDur string + policyEtagExpTime string + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + want map[string]interface{} + }{ + { + name: "GetPolicyCache success", + fields: fields{ + policyd: &PolicydMock{}, + }, + args: args{ + ctx: context.Background(), + }, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &authorizer{ + pubkeyd: tt.fields.pubkeyd, + policyd: tt.fields.policyd, + jwkd: tt.fields.jwkd, + roleProcessor: tt.fields.roleProcessor, + athenzURL: tt.fields.athenzURL, + client: tt.fields.client, + cache: tt.fields.cache, + cacheExp: tt.fields.cacheExp, + roleCertURIPrefix: tt.fields.roleCertURIPrefix, + pubkeyRefreshDuration: tt.fields.pubkeyRefreshDuration, + pubkeySysAuthDomain: tt.fields.pubkeySysAuthDomain, + pubkeyEtagExpTime: tt.fields.pubkeyEtagExpTime, + pubkeyEtagFlushDur: tt.fields.pubkeyEtagFlushDur, + policyExpireMargin: tt.fields.policyExpireMargin, + athenzDomains: tt.fields.athenzDomains, + policyRefreshDuration: tt.fields.policyRefreshDuration, + policyEtagFlushDur: tt.fields.policyEtagFlushDur, + policyEtagExpTime: tt.fields.policyEtagExpTime, + } + if got := a.GetPolicyCache(tt.args.ctx); !reflect.DeepEqual(got, tt.want) { + t.Errorf("authorizer.GetPolicyCache() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/doc.go b/doc.go index 74672f24..59c529b7 100644 --- a/doc.go +++ b/doc.go @@ -14,5 +14,5 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package providerd represents the policy updater daemon. -package providerd +// Package authorizerd represents the policy updater daemon. +package authorizerd diff --git a/errors.go b/errors.go index 3a8fbc92..bb0443f9 100644 --- a/errors.go +++ b/errors.go @@ -13,13 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -package providerd +package authorizerd import ( "errors" - "github.com/yahoojapan/athenz-policy-updater/policy" - "github.com/yahoojapan/athenz-policy-updater/role" + "github.com/yahoojapan/athenz-authorizer/policy" + "github.com/yahoojapan/athenz-authorizer/role" ) var ( diff --git a/go.mod b/go.mod index 2753f70f..35a82006 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,15 @@ -module github.com/yahoojapan/athenz-policy-updater +module github.com/yahoojapan/athenz-authorizer go 1.12 require ( github.com/ardielle/ardielle-go v1.5.2 + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/google/go-cmp v0.3.0 - github.com/kpango/gache v1.1.15 - github.com/kpango/glg v1.4.4 + github.com/kpango/gache v1.1.19 + github.com/kpango/glg v1.4.5 + github.com/lestrrat-go/jwx v0.9.0 github.com/pkg/errors v0.8.1 - github.com/yahoo/athenz v1.8.23 + github.com/yahoo/athenz v1.8.24 golang.org/x/sync v0.0.0-20190423024810-112230192c58 ) diff --git a/go.sum b/go.sum index 3716ead2..a17f3cce 100644 --- a/go.sum +++ b/go.sum @@ -3,9 +3,10 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE github.com/OneOfOne/xxhash v1.2.5/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= github.com/OrlovEvgeny/go-mcache v0.0.0-20181113222421-bed69649df7d/go.mod h1:HyURA1Z5rjNkt9E7XyiegZk1ZBvvB+1vYzkeu52goIc= github.com/OrlovEvgeny/go-mcache v0.0.0-20190520090815-302f7b82bb96/go.mod h1:9X0sgxPdm23rjgP2JD9uS+0QL2EEpQKRgABK3+BQ81A= -github.com/VictoriaMetrics/fastcache v1.5.0/go.mod h1:+jv9Ckb+za/P1ZRg/sulP5Ni1v49daAVERr0H3CuscE= +github.com/VictoriaMetrics/fastcache v1.5.1/go.mod h1:+jv9Ckb+za/P1ZRg/sulP5Ni1v49daAVERr0H3CuscE= github.com/allegro/bigcache v1.1.0/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= +github.com/allegro/bigcache v1.2.1/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/ardielle/ardielle-go v1.5.2 h1:TilHTpHIQJ27R1Tl/iITBzMwiUGSlVfiVhwDNGM3Zj4= github.com/ardielle/ardielle-go v1.5.2/go.mod h1:I4hy1n795cUhaVt/ojz83SNVCYIGsAFAONtv2Dr7HUI= github.com/ardielle/ardielle-tools v1.5.4/go.mod h1:oZN+JRMnqGiIhrzkRN9l26Cej9dEx4jeNG6A+AdkShk= @@ -21,6 +22,8 @@ github.com/coocood/freecache v1.0.1/go.mod h1:ePwxCDzOYvARfHdr1pByNct1at3CoKnsip github.com/coocood/freecache v1.1.0/go.mod h1:ePwxCDzOYvARfHdr1pByNct1at3CoKnsipOHwKlNbzI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= @@ -32,13 +35,15 @@ github.com/hlts2/gocache v0.0.0-20190217073200-8b772e486b6e/go.mod h1:F4tUovaw56 github.com/jawher/mow.cli v1.0.4/go.mod h1:5hQj2V8g+qYmLUVWqu4Wuja1pI57M83EChYLVZ0sMKk= github.com/jawher/mow.cli v1.1.0/go.mod h1:aNaQlc7ozF3vw6IJ2dHjp2ZFiA4ozMIYY6PyuRJwlUg= github.com/kpango/fastime v1.0.0/go.mod h1:Y5XY5bLG5yc7g2XmMUzc22XYV1XaH+KgUOHkDvLp4SA= -github.com/kpango/fastime v1.0.12 h1:7I1Ha4tu1GOM9LIGEQnNYCcA5BrcYfJ+4ZU0ed491B8= -github.com/kpango/fastime v1.0.12/go.mod h1:lVqUTcXmQnk1wriyvq5DElbRSRDC0XtqbXQRdz0Eo+g= +github.com/kpango/fastime v1.0.14 h1:oubwGg1oUyxe6HLOgvDUcuCmC1/1AXgwhtRVGsUV3f0= +github.com/kpango/fastime v1.0.14/go.mod h1:lVqUTcXmQnk1wriyvq5DElbRSRDC0XtqbXQRdz0Eo+g= github.com/kpango/gache v1.1.0/go.mod h1:BHKRCYnJ2pRFFIJNc61KTJb3KXSzlrt/ITfgfCQJAJw= -github.com/kpango/gache v1.1.15 h1:eKRYImFf9T49xmGm+f+XMq5Lz5LJ4ovVvb1H3t3eIc4= -github.com/kpango/gache v1.1.15/go.mod h1:fPggz5URX77Xcsb8nve85y47ltkJNfT0sSyCEpcxEn8= -github.com/kpango/glg v1.4.4 h1:bXzAlvur2Nk4RseQNedBi70gIPdpNLPhAoYQ7wD0RjE= -github.com/kpango/glg v1.4.4/go.mod h1:besyu510on2Btx6QuKa0Lqj7PaMP/0O9AUwlaiC+et8= +github.com/kpango/gache v1.1.19 h1:3Vdcm1Ic35xU3fmr6SHOSWnldw9fSoBhVYxVpOUxet0= +github.com/kpango/gache v1.1.19/go.mod h1:WDvNUs3vLe+1yIOzrpnKwxwNDuUm41mI/Mq7RX6ZIOs= +github.com/kpango/glg v1.4.5 h1:fToaqeCUkPAmdArV3NNucWGX15vXfh2RB23Fgg/vv5A= +github.com/kpango/glg v1.4.5/go.mod h1:Hq2meR77NKh8vxar+lCIjUHpCPh0Q+LQUFDmduvW9G4= +github.com/lestrrat-go/jwx v0.9.0 h1:Fnd0EWzTm0kFrBPzE/PEPp9nzllES5buMkksPMjEKpM= +github.com/lestrrat-go/jwx v0.9.0/go.mod h1:iEoxlYfZjvoGpuWwxUz+eR5e6KTJGsaRcy/YNA/UnBk= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -49,8 +54,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/vmihailenco/msgpack v4.0.1+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= -github.com/yahoo/athenz v1.8.23 h1:5s9keEQG88iflPFRh1d0RjWvi9TUhCr3dwsb7XnuY3o= -github.com/yahoo/athenz v1.8.23/go.mod h1:wQ3kpWCncoxMXxj2mr/Gf6YJwtE+Uwk7DOPYpm/C99A= +github.com/yahoo/athenz v1.8.24 h1:Bf+2xcG06wgeDY/X/dbJft5uW4ip9ao/5Kvp+uC5osw= +github.com/yahoo/athenz v1.8.24/go.mod h1:wQ3kpWCncoxMXxj2mr/Gf6YJwtE+Uwk7DOPYpm/C99A= golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/net v0.0.0-20180921000356-2f5d2388922f/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/jwk/daemon.go b/jwk/daemon.go new file mode 100644 index 00000000..ae56c002 --- /dev/null +++ b/jwk/daemon.go @@ -0,0 +1,144 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package jwk + +import ( + "context" + "fmt" + "net/http" + "sync/atomic" + "time" + + "github.com/kpango/glg" + "github.com/lestrrat-go/jwx/jwk" + "github.com/pkg/errors" +) + +// Daemon represents the daemon to retrieve jwk from Athenz. +type Daemon interface { + Start(ctx context.Context) <-chan error + Update(context.Context) error + GetProvider() Provider +} + +type jwkd struct { + athenzURL string + refreshDuration time.Duration + errRetryInterval time.Duration + + client *http.Client + + keys atomic.Value +} + +// Provider represent the jwk provider to retrive the json web key. +type Provider func(keyID string) interface{} + +// New represent the constructor of Policyd +func New(opts ...Option) (Daemon, error) { + j := new(jwkd) + for _, opt := range append(defaultOptions, opts...) { + err := opt(j) + if err != nil { + return nil, errors.Wrap(err, "error create policyd") + } + } + + return j, nil +} + +func (j *jwkd) Start(ctx context.Context) <-chan error { + glg.Info("Starting jwk updator") + ech := make(chan error, 100) + fch := make(chan struct{}, 1) + if err := j.Update(ctx); err != nil { + ech <- errors.Wrap(err, "error update athenz json web key") + fch <- struct{}{} + } + + go func() { + defer close(fch) + defer close(ech) + ticker := time.NewTicker(j.refreshDuration) + ebuf := errors.New("") + + update := func() { + if err := j.Update(ctx); err != nil { + err = errors.Wrap(err, "error update athenz json web key") + select { + case ech <- errors.Wrap(ebuf, err.Error()): + ebuf = errors.New("") + default: + ebuf = errors.Wrap(ebuf, err.Error()) + } + select { + case fch <- struct{}{}: + default: + glg.Warn("failure queue already full") + } + } + } + + for { + select { + case <-ctx.Done(): + glg.Info("Stopping jwkd") + ticker.Stop() + if ebuf.Error() != "" { + ech <- errors.Wrap(ctx.Err(), ebuf.Error()) + } else { + ech <- ctx.Err() + } + return + case <-fch: + update() + case <-ticker.C: + update() + } + } + }() + + return ech +} + +func (j *jwkd) Update(ctx context.Context) (err error) { + url := fmt.Sprintf("https://%s/oauth2/keys", j.athenzURL) + keys, err := jwk.FetchHTTP(url, jwk.WithHTTPClient(j.client)) + if err != nil { + return err + } + + j.keys.Store(keys) + return nil +} + +func (j *jwkd) GetProvider() Provider { + return j.getKey +} + +func (j *jwkd) getKey(keyID string) interface{} { + if keyID == "" { + return nil + } + + for _, keys := range j.keys.Load().(*jwk.Set).LookupKeyID(keyID) { + raw, err := keys.Materialize() + if err == nil { + return raw + } + } + return nil +} diff --git a/jwk/daemon_test.go b/jwk/daemon_test.go new file mode 100644 index 00000000..d09585a6 --- /dev/null +++ b/jwk/daemon_test.go @@ -0,0 +1,541 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package jwk + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/lestrrat-go/jwx/jwk" + "github.com/pkg/errors" +) + +func TestNew(t *testing.T) { + type args struct { + opts []Option + } + tests := []struct { + name string + args args + want Daemon + wantErr bool + }{ + { + name: "New daemon success", + args: args{ + opts: []Option{ + WithAthenzURL("www.dummy.com"), + }, + }, + want: &jwkd{ + athenzURL: "www.dummy.com", + refreshDuration: time.Hour * 24, + errRetryInterval: time.Millisecond, + client: http.DefaultClient, + }, + }, + { + name: "New daemon fail", + args: args{ + opts: []Option{ + WithRefreshDuration("dummy"), + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := New(tt.args.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_jwkd_Start(t *testing.T) { + type fields struct { + athenzURL string + refreshDuration time.Duration + errRetryInterval time.Duration + client *http.Client + keys atomic.Value + } + type args struct { + ctx context.Context + } + type test struct { + name string + fields fields + args args + checkFunc func(*jwkd, <-chan error) error + afterFunc func() + } + tests := []test{ + func() test { + k := `{ + "e":"AQAB", + "kty":"RSA", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + }` + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(k)) + })) + ctx, cancel := context.WithCancel(context.Background()) + + return test{ + name: "Start success", + fields: fields{ + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + refreshDuration: time.Millisecond * 10, + errRetryInterval: time.Millisecond, + client: srv.Client(), + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(j *jwkd, ch <-chan error) error { + time.Sleep(time.Millisecond * 100) + cancel() + if k := j.keys.Load(); k == nil { + return errors.New("cannot update keys") + } + + return nil + }, + afterFunc: func() { + cancel() + }, + } + }(), + func() test { + i := 1 + k := `{ + "e":"AQAB", + "kty":"RSA", + "kid" :"%s", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + }` + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(fmt.Sprintf(k, i))) + i = i + 1 + })) + ctx, cancel := context.WithCancel(context.Background()) + + return test{ + name: "Start can update", + fields: fields{ + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + refreshDuration: time.Millisecond * 10, + errRetryInterval: time.Millisecond, + client: srv.Client(), + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(j *jwkd, ch <-chan error) error { + time.Sleep(time.Millisecond * 100) + k1 := j.keys.Load() + if k1 == nil { + return errors.New("cannot update keys") + } + + time.Sleep(time.Millisecond * 30) + cancel() + + k2 := j.keys.Load() + if k2 == nil { + return errors.New("cannot update keys") + } + + if k1 == k2 { + return errors.Errorf("key do not update after it starts, k1: %s, k2: %s", k1, k2) + } + + return nil + }, + afterFunc: func() { + cancel() + }, + } + }(), + func() test { + i := 1 + k := `{ + "e":"AQAB", + "kty":"RSA", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + }` + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if i < 3 { + i++ + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(200) + w.Write([]byte(k)) + i = i + 1 + })) + ctx, cancel := context.WithCancel(context.Background()) + + return test{ + name: "Start retry update", + fields: fields{ + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + refreshDuration: time.Millisecond * 10, + errRetryInterval: time.Millisecond, + client: srv.Client(), + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(j *jwkd, ch <-chan error) error { + time.Sleep(time.Millisecond * 100) + cancel() + if k := j.keys.Load(); k == nil { + return errors.New("cannot update keys") + } + + return nil + }, + afterFunc: func() { + cancel() + }, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.afterFunc != nil { + defer tt.afterFunc() + } + j := &jwkd{ + athenzURL: tt.fields.athenzURL, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + client: tt.fields.client, + keys: tt.fields.keys, + } + got := j.Start(tt.args.ctx) + if tt.checkFunc != nil { + if err := tt.checkFunc(j, got); err != nil { + t.Errorf("jwkd.Start() error = %v", err) + } + } + }) + } +} + +func Test_jwkd_Update(t *testing.T) { + type fields struct { + athenzURL string + refreshDuration time.Duration + errRetryInterval time.Duration + client *http.Client + keys atomic.Value + } + type args struct { + ctx context.Context + } + type test struct { + name string + fields fields + args args + checkFunc func(*jwkd) error + wantErr bool + } + tests := []test{ + func() test { + k := `{ + "e":"AQAB", + "kty":"RSA", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + }` + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(k)) + })) + + return test{ + name: "Update success", + fields: fields{ + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + client: srv.Client(), + }, + args: args{ + ctx: context.Background(), + }, + checkFunc: func(j *jwkd) error { + val := j.keys.Load() + if val == nil { + return errors.New("keys is empty") + } + + s := val.(*jwk.Set) + if _, ok := s.Keys[0].(*jwk.RSAPublicKey); !ok { + return errors.Errorf("Unexpected type: %v", reflect.TypeOf(s.Keys[0])) + } + return nil + }, + } + }(), + func() test { + k := `{ + "e":"AQAB", + "kty":"dummy", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + }` + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(k)) + })) + + return test{ + name: "Update fail", + fields: fields{ + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + client: srv.Client(), + }, + args: args{ + ctx: context.Background(), + }, + checkFunc: func(j *jwkd) error { + if j.keys.Load() != nil { + return errors.Errorf("keys expected nil") + } + return nil + }, + wantErr: true, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + j := &jwkd{ + athenzURL: tt.fields.athenzURL, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + client: tt.fields.client, + keys: tt.fields.keys, + } + if err := j.Update(tt.args.ctx); (err != nil) != tt.wantErr { + t.Errorf("jwkd.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.checkFunc != nil { + if err := tt.checkFunc(j); err != nil { + t.Errorf("jwkd.Update() error = %v", err) + } + } + }) + } +} + +func Test_jwkd_GetProvider(t *testing.T) { + type fields struct { + athenzURL string + refreshDuration time.Duration + errRetryInterval time.Duration + client *http.Client + keys atomic.Value + } + tests := []struct { + name string + fields fields + checkFunc func(Provider) error + }{ + { + name: "get success", + checkFunc: func(p Provider) error { + if p == nil { + return errors.New("GetProvider return nil") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + j := &jwkd{ + athenzURL: tt.fields.athenzURL, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + client: tt.fields.client, + keys: tt.fields.keys, + } + got := j.GetProvider() + if err := tt.checkFunc(got); err != nil { + t.Errorf("jwkd.GetProvider() err %v", err) + } + }) + } +} + +func Test_jwkd_getKey(t *testing.T) { + type fields struct { + athenzURL string + refreshDuration time.Duration + errRetryInterval time.Duration + client *http.Client + keys atomic.Value + } + type args struct { + keyID string + } + type test struct { + name string + fields fields + args args + want interface{} + } + genKey := func() *rsa.PrivateKey { + k, _ := rsa.GenerateKey(rand.Reader, 2048) + return k + } + newKey := func(k interface{}, keyID string) jwk.Key { + jwkKey, _ := jwk.New(k) + jwkKey.Set(jwk.KeyIDKey, keyID) + return jwkKey + } + tests := []test{ + func() test { + rsaKey := genKey() + k := newKey(rsaKey, "dummyID") + set := &jwk.Set{ + Keys: []jwk.Key{ + k, + }, + } + key := atomic.Value{} + key.Store(set) + + return test{ + name: "get key success", + fields: fields{ + keys: key, + }, + args: args{ + keyID: "dummyID", + }, + want: rsaKey, + } + }(), + func() test { + rsaKey := genKey() + k := newKey(rsaKey, "dummyID") + set := &jwk.Set{ + Keys: []jwk.Key{ + k, + }, + } + + key := atomic.Value{} + key.Store(set) + + return test{ + name: "get key not found", + fields: fields{ + keys: key, + }, + args: args{ + keyID: "not exists", + }, + want: nil, + } + }(), + func() test { + rsaKey := genKey() + k := newKey(rsaKey, "") + set := &jwk.Set{ + Keys: []jwk.Key{ + k, + }, + } + + key := atomic.Value{} + key.Store(set) + + return test{ + name: "get key id empty return nil", + fields: fields{ + keys: key, + }, + args: args{ + keyID: "", + }, + want: nil, + } + }(), + func() test { + rsaKey1 := genKey() + k1 := newKey(rsaKey1, "dummyID1") + + rsaKey2 := genKey() + k2 := newKey(rsaKey2, "dummyID2") + + rsaKey3 := genKey() + k3 := newKey(rsaKey3, "dummyID3") + + set := &jwk.Set{ + Keys: []jwk.Key{ + k1, k2, k3, + }, + } + key := atomic.Value{} + key.Store(set) + + return test{ + name: "get key success from multiple key", + fields: fields{ + keys: key, + }, + args: args{ + keyID: "dummyID2", + }, + want: rsaKey2, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + j := &jwkd{ + athenzURL: tt.fields.athenzURL, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + client: tt.fields.client, + keys: tt.fields.keys, + } + if got := j.getKey(tt.args.keyID); !reflect.DeepEqual(got, tt.want) { + t.Errorf("jwkd.getKey() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/jwk/doc.go b/jwk/doc.go new file mode 100644 index 00000000..b2998496 --- /dev/null +++ b/jwk/doc.go @@ -0,0 +1,18 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package jwk represents the jwk daemon fetching logic and the interface +package jwk diff --git a/jwk/error.go b/jwk/error.go new file mode 100644 index 00000000..5efdb39c --- /dev/null +++ b/jwk/error.go @@ -0,0 +1,23 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package jwk + +import "github.com/pkg/errors" + +var ( + // ErrFetchAthenzJWK "Fetch athenz json web key error" + ErrFetchAthenzJWK = errors.New("Fetch athenz json web key error") +) diff --git a/jwk/option.go b/jwk/option.go new file mode 100644 index 00000000..c9f53da7 --- /dev/null +++ b/jwk/option.go @@ -0,0 +1,83 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package jwk + +import ( + "net/http" + "time" + + "github.com/pkg/errors" +) + +var ( + defaultOptions = []Option{ + WithRefreshDuration("24h"), + WithErrRetryInterval("1ms"), + WithHTTPClient(http.DefaultClient), + } +) + +// Option represents a functional options pattern interface +type Option func(*jwkd) error + +// WithAthenzURL represents set athenzURL functional option +func WithAthenzURL(url string) Option { + return func(j *jwkd) error { + if url == "" { + return nil + } + j.athenzURL = url + return nil + } +} + +// WithRefreshDuration represents a RefreshDuration functional option +func WithRefreshDuration(t string) Option { + return func(j *jwkd) error { + if t == "" { + return nil + } + rd, err := time.ParseDuration(t) + if err != nil { + return errors.Wrap(err, "invalid refresh duration") + } + j.refreshDuration = rd + return nil + } +} + +// WithErrRetryInterval represents a ErrRetryInterval functional option +func WithErrRetryInterval(i string) Option { + return func(j *jwkd) error { + if i == "" { + return nil + } + ri, err := time.ParseDuration(i) + if err != nil { + return errors.Wrap(err, "invalid err retry interval") + } + j.errRetryInterval = ri + return nil + } +} + +// WithHTTPClient represents a HTTPClient functional option +func WithHTTPClient(cl *http.Client) Option { + return func(j *jwkd) error { + j.client = cl + return nil + } +} diff --git a/jwk/option_test.go b/jwk/option_test.go new file mode 100644 index 00000000..88157216 --- /dev/null +++ b/jwk/option_test.go @@ -0,0 +1,266 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package jwk + +import ( + "fmt" + "net/http" + "reflect" + "testing" + "time" +) + +func TestWithAthenzURL(t *testing.T) { + type args struct { + url string + } + tests := []struct { + name string + args args + checkFunc func(Option) error + }{ + { + name: "set success", + args: args{ + "http://dummy.com", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if pol.athenzURL != "http://dummy.com" { + return fmt.Errorf("Error") + } + + return nil + }, + }, + { + name: "empty value", + args: args{ + "", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if !reflect.DeepEqual(pol, &jwkd{}) { + return fmt.Errorf("expected no changes, but got %v", pol) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithAthenzURL(tt.args.url) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithAthenzURL() error = %v", err) + } + }) + } +} + +func TestWithRefreshDuration(t *testing.T) { + type args struct { + t string + } + tests := []struct { + name string + args args + checkFunc func(Option) error + }{ + { + name: "set success", + args: args{ + "1h", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if pol.refreshDuration != time.Hour { + return fmt.Errorf("Error") + } + + return nil + }, + }, { + name: "invalid format", + args: args{ + "dummy", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err == nil { + return fmt.Errorf("expected error, but not return") + } + + return nil + }, + }, + { + name: "empty value", + args: args{ + "", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if !reflect.DeepEqual(pol, &jwkd{}) { + return fmt.Errorf("expected no changes, but got %v", pol) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithRefreshDuration(tt.args.t) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithRefreshDuration() error = %v", err) + } + }) + } +} + +func TestWithErrRetryInterval(t *testing.T) { + type args struct { + i string + } + tests := []struct { + name string + args args + checkFunc func(Option) error + }{ + { + name: "set success", + args: args{ + "1h", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if pol.errRetryInterval != time.Hour { + return fmt.Errorf("Error") + } + + return nil + }, + }, { + name: "invalid format", + args: args{ + "dummy", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err == nil { + return fmt.Errorf("expected error, but not return") + } + + return nil + }, + }, + { + name: "empty value", + args: args{ + "", + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if !reflect.DeepEqual(pol, &jwkd{}) { + return fmt.Errorf("expected no changes, but got %v", pol) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithErrRetryInterval(tt.args.i) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithErrRetryInterval() error= %v", err) + } + }) + } +} + +func TestWithHTTPClient(t *testing.T) { + type args struct { + cl *http.Client + } + type test struct { + name string + args args + checkFunc func(Option) error + } + tests := []test{ + func() test { + cl := &http.Client{} + return test{ + name: "set success", + args: args{ + cl: cl, + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if pol.client != cl { + return fmt.Errorf("Error") + } + + return nil + }, + } + }(), + { + name: "empty value", + args: args{ + nil, + }, + checkFunc: func(opt Option) error { + pol := &jwkd{} + if err := opt(pol); err != nil { + return err + } + if !reflect.DeepEqual(pol, &jwkd{}) { + return fmt.Errorf("expected no changes, but got %v", pol) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithHTTPClient(tt.args.cl) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithHTTPClient() error = %v", err) + } + }) + } +} diff --git a/model.go b/model.go index 3fa8ecce..7f456485 100644 --- a/model.go +++ b/model.go @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -package providerd +package authorizerd type signedPolicy struct { KeyID string `json:"keyId"` diff --git a/option.go b/option.go index f45ea002..296a8b88 100644 --- a/option.go +++ b/option.go @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -package providerd +package authorizerd import ( "net/http" @@ -22,34 +22,38 @@ import ( var ( defaultOptions = []Option{ - AthenzURL("www.athenz.com/zts/v1"), - Transport(nil), - CacheExp(time.Minute), + WithAthenzURL("www.athenz.com/zts/v1"), + WithTransport(nil), + WithCacheExp(time.Minute), + WithRoleCertURIPrefix("athenz://role/"), + WithEnablePubkeyd(), + WithEnablePolicyd(), + WithEnableJwkd(), } ) // Option represents a functional options pattern interface -type Option func(*provider) error +type Option func(*authorizer) error // AthenzURL represents a AthenzURL functional option -func AthenzURL(url string) Option { - return func(prov *provider) error { +func WithAthenzURL(url string) Option { + return func(prov *authorizer) error { prov.athenzURL = url return nil } } -// AthenzDomains represents a AthenzDomains functional option -func AthenzDomains(domains ...string) Option { - return func(prov *provider) error { +// WithAthenzDomains represents a AthenzDomains functional option +func WithAthenzDomains(domains ...string) Option { + return func(prov *authorizer) error { prov.athenzDomains = domains return nil } } -// Transport represents a Transport functional option -func Transport(t *http.Transport) Option { - return func(prov *provider) error { +// WithTransport represents a Transport functional option +func WithTransport(t *http.Transport) Option { + return func(prov *authorizer) error { if t == nil { prov.client = &http.Client{ Timeout: time.Second * 30, @@ -63,46 +67,70 @@ func Transport(t *http.Transport) Option { } } -// CacheExp represents the cache expiration time -func CacheExp(exp time.Duration) Option { - return func(prov *provider) error { +// WithCacheExp represents the cache expiration time +func WithCacheExp(exp time.Duration) Option { + return func(prov *authorizer) error { prov.cache.SetDefaultExpire(exp) prov.cacheExp = exp return nil } } +// WithRoleCertURIPrefix represents a RoleCertURIPrefix functional option +func WithRoleCertURIPrefix(t string) Option { + return func(prov *authorizer) error { + prov.roleCertURIPrefix = t + return nil + } +} + /* Pubkeyd parameters */ -// PubkeyRefreshDuration represents a PubkeyRefreshDuration functional option -func PubkeyRefreshDuration(t string) Option { - return func(prov *provider) error { +// WithEnablePubkeyd represents a EnablePubkey functional optiond +func WithEnablePubkeyd() Option { + return func(prov *authorizer) error { + prov.disablePubkeyd = false + return nil + } +} + +// WithDisablePubkeyd represents a DisablePubkey functional optiond +func WithDisablePubkeyd() Option { + return func(prov *authorizer) error { + prov.disablePubkeyd = true + return nil + } +} + +// WithPubkeyRefreshDuration represents a PubkeyRefreshDuration functional option +func WithPubkeyRefreshDuration(t string) Option { + return func(prov *authorizer) error { prov.pubkeyRefreshDuration = t return nil } } -// PubkeySysAuthDomain represents a PubkeySysAuthDomain functional option -func PubkeySysAuthDomain(domain string) Option { - return func(prov *provider) error { +// WithPubkeySysAuthDomain represents a PubkeySysAuthDomain functional option +func WithPubkeySysAuthDomain(domain string) Option { + return func(prov *authorizer) error { prov.pubkeySysAuthDomain = domain return nil } } -// PubkeyEtagExpTime represents a PubkeyEtagExpTime functional option -func PubkeyEtagExpTime(t string) Option { - return func(prov *provider) error { +// WithPubkeyEtagExpTime represents a PubkeyEtagExpTime functional option +func WithPubkeyEtagExpTime(t string) Option { + return func(prov *authorizer) error { prov.pubkeyEtagExpTime = t return nil } } -// PubkeyEtagFlushDur represents a PubkeyEtagFlushDur functional option -func PubkeyEtagFlushDur(t string) Option { - return func(prov *provider) error { +// WithPubkeyEtagFlushDuration represents a PubkeyEtagFlushDur functional option +func WithPubkeyEtagFlushDuration(t string) Option { + return func(prov *authorizer) error { prov.pubkeyEtagFlushDur = t return nil } @@ -112,34 +140,78 @@ func PubkeyEtagFlushDur(t string) Option { policyd parameters */ -// PolicyRefreshDuration represents a PolicyRefreshDuration functional option -func PolicyRefreshDuration(t string) Option { - return func(prov *provider) error { +// WithEnablePolicyd represents a EnablePolicyd functional optiond +func WithEnablePolicyd() Option { + return func(prov *authorizer) error { + prov.disablePolicyd = false + return nil + } +} + +// WithDisablePolicyd represents a DisablePolicyd functional optiond +func WithDisablePolicyd() Option { + return func(prov *authorizer) error { + prov.disablePolicyd = true + return nil + } +} + +// WithPolicyRefreshDuration represents a PolicyRefreshDuration functional option +func WithPolicyRefreshDuration(t string) Option { + return func(prov *authorizer) error { prov.policyRefreshDuration = t return nil } } -// PolicyExpireMargin represents a PolicyExpireMargin functional option -func PolicyExpireMargin(t string) Option { - return func(prov *provider) error { +// WithPolicyExpireMargin represents a PolicyExpireMargin functional option +func WithPolicyExpireMargin(t string) Option { + return func(prov *authorizer) error { prov.policyExpireMargin = t return nil } } -// PolicyEtagExpTime represents a PolicyEtagExpTime functional option -func PolicyEtagExpTime(t string) Option { - return func(prov *provider) error { +// WithPolicyEtagExpTime represents a PolicyEtagExpTime functional option +func WithPolicyEtagExpTime(t string) Option { + return func(prov *authorizer) error { prov.policyEtagExpTime = t return nil } } -// PolicyEtagFlushDur represents a PolicyEtagFlushDur functional option -func PolicyEtagFlushDur(t string) Option { - return func(prov *provider) error { +// WithPolicyEtagFlushDuration represents a PolicyEtagFlushDur functional option +func WithPolicyEtagFlushDuration(t string) Option { + return func(prov *authorizer) error { prov.policyEtagFlushDur = t return nil } } + +/* + jwkd parameters +*/ + +// WithEnableJwkd represents a EnableJwkd functional optiond +func WithEnableJwkd() Option { + return func(prov *authorizer) error { + prov.disableJwkd = false + return nil + } +} + +// WithDisableJwkd represents a DisableJwkd functional optiond +func WithDisableJwkd() Option { + return func(prov *authorizer) error { + prov.disableJwkd = true + return nil + } +} + +// WithJwkRefreshDuration represents a JwkRefreshDuration functional option +func WithJwkRefreshDuration(t string) Option { + return func(prov *authorizer) error { + prov.jwkRefreshDuration = t + return nil + } +} diff --git a/option_test.go b/option_test.go index e2d447c5..a7f9fbb2 100644 --- a/option_test.go +++ b/option_test.go @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -package providerd +package authorizerd import ( "fmt" @@ -25,7 +25,65 @@ import ( "github.com/kpango/gache" ) -func TestPolicyRefreshDuration(t *testing.T) { +func TestWithEnablePubkeyd(t *testing.T) { + tests := []struct { + name string + checkFunc func(Option) error + }{ + { + name: "set success", + checkFunc: func(opt Option) error { + prov := &authorizer{} + if err := opt(prov); err != nil { + return err + } + if prov.disablePubkeyd != false { + return fmt.Errorf("invalid param was set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithEnablePubkeyd() + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithEnablePubkeyd() error = %v", err) + } + }) + } +} + +func TestWithDisablePubkeyd(t *testing.T) { + tests := []struct { + name string + checkFunc func(Option) error + }{ + { + name: "set success", + checkFunc: func(opt Option) error { + prov := &authorizer{} + if err := opt(prov); err != nil { + return err + } + if prov.disablePubkeyd != true { + return fmt.Errorf("invalid param was set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithDisablePubkeyd() + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithDisablePubkeyd() error = %v", err) + } + }) + } +} + +func TestWithPolicyRefreshDuration(t *testing.T) { type args struct { t string } @@ -40,7 +98,7 @@ func TestPolicyRefreshDuration(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -53,14 +111,14 @@ func TestPolicyRefreshDuration(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PolicyRefreshDuration(tt.args.t) + got := WithPolicyRefreshDuration(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("PolicyRefreshDuration() error = %v", err) + t.Errorf("WithPolicyRefreshDuration() error = %v", err) } }) } } -func TestPubkeyRefreshDuration(t *testing.T) { +func TestWithPubkeyRefreshDuration(t *testing.T) { type args struct { t string } @@ -75,7 +133,7 @@ func TestPubkeyRefreshDuration(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -88,14 +146,14 @@ func TestPubkeyRefreshDuration(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PubkeyRefreshDuration(tt.args.t) + got := WithPubkeyRefreshDuration(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("PubkeyRefreshDuration() error = %v", err) + t.Errorf("WithPubkeyRefreshDuration() error = %v", err) } }) } } -func TestAthenzURL(t *testing.T) { +func TestWithAthenzURL(t *testing.T) { type args struct { t string } @@ -110,7 +168,7 @@ func TestAthenzURL(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -123,14 +181,14 @@ func TestAthenzURL(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := AthenzURL(tt.args.t) + got := WithAthenzURL(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("AthenzURL() error = %v", err) + t.Errorf("WithAthenzURL() error = %v", err) } }) } } -func TestAthenzDomains(t *testing.T) { +func TestWithAthenzDomains(t *testing.T) { type args struct { t []string } @@ -145,7 +203,7 @@ func TestAthenzDomains(t *testing.T) { t: []string{"dummy1", "dummy2"}, }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -158,14 +216,15 @@ func TestAthenzDomains(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := AthenzDomains(tt.args.t...) + got := WithAthenzDomains(tt.args.t...) if err := tt.checkFunc(got); err != nil { - t.Errorf("AthenzDomains() error = %v", err) + t.Errorf("WithAthenzDomains() error = %v", err) } }) } } -func TestPubkeySysAuthDomain(t *testing.T) { + +func TestWithPubkeySysAuthDomain(t *testing.T) { type args struct { t string } @@ -180,7 +239,7 @@ func TestPubkeySysAuthDomain(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -193,14 +252,15 @@ func TestPubkeySysAuthDomain(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PubkeySysAuthDomain(tt.args.t) + got := WithPubkeySysAuthDomain(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("PubkeySysAuthDomain() error = %v", err) + t.Errorf("WithPubkeySysAuthDomain() error = %v", err) } }) } } -func TestPubkeyEtagExpTime(t *testing.T) { + +func TestWithPubkeyEtagExpTime(t *testing.T) { type args struct { t string } @@ -215,7 +275,7 @@ func TestPubkeyEtagExpTime(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -228,14 +288,14 @@ func TestPubkeyEtagExpTime(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PubkeyEtagExpTime(tt.args.t) + got := WithPubkeyEtagExpTime(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("PubkeyEtagExpTime() error = %v", err) + t.Errorf("WithPubkeyEtagExpTime() error = %v", err) } }) } } -func TestPubkeyEtagFlushDur(t *testing.T) { +func TestWithPubkeyEtagFlushDuration(t *testing.T) { type args struct { t string } @@ -250,7 +310,7 @@ func TestPubkeyEtagFlushDur(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -263,14 +323,73 @@ func TestPubkeyEtagFlushDur(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PubkeyEtagFlushDur(tt.args.t) + got := WithPubkeyEtagFlushDuration(tt.args.t) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithPubkeyEtagFlushDuration() error = %v", err) + } + }) + } +} + +func TestWithEnablePolicyd(t *testing.T) { + tests := []struct { + name string + checkFunc func(Option) error + }{ + { + name: "set success", + checkFunc: func(opt Option) error { + prov := &authorizer{} + if err := opt(prov); err != nil { + return err + } + if prov.disablePolicyd != false { + return fmt.Errorf("invalid param was set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithEnablePolicyd() + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithEnablePolicyd() error = %v", err) + } + }) + } +} + +func TestWithDisablePolicyd(t *testing.T) { + tests := []struct { + name string + checkFunc func(Option) error + }{ + { + name: "set success", + checkFunc: func(opt Option) error { + prov := &authorizer{} + if err := opt(prov); err != nil { + return err + } + if prov.disablePolicyd != true { + return fmt.Errorf("invalid param was set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithDisablePolicyd() if err := tt.checkFunc(got); err != nil { - t.Errorf("PubkeyEtagFlushDur() error = %v", err) + t.Errorf("WithDisablePolicyd() error = %v", err) } }) } } -func TestPolicyExpireMargin(t *testing.T) { + +func TestWithPolicyExpireMargin(t *testing.T) { type args struct { t string } @@ -285,7 +404,7 @@ func TestPolicyExpireMargin(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -298,14 +417,14 @@ func TestPolicyExpireMargin(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PolicyExpireMargin(tt.args.t) + got := WithPolicyExpireMargin(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("PolicyExpireMargin() error = %v", err) + t.Errorf("WithPolicyExpireMargin() error = %v", err) } }) } } -func TestPolicyEtagFlushDur(t *testing.T) { +func TestWithPolicyEtagFlushDuration(t *testing.T) { type args struct { t string } @@ -320,7 +439,7 @@ func TestPolicyEtagFlushDur(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -333,9 +452,9 @@ func TestPolicyEtagFlushDur(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PolicyEtagFlushDur(tt.args.t) + got := WithPolicyEtagFlushDuration(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("PolicyEtagFlushDur() error = %v", err) + t.Errorf("WithPolicyEtagFlushDuration() error = %v", err) } }) } @@ -355,7 +474,7 @@ func TestPolicyEtagExpTime(t *testing.T) { t: "dummy", }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -368,9 +487,9 @@ func TestPolicyEtagExpTime(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PolicyEtagExpTime(tt.args.t) + got := WithPolicyEtagExpTime(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("PolicyEtagExpTime() error = %v", err) + t.Errorf("WithPolicyEtagExpTime() error = %v", err) } }) } @@ -390,7 +509,7 @@ func TestCacheExp(t *testing.T) { d: time.Duration(time.Hour * 2), }, checkFunc: func(opt Option) error { - prov := &provider{ + prov := &authorizer{ cache: gache.New(), } if err := opt(prov); err != nil { @@ -405,9 +524,9 @@ func TestCacheExp(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := CacheExp(tt.args.d) + got := WithCacheExp(tt.args.d) if err := tt.checkFunc(got); err != nil { - t.Errorf("CacheExp() error = %v", err) + t.Errorf("WithCacheExp() error = %v", err) } }) } @@ -427,7 +546,7 @@ func TestTransport(t *testing.T) { t: &http.Transport{}, }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -443,7 +562,7 @@ func TestTransport(t *testing.T) { t: nil, }, checkFunc: func(opt Option) error { - prov := &provider{} + prov := &authorizer{} if err := opt(prov); err != nil { return err } @@ -459,9 +578,103 @@ func TestTransport(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := Transport(tt.args.t) + got := WithTransport(tt.args.t) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithTransport() error = %v", err) + } + }) + } +} + +func TestWithEnableJwkd(t *testing.T) { + tests := []struct { + name string + checkFunc func(Option) error + }{ + { + name: "set success", + checkFunc: func(opt Option) error { + prov := &authorizer{} + if err := opt(prov); err != nil { + return err + } + if prov.disableJwkd != false { + return fmt.Errorf("invalid param was set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithEnableJwkd() + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithEnableJwkd() error = %v", err) + } + }) + } +} + +func TestWithDisableJwkd(t *testing.T) { + tests := []struct { + name string + checkFunc func(Option) error + }{ + { + name: "set success", + checkFunc: func(opt Option) error { + prov := &authorizer{} + if err := opt(prov); err != nil { + return err + } + if prov.disableJwkd != true { + return fmt.Errorf("invalid param was set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithDisableJwkd() + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithDisableJwkd() error = %v", err) + } + }) + } +} + +func TestWithJwkRefreshDuration(t *testing.T) { + type args struct { + t string + } + tests := []struct { + name string + args args + checkFunc func(Option) error + }{ + { + name: "set success", + args: args{ + t: "dummy", + }, + checkFunc: func(opt Option) error { + prov := &authorizer{} + if err := opt(prov); err != nil { + return err + } + if prov.jwkRefreshDuration != "dummy" { + return fmt.Errorf("invalid param was set") + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithJwkRefreshDuration(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("Transport() error = %v", err) + t.Errorf("WithJwkRefreshDuration() error = %v", err) } }) } diff --git a/policy/policyd.go b/policy/daemon.go similarity index 81% rename from policy/policyd.go rename to policy/daemon.go index a9e93f7b..015976b7 100644 --- a/policy/policyd.go +++ b/policy/daemon.go @@ -30,21 +30,25 @@ import ( "github.com/kpango/glg" "github.com/pkg/errors" "github.com/yahoo/athenz/utils/zpe-updater/util" - "github.com/yahoojapan/athenz-policy-updater/pubkey" + "github.com/yahoojapan/athenz-authorizer/pubkey" "golang.org/x/sync/errgroup" ) // Policyd represent the daemon to retrieve policy data from Athenz. -type Policyd interface { - StartPolicyUpdater(context.Context) <-chan error - UpdatePolicy(context.Context) error +type Daemon interface { + Start(context.Context) <-chan error + Update(context.Context) error CheckPolicy(ctx context.Context, domain string, roles []string, action, resource string) error + GetPolicyCache(context.Context) map[string]interface{} } type policyd struct { - expireMargin time.Duration // expire margin force update policy when the policy expire time hit the margin - rolePolicies gache.Gache //*sync.Map // map[:role.][]Assertion - refreshDuration time.Duration + expireMargin time.Duration // expire margin force update policy when the policy expire time hit the margin + rolePolicies gache.Gache //*sync.Map // map[:role.][]Assertion + policyExpiredDuration time.Duration + + refreshDuration time.Duration + //flushDur time.Duration errRetryInterval time.Duration pkp pubkey.Provider @@ -65,8 +69,8 @@ type etagCache struct { sp *SignedPolicy } -// NewPolicyd represent the constructor of Policyd -func NewPolicyd(opts ...Option) (Policyd, error) { +// New represent the constructor of Policyd +func New(opts ...Option) (Daemon, error) { p := &policyd{ rolePolicies: gache.New(), etagCache: gache.New(), @@ -78,8 +82,7 @@ func NewPolicyd(opts ...Option) (Policyd, error) { }) for _, opt := range append(defaultOptions, opts...) { - err := opt(p) - if err != nil { + if err := opt(p); err != nil { return nil, errors.Wrap(err, "error create policyd") } } @@ -87,12 +90,13 @@ func NewPolicyd(opts ...Option) (Policyd, error) { return p, nil } -// StartPolicyUpdater starts the Policy daemon to retrive the policy data periodically -func (p *policyd) StartPolicyUpdater(ctx context.Context) <-chan error { +// Start starts the Policy daemon to retrive the policy data periodically +func (p *policyd) Start(ctx context.Context) <-chan error { glg.Info("Starting policyd updater") ech := make(chan error, 100) fch := make(chan struct{}, 1) - if err := p.UpdatePolicy(ctx); err != nil { + if err := p.Update(ctx); err != nil { + glg.Debugf("Error initialize policy data, err: %v", err) ech <- errors.Wrap(err, "error update policy") fch <- struct{}{} } @@ -100,32 +104,23 @@ func (p *policyd) StartPolicyUpdater(ctx context.Context) <-chan error { go func() { defer close(fch) defer close(ech) + + p.rolePolicies.StartExpired(ctx, p.policyExpiredDuration) p.etagCache.StartExpired(ctx, p.etagFlushDur) - p.rolePolicies.StartExpired(ctx, time.Hour*24) ticker := time.NewTicker(p.refreshDuration) - ebuf := errors.New("") for { select { case <-ctx.Done(): glg.Info("Stopping policyd updater") ticker.Stop() ech <- ctx.Err() - if ebuf.Error() != "" { - ech <- errors.Wrap(ctx.Err(), ebuf.Error()) - } else { - ech <- ctx.Err() - } return case <-fch: - if err := p.UpdatePolicy(ctx); err != nil { - err = errors.Wrap(err, "error update policy") - select { - case ech <- errors.Wrap(ebuf, err.Error()): - ebuf = errors.New("") - default: - ebuf = errors.Wrap(ebuf, err.Error()) - } + if err := p.Update(ctx); err != nil { + ech <- errors.Wrap(err, "error update policy") + time.Sleep(p.errRetryInterval) + select { case fch <- struct{}{}: default: @@ -133,14 +128,9 @@ func (p *policyd) StartPolicyUpdater(ctx context.Context) <-chan error { } } case <-ticker.C: - if err := p.UpdatePolicy(ctx); err != nil { - err = errors.Wrap(err, "error update policy") - select { - case ech <- errors.Wrap(ebuf, err.Error()): - ebuf = errors.New("") - default: - ebuf = errors.Wrap(ebuf, err.Error()) - } + if err := p.Update(ctx); err != nil { + ech <- errors.Wrap(err, "error update policy") + select { case fch <- struct{}{}: default: @@ -154,8 +144,8 @@ func (p *policyd) StartPolicyUpdater(ctx context.Context) <-chan error { return ech } -// UpdatePolicy updates and cache policy data -func (p *policyd) UpdatePolicy(ctx context.Context) error { +// Update updates and cache policy data +func (p *policyd) Update(ctx context.Context) error { glg.Info("Updating policy") defer glg.Info("Updated policy") eg := errgroup.Group{} @@ -233,19 +223,26 @@ func (p *policyd) CheckPolicy(ctx context.Context, domain string, roles []string return err } +func (p *policyd) GetPolicyCache(ctx context.Context) map[string]interface{} { + return p.rolePolicies.ToRawMap(ctx) +} + func (p *policyd) fetchAndCachePolicy(ctx context.Context, dom string) error { spd, upd, err := p.fetchPolicy(ctx, dom) if err != nil { + glg.Debugf("fetch policy failed, err: %v", err) return errors.Wrap(err, "error fetch policy") } + glg.Debugf("fetch policy success, updated: %v", upd) if upd { - if glg.Get().GetCurrentMode(glg.DEBG) != glg.NONE { + glg.DebugFunc(func() string { rawpol, _ := json.Marshal(spd) - glg.Debugf("fetched policy data:\tdomain\t%s\tbody\t%s", dom, (string)(rawpol)) - } + return fmt.Sprintf("fetched policy data:\tdomain\t%s\tbody\t%s", dom, (string)(rawpol)) + }) if err = p.simplifyAndCache(ctx, spd); err != nil { + glg.Debugf("simplify and cache error: %v", err) return errors.Wrap(err, "error simplify and cache") } } @@ -258,9 +255,10 @@ func (p *policyd) fetchPolicy(ctx context.Context, domain string) (*SignedPolicy // https://{www.athenz.com/zts/v1}/domain/{athenz domain}/signed_policy_data url := fmt.Sprintf("https://%s/domain/%s/signed_policy_data", p.athenzURL, domain) + glg.Debugf("fetching policy, url: %v", url) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - glg.Errorf("Fetch policy error, domain: %s, error: %v", domain, err) + glg.Errorf("fetch policy error, domain: %s, error: %v", domain, err) return nil, false, errors.Wrap(err, "error creating fetch policy request") } @@ -324,11 +322,9 @@ func (p *policyd) fetchPolicy(ctx context.Context, domain string) (*SignedPolicy func (p *policyd) simplifyAndCache(ctx context.Context, sp *SignedPolicy) error { rp := gache.New() - defer rp.Clear() - eg := errgroup.Group{} - mu := new(sync.Mutex) - assm := new(sync.Map) + assm := new(sync.Map) // assertion map + for _, policy := range sp.DomainSignedPolicyData.SignedPolicyData.PolicyData.Policies { pol := policy eg.Go(func() error { @@ -341,11 +337,11 @@ func (p *policyd) simplifyAndCache(ctx context.Context, sp *SignedPolicy) error if _, ok := assm.Load(km); !ok { assm.Store(km, ass) } else { + // deny policy will override allow policy, and also remove duplication if strings.EqualFold("deny", ass.Effect) { assm.Store(km, ass) } } - } } @@ -362,19 +358,20 @@ func (p *policyd) simplifyAndCache(ctx context.Context, sp *SignedPolicy) error ass := val.(*util.Assertion) a, err := NewAssertion(ass.Action, ass.Resource, ass.Effect) if err != nil { + glg.Debugf("error adding assertion to the cache, err: %v", err) retErr = err return false } - var asss []*Assertion - mu.Lock() + var asss []*Assertion if r, ok := rp.Get(ass.Role); ok { asss = append(r.([]*Assertion), a) } else { asss = []*Assertion{a} } rp.SetWithExpire(ass.Role, asss, time.Duration(sp.DomainSignedPolicyData.SignedPolicyData.Expires.UnixNano())) - mu.Unlock() + + glg.Debugf("added assertion to the cache: %+v", ass) return true }) if retErr != nil { diff --git a/policy/policyd_test.go b/policy/daemon_test.go similarity index 78% rename from policy/policyd_test.go rename to policy/daemon_test.go index 23e77c41..cc09c7cc 100644 --- a/policy/policyd_test.go +++ b/policy/daemon_test.go @@ -31,18 +31,18 @@ import ( "github.com/pkg/errors" authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" "github.com/yahoo/athenz/utils/zpe-updater/util" - "github.com/yahoojapan/athenz-policy-updater/pubkey" + "github.com/yahoojapan/athenz-authorizer/pubkey" ) -func TestNewPolicyd(t *testing.T) { +func TestNew(t *testing.T) { type args struct { opts []Option } tests := []struct { name string args args - want Policyd - checkFunc func(got Policyd) error + want Daemon + checkFunc func(got Daemon) error wantErr bool }{ { @@ -50,7 +50,7 @@ func TestNewPolicyd(t *testing.T) { args: args{ opts: []Option{}, }, - checkFunc: func(got Policyd) error { + checkFunc: func(got Daemon) error { p := got.(*policyd) if p.expireMargin != time.Hour*3 { return errors.New("invalid expireMargin") @@ -61,9 +61,9 @@ func TestNewPolicyd(t *testing.T) { { name: "new success with options", args: args{ - opts: []Option{ExpireMargin("5s")}, + opts: []Option{WithExpireMargin("5s")}, }, - checkFunc: func(got Policyd) error { + checkFunc: func(got Daemon) error { p := got.(*policyd) if p.expireMargin != time.Second*5 { return errors.New("invalid expireMargin") @@ -74,41 +74,42 @@ func TestNewPolicyd(t *testing.T) { { name: "new error due to options", args: args{ - opts: []Option{EtagExpTime("dummy")}, + opts: []Option{WithEtagExpTime("dummy")}, }, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewPolicyd(tt.args.opts...) + got, err := New(tt.args.opts...) if (err != nil) != tt.wantErr { - t.Errorf("NewPolicyd() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } if tt.checkFunc != nil { if err := tt.checkFunc(got); err != nil { - t.Errorf("NewPolicyd() = %v", err) + t.Errorf("New() = %v", err) } } }) } } -func Test_policy_StartPolicyUpdater(t *testing.T) { +func Test_policy_Start(t *testing.T) { type fields struct { - expireMargin time.Duration - rolePolicies gache.Gache - refreshDuration time.Duration - errRetryInterval time.Duration - pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration - etagExpTime time.Duration - athenzURL string - athenzDomains []string - client *http.Client + expireMargin time.Duration + rolePolicies gache.Gache + policyExpiredDuration time.Duration + refreshDuration time.Duration + errRetryInterval time.Duration + pkp pubkey.Provider + etagCache gache.Gache + etagFlushDur time.Duration + etagExpTime time.Duration + athenzURL string + athenzDomains []string + client *http.Client } type args struct { ctx context.Context @@ -131,16 +132,17 @@ func Test_policy_StartPolicyUpdater(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) return test{ - name: "Start updator success", + name: "Start success", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - etagFlushDur: time.Second, - refreshDuration: time.Second, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + etagFlushDur: time.Second, + refreshDuration: time.Second, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -188,16 +190,17 @@ func Test_policy_StartPolicyUpdater(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) return test{ - name: "Start updator can update cache", + name: "Start can update cache", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - etagFlushDur: time.Second, - refreshDuration: time.Millisecond * 30, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + etagFlushDur: time.Second, + refreshDuration: time.Millisecond * 30, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -261,17 +264,18 @@ func Test_policy_StartPolicyUpdater(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) return test{ - name: "Start updator retry update", + name: "Start retry update", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - etagFlushDur: time.Second, - refreshDuration: time.Minute, - errRetryInterval: time.Millisecond * 5, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + etagFlushDur: time.Second, + refreshDuration: time.Minute, + errRetryInterval: time.Millisecond * 5, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -325,41 +329,43 @@ func Test_policy_StartPolicyUpdater(t *testing.T) { defer tt.afterFunc() } p := &policyd{ - expireMargin: tt.fields.expireMargin, - rolePolicies: tt.fields.rolePolicies, - refreshDuration: tt.fields.refreshDuration, - errRetryInterval: tt.fields.errRetryInterval, - pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, - etagExpTime: tt.fields.etagExpTime, - athenzURL: tt.fields.athenzURL, - athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, + expireMargin: tt.fields.expireMargin, + rolePolicies: tt.fields.rolePolicies, + policyExpiredDuration: tt.fields.policyExpiredDuration, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + pkp: tt.fields.pkp, + etagCache: tt.fields.etagCache, + etagFlushDur: tt.fields.etagFlushDur, + etagExpTime: tt.fields.etagExpTime, + athenzURL: tt.fields.athenzURL, + athenzDomains: tt.fields.athenzDomains, + client: tt.fields.client, } - ch := p.StartPolicyUpdater(tt.args.ctx) + ch := p.Start(tt.args.ctx) if tt.checkFunc != nil { if err := tt.checkFunc(p, ch); err != nil { - t.Errorf("policy.StartPolicyUpdater() error = %v", err) + t.Errorf("policy.Start() error = %v", err) } } }) } } -func Test_policy_UpdatePolicy(t *testing.T) { +func Test_policy_Update(t *testing.T) { type fields struct { - expireMargin time.Duration - rolePolicies gache.Gache - refreshDuration time.Duration - errRetryInterval time.Duration - pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration - etagExpTime time.Duration - athenzURL string - athenzDomains []string - client *http.Client + expireMargin time.Duration + rolePolicies gache.Gache + policyExpiredDuration time.Duration + refreshDuration time.Duration + errRetryInterval time.Duration + pkp pubkey.Provider + etagCache gache.Gache + etagFlushDur time.Duration + etagExpTime time.Duration + athenzURL string + athenzDomains []string + client *http.Client } type args struct { ctx context.Context @@ -385,12 +391,13 @@ func Test_policy_UpdatePolicy(t *testing.T) { return test{ name: "Update policy success", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -480,14 +487,15 @@ func Test_policy_UpdatePolicy(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*10)) return test{ - name: "Update policy success", + name: "Update error, context timeout", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -517,27 +525,28 @@ func Test_policy_UpdatePolicy(t *testing.T) { } p := &policyd{ - expireMargin: tt.fields.expireMargin, - rolePolicies: tt.fields.rolePolicies, - refreshDuration: tt.fields.refreshDuration, - errRetryInterval: tt.fields.errRetryInterval, - pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, - etagExpTime: tt.fields.etagExpTime, - athenzURL: tt.fields.athenzURL, - athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, + expireMargin: tt.fields.expireMargin, + rolePolicies: tt.fields.rolePolicies, + policyExpiredDuration: tt.fields.policyExpiredDuration, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + pkp: tt.fields.pkp, + etagCache: tt.fields.etagCache, + etagFlushDur: tt.fields.etagFlushDur, + etagExpTime: tt.fields.etagExpTime, + athenzURL: tt.fields.athenzURL, + athenzDomains: tt.fields.athenzDomains, + client: tt.fields.client, } if tt.beforeFunc != nil { tt.beforeFunc() } - if err := p.UpdatePolicy(tt.args.ctx); (err != nil) && tt.wantErr != "" && err.Error() != tt.wantErr { - t.Errorf("policy.UpdatePolicy() error = %v, wantErr %v", err, tt.wantErr) + if err := p.Update(tt.args.ctx); (err != nil) && tt.wantErr != "" && err.Error() != tt.wantErr { + t.Errorf("policy.Update() error = %v, wantErr %v", err, tt.wantErr) } if tt.checkFunc != nil { if err := tt.checkFunc(p); err != nil { - t.Errorf("policy.UpdatePolicy() error = %v", err) + t.Errorf("policy.Update() error = %v", err) } } }) @@ -761,17 +770,18 @@ func Test_policy_CheckPolicy(t *testing.T) { func Test_policy_fetchAndCachePolicy(t *testing.T) { type fields struct { - expireMargin time.Duration - rolePolicies gache.Gache - refreshDuration time.Duration - errRetryInterval time.Duration - pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration - etagExpTime time.Duration - athenzURL string - athenzDomains []string - client *http.Client + expireMargin time.Duration + rolePolicies gache.Gache + policyExpiredDuration time.Duration + refreshDuration time.Duration + errRetryInterval time.Duration + pkp pubkey.Provider + etagCache gache.Gache + etagFlushDur time.Duration + etagExpTime time.Duration + athenzURL string + athenzDomains []string + client *http.Client } type args struct { ctx context.Context @@ -796,12 +806,13 @@ func Test_policy_fetchAndCachePolicy(t *testing.T) { return test{ name: "fetch policy success with updated policy", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -837,12 +848,13 @@ func Test_policy_fetchAndCachePolicy(t *testing.T) { return test{ name: "fetch policy failed", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -869,12 +881,13 @@ func Test_policy_fetchAndCachePolicy(t *testing.T) { return test{ name: "simplifyAndCache failed", fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -894,17 +907,18 @@ func Test_policy_fetchAndCachePolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &policyd{ - expireMargin: tt.fields.expireMargin, - rolePolicies: tt.fields.rolePolicies, - refreshDuration: tt.fields.refreshDuration, - errRetryInterval: tt.fields.errRetryInterval, - pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, - etagExpTime: tt.fields.etagExpTime, - athenzURL: tt.fields.athenzURL, - athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, + expireMargin: tt.fields.expireMargin, + rolePolicies: tt.fields.rolePolicies, + policyExpiredDuration: tt.fields.policyExpiredDuration, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + pkp: tt.fields.pkp, + etagCache: tt.fields.etagCache, + etagFlushDur: tt.fields.etagFlushDur, + etagExpTime: tt.fields.etagExpTime, + athenzURL: tt.fields.athenzURL, + athenzDomains: tt.fields.athenzDomains, + client: tt.fields.client, } if err := p.fetchAndCachePolicy(tt.args.ctx, tt.args.dom); (err != nil) != tt.wantErr { t.Errorf("policy.fetchAndCachePolicy() error = %v, wantErr %v", err, tt.wantErr) @@ -920,17 +934,18 @@ func Test_policy_fetchAndCachePolicy(t *testing.T) { func Test_policy_fetchPolicy(t *testing.T) { type fields struct { - expireMargin time.Duration - rolePolicies gache.Gache - refreshDuration time.Duration - errRetryInterval time.Duration - pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration - etagExpTime time.Duration - athenzURL string - athenzDomains []string - client *http.Client + expireMargin time.Duration + rolePolicies gache.Gache + policyExpiredDuration time.Duration + refreshDuration time.Duration + errRetryInterval time.Duration + pkp pubkey.Provider + etagCache gache.Gache + etagFlushDur time.Duration + etagExpTime time.Duration + athenzURL string + athenzDomains []string + client *http.Client } type args struct { ctx context.Context @@ -954,11 +969,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test fetch success", fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + policyExpiredDuration: time.Minute * 30, + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1013,11 +1029,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test fetch error url", fields: fields{ - athenzURL: " ", - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Second, - client: srv.Client(), + athenzURL: " ", + policyExpiredDuration: time.Minute * 30, + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Second, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1075,11 +1092,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test etag exists but not modified", fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: etagCac, - etagExpTime: time.Minute, - expireMargin: time.Second, - client: srv.Client(), + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + policyExpiredDuration: time.Minute * 30, + etagCache: etagCac, + etagExpTime: time.Minute, + expireMargin: time.Second, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1147,11 +1165,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test etag exists but modified", fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: etagCac, - etagExpTime: time.Minute, - expireMargin: time.Second, - client: srv.Client(), + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + policyExpiredDuration: time.Minute * 30, + etagCache: etagCac, + etagExpTime: time.Minute, + expireMargin: time.Second, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1205,11 +1224,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test fetch error make https request", fields: fields{ - athenzURL: "dummyURL", - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + athenzURL: "dummyURL", + policyExpiredDuration: time.Minute * 30, + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1247,11 +1267,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test fetch error return not ok", fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + policyExpiredDuration: time.Minute * 30, + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1290,11 +1311,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test fetch error decode policy", fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + policyExpiredDuration: time.Minute * 30, + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1333,11 +1355,12 @@ func Test_policy_fetchPolicy(t *testing.T) { return test{ name: "test fetch verify error", fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagExpTime: time.Minute, - expireMargin: time.Hour, - client: srv.Client(), + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + policyExpiredDuration: time.Minute * 30, + etagCache: gache.New(), + etagExpTime: time.Minute, + expireMargin: time.Hour, + client: srv.Client(), pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { return VerifierMock{ VerifyFunc: func(d, s string) error { @@ -1370,17 +1393,18 @@ func Test_policy_fetchPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &policyd{ - expireMargin: tt.fields.expireMargin, - rolePolicies: tt.fields.rolePolicies, - refreshDuration: tt.fields.refreshDuration, - errRetryInterval: tt.fields.errRetryInterval, - pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, - etagExpTime: tt.fields.etagExpTime, - athenzURL: tt.fields.athenzURL, - athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, + expireMargin: tt.fields.expireMargin, + rolePolicies: tt.fields.rolePolicies, + policyExpiredDuration: tt.fields.policyExpiredDuration, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + pkp: tt.fields.pkp, + etagCache: tt.fields.etagCache, + etagFlushDur: tt.fields.etagFlushDur, + etagExpTime: tt.fields.etagExpTime, + athenzURL: tt.fields.athenzURL, + athenzDomains: tt.fields.athenzDomains, + client: tt.fields.client, } got, got1, err := p.fetchPolicy(tt.args.ctx, tt.args.domain) @@ -1393,17 +1417,18 @@ func Test_policy_fetchPolicy(t *testing.T) { func Test_policy_simplifyAndCache(t *testing.T) { type fields struct { - expireMargin time.Duration - rolePolicies gache.Gache - refreshDuration time.Duration - errRetryInterval time.Duration - pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration - etagExpTime time.Duration - athenzURL string - athenzDomains []string - client *http.Client + expireMargin time.Duration + rolePolicies gache.Gache + policyExpiredDuration time.Duration + refreshDuration time.Duration + errRetryInterval time.Duration + pkp pubkey.Provider + etagCache gache.Gache + etagFlushDur time.Duration + etagExpTime time.Duration + athenzURL string + athenzDomains []string + client *http.Client } type args struct { ctx context.Context @@ -1429,7 +1454,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache success with data", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1524,7 +1550,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "test context done", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: ctx, @@ -1584,7 +1611,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache deny overwrite allow", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1671,7 +1699,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache success with no data", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1692,7 +1721,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache failed with invalid assertion", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1727,7 +1757,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache success with no data", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1748,7 +1779,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache failed with invalid assertion", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1791,7 +1823,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache replace by new assertion", fields: fields{ - rolePolicies: rp, + rolePolicies: rp, + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1853,7 +1886,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache delete", fields: fields{ - rolePolicies: rp, + rolePolicies: rp, + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1914,7 +1948,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache success with 100x100 data", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -1974,7 +2009,8 @@ func Test_policy_simplifyAndCache(t *testing.T) { return test{ name: "cache success with no race condition with 100x100 data", fields: fields{ - rolePolicies: gache.New(), + rolePolicies: gache.New(), + policyExpiredDuration: time.Minute * 30, }, args: args{ ctx: context.Background(), @@ -2034,17 +2070,18 @@ func Test_policy_simplifyAndCache(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &policyd{ - expireMargin: tt.fields.expireMargin, - rolePolicies: tt.fields.rolePolicies, - refreshDuration: tt.fields.refreshDuration, - errRetryInterval: tt.fields.errRetryInterval, - pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, - etagExpTime: tt.fields.etagExpTime, - athenzURL: tt.fields.athenzURL, - athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, + expireMargin: tt.fields.expireMargin, + rolePolicies: tt.fields.rolePolicies, + policyExpiredDuration: tt.fields.policyExpiredDuration, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + pkp: tt.fields.pkp, + etagCache: tt.fields.etagCache, + etagFlushDur: tt.fields.etagFlushDur, + etagExpTime: tt.fields.etagExpTime, + athenzURL: tt.fields.athenzURL, + athenzDomains: tt.fields.athenzDomains, + client: tt.fields.client, } if err := p.simplifyAndCache(tt.args.ctx, tt.args.sp); (err != nil) != tt.wantErr { t.Errorf("policy.simplifyAndCache() error = %v, wantErr %v", err, tt.wantErr) @@ -2057,3 +2094,64 @@ func Test_policy_simplifyAndCache(t *testing.T) { }) } } + +func Test_policyd_GetPolicyCache(t *testing.T) { + type fields struct { + expireMargin time.Duration + rolePolicies gache.Gache + policyExpiredDuration time.Duration + refreshDuration time.Duration + errRetryInterval time.Duration + pkp pubkey.Provider + etagCache gache.Gache + etagFlushDur time.Duration + etagExpTime time.Duration + athenzURL string + athenzDomains []string + client *http.Client + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + want map[string]interface{} + }{ + { + name: "get policy cache success", + fields: fields{ + rolePolicies: func() gache.Gache { + g := gache.New() + return g + }(), + }, + args: args{ + ctx: context.Background(), + }, + want: make(map[string]interface{}), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &policyd{ + expireMargin: tt.fields.expireMargin, + rolePolicies: tt.fields.rolePolicies, + policyExpiredDuration: tt.fields.policyExpiredDuration, + refreshDuration: tt.fields.refreshDuration, + errRetryInterval: tt.fields.errRetryInterval, + pkp: tt.fields.pkp, + etagCache: tt.fields.etagCache, + etagFlushDur: tt.fields.etagFlushDur, + etagExpTime: tt.fields.etagExpTime, + athenzURL: tt.fields.athenzURL, + athenzDomains: tt.fields.athenzDomains, + client: tt.fields.client, + } + if got := p.GetPolicyCache(tt.args.ctx); !reflect.DeepEqual(got, tt.want) { + t.Errorf("policyd.GetPolicyCache() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/policy/option.go b/policy/option.go index 4ac7c255..f47dac85 100644 --- a/policy/option.go +++ b/policy/option.go @@ -20,25 +20,26 @@ import ( "time" "github.com/pkg/errors" - "github.com/yahoojapan/athenz-policy-updater/pubkey" + "github.com/yahoojapan/athenz-authorizer/pubkey" ) var ( defaultOptions = []Option{ - ExpireMargin("3h"), - EtagFlushDur("12h"), - EtagExpTime("24h"), - RefreshDuration("30m"), - ErrRetryInterval("1m"), - HTTPClient(&http.Client{}), + WithExpireMargin("3h"), + WithEtagFlushDuration("12h"), + WithEtagExpTime("24h"), + WithPolicyExpiredDuration("1m"), + WithRefreshDuration("30m"), + WithErrRetryInterval("1m"), + WithHTTPClient(http.DefaultClient), } ) // Option represents a functional options pattern interface type Option func(*policyd) error -// EtagFlushDur represents a ETagFlushDur functional option -func EtagFlushDur(t string) Option { +// WithEtagFlushDuration represents a ETagFlushDur functional option +func WithEtagFlushDuration(t string) Option { return func(pol *policyd) error { if t == "" { return nil @@ -52,8 +53,8 @@ func EtagFlushDur(t string) Option { } } -// ExpireMargin represents a ExpiryMargin functional option -func ExpireMargin(t string) Option { +// WithExpireMargin represents a ExpiryMargin functional option +func WithExpireMargin(t string) Option { return func(pol *policyd) error { if t == "" { return nil @@ -67,8 +68,8 @@ func ExpireMargin(t string) Option { } } -// EtagExpTime represents a EtagExpTime functional option -func EtagExpTime(t string) Option { +// WithEtagExpTime represents a EtagExpTime functional option +func WithEtagExpTime(t string) Option { return func(pol *policyd) error { if t == "" { return nil @@ -82,8 +83,8 @@ func EtagExpTime(t string) Option { } } -// AthenzURL represents a AthenzURL functional option -func AthenzURL(url string) Option { +// WithAthenzURL represents a AthenzURL functional option +func WithAthenzURL(url string) Option { return func(pol *policyd) error { if url == "" { return nil @@ -93,8 +94,8 @@ func AthenzURL(url string) Option { } } -// AthenzDomains represents a AthenzDomain functional option -func AthenzDomains(doms ...string) Option { +// WithAthenzDomains represents a AthenzDomain functional option +func WithAthenzDomains(doms ...string) Option { return func(pol *policyd) error { if doms == nil { return nil @@ -104,8 +105,23 @@ func AthenzDomains(doms ...string) Option { } } -// RefreshDuration represents a RefreshDuration functional option -func RefreshDuration(t string) Option { +// WithPolicyExpiredDuration represents a PolicyExpiredDuration functional option +func WithPolicyExpiredDuration(t string) Option { + return func(pol *policyd) error { + if t == "" { + return nil + } + rd, err := time.ParseDuration(t) + if err != nil { + return errors.Wrap(err, "invalid refresh duration") + } + pol.policyExpiredDuration = rd + return nil + } +} + +// WithRefreshDuration represents a RefreshDuration functional option +func WithRefreshDuration(t string) Option { return func(pol *policyd) error { if t == "" { return nil @@ -119,8 +135,8 @@ func RefreshDuration(t string) Option { } } -// HTTPClient represents a HttpClient functional option -func HTTPClient(c *http.Client) Option { +// WithHTTPClient represents a HttpClient functional option +func WithHTTPClient(c *http.Client) Option { return func(pol *policyd) error { if c != nil { pol.client = c @@ -129,8 +145,8 @@ func HTTPClient(c *http.Client) Option { } } -// PubKeyProvider represents a PubKeyProvider functional option -func PubKeyProvider(pkp pubkey.Provider) Option { +// WithPubKeyProvider represents a PubKeyProvider functional option +func WithPubKeyProvider(pkp pubkey.Provider) Option { return func(pol *policyd) error { if pkp != nil { pol.pkp = pkp @@ -139,8 +155,8 @@ func PubKeyProvider(pkp pubkey.Provider) Option { } } -// ErrRetryInterval represents a ErrRetryInterval functional option -func ErrRetryInterval(i string) Option { +// WithErrRetryInterval represents a ErrRetryInterval functional option +func WithErrRetryInterval(i string) Option { return func(pol *policyd) error { if i == "" { return nil diff --git a/policy/option_test.go b/policy/option_test.go index ec99d801..7040354a 100644 --- a/policy/option_test.go +++ b/policy/option_test.go @@ -23,10 +23,10 @@ import ( "time" authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" - "github.com/yahoojapan/athenz-policy-updater/pubkey" + "github.com/yahoojapan/athenz-authorizer/pubkey" ) -func TestEtagFlushDur(t *testing.T) { +func TestWithEtagFlushDur(t *testing.T) { type args struct { t string } @@ -84,15 +84,15 @@ func TestEtagFlushDur(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := EtagFlushDur(tt.args.t) + got := WithEtagFlushDuration(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("EtagFlushDur() error = %v", err) + t.Errorf("WithEtagFlushDur() error = %v", err) } }) } } -func TestExpireMargin(t *testing.T) { +func TestWithExpireMargin(t *testing.T) { type args struct { t string } @@ -150,15 +150,15 @@ func TestExpireMargin(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ExpireMargin(tt.args.t) + got := WithExpireMargin(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("ExpireMargin() error = %v", err) + t.Errorf("WithExpireMargin() error = %v", err) } }) } } -func TestEtagExpTime(t *testing.T) { +func TestWithEtagExpTime(t *testing.T) { type args struct { t string } @@ -216,15 +216,15 @@ func TestEtagExpTime(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := EtagExpTime(tt.args.t) + got := WithEtagExpTime(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("EtagExpTime() error = %v", err) + t.Errorf("WithEtagExpTime() error = %v", err) } }) } } -func TestAthenzURL(t *testing.T) { +func TestWithAthenzURL(t *testing.T) { type args struct { t string } @@ -269,15 +269,15 @@ func TestAthenzURL(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := AthenzURL(tt.args.t) + got := WithAthenzURL(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("AthenzURL() error = %v", err) + t.Errorf("WithAthenzURL() error = %v", err) } }) } } -func TestAthenzDomains(t *testing.T) { +func TestWithAthenzDomains(t *testing.T) { type args struct { t []string } @@ -322,15 +322,15 @@ func TestAthenzDomains(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := AthenzDomains(tt.args.t...) + got := WithAthenzDomains(tt.args.t...) if err := tt.checkFunc(got); err != nil { - t.Errorf("AthenzDomains() error = %v", err) + t.Errorf("WithAthenzDomains() error = %v", err) } }) } } -func TestRefreshDuration(t *testing.T) { +func TestWithRefreshDuration(t *testing.T) { type args struct { t string } @@ -388,15 +388,15 @@ func TestRefreshDuration(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := RefreshDuration(tt.args.t) + got := WithRefreshDuration(tt.args.t) if err := tt.checkFunc(got); err != nil { - t.Errorf("RefreshDuration() error = %v", err) + t.Errorf("WithRefreshDuration() error = %v", err) } }) } } -func TestHTTPClient(t *testing.T) { +func TestWithHTTPClient(t *testing.T) { type args struct { c *http.Client } @@ -445,15 +445,15 @@ func TestHTTPClient(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := HTTPClient(tt.args.c) + got := WithHTTPClient(tt.args.c) if err := tt.checkFunc(got); err != nil { - t.Errorf("HTTPClient() error = %v", err) + t.Errorf("WithHTTPClient() error = %v", err) } }) } } -func TestPubKeyProvider(t *testing.T) { +func TestWithPubKeyProvider(t *testing.T) { type args struct { pkp pubkey.Provider } @@ -504,15 +504,15 @@ func TestPubKeyProvider(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := PubKeyProvider(tt.args.pkp) + got := WithPubKeyProvider(tt.args.pkp) if err := tt.checkFunc(got); err != nil { - t.Errorf("PubKeyProvider() error = %v", err) + t.Errorf("WithPubKeyProvider() error = %v", err) } }) } } -func TestErrRetryInterval(t *testing.T) { +func TestWithErrRetryInterval(t *testing.T) { type args struct { i string } @@ -570,9 +570,9 @@ func TestErrRetryInterval(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ErrRetryInterval(tt.args.i) + got := WithErrRetryInterval(tt.args.i) if err := tt.checkFunc(got); err != nil { - t.Errorf("ErrRetryInterval() error= %v", err) + t.Errorf("WithErrRetryInterval() error= %v", err) } }) } diff --git a/policy/signed_policy.go b/policy/signed_policy.go index 4a4a18fc..7d9114c9 100644 --- a/policy/signed_policy.go +++ b/policy/signed_policy.go @@ -21,7 +21,7 @@ import ( "github.com/pkg/errors" "github.com/yahoo/athenz/utils/zpe-updater/util" - "github.com/yahoojapan/athenz-policy-updater/pubkey" + "github.com/yahoojapan/athenz-authorizer/pubkey" ) // SignedPolicy represents the signed policy data diff --git a/policy/signed_policy_test.go b/policy/signed_policy_test.go index d537f18d..603d3f07 100644 --- a/policy/signed_policy_test.go +++ b/policy/signed_policy_test.go @@ -22,7 +22,7 @@ import ( "github.com/pkg/errors" authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" "github.com/yahoo/athenz/utils/zpe-updater/util" - "github.com/yahoojapan/athenz-policy-updater/pubkey" + "github.com/yahoojapan/athenz-authorizer/pubkey" ) func TestSignedPolicy_Verify(t *testing.T) { diff --git a/providerd.go b/providerd.go deleted file mode 100644 index faef5ad0..00000000 --- a/providerd.go +++ /dev/null @@ -1,170 +0,0 @@ -/* -Copyright (C) 2018 Yahoo Japan Corporation Athenz team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package providerd - -import ( - "context" - "net/http" - "time" - - "github.com/kpango/gache" - "github.com/kpango/glg" - - "github.com/pkg/errors" - "github.com/yahoojapan/athenz-policy-updater/policy" - "github.com/yahoojapan/athenz-policy-updater/pubkey" - "github.com/yahoojapan/athenz-policy-updater/role" -) - -// Providerd represents a daemon for user to verify the role token -type Providerd interface { - StartProviderd(context.Context) <-chan error - VerifyRoleToken(ctx context.Context, tok, act, res string) error - // VerifyRoleCert(cert []*x509.Certificate) error -} - -type provider struct { - // - pubkeyd pubkey.Pubkeyd - policyd policy.Policyd - roleTokenParser role.RoleTokenParser - - // common parameters - athenzURL string - client *http.Client - - // successful result cache - cache gache.Gache - cacheExp time.Duration - - // pubkeyd parameters - pubkeyRefreshDuration string - pubkeySysAuthDomain string - pubkeyEtagExpTime string - pubkeyEtagFlushDur string - - // policyd parameters - policyExpireMargin string - athenzDomains []string - policyRefreshDuration string - policyEtagFlushDur string - policyEtagExpTime string -} - -// New return Providerd -// This function will initialize the Providerd object with the options -func New(opts ...Option) (Providerd, error) { - prov := &provider{ - cache: gache.New(), - } - - var err error - for _, opt := range append(defaultOptions, opts...) { - if err = opt(prov); err != nil { - return nil, errors.Wrap(err, "error creating providerd") - } - } - - if prov.pubkeyd, err = pubkey.NewPubkeyd( - pubkey.AthenzURL(prov.athenzURL), - pubkey.SysAuthDomain(prov.pubkeySysAuthDomain), - pubkey.ETagExpTime(prov.pubkeyEtagExpTime), - pubkey.ETagFlushDur(prov.pubkeyEtagFlushDur), - pubkey.RefreshDuration(prov.pubkeyRefreshDuration), - pubkey.HTTPClient(prov.client), - ); err != nil { - return nil, errors.Wrap(err, "error create pubkeyd") - } - - if prov.policyd, err = policy.NewPolicyd( - policy.ExpireMargin(prov.policyExpireMargin), - policy.EtagFlushDur(prov.policyEtagFlushDur), - policy.EtagExpTime(prov.policyEtagExpTime), - policy.AthenzURL(prov.athenzURL), - policy.AthenzDomains(prov.athenzDomains...), - policy.RefreshDuration(prov.policyRefreshDuration), - policy.HTTPClient(prov.client), - policy.PubKeyProvider(prov.pubkeyd.GetProvider()), - ); err != nil { - return nil, errors.Wrap(err, "error create policyd") - } - - prov.roleTokenParser = role.NewRoleTokenParser(prov.pubkeyd.GetProvider()) - - return prov, nil -} - -// StartProviderd starts provider daemon. -func (p *provider) StartProviderd(ctx context.Context) <-chan error { - ech := make(chan error, 200) - - g := p.cache.StartExpired(ctx, p.cacheExp/2) - - cech := p.pubkeyd.StartPubkeyUpdater(ctx) - pech := p.policyd.StartPolicyUpdater(ctx) - - go func() { - for { - select { - case <-ctx.Done(): - g.Clear() - ech <- ctx.Err() - return - case err := <-cech: - if err != nil { - ech <- errors.Wrap(err, "update pubkey error") - } - case err := <-pech: - if err != nil { - ech <- errors.Wrap(err, "update policy error") - } - } - } - }() - - return ech -} - -// VerifyRoleToken verifies the role token for specific resource and return and verification error. -func (p *provider) VerifyRoleToken(ctx context.Context, tok, act, res string) error { - if act == "" || res == "" { - return errors.Wrap(ErrInvalidParameters, "empty action / resource") - } - - // check if exists in verification success cache - _, ok := p.cache.Get(tok + act + res) - if ok { - glg.Debugf("use cached roletoken result. tok: %s, act: %s, res: %s", tok, act, res) - return nil - } - - rt, err := p.roleTokenParser.ParseAndValidateRoleToken(tok) - if err != nil { - glg.Debugf("error parse and validate role token, err: %v", err) - return errors.Wrap(err, "error verify role token") - } - if err = p.policyd.CheckPolicy(ctx, rt.Domain, rt.Roles, act, res); err != nil { - glg.Debugf("error check, err: %v", err) - return errors.Wrap(err, "role token unauthorizate") - } - glg.Debugf("set roletoken result. tok: %s, act: %s, res: %s", tok, act, res) - p.cache.SetWithExpire(tok+act+res, struct{}{}, p.cacheExp) - return nil -} - -//func (p *provider) VerifyRoleCert(cert []*x509.Certificate) error { -// return nil -//} diff --git a/providerd_mock_test.go b/providerd_mock_test.go deleted file mode 100644 index 329189a7..00000000 --- a/providerd_mock_test.go +++ /dev/null @@ -1,69 +0,0 @@ -/* -Copyright (C) 2018 Yahoo Japan Corporation Athenz team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package providerd - -import ( - "context" - "time" - - "github.com/pkg/errors" - "github.com/yahoojapan/athenz-policy-updater/policy" - "github.com/yahoojapan/athenz-policy-updater/pubkey" - "github.com/yahoojapan/athenz-policy-updater/role" -) - -type ConfdMock struct { - pubkey.Pubkeyd - confdExp time.Duration -} - -func (cm *ConfdMock) StartPubkeyUpdater(ctx context.Context) <-chan error { - ech := make(chan error, 1) - go func() { - time.Sleep(cm.confdExp) - ech <- errors.New("pubkey error") - }() - return ech -} - -type PolicydMock struct { - policy.Policyd - policydExp time.Duration - wantErr error -} - -func (pm *PolicydMock) StartPolicyUpdater(context.Context) <-chan error { - ech := make(chan error, 1) - go func() { - time.Sleep(pm.policydExp) - ech <- errors.New("policyd error") - }() - return ech -} - -func (pm *PolicydMock) CheckPolicy(ctx context.Context, domain string, roles []string, action, resource string) error { - return pm.wantErr -} - -type RoleTokenMock struct { - role.RoleTokenParser - wantErr error - rt *role.RoleToken -} - -func (rm *RoleTokenMock) ParseAndValidateRoleToken(tok string) (*role.RoleToken, error) { - return rm.rt, rm.wantErr -} diff --git a/providerd_test.go b/providerd_test.go deleted file mode 100644 index b5237852..00000000 --- a/providerd_test.go +++ /dev/null @@ -1,449 +0,0 @@ -/* -Copyright (C) 2018 Yahoo Japan Corporation Athenz team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package providerd - -import ( - "context" - "testing" - "time" - - "github.com/kpango/gache" - "github.com/pkg/errors" - "github.com/yahoojapan/athenz-policy-updater/policy" - "github.com/yahoojapan/athenz-policy-updater/pubkey" - "github.com/yahoojapan/athenz-policy-updater/role" -) - -func TestNew(t *testing.T) { - type args struct { - opts []Option - } - tests := []struct { - name string - args args - checkFunc func(Providerd, error) error - }{ - { - name: "test new success", - args: args{ - []Option{}, - }, - checkFunc: func(prov Providerd, err error) error { - if err != nil { - return errors.Wrap(err, "unexpected error") - } - if prov.(*provider).athenzURL != "www.athenz.com/zts/v1" { - return errors.New("invalid url") - } - if prov.(*provider).pubkeyd == nil { - return errors.New("cannot new pubkeyd") - } - if prov.(*provider).policyd == nil { - return errors.New("cannot new policyd") - } - return nil - }, - }, - { - name: "test new success with options", - args: args{ - []Option{AthenzURL("www.dummy.com")}, - }, - checkFunc: func(prov Providerd, err error) error { - if err != nil { - return errors.Wrap(err, "unexpected error") - } - if prov.(*provider).athenzURL != "www.dummy.com" { - return errors.New("invalid url") - } - return nil - }, - }, - { - name: "test NewPubkeyd returns error", - args: args{ - []Option{PubkeyEtagExpTime("dummy")}, - }, - checkFunc: func(prov Providerd, err error) error { - want := "error create pubkeyd: invalid etag expire time: time: invalid duration dummy" - if err.Error() != want { - return errors.Errorf("Unexpected error: %s, expected: %s", err, want) - } - return nil - }, - }, - { - name: "test NewPolicy returns error", - args: args{ - []Option{PolicyEtagExpTime("dummy")}, - }, - checkFunc: func(prov Providerd, err error) error { - if err.Error() != "error create policyd: error create policyd: invalid etag expire time: time: invalid duration dummy" { - return errors.Wrap(err, "unexpected error") - } - return nil - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, goter := New(tt.args.opts...) - if err := tt.checkFunc(got, goter); err != nil { - t.Errorf("New() error = %v", err) - } - }) - } -} - -func TestStartProviderd(t *testing.T) { - type fields struct { - pubkeyd pubkey.Pubkeyd - policyd policy.Policyd - cache gache.Gache - cacheExp time.Duration - } - type args struct { - ctx context.Context - } - type test struct { - name string - fields fields - args args - checkFunc func(Providerd, error) error - afterFunc func() - } - tests := []test{ - func() test { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*10)) - cm := &ConfdMock{ - confdExp: time.Second, - } - pm := &PolicydMock{ - policydExp: time.Second, - } - return test{ - name: "test context done", - fields: fields{ - pubkeyd: cm, - policyd: pm, - cache: gache.New(), - cacheExp: time.Minute, - }, - args: args{ - ctx: ctx, - }, - checkFunc: func(prov Providerd, err error) error { - if err.Error() != "context deadline exceeded" { - return errors.Wrap(err, "unexpected err") - } - return nil - }, - afterFunc: func() { - cancel() - }, - } - }(), - func() test { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) - cm := &ConfdMock{ - confdExp: time.Millisecond * 10, - } - pm := &PolicydMock{ - policydExp: time.Second, - } - return test{ - name: "test context pubkey updater returns error", - fields: fields{ - pubkeyd: cm, - policyd: pm, - cache: gache.New(), - cacheExp: time.Minute, - }, - args: args{ - ctx: ctx, - }, - checkFunc: func(prov Providerd, err error) error { - if err.Error() != "update pubkey error: pubkey error" { - return errors.Wrap(err, "unexpected err") - } - return nil - }, - afterFunc: func() { - cancel() - }, - } - }(), - func() test { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) - cm := &ConfdMock{ - confdExp: time.Second, - } - pm := &PolicydMock{ - policydExp: time.Millisecond * 10, - } - return test{ - name: "test policyd returns error", - fields: fields{ - pubkeyd: cm, - policyd: pm, - cache: gache.New(), - cacheExp: time.Minute, - }, - args: args{ - ctx: ctx, - }, - checkFunc: func(prov Providerd, err error) error { - if err.Error() != "update policy error: policyd error" { - return errors.Wrap(err, "unexpected err") - } - return nil - }, - afterFunc: func() { - cancel() - }, - } - }(), - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prov := &provider{ - pubkeyd: tt.fields.pubkeyd, - policyd: tt.fields.policyd, - cache: tt.fields.cache, - cacheExp: tt.fields.cacheExp, - } - ch := prov.StartProviderd(tt.args.ctx) - goter := <-ch - if err := tt.checkFunc(prov, goter); err != nil { - t.Errorf("StartProviderd() error = %v", err) - } - tt.afterFunc() - }) - } -} - -func TestVerifyRoleToken(t *testing.T) { - type args struct { - ctx context.Context - tok string - act string - res string - } - type fields struct { - policyd policy.Policyd - cache gache.Gache - cacheExp time.Duration - roleTokenParser role.RoleTokenParser - } - type test struct { - name string - args args - fields fields - wantErr string - checkFunc func(*provider) error - } - tests := []test{ - func() test { - c := gache.New() - rm := &RoleTokenMock{ - rt: &role.RoleToken{}, - wantErr: nil, - } - cm := &PolicydMock{ - wantErr: nil, - } - return test{ - name: "test verify success", - args: args{ - ctx: context.Background(), - tok: "dummyTok", - act: "dummyAct", - res: "dummyRes", - }, - fields: fields{ - policyd: cm, - roleTokenParser: rm, - cache: c, - cacheExp: time.Minute, - }, - wantErr: "", - checkFunc: func(prov *provider) error { - _, ok := prov.cache.Get("dummyTokdummyActdummyRes") - if !ok { - return errors.New("cannot get dummyTokdummyActdummyRes from cache") - } - return nil - }, - } - }(), - func() test { - c := gache.New() - c.Set("dummyTokdummyActdummyRes", "dummy") - rm := &RoleTokenMock{ - rt: &role.RoleToken{}, - wantErr: nil, - } - cm := &PolicydMock{ - wantErr: nil, - } - return test{ - name: "test use cache success", - args: args{ - ctx: context.Background(), - tok: "dummyTok", - act: "dummyAct", - res: "dummyRes", - }, - fields: fields{ - policyd: cm, - roleTokenParser: rm, - cache: c, - cacheExp: time.Minute, - }, - wantErr: "", - } - }(), - func() test { - c := gache.New() - c.Set("dummyTokdummyActdummyRes", "dummy") - rm := &RoleTokenMock{ - rt: &role.RoleToken{}, - wantErr: nil, - } - cm := &PolicydMock{ - wantErr: nil, - } - return test{ - name: "test empty action", - args: args{ - ctx: context.Background(), - tok: "dummyTok", - act: "", - res: "dummyRes", - }, - fields: fields{ - policyd: cm, - roleTokenParser: rm, - cache: c, - cacheExp: time.Minute, - }, - wantErr: "empty action / resource: Access denied due to invalid/empty action/resource values", - } - }(), - func() test { - c := gache.New() - c.Set("dummyTokdummyActdummyRes", "dummy") - rm := &RoleTokenMock{ - rt: &role.RoleToken{}, - wantErr: nil, - } - cm := &PolicydMock{ - wantErr: nil, - } - return test{ - name: "test empty res", - args: args{ - ctx: context.Background(), - tok: "dummyTok", - act: "dummyAct", - res: "", - }, - fields: fields{ - policyd: cm, - roleTokenParser: rm, - cache: c, - cacheExp: time.Minute, - }, - wantErr: "empty action / resource: Access denied due to invalid/empty action/resource values", - } - }(), - func() test { - c := gache.New() - rm := &RoleTokenMock{ - wantErr: errors.New("cannot parse roletoken"), - } - cm := &PolicydMock{} - return test{ - name: "test parse roletoken error", - args: args{ - ctx: context.Background(), - tok: "dummyTok", - act: "dummyAct", - res: "dummyRes", - }, - fields: fields{ - policyd: cm, - roleTokenParser: rm, - cache: c, - cacheExp: time.Minute, - }, - wantErr: "error verify role token: cannot parse roletoken", - } - }(), - func() test { - c := gache.New() - rm := &RoleTokenMock{ - rt: &role.RoleToken{}, - } - cm := &PolicydMock{ - wantErr: errors.New("deny"), - } - return test{ - name: "test return deny", - args: args{ - ctx: context.Background(), - tok: "dummyTok", - act: "dummyAct", - res: "dummyRes", - }, - fields: fields{ - policyd: cm, - roleTokenParser: rm, - cache: c, - cacheExp: time.Minute, - }, - wantErr: "role token unauthorizate: deny", - } - }(), - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prov := &provider{ - policyd: tt.fields.policyd, - roleTokenParser: tt.fields.roleTokenParser, - cache: tt.fields.cache, - cacheExp: tt.fields.cacheExp, - } - err := prov.VerifyRoleToken(tt.args.ctx, tt.args.tok, tt.args.act, tt.args.res) - if err != nil { - if err.Error() != tt.wantErr { - t.Errorf("VerifyRoleToken() unexpected error want:%s, result:%s", tt.wantErr, err.Error()) - return - } - } else { - if tt.wantErr != "" { - t.Errorf("VerifyRoleToken() return nil. want %s", tt.wantErr) - return - } - } - if tt.checkFunc != nil { - if err := tt.checkFunc(prov); err != nil { - t.Errorf("VerifyRoleToken() error: %v", err) - } - } - }) - } -} diff --git a/pubkey/pubkeyd.go b/pubkey/daemon.go similarity index 70% rename from pubkey/pubkeyd.go rename to pubkey/daemon.go index a288493d..dab7bb93 100644 --- a/pubkey/pubkeyd.go +++ b/pubkey/daemon.go @@ -26,22 +26,21 @@ import ( "sync" "time" - "github.com/pkg/errors" - "github.com/kpango/gache" "github.com/kpango/glg" + "github.com/pkg/errors" authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" "golang.org/x/sync/errgroup" ) -// Pubkeyd represent the daemon to retrieve public key data. -type Pubkeyd interface { - StartPubkeyUpdater(ctx context.Context) <-chan error - UpdatePubkey(context.Context) error +// Daemon represent the daemon to retrieve public key data. +type Daemon interface { + Start(ctx context.Context) <-chan error + Update(context.Context) error GetProvider() Provider } -type athenzPubkeyd struct { +type pubkeyd struct { athenzURL string sysAuthDomain string refreshDuration time.Duration @@ -86,9 +85,9 @@ var ( regex = regexp.MustCompile("^(http|https)://") ) -// NewPubkeyd represent the constructor of Pubkeyd -func NewPubkeyd(opts ...Option) (Pubkeyd, error) { - c := &athenzPubkeyd{ +// New represent the constructor of Pubkeyd +func New(opts ...Option) (Daemon, error) { + c := &pubkeyd{ confCache: &AthenzConfig{ ZMSPubKeys: new(sync.Map), ZTSPubKeys: new(sync.Map), @@ -105,43 +104,36 @@ func NewPubkeyd(opts ...Option) (Pubkeyd, error) { return c, nil } -// StartPubkeyUpdater starts the pubkey daemon to retrive the public key periodically -func (c *athenzPubkeyd) StartPubkeyUpdater(ctx context.Context) <-chan error { +// Start starts the pubkey daemon to retrive the public key periodically +func (p *pubkeyd) Start(ctx context.Context) <-chan error { glg.Info("Starting pubkey updator") + ech := make(chan error, 100) fch := make(chan struct{}, 1) - if err := c.UpdatePubkey(ctx); err != nil { - ech <- errors.Wrap(err, "error update athenz pubkey") + if err := p.Update(ctx); err != nil { + ech <- errors.Wrap(err, "error update pubkey") fch <- struct{}{} } go func() { defer close(fch) defer close(ech) - c.etagCache.StartExpired(ctx, c.etagFlushDur) - ticker := time.NewTicker(c.refreshDuration) - ebuf := errors.New("") + + p.etagCache.StartExpired(ctx, p.etagFlushDur) + ticker := time.NewTicker(p.refreshDuration) for { select { case <-ctx.Done(): - glg.Info("Stopping pubkeyd") + glg.Info("Stopping pubkey updater") ticker.Stop() - if ebuf.Error() != "" { - ech <- errors.Wrap(ctx.Err(), ebuf.Error()) - } else { - ech <- ctx.Err() - } + ech <- ctx.Err() return case <-fch: - if err := c.UpdatePubkey(ctx); err != nil { - err = errors.Wrap(err, "error update athenz pubkey") - select { - case ech <- errors.Wrap(ebuf, err.Error()): - ebuf = errors.New("") - default: - ebuf = errors.Wrap(ebuf, err.Error()) - } - time.Sleep(c.errRetryInterval) + if err := p.Update(ctx); err != nil { + ech <- errors.Wrap(err, "error update pubkey") + + time.Sleep(p.errRetryInterval) + select { case fch <- struct{}{}: default: @@ -149,14 +141,9 @@ func (c *athenzPubkeyd) StartPubkeyUpdater(ctx context.Context) <-chan error { } } case <-ticker.C: - if err := c.UpdatePubkey(ctx); err != nil { - err = errors.Wrap(err, "error update athenz pubkey") - select { - case ech <- errors.Wrap(ebuf, err.Error()): - ebuf = errors.New("") - default: - ebuf = errors.Wrap(ebuf, err.Error()) - } + if err := p.Update(ctx); err != nil { + ech <- errors.Wrap(err, "error update pubkey") + select { case fch <- struct{}{}: default: @@ -170,8 +157,8 @@ func (c *athenzPubkeyd) StartPubkeyUpdater(ctx context.Context) <-chan error { return ech } -// UpdatePubkey updates and cache athenz public key data -func (c *athenzPubkeyd) UpdatePubkey(ctx context.Context) error { +// Update updates and cache athenz public key data +func (p *pubkeyd) Update(ctx context.Context) error { glg.Info("Updating athenz pubkey") eg := errgroup.Group{} @@ -179,9 +166,9 @@ func (c *athenzPubkeyd) UpdatePubkey(ctx context.Context) error { updConf := func(env AthenzEnv, cache *sync.Map) error { cm := new(sync.Map) dec := new(authcore.YBase64) - pubKeys, upded, err := c.fetchPubKeyEntries(ctx, env) + pubKeys, upded, err := p.fetchPubKeyEntries(ctx, env) if err != nil { - glg.Errorf("Error updating athenz pubkey, error: %v", err) + glg.Errorf("Error updating athenz pubkey, env: %v, error: %v", env, err) return errors.Wrap(err, "error fetch public key entries") } if !upded { @@ -190,19 +177,19 @@ func (c *athenzPubkeyd) UpdatePubkey(ctx context.Context) error { } for _, key := range pubKeys.PublicKeys { - glg.Debugf("Decoding key, keyID: %v", key.ID) + glg.Debugf("Decoding key, env: %v, keyID: %v", env, key.ID) decKey, err := dec.DecodeString(key.Key) if err != nil { - glg.Errorf("error decoding key, error: %v", err) + glg.Errorf("error decoding key, env: %v, error: %v", env, err) return errors.Wrap(err, "error decoding key") } ver, err := authcore.NewVerifier(decKey) if err != nil { - glg.Errorf("error initializing verifier, error: %v", err) + glg.Errorf("error initializing verifier, env: %v, error: %v", env, err) return errors.Wrap(err, "error initializing verifier") } cm.Store(key.ID, ver) - glg.Debugf("Successfully decode key, keyID: %v", key.ID) + glg.Debugf("Successfully decode key, env: %v, keyID: %v", env, key.ID) } cm.Range(func(key interface{}, val interface{}) bool { cache.Store(key, val) @@ -221,7 +208,7 @@ func (c *athenzPubkeyd) UpdatePubkey(ctx context.Context) error { eg.Go(func() error { glg.Info("Updating ZTS athenz pubkey") - if err := updConf(EnvZTS, c.confCache.ZTSPubKeys); err != nil { + if err := updConf(EnvZTS, p.confCache.ZTSPubKeys); err != nil { return errors.Wrap(err, "Error updating ZTS athenz pubkey") } glg.Info("Update ZTS athenz pubkey success") @@ -230,7 +217,7 @@ func (c *athenzPubkeyd) UpdatePubkey(ctx context.Context) error { eg.Go(func() error { glg.Info("Updating ZMS athenz pubkey") - if err := updConf(EnvZMS, c.confCache.ZMSPubKeys); err != nil { + if err := updConf(EnvZMS, p.confCache.ZMSPubKeys); err != nil { return errors.Wrap(err, "Error updating ZMS athenz pubkey") } glg.Info("Update ZMS athenz pubkey success") @@ -245,14 +232,14 @@ func (c *athenzPubkeyd) UpdatePubkey(ctx context.Context) error { } // GetProvider returns the public key provider for user to get the public key -func (c *athenzPubkeyd) GetProvider() Provider { - return c.getPubKey +func (p *pubkeyd) GetProvider() Provider { + return p.getPubKey } -func (c *athenzPubkeyd) fetchPubKeyEntries(ctx context.Context, env AthenzEnv) (*SysAuthConfig, bool, error) { +func (p *pubkeyd) fetchPubKeyEntries(ctx context.Context, env AthenzEnv) (*SysAuthConfig, bool, error) { glg.Info("Fetching public key entries") // https://{www.athenz.com/zts/v1}/domain/sys.auth/service/zts - url := fmt.Sprintf("https://%s/domain/%s/service/%s", c.athenzURL, c.sysAuthDomain, env) + url := fmt.Sprintf("https://%s/domain/%s/service/%s", p.athenzURL, p.sysAuthDomain, env) glg.Debugf("Fetching public key from %s", url) req, err := http.NewRequest(http.MethodGet, url, nil) @@ -262,14 +249,14 @@ func (c *athenzPubkeyd) fetchPubKeyEntries(ctx context.Context, env AthenzEnv) ( } // etag header - t, ok := c.etagCache.Get(string(env)) + t, ok := p.etagCache.Get(string(env)) if ok { eTag := t.(*confCache).eTag glg.Debugf("ETag %v found in the cache", eTag) req.Header.Set("If-None-Match", eTag) } - r, err := c.client.Do(req.WithContext(ctx)) + r, err := p.client.Do(req.WithContext(ctx)) if err != nil { glg.Errorf("Error making HTTP request, error: %v", err) return nil, false, errors.Wrap(err, "error make http request") @@ -305,16 +292,16 @@ func (c *athenzPubkeyd) fetchPubKeyEntries(ctx context.Context, env AthenzEnv) ( eTag := r.Header.Get("ETag") if eTag != "" { glg.Debugf("Setting ETag %v", eTag) - c.etagCache.SetWithExpire(string(env), &confCache{eTag, sac}, c.etagExpTime) + p.etagCache.SetWithExpire(string(env), &confCache{eTag, sac}, p.etagExpTime) } glg.Info("Fetch public key entries success") return sac, true, nil } -func (c *athenzPubkeyd) getPubKey(env AthenzEnv, keyID string) authcore.Verifier { +func (p *pubkeyd) getPubKey(env AthenzEnv, keyID string) authcore.Verifier { if env == EnvZTS { - ver, ok := c.confCache.ZTSPubKeys.Load(keyID) + ver, ok := p.confCache.ZTSPubKeys.Load(keyID) if !ok { glg.Warnf("ZTS PubKey Load Failed keyID[%s] getZTSPubKey %v", keyID, ver) return nil @@ -322,7 +309,7 @@ func (c *athenzPubkeyd) getPubKey(env AthenzEnv, keyID string) authcore.Verifier return ver.(authcore.Verifier) } - ver, ok := c.confCache.ZMSPubKeys.Load(keyID) + ver, ok := p.confCache.ZMSPubKeys.Load(keyID) if !ok { glg.Warnf("ZMS PubKey Load Failed keyID[%s] getZMSPubKey %v", keyID, ver) return nil diff --git a/pubkey/pubkeyd_test.go b/pubkey/daemon_test.go similarity index 92% rename from pubkey/pubkeyd_test.go rename to pubkey/daemon_test.go index 860f529c..b485b23b 100644 --- a/pubkey/pubkeyd_test.go +++ b/pubkey/daemon_test.go @@ -32,61 +32,61 @@ import ( authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" ) -func Test_pubkey_NewPubkeyd(t *testing.T) { +func Test_pubkey_New(t *testing.T) { type args struct { opts []Option } type test struct { name string args args - checkFunc func(Pubkeyd, error) error + checkFunc func(Daemon, error) error } tests := []test{ { - name: "new athenz athenzPubkeyd success", + name: "new athenz pubkeyd success", args: args{ opts: []Option{}, }, - checkFunc: func(got Pubkeyd, err error) error { + checkFunc: func(got Daemon, err error) error { if err != nil { return err } - if got.(*athenzPubkeyd).sysAuthDomain != "sys.auth" { + if got.(*pubkeyd).sysAuthDomain != "sys.auth" { return errors.New("cannot set default options") } return nil }, }, { - name: "new athenz athenzPubkeyd success with options", + name: "new athenz pubkeyd success with options", args: args{ opts: []Option{ - SysAuthDomain("dummyd"), - AthenzURL("dummyURL"), + WithSysAuthDomain("dummyd"), + WithAthenzURL("dummyURL"), }, }, - checkFunc: func(got Pubkeyd, err error) error { + checkFunc: func(got Daemon, err error) error { if err != nil { return err } - if got.(*athenzPubkeyd).sysAuthDomain != "dummyd" || got.(*athenzPubkeyd).athenzURL != "dummyURL" { + if got.(*pubkeyd).sysAuthDomain != "dummyd" || got.(*pubkeyd).athenzURL != "dummyURL" { return errors.New("cannot set optional params") } return nil }, }, { - name: "new athenz athenzPubkeyd success with invalid options", + name: "new athenz pubkeyd success with invalid options", args: args{ opts: []Option{ - SysAuthDomain("dummyd"), - AthenzURL("dummyURL"), - ETagExpTime("invalid"), + WithSysAuthDomain("dummyd"), + WithAthenzURL("dummyURL"), + WithEtagExpTime("invalid"), }, }, - checkFunc: func(got Pubkeyd, err error) error { + checkFunc: func(got Daemon, err error) error { if got != nil { - return errors.New("get invalid Pubkeyd") + return errors.New("get invalid Daemon") } if err.Error() != "invalid etag expire time: time: invalid duration invalid" { return errors.Wrap(err, "unexpected error") @@ -97,17 +97,17 @@ func Test_pubkey_NewPubkeyd(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewPubkeyd(tt.args.opts...) + got, err := New(tt.args.opts...) err = tt.checkFunc(got, err) if err != nil { - t.Errorf("NewPubkeyd() = %v", err) + t.Errorf("New() = %v", err) } }) } } func Test_pubkey_getPubKey(t *testing.T) { - c := &athenzPubkeyd{ + c := &pubkeyd{ confCache: &AthenzConfig{ ZMSPubKeys: new(sync.Map), ZTSPubKeys: new(sync.Map), @@ -190,7 +190,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { name string fields fields args args - checkFunc func(c *athenzPubkeyd, sac *SysAuthConfig, upd bool, err error) error + checkFunc func(c *pubkeyd, sac *SysAuthConfig, upd bool, err error) error } tests := []test{ func() test { @@ -218,7 +218,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { ctx: context.Background(), env: "dummyEnv", }, - checkFunc: func(c *athenzPubkeyd, sac *SysAuthConfig, upd bool, err error) error { + checkFunc: func(c *pubkeyd, sac *SysAuthConfig, upd bool, err error) error { if err != nil { return err } @@ -294,7 +294,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { ctx: context.Background(), env: "dummyEnv", }, - checkFunc: func(c *athenzPubkeyd, sac *SysAuthConfig, upd bool, err error) error { + checkFunc: func(c *pubkeyd, sac *SysAuthConfig, upd bool, err error) error { if err != nil { return err } @@ -352,7 +352,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { ctx: context.Background(), env: "dummyEnv", }, - checkFunc: func(c *athenzPubkeyd, sac *SysAuthConfig, upd bool, err error) error { + checkFunc: func(c *pubkeyd, sac *SysAuthConfig, upd bool, err error) error { if err != nil { return err } @@ -406,7 +406,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { ctx: context.Background(), env: "dummyEnv", }, - checkFunc: func(c *athenzPubkeyd, sac *SysAuthConfig, upd bool, err error) error { + checkFunc: func(c *pubkeyd, sac *SysAuthConfig, upd bool, err error) error { wantErr := "http return status not OK: Fetch athenz pubkey error" if err != nil { if err.Error() == wantErr { @@ -441,7 +441,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { ctx: context.Background(), env: "dummyEnv", }, - checkFunc: func(c *athenzPubkeyd, sac *SysAuthConfig, upd bool, err error) error { + checkFunc: func(c *pubkeyd, sac *SysAuthConfig, upd bool, err error) error { wantErr := `error creating get pubkey request: parse https:// /domain/dummyDom/service/dummyEnv: invalid character " " in host name` if err != nil { if err.Error() == wantErr { @@ -476,7 +476,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { ctx: context.Background(), env: "dummyEnv", }, - checkFunc: func(c *athenzPubkeyd, sac *SysAuthConfig, upd bool, err error) error { + checkFunc: func(c *pubkeyd, sac *SysAuthConfig, upd bool, err error) error { wantErr := "json format not correct: EOF" if err != nil { if err.Error() == wantErr { @@ -492,7 +492,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &athenzPubkeyd{ + c := &pubkeyd{ refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, etagCache: tt.fields.etagCache, @@ -513,7 +513,7 @@ func Test_pubkey_fetchPubKeyEntries(t *testing.T) { } func Test_pubkey_GetProvider(t *testing.T) { - c := &athenzPubkeyd{ + c := &pubkeyd{ confCache: &AthenzConfig{}, } type test struct { @@ -536,7 +536,7 @@ func Test_pubkey_GetProvider(t *testing.T) { } } -func Test_pubkey_UpdatePubkey(t *testing.T) { +func Test_pubkey_Update(t *testing.T) { type fields struct { refreshDuration time.Duration errRetryInterval time.Duration @@ -555,7 +555,7 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { name string fields fields args args - checkFunc func(*athenzPubkeyd, error) error + checkFunc func(*pubkeyd, error) error } tests := []test{ func() test { @@ -590,7 +590,7 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { args: args{ ctx: context.Background(), }, - checkFunc: func(c *athenzPubkeyd, goter error) error { + checkFunc: func(c *pubkeyd, goter error) error { if goter != nil { return goter } @@ -685,7 +685,7 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { args: args{ ctx: context.Background(), }, - checkFunc: func(c *athenzPubkeyd, goter error) error { + checkFunc: func(c *pubkeyd, goter error) error { if goter != nil { return goter } @@ -754,7 +754,7 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { args: args{ ctx: context.Background(), }, - checkFunc: func(c *athenzPubkeyd, goter error) error { + checkFunc: func(c *pubkeyd, goter error) error { wantErr := "error when processing pubkey: Error updating ZMS athenz pubkey: error fetch public key entries: json format not correct: EOF" if goter.Error() != wantErr { return errors.Wrap(goter, "unexpected error") @@ -795,7 +795,7 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { args: args{ ctx: context.Background(), }, - checkFunc: func(c *athenzPubkeyd, goter error) error { + checkFunc: func(c *pubkeyd, goter error) error { wantErr := "error when processing pubkey: Error updating ZMS athenz pubkey: error decoding key: illegal base64 data at input byte 6" if goter.Error() != wantErr { return errors.Wrap(goter, "unexpected error") @@ -836,7 +836,7 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { args: args{ ctx: context.Background(), }, - checkFunc: func(c *athenzPubkeyd, goter error) error { + checkFunc: func(c *pubkeyd, goter error) error { wantErr := "error when processing pubkey: Error updating ZTS athenz pubkey: error initializing verifier: Unable to load public key" if goter.Error() != wantErr { return errors.Wrap(goter, "unexpected error") @@ -849,7 +849,7 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &athenzPubkeyd{ + c := &pubkeyd{ refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, etagCache: tt.fields.etagCache, @@ -860,9 +860,9 @@ func Test_pubkey_UpdatePubkey(t *testing.T) { client: tt.fields.client, confCache: tt.fields.confCache, } - err := c.UpdatePubkey(tt.args.ctx) + err := c.Update(tt.args.ctx) if err = tt.checkFunc(c, err); err != nil { - t.Errorf("c.UpdatePubkey() error = %v", err) + t.Errorf("c.Update() error = %v", err) } }) } @@ -887,7 +887,7 @@ func Test_pubkey_StartpubkeyUpdator(t *testing.T) { name string fields fields args args - checkFunc func(*athenzPubkeyd, <-chan error) error + checkFunc func(*pubkeyd, <-chan error) error } tests := []test{ func() test { @@ -926,7 +926,7 @@ func Test_pubkey_StartpubkeyUpdator(t *testing.T) { args: args{ ctx: ctx, }, - checkFunc: func(c *athenzPubkeyd, ch <-chan error) error { + checkFunc: func(c *pubkeyd, ch <-chan error) error { cancel() ind := 0 var err error @@ -995,7 +995,7 @@ func Test_pubkey_StartpubkeyUpdator(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) return test{ - name: "test UpdatePubkey failed", + name: "test Update failed", fields: fields{ athenzURL: strings.Replace(srv.URL, "https://", "", 1), sysAuthDomain: "dummyDom", @@ -1013,11 +1013,11 @@ func Test_pubkey_StartpubkeyUpdator(t *testing.T) { args: args{ ctx: ctx, }, - checkFunc: func(c *athenzPubkeyd, ch <-chan error) error { + checkFunc: func(c *pubkeyd, ch <-chan error) error { goter := <-ch cancel() - want := "error update athenz pubkey: error when processing pubkey: Error updating ZTS athenz pubkey: error fetch public key entries: json format not correct: EOF" + want := "error update pubkey: error when processing pubkey: Error updating ZTS athenz pubkey: error fetch public key entries: json format not correct: EOF" if goter.Error() != want { return errors.Errorf("got: %s, want: %s", goter, want) } @@ -1073,7 +1073,7 @@ func Test_pubkey_StartpubkeyUpdator(t *testing.T) { args: args{ ctx: ctx, }, - checkFunc: func(c *athenzPubkeyd, ch <-chan error) error { + checkFunc: func(c *pubkeyd, ch <-chan error) error { go func() { for { <-ch @@ -1137,7 +1137,7 @@ func Test_pubkey_StartpubkeyUpdator(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &athenzPubkeyd{ + c := &pubkeyd{ refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, etagCache: tt.fields.etagCache, @@ -1148,9 +1148,9 @@ func Test_pubkey_StartpubkeyUpdator(t *testing.T) { client: tt.fields.client, confCache: tt.fields.confCache, } - ch := c.StartPubkeyUpdater(tt.args.ctx) + ch := c.Start(tt.args.ctx) if err := tt.checkFunc(c, ch); err != nil { - t.Errorf("c.StartPubkeyUpdater() error = %v", err) + t.Errorf("c.Start() error = %v", err) } }) } diff --git a/pubkey/option.go b/pubkey/option.go index daf6b9bc..a9b764f5 100644 --- a/pubkey/option.go +++ b/pubkey/option.go @@ -24,21 +24,21 @@ import ( var ( defaultOptions = []Option{ - SysAuthDomain("sys.auth"), - ETagExpTime("168h"), // 1 week - ETagFlushDur("84h"), - RefreshDuration("24h"), - ErrRetryInterval("1m"), - HTTPClient(&http.Client{}), + WithSysAuthDomain("sys.auth"), + WithEtagExpTime("168h"), // 1 week + WithEtagFlushDuration("84h"), + WithRefreshDuration("24h"), + WithErrRetryInterval("1m"), + WithHTTPClient(&http.Client{}), } ) // Option represents a functional options pattern interface -type Option func(*athenzPubkeyd) error +type Option func(*pubkeyd) error -// AthenzURL represents a AthenzURL functional option -func AthenzURL(url string) Option { - return func(c *athenzPubkeyd) error { +// WithAthenzURL represents a AthenzURL functional option +func WithAthenzURL(url string) Option { + return func(c *pubkeyd) error { if url == "" { return nil } @@ -47,9 +47,9 @@ func AthenzURL(url string) Option { } } -// SysAuthDomain represents a SysAuthDomain functional option -func SysAuthDomain(d string) Option { - return func(c *athenzPubkeyd) error { +// WithSysAuthDomain represents a SysAuthDomain functional option +func WithSysAuthDomain(d string) Option { + return func(c *pubkeyd) error { if d == "" { return nil } @@ -58,9 +58,9 @@ func SysAuthDomain(d string) Option { } } -// ETagExpTime represents a ETagExpTime functional option -func ETagExpTime(t string) Option { - return func(c *athenzPubkeyd) error { +// WithEtagExpTime represents a EtagExpTime functional option +func WithEtagExpTime(t string) Option { + return func(c *pubkeyd) error { if t == "" { return nil } @@ -74,9 +74,9 @@ func ETagExpTime(t string) Option { } } -// ETagFlushDur represents a ETagFlushDur functional option -func ETagFlushDur(t string) Option { - return func(c *athenzPubkeyd) error { +// WithEtagFlushDuration represents a EtagFlushDur functional option +func WithEtagFlushDuration(t string) Option { + return func(c *pubkeyd) error { if t == "" { return nil } @@ -90,9 +90,9 @@ func ETagFlushDur(t string) Option { } } -// RefreshDuration represents a RefreshDuration functional option -func RefreshDuration(t string) Option { - return func(c *athenzPubkeyd) error { +// WithRefreshDuration represents a RefreshDuration functional option +func WithRefreshDuration(t string) Option { + return func(c *pubkeyd) error { if t == "" { return nil } @@ -106,9 +106,9 @@ func RefreshDuration(t string) Option { } } -// HTTPClient represents a HTTPClient functional option -func HTTPClient(cl *http.Client) Option { - return func(c *athenzPubkeyd) error { +// WithHTTPClient represents a HTTPClient functional option +func WithHTTPClient(cl *http.Client) Option { + return func(c *pubkeyd) error { if c != nil { c.client = cl } @@ -116,9 +116,9 @@ func HTTPClient(cl *http.Client) Option { } } -// ErrRetryInterval represents a ErrRetryInterval functional option -func ErrRetryInterval(i string) Option { - return func(c *athenzPubkeyd) error { +// WithErrRetryInterval represents a ErrRetryInterval functional option +func WithErrRetryInterval(i string) Option { + return func(c *pubkeyd) error { if i == "" { return nil } diff --git a/pubkey/option_test.go b/pubkey/option_test.go index d7aa1a77..5f7e78b9 100644 --- a/pubkey/option_test.go +++ b/pubkey/option_test.go @@ -23,7 +23,7 @@ import ( "time" ) -func TestAthenzURL(t *testing.T) { +func TestWithAthenzURL(t *testing.T) { type args struct { athenzURL string } @@ -38,7 +38,7 @@ func TestAthenzURL(t *testing.T) { athenzURL: "", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.athenzURL != "" { @@ -53,7 +53,7 @@ func TestAthenzURL(t *testing.T) { athenzURL: "dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.athenzURL != "dummy" { @@ -68,7 +68,7 @@ func TestAthenzURL(t *testing.T) { athenzURL: "http://dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.athenzURL != "dummy" { @@ -83,7 +83,7 @@ func TestAthenzURL(t *testing.T) { athenzURL: "https://dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.athenzURL != "dummy" { @@ -98,7 +98,7 @@ func TestAthenzURL(t *testing.T) { athenzURL: "ftp://dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.athenzURL != "ftp://dummy" { @@ -110,19 +110,19 @@ func TestAthenzURL(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := AthenzURL(tt.args.athenzURL) + got := WithAthenzURL(tt.args.athenzURL) if got == nil { - t.Errorf("AthenzURL() = nil") + t.Errorf("WithAthenzURL() = nil") return } if err := tt.checkFunc(got); err != nil { - t.Errorf("AthenzURL() = %v", err) + t.Errorf("WithAthenzURL() = %v", err) } }) } } -func TestSysAuthDomain(t *testing.T) { +func TestWithSysAuthDomain(t *testing.T) { type args struct { domain string } @@ -137,7 +137,7 @@ func TestSysAuthDomain(t *testing.T) { domain: "dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.sysAuthDomain != "dummy" { @@ -152,7 +152,7 @@ func TestSysAuthDomain(t *testing.T) { domain: "", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.sysAuthDomain != "" { return fmt.Errorf("invalid domain wasset") @@ -163,19 +163,19 @@ func TestSysAuthDomain(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := SysAuthDomain(tt.args.domain) + got := WithSysAuthDomain(tt.args.domain) if got == nil { - t.Errorf("SysAuthDomain() = nil") + t.Errorf("WithSysAuthDomain() = nil") return } if err := tt.checkFunc(got); err != nil { - t.Errorf("SysAuthDomain() = %v", err) + t.Errorf("WithSysAuthDomain() = %v", err) } }) } } -func TestETagExpTime(t *testing.T) { +func TestWithEtagExpTime(t *testing.T) { type args struct { time string } @@ -190,7 +190,7 @@ func TestETagExpTime(t *testing.T) { time: "2h", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.etagExpTime != time.Duration(time.Hour*2) { @@ -205,9 +205,9 @@ func TestETagExpTime(t *testing.T) { time: "", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) - if !reflect.DeepEqual(c, &athenzPubkeyd{}) { + if !reflect.DeepEqual(c, &pubkeyd{}) { return fmt.Errorf("expected no changes, but got %v", c) } return nil @@ -219,7 +219,7 @@ func TestETagExpTime(t *testing.T) { time: "dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} err := got(c) if err == nil { @@ -231,19 +231,19 @@ func TestETagExpTime(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ETagExpTime(tt.args.time) + got := WithEtagExpTime(tt.args.time) if got == nil { - t.Errorf("ETagExpTime() = nil") + t.Errorf("WithEtagExpTime() = nil") return } if err := tt.checkFunc(got); err != nil { - t.Errorf("ETagExpTime() = %v", err) + t.Errorf("WithEtagExpTime() = %v", err) } }) } } -func TestErrRetryInterval(t *testing.T) { +func TestWithErrRetryInterval(t *testing.T) { type args struct { time string } @@ -258,7 +258,7 @@ func TestErrRetryInterval(t *testing.T) { time: "2h", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.errRetryInterval != time.Duration(time.Hour*2) { @@ -273,9 +273,9 @@ func TestErrRetryInterval(t *testing.T) { time: "", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) - if !reflect.DeepEqual(c, &athenzPubkeyd{}) { + if !reflect.DeepEqual(c, &pubkeyd{}) { return fmt.Errorf("expected no changes, but got %v", c) } return nil @@ -287,7 +287,7 @@ func TestErrRetryInterval(t *testing.T) { time: "dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} err := got(c) if err == nil { @@ -299,19 +299,19 @@ func TestErrRetryInterval(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ErrRetryInterval(tt.args.time) + got := WithErrRetryInterval(tt.args.time) if got == nil { - t.Errorf("ErrRetryInterval() = nil") + t.Errorf("WithErrRetryInterval() = nil") return } if err := tt.checkFunc(got); err != nil { - t.Errorf("ErrRetryInterval() = %v", err) + t.Errorf("WithErrRetryInterval() = %v", err) } }) } } -func TestETagFlushDur(t *testing.T) { +func TestWithEtagFlushDuration(t *testing.T) { type args struct { dur string } @@ -326,7 +326,7 @@ func TestETagFlushDur(t *testing.T) { dur: "2h", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.etagFlushDur != time.Duration(time.Hour*2) { @@ -341,7 +341,7 @@ func TestETagFlushDur(t *testing.T) { dur: "dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} err := got(c) if err == nil { @@ -356,9 +356,9 @@ func TestETagFlushDur(t *testing.T) { dur: "", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) - if !reflect.DeepEqual(c, &athenzPubkeyd{}) { + if !reflect.DeepEqual(c, &pubkeyd{}) { return fmt.Errorf("expected no changes, but got %v", c) } return nil @@ -367,19 +367,19 @@ func TestETagFlushDur(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ETagFlushDur(tt.args.dur) + got := WithEtagFlushDuration(tt.args.dur) if got == nil { - t.Errorf("ETagFlushDur() = nil") + t.Errorf("WithEtagFlushDuration() = nil") return } if err := tt.checkFunc(got); err != nil { - t.Errorf("ETagFlushDur() = %v", err) + t.Errorf("WithEtagFlushDuration() = %v", err) } }) } } -func TestRefreshDuration(t *testing.T) { +func TestWithRefreshDuration(t *testing.T) { type args struct { dur string } @@ -394,7 +394,7 @@ func TestRefreshDuration(t *testing.T) { dur: "2h", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) if c.refreshDuration != time.Duration(time.Hour*2) { @@ -409,7 +409,7 @@ func TestRefreshDuration(t *testing.T) { dur: "dummy", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} err := got(c) if err == nil { @@ -424,9 +424,9 @@ func TestRefreshDuration(t *testing.T) { dur: "", }, checkFunc: func(got Option) error { - c := &athenzPubkeyd{} + c := &pubkeyd{} got(c) - if !reflect.DeepEqual(c, &athenzPubkeyd{}) { + if !reflect.DeepEqual(c, &pubkeyd{}) { return fmt.Errorf("expected no changes, but got %v", c) } return nil @@ -435,19 +435,19 @@ func TestRefreshDuration(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := RefreshDuration(tt.args.dur) + got := WithRefreshDuration(tt.args.dur) if got == nil { - t.Errorf("RefreshDuration() = nil") + t.Errorf("WithRefreshDuration() = nil") return } if err := tt.checkFunc(got); err != nil { - t.Errorf("RefreshDuration() = %v", err) + t.Errorf("WithRefreshDuration() = %v", err) } }) } } -func TestHTTPClient(t *testing.T) { +func TestWithHTTPClient(t *testing.T) { type args struct { c *http.Client } @@ -465,7 +465,7 @@ func TestHTTPClient(t *testing.T) { c: c, }, checkFunc: func(opt Option) error { - cd := &athenzPubkeyd{} + cd := &pubkeyd{} if err := opt(cd); err != nil { return err } @@ -483,11 +483,11 @@ func TestHTTPClient(t *testing.T) { nil, }, checkFunc: func(opt Option) error { - cd := &athenzPubkeyd{} + cd := &pubkeyd{} if err := opt(cd); err != nil { return err } - if !reflect.DeepEqual(cd, &athenzPubkeyd{}) { + if !reflect.DeepEqual(cd, &pubkeyd{}) { return fmt.Errorf("expected no changes, but got %v", cd) } return nil @@ -496,9 +496,9 @@ func TestHTTPClient(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := HTTPClient(tt.args.c) + got := WithHTTPClient(tt.args.c) if err := tt.checkFunc(got); err != nil { - t.Errorf("HTTPClient() error = %v", err) + t.Errorf("WithHTTPClient() error = %v", err) } }) } diff --git a/role/asserts/private.pem b/role/asserts/private.pem new file mode 100644 index 00000000..418e79d4 --- /dev/null +++ b/role/asserts/private.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAnM+7BmFS0Ld1u1yl6FrlP8bSSbtwcH0soGn63I1kWxpnrLXl +YiNpYL6ThZCe2TjV4r5m06/cKg/D/P5fPplsqmEyUSN2G550HugW/7bQVShSaJpv +a1MocHxL9pS+xUfWPVt3+SPSM/0W2gUP3ljBQ8MRiJasQe4/PMT1O0LVMHHLswV4 +du1qNqbSzR3Qomlx/sk9F0qWOm5vI8r369pFXg01dMeYiiM4KC0RYPkeSWp23rLk +/wOoWha2o0u9h+5ewkfAYRAxUmttJ6eQ+yt1koEGJ2tgqt8G+0oKTYz3LOeJKaid +2IK0esaBQfAzKvdj34ZnL+ilWw2gewe5U1wQDwIDAQABAoIBAHi1LzJqGGWx0162 +or+JuI6vbZB0SMlOkdupuQGtlWLLoKLCIiC5QZTHHqfh+2Ua6wnvpxesd72pBSTq +aka1s4Az8ZejxHbeMmTYI1wUTao/r6/1sW7cRHTSOWdGeNNDyRbSIjgV6uk6GS0a +WGy/xYVz0zthQJg/3U3aDyve9lyjB8P0THnMKUkpJFapxuHYcQm3PBXSM0YMUaeQ +T/GU3dxVC4YnijQQGBiezOANRVNwHVxdg/SUPoOH1aqNecN0ecWQ0BOnxUrixQsG +3K13Ubkl1JtfltM77tlC1INYNxPwr+vA9K+lwNCJhMZghRdlp+UTakzYFx6OEgQR +BYTckgECgYEAyH0pCvHCMFIJ7GK5Fquw7PGzBC2wdovATCH+sfbS65tu8saWXviM +BwUdY3YsUNv/dk60SKOxU7uQGpQ+w/qcmK/WYLsxFthTFubSUy9QbbVMzpQeBDpi +9AaqY0R4CGHCi4e/sIfJTipEX7XTZDj6rEQwcGkkWdaCMCgOx9XqrHECgYEAyDqr +31ycJjepF3T1QvUuelKgdrcHHoD4jRseug2GQeIv9g8e36E1X9wQYdh67uAIwVTU +y4TZvnRvleBEkd1NkNcJlLmkhKnLnECcNIS17grghN1UFsMLH+uFhRDPyPoyXjA8 +ZAQGh7vfo0shs/L89uHqd98H/Uuu/4lv7EOKxH8CgYEAp7dMJkOn1xRKCN1tSHHI +R+7Jeq1d+U1fSFEH54g7Wa1lWuKV2EzDlvvIYfPxjpL7WdTtK/cD2LAHFLT+7KMx +bOlrxO/TWPEOURI20C+8cIpB/m4Zzh+pt8n7r58ParOdM2wUB3EQDbt+BzLr6+Ne +j81bWC4coqq3reFUvAdPkYECgYB6XHCANWYvbMBm11CytIbMtgXdxogROhuqj7I9 +XNp5FLLemeryGuA1TpSsVtD5feubyi2ome0/GITAgKcmxKkMJH10Z+aENAd723ga +GCfd6sO+LkufBV6dCR81bEqutUdmi++750HeXQ+UCOv1vj6c2P6idqe5QTEWdHTz +W2tIHQKBgQCBorUmvzxuUIhQXmpXUpwBdwOs3xDDBmewHtCCIHC0omgiVYa0p3w9 +aHqDnkBQ7yd3OQFn5LwFaCLN44anhz5w84trWFvIvip0HoCYuCKmumP5jNoQ5f8Q +oBOEwiBq22SeABgCaj7pdQV7zOp/qD5mV/TaHQLoiSwBFYpTGKVGJA== +-----END RSA PRIVATE KEY----- diff --git a/role/asserts/public.pem b/role/asserts/public.pem new file mode 100644 index 00000000..fdf188aa --- /dev/null +++ b/role/asserts/public.pem @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAnM+7BmFS0Ld1u1yl6Frl +P8bSSbtwcH0soGn63I1kWxpnrLXlYiNpYL6ThZCe2TjV4r5m06/cKg/D/P5fPpls +qmEyUSN2G550HugW/7bQVShSaJpva1MocHxL9pS+xUfWPVt3+SPSM/0W2gUP3ljB +Q8MRiJasQe4/PMT1O0LVMHHLswV4du1qNqbSzR3Qomlx/sk9F0qWOm5vI8r369pF +Xg01dMeYiiM4KC0RYPkeSWp23rLk/wOoWha2o0u9h+5ewkfAYRAxUmttJ6eQ+yt1 +koEGJ2tgqt8G+0oKTYz3LOeJKaid2IK0esaBQfAzKvdj34ZnL+ilWw2gewe5U1wQ +DwIDAQAB +-----END PUBLIC KEY----- diff --git a/role/claim.go b/role/claim.go new file mode 100644 index 00000000..a7ce2ea8 --- /dev/null +++ b/role/claim.go @@ -0,0 +1,65 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package role + +import ( + "fmt" + "time" + + jwt "github.com/dgrijalva/jwt-go" +) + +// Claim represents role jwt claim data. +type Claim struct { + Domain string `json:"d"` + Email string `json:"email"` + KeyID string `json:"k"` + MFA string `json:"mfa"` + Role string `json:"r"` + Salt string `json:"a"` + UserID string `json:"u"` + UserName string `json:"n"` + Version string `json:"v"` + jwt.StandardClaims +} + +// Valid is copy from source code, and changed c.VerifyExpiresAt parameter. +func (c *Claim) Valid() error { + vErr := new(jwt.ValidationError) + now := jwt.TimeFunc().Unix() + + if c.VerifyExpiresAt(now, true) == false { + delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) + vErr.Inner = fmt.Errorf("token is expired by %v", delta) + vErr.Errors |= jwt.ValidationErrorExpired + } + + if c.VerifyIssuedAt(now, false) == false { + vErr.Inner = fmt.Errorf("Token used before issued") + vErr.Errors |= jwt.ValidationErrorIssuedAt + } + + if c.VerifyNotBefore(now, false) == false { + vErr.Inner = fmt.Errorf("token is not valid yet") + vErr.Errors |= jwt.ValidationErrorNotValidYet + } + + if vErr.Errors == 0 { + return nil + } + + return vErr +} diff --git a/role/option.go b/role/option.go new file mode 100644 index 00000000..a7c5470f --- /dev/null +++ b/role/option.go @@ -0,0 +1,42 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package role + +import ( + "github.com/yahoojapan/athenz-authorizer/jwk" + "github.com/yahoojapan/athenz-authorizer/pubkey" +) + +var ( + defaultOptions = []Option{} +) + +// Option represents a functional options pattern interface +type Option func(*rtp) + +// WithPubkeyProvider represents set pubkey provider functional option +func WithPubkeyProvider(pkp pubkey.Provider) Option { + return func(r *rtp) { + r.pkp = pkp + } +} + +// WithJWKProvider represents set pubkey provider functional option +func WithJWKProvider(jwkp jwk.Provider) Option { + return func(r *rtp) { + r.jwkp = jwkp + } +} diff --git a/role/option_test.go b/role/option_test.go new file mode 100644 index 00000000..e0a3f6fd --- /dev/null +++ b/role/option_test.go @@ -0,0 +1,136 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package role + +import ( + "fmt" + "reflect" + "testing" + + authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" + "github.com/yahoojapan/athenz-authorizer/jwk" + "github.com/yahoojapan/athenz-authorizer/pubkey" +) + +func TestWithPubkeyProvider(t *testing.T) { + type args struct { + pkp pubkey.Provider + } + type test struct { + name string + args args + checkFunc func(Option) error + } + tests := []test{ + func() test { + pkp := pubkey.Provider(func(pubkey.AthenzEnv, string) authcore.Verifier { + return nil + }) + return test{ + name: "set success", + args: args{ + pkp: pkp, + }, + checkFunc: func(opt Option) error { + pol := &rtp{} + opt(pol) + if reflect.ValueOf(pol.pkp) != reflect.ValueOf(pkp) { + return fmt.Errorf("Error") + } + + return nil + }, + } + }(), + { + name: "empty value", + args: args{ + nil, + }, + checkFunc: func(opt Option) error { + pol := &rtp{} + opt(pol) + if !reflect.DeepEqual(pol, &rtp{}) { + return fmt.Errorf("expected no changes, but got %v", pol) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithPubkeyProvider(tt.args.pkp) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithPubkeyProvider() error: %v", err) + } + }) + } +} + +func TestWithJWKProvider(t *testing.T) { + type args struct { + jwkp jwk.Provider + } + type test struct { + name string + args args + checkFunc func(Option) error + } + tests := []test{ + func() test { + pkp := jwk.Provider(func(string) interface{} { + return nil + }) + return test{ + name: "set success", + args: args{ + jwkp: pkp, + }, + checkFunc: func(opt Option) error { + pol := &rtp{} + opt(pol) + if reflect.ValueOf(pol.jwkp) != reflect.ValueOf(pkp) { + return fmt.Errorf("Error") + } + + return nil + }, + } + }(), + { + name: "empty value", + args: args{ + nil, + }, + checkFunc: func(opt Option) error { + pol := &rtp{} + opt(pol) + if !reflect.DeepEqual(pol, &rtp{}) { + return fmt.Errorf("expected no changes, but got %v", pol) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithJWKProvider(tt.args.jwkp) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithJWKProvider() error: %v", err) + } + }) + } +} diff --git a/role/role.go b/role/processor.go similarity index 51% rename from role/role.go rename to role/processor.go index 319ee017..7bfd571b 100644 --- a/role/role.go +++ b/role/processor.go @@ -18,29 +18,35 @@ package role import ( "strings" + jwt "github.com/dgrijalva/jwt-go" "github.com/pkg/errors" - "github.com/yahoojapan/athenz-policy-updater/pubkey" + "github.com/yahoojapan/athenz-authorizer/jwk" + "github.com/yahoojapan/athenz-authorizer/pubkey" ) -// RoleTokenParser represents the role token parser interface. -type RoleTokenParser interface { - ParseAndValidateRoleToken(tok string) (*RoleToken, error) +// Processor represents the role token parser interface. +type Processor interface { + ParseAndValidateRoleToken(tok string) (*Token, error) + ParseAndValidateRoleJWT(cred string) (*Claim, error) } type rtp struct { - pkp pubkey.Provider + pkp pubkey.Provider + jwkp jwk.Provider } -// NewRoleTokenParser returns the RoleTokenParser instance. -func NewRoleTokenParser(prov pubkey.Provider) RoleTokenParser { - return &rtp{ - pkp: prov, +// New returns the Role instance. +func New(opts ...Option) Processor { + r := new(rtp) + for _, opt := range append(defaultOptions, opts...) { + opt(r) } + return r } -// ParseAndValidateRoleToken return the parsed and validiated role token, and return any parsing and validate errors. -func (r *rtp) ParseAndValidateRoleToken(tok string) (*RoleToken, error) { - rt, err := r.parseRoleToken(tok) +// ParseAndValidateRoleToken return the parsed and validated role token, and return any parsing and validate errors. +func (r *rtp) ParseAndValidateRoleToken(tok string) (*Token, error) { + rt, err := r.parseToken(tok) if err != nil { return nil, errors.Wrap(err, "error parse role token") } @@ -51,13 +57,13 @@ func (r *rtp) ParseAndValidateRoleToken(tok string) (*RoleToken, error) { return rt, nil } -func (r *rtp) parseRoleToken(tok string) (*RoleToken, error) { +func (r *rtp) parseToken(tok string) (*Token, error) { st := strings.SplitN(tok, ";s=", 2) if len(st) != 2 { return nil, errors.Wrap(ErrRoleTokenInvalid, "no signature found") } - rt := &RoleToken{ + rt := &Token{ UnsignedToken: st[0], } @@ -73,7 +79,20 @@ func (r *rtp) parseRoleToken(tok string) (*RoleToken, error) { return rt, nil } -func (r *rtp) validate(rt *RoleToken) error { +func (r *rtp) ParseAndValidateRoleJWT(cred string) (*Claim, error) { + tok, err := jwt.ParseWithClaims(cred, &Claim{}, r.keyFunc) + if err != nil { + return nil, err + } + + if claims, ok := tok.Claims.(*Claim); ok && tok.Valid { + return claims, nil + } + + return nil, errors.New("error invalid jwt token") +} + +func (r *rtp) validate(rt *Token) error { if rt.Expired() { return errors.Wrapf(ErrRoleTokenExpired, "token expired") } @@ -83,3 +102,18 @@ func (r *rtp) validate(rt *RoleToken) error { } return ver.Verify(rt.UnsignedToken, rt.Signature) } + +// keyFunc extract the key id from the token, and return corresponding key +func (r *rtp) keyFunc(token *jwt.Token) (interface{}, error) { + keyID, ok := token.Header["kid"] + if !ok { + return nil, errors.New("kid not written in header") + } + + key := r.jwkp(keyID.(string)) + if key == nil { + return nil, errors.Errorf("key cannot be found, keyID: %s", keyID) + } + + return key, nil +} diff --git a/role/processor_test.go b/role/processor_test.go new file mode 100644 index 00000000..d51385e2 --- /dev/null +++ b/role/processor_test.go @@ -0,0 +1,511 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package role + +import ( + "crypto/rsa" + "fmt" + "io/ioutil" + "reflect" + "testing" + "time" + + jwt "github.com/dgrijalva/jwt-go" + authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" + "github.com/yahoojapan/athenz-authorizer/jwk" + "github.com/yahoojapan/athenz-authorizer/pubkey" +) + +func TestNew(t *testing.T) { + type args struct { + opts []Option + } + type test struct { + name string + args args + want Processor + } + tests := []test{ + { + name: "new success", + args: args{ + opts: nil, + }, + want: &rtp{ + nil, + nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := New(tt.args.opts...); !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_rtp_ParseAndValidateRoleToken(t *testing.T) { + type fields struct { + pkp pubkey.Provider + } + type args struct { + tok string + } + tests := []struct { + name string + fields fields + args args + want *Token + wantErr bool + }{ + { + name: "parse validate success", + fields: fields{ + pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { + return VerifierMock{ + VerifyFunc: func(string, string) error { + return nil + }, + } + }, + }, + args: args{ + tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25;s=dummysignature", + }, + want: &Token{ + UnsignedToken: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25", + Domain: "dummy.sidecartest", + ExpiryTime: time.Unix(9999999999, 0), + KeyID: "0", + Roles: []string{"users"}, + Signature: "dummysignature", + }, + }, + { + name: "parse error", + args: args{ + tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25", + }, + wantErr: true, + }, + { + name: "validate error", + fields: fields{ + pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { + return VerifierMock{ + VerifyFunc: func(string, string) error { + return fmt.Errorf("") + }, + } + }, + }, + args: args{ + tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25;s=dummysignature", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &rtp{ + pkp: tt.fields.pkp, + } + got, err := r.ParseAndValidateRoleToken(tt.args.tok) + if (err != nil) != tt.wantErr { + t.Errorf("rtp.ParseAndValidateRoleToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("rtp.ParseAndValidateRoleToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_rtp_parseToken(t *testing.T) { + type fields struct { + pkp pubkey.Provider + } + type args struct { + tok string + } + tests := []struct { + name string + fields fields + args args + want *Token + wantErr bool + }{ + { + name: "parse success", + args: args{ + tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=1550643321;k=0;i=172.16.168.25;s=dummysignature", + }, + want: &Token{ + UnsignedToken: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=1550643321;k=0;i=172.16.168.25", + Domain: "dummy.sidecartest", + ExpiryTime: time.Date(2019, 2, 20, 6, 15, 21, 0, time.UTC).Local(), + KeyID: "0", + Roles: []string{"users"}, + Signature: "dummysignature", + }, + }, + { + name: "signature not found", + args: args{ + tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=1550643321;k=0;i=172.16.168.25", + }, + wantErr: true, + }, + { + name: "invalid key value format", + args: args{ + tok: "v=Z1;d=dummy.sidecartest=;r=users;eabcd;s=dummy", + }, + wantErr: true, + }, + { + name: "set value error", + args: args{ + tok: "v=Z1;d=dummy.sidecartest=;r=users;e=abcd;s=dummy", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &rtp{ + pkp: tt.fields.pkp, + } + got, err := r.parseToken(tt.args.tok) + if (err != nil) != tt.wantErr { + t.Errorf("rtp.parseToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("rtp.parseToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_rtp_validate(t *testing.T) { + type fields struct { + pkp pubkey.Provider + } + type args struct { + rt *Token + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "validate success", + fields: fields{ + pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { + return VerifierMock{ + VerifyFunc: func(string, string) error { + return nil + }, + } + }, + }, + args: args{ + &Token{ + ExpiryTime: time.Now().Add(time.Hour), + }, + }, + }, + { + name: "token expired", + fields: fields{ + pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { + return VerifierMock{ + VerifyFunc: func(string, string) error { + return nil + }, + } + }, + }, + args: args{ + &Token{ + ExpiryTime: time.Now().Add(-1 * time.Hour), + }, + }, + wantErr: true, + }, + { + name: "validate error", + fields: fields{ + pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { + return VerifierMock{ + VerifyFunc: func(string, string) error { + return fmt.Errorf("") + }, + } + }, + }, + args: args{ + &Token{ + ExpiryTime: time.Now().Add(time.Hour), + }, + }, + wantErr: true, + }, + { + name: "verifier not found", + fields: fields{ + pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { + return nil + }, + }, + args: args{ + &Token{ + ExpiryTime: time.Now().Add(time.Hour), + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &rtp{ + pkp: tt.fields.pkp, + } + if err := r.validate(tt.args.rt); (err != nil) != tt.wantErr { + t.Errorf("rtp.validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_rtp_ParseAndValidateRoleJWT(t *testing.T) { + type fields struct { + pkp pubkey.Provider + jwkp jwk.Provider + } + type args struct { + cred string + } + type test struct { + name string + fields fields + args args + want *Claim + wantErr bool + } + + LoadRSAPublicKeyFromDisk := func(location string) *rsa.PublicKey { + keyData, e := ioutil.ReadFile(location) + if e != nil { + panic(e.Error()) + } + key, e := jwt.ParseRSAPublicKeyFromPEM(keyData) + if e != nil { + panic(e.Error()) + } + return key + } + + tests := []test{ + func() test { + return test{ + name: "verify jwt success", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + return LoadRSAPublicKeyFromDisk("./asserts/public.pem") + }), + }, + args: args{ + cred: `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ.eyJuYW1lIjoiSm9obiBEb2UiLCJhZG1pbiI6dHJ1ZSwiZXhwIjo5OTk5OTk5OTk5fQ.MBv8JoDPjlwhwCzPdkVH0C7HGjtLsVdVsbduNSbnIVtLEcD1yfsVqUKpUupYx2h6o_gKgjTbNG2C6zidV6YsxXu5s-D-YSN15MO_Mjm1WJducK0OJURC8o7u83LcgoEXZQTjA3gQVBGSbyNELCBQKN451OHMOPcIYDLdgXS4iqiZPPBxd1VuNGoMtUshZQR5mGp5F3Yk1YQg9QPicN4-gDh-PF5l87ouTj6O1WyxGuY2qHmGzun3xe_Ma1kzslbL95MtzOLR6seCaSCfanUxC2FjD2hPj4I7HZuYIIFsQRAb_pguhh4dkEkb3op5XcpgoHQr26SlkKAUEFLmUa6qvg`, + }, + want: func() *Claim { + c := &Claim{} + c.ExpiresAt = 9999999999 + return c + }(), + wantErr: false, + } + }(), + func() test { + return test{ + name: "verify jwt fail, no expiration definied", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + return LoadRSAPublicKeyFromDisk("./asserts/public.pem") + }), + }, + args: args{ + cred: `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ.eyJuYW1lIjoiSm9obiBEb2UiLCJhZG1pbiI6dHJ1ZX0.UtLx_xg2OWF7_sk9P7jcBsS9WqE4st_gvSskRoG92ktDXjSsBa-p2LmArFnFHp-cb3qnXUwc3_Ksg9w10r0iVpxg8lZfGUCmIfauaaoCuxRdogWIAaY4mIXyglQcSgIruo17wMJ-kHyJxr50lWMiyxFYf6ANUE8W2FaiDgwQuGraF4UQKDwmytGai1mHnc8_u5CanEmETWdax-Pe37BikPorljCIoYIyMTpIfdjM3A8s5Ipo8SHagnUPU0a-jS1sU2UjLo4vnDnPwur_6d5im9XuZD6DGHgaQRo4Zh-ZdvEJR8QTtdb2op14jzTaQGLYJNbPiH8yklBhtKMCAPHFuw`, + }, + wantErr: true, + } + }(), + func() test { + return test{ + name: "verify jwt fail, expired jwt", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + return LoadRSAPublicKeyFromDisk("./asserts/public.pem") + }), + }, + args: args{ + cred: `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ.eyJuYW1lIjoiSm9obiBEb2UiLCJhZG1pbiI6dHJ1ZSwiZXhwIjoxfQ.h5jrpuSZDjpqo8Ri-yUzq22qis_CIMuTQE6WR5myHW8Z8VhEOLInZU59kmu5Ardud3gjjtMI6kIJrUcVeYBcmE_MG4iMiah767hB-09Bm_lmh6mdEK3wP_m8_JX4OWKHqHyZSZgjJKGNCT-yHZEXuOLpydCLpIaL7znAA3-eDAnyUjZcVipA0J-BwS1I27zHOW6NumQEuXQMau2f1pH4Z77e3etNGA3yG7yG30YaqaSEWfah9BMZwgLx2fnuHAbcyNEpSl5nHZYdTyINtMsurUkDuou8c1G0WIvu4Rn2Wksey0GWdVNsclqeNaFsgsHyVwKsOVFvslQ3qTcwSjw73Q`, + }, + wantErr: true, + } + }(), + func() test { + return test{ + name: "verify jwt fail, invalid signature", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + return LoadRSAPublicKeyFromDisk("./asserts/public.pem") + }), + }, + args: args{ + cred: `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ.eyJuYW1lIjoiSm9obiBEb2UiLCJhZG1pbiI6dHJ1ZSwiZXhwIjoxfQ.h5jrpuSZDjpqo8Ri-yUzq22qis_CIMuTQE6WR5myHW8Z8VhEOLInZU59kmu5Ardud3gjjtMI6kIJrUcVeYBcmE_MG4iMiah767hB-09Bm_lmh6mdEK3wP_m8_JX4OWKHqHyZSZgjJKGNCT-yHZEXuOLpydCLpIaL7znAA3-eDAnyUjZcVipA0J-BwS1I27zHOW6NumQEuXQMau2f1pH4Z77e3etNGA3yG7yG30YaqaSEWfah9BMZwgLx2fnuHAbcyNEpSl5nHZYdTyINtMsurUkDuou8c1G0WIvu4Rn2Wksey0GWdVNsclqeNaFsgsHyVwKsOVFvslQ3qTcwSjw73Qe`, + }, + wantErr: true, + } + }(), + func() test { + return test{ + name: "verify jwt fail, invalid jwt format", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + return LoadRSAPublicKeyFromDisk("./asserts/public.pem") + }), + }, + args: args{ + cred: `dummy`, + }, + wantErr: true, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &rtp{ + pkp: tt.fields.pkp, + jwkp: tt.fields.jwkp, + } + got, err := r.ParseAndValidateRoleJWT(tt.args.cred) + if (err != nil) != tt.wantErr { + t.Errorf("rtp.ParseAndValidateRoleJWT() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("rtp.ParseAndValidateRoleJWT() = %+v, want %v", got, tt.want) + } + }) + } +} + +func Test_rtp_keyFunc(t *testing.T) { + type fields struct { + pkp pubkey.Provider + jwkp jwk.Provider + } + type args struct { + token *jwt.Token + } + type test struct { + name string + fields fields + args args + want interface{} + wantErr bool + } + tests := []test{ + { + name: "key return success", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + if kid == "1" { + return "key" + } + return nil + }), + }, + args: args{ + token: &jwt.Token{ + Header: map[string]interface{}{ + "kid": "1", + }, + }, + }, + want: "key", + }, + { + name: "key header not found", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + if kid == "1" { + return "key" + } + return nil + }), + }, + args: args{ + token: &jwt.Token{ + Header: map[string]interface{}{}, + }, + }, + wantErr: true, + }, + { + name: "key not found", + fields: fields{ + jwkp: jwk.Provider(func(kid string) interface{} { + if kid == "1" { + return nil + } + return "key" + }), + }, + args: args{ + token: &jwt.Token{ + Header: map[string]interface{}{ + "kid": "1", + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &rtp{ + pkp: tt.fields.pkp, + jwkp: tt.fields.jwkp, + } + got, err := r.keyFunc(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("rtp.keyFunc() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("rtp.keyFunc() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/role/role_test.go b/role/role_test.go deleted file mode 100644 index c1e266cd..00000000 --- a/role/role_test.go +++ /dev/null @@ -1,298 +0,0 @@ -/* -Copyright (C) 2018 Yahoo Japan Corporation Athenz team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package role - -import ( - "fmt" - "reflect" - "testing" - "time" - - authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" - "github.com/yahoojapan/athenz-policy-updater/pubkey" -) - -func TestNewRoleTokenParser(t *testing.T) { - type args struct { - prov pubkey.Provider - } - type test struct { - name string - args args - want RoleTokenParser - } - tests := []test{ - func() test { - /* p := pubkey.Provider(func(pubkey.AthenzEnv, string) authcore.Verifier { - return nil - })*/ - return test{ - name: "new success", - args: args{ - nil, - }, - want: &rtp{ - nil, - }, - } - }(), - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NewRoleTokenParser(tt.args.prov); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewRoleTokenParser() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_rtp_ParseAndValidateRoleToken(t *testing.T) { - type fields struct { - pkp pubkey.Provider - } - type args struct { - tok string - } - tests := []struct { - name string - fields fields - args args - want *RoleToken - wantErr bool - }{ - { - name: "parse validate success", - fields: fields{ - pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(string, string) error { - return nil - }, - } - }, - }, - args: args{ - tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25;s=dummysignature", - }, - want: &RoleToken{ - UnsignedToken: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25", - Domain: "dummy.sidecartest", - ExpiryTime: time.Unix(9999999999, 0), - KeyID: "0", - Roles: []string{"users"}, - Signature: "dummysignature", - }, - }, - { - name: "parse error", - args: args{ - tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25", - }, - wantErr: true, - }, - { - name: "validate error", - fields: fields{ - pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(string, string) error { - return fmt.Errorf("") - }, - } - }, - }, - args: args{ - tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=9999999999;k=0;i=172.16.168.25;s=dummysignature", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &rtp{ - pkp: tt.fields.pkp, - } - got, err := r.ParseAndValidateRoleToken(tt.args.tok) - if (err != nil) != tt.wantErr { - t.Errorf("rtp.ParseAndValidateRoleToken() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("rtp.ParseAndValidateRoleToken() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_rtp_parseRoleToken(t *testing.T) { - type fields struct { - pkp pubkey.Provider - } - type args struct { - tok string - } - tests := []struct { - name string - fields fields - args args - want *RoleToken - wantErr bool - }{ - { - name: "parse success", - args: args{ - tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=1550643321;k=0;i=172.16.168.25;s=dummysignature", - }, - want: &RoleToken{ - UnsignedToken: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=1550643321;k=0;i=172.16.168.25", - Domain: "dummy.sidecartest", - ExpiryTime: time.Date(2019, 2, 20, 6, 15, 21, 0, time.UTC).Local(), - KeyID: "0", - Roles: []string{"users"}, - Signature: "dummysignature", - }, - }, - { - name: "signature not found", - args: args{ - tok: "v=Z1;d=dummy.sidecartest;r=users;p=takumats.tenant.test;h=dummyhost;a=e55ee6ddc3e3c27c;t=1550463321;e=1550643321;k=0;i=172.16.168.25", - }, - wantErr: true, - }, - { - name: "invalid key value format", - args: args{ - tok: "v=Z1;d=dummy.sidecartest=;r=users;eabcd;s=dummy", - }, - wantErr: true, - }, - { - name: "set value error", - args: args{ - tok: "v=Z1;d=dummy.sidecartest=;r=users;e=abcd;s=dummy", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &rtp{ - pkp: tt.fields.pkp, - } - got, err := r.parseRoleToken(tt.args.tok) - if (err != nil) != tt.wantErr { - t.Errorf("rtp.parseRoleToken() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("rtp.parseRoleToken() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_rtp_validate(t *testing.T) { - type fields struct { - pkp pubkey.Provider - } - type args struct { - rt *RoleToken - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - { - name: "validate success", - fields: fields{ - pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(string, string) error { - return nil - }, - } - }, - }, - args: args{ - &RoleToken{ - ExpiryTime: time.Now().Add(time.Hour), - }, - }, - }, - { - name: "token expired", - fields: fields{ - pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(string, string) error { - return nil - }, - } - }, - }, - args: args{ - &RoleToken{ - ExpiryTime: time.Now().Add(-1 * time.Hour), - }, - }, - wantErr: true, - }, - { - name: "validate error", - fields: fields{ - pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(string, string) error { - return fmt.Errorf("") - }, - } - }, - }, - args: args{ - &RoleToken{ - ExpiryTime: time.Now().Add(time.Hour), - }, - }, - wantErr: true, - }, - { - name: "verifier not found", - fields: fields{ - pkp: func(pubkey.AthenzEnv, string) authcore.Verifier { - return nil - }, - }, - args: args{ - &RoleToken{ - ExpiryTime: time.Now().Add(time.Hour), - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &rtp{ - pkp: tt.fields.pkp, - } - if err := r.validate(tt.args.rt); (err != nil) != tt.wantErr { - t.Errorf("rtp.validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/role/role_token.go b/role/token.go similarity index 92% rename from role/role_token.go rename to role/token.go index bab23bcc..4543bd01 100644 --- a/role/role_token.go +++ b/role/token.go @@ -23,8 +23,8 @@ import ( "github.com/pkg/errors" ) -// RoleToken represents role token data. -type RoleToken struct { +// Token represents role token data. +type Token struct { // Version string // required Domain string // required Roles []string // required @@ -42,7 +42,7 @@ type RoleToken struct { } // SetParams sets the value for corresponding key data. -func (r *RoleToken) SetParams(key, value string) error { +func (r *Token) SetParams(key, value string) error { switch key { // case "a": // r.Salt = value @@ -82,6 +82,6 @@ func (r *RoleToken) SetParams(key, value string) error { } // Expired returns if the role token is expired or not. -func (r *RoleToken) Expired() bool { +func (r *Token) Expired() bool { return time.Now().After(r.ExpiryTime) } diff --git a/role/role_token_test.go b/role/token_test.go similarity index 84% rename from role/role_token_test.go rename to role/token_test.go index ed0aa988..ec3b8a63 100644 --- a/role/role_token_test.go +++ b/role/token_test.go @@ -23,7 +23,7 @@ import ( "time" ) -func TestRoleToken_SetParams(t *testing.T) { +func TestToken_SetParams(t *testing.T) { type fields struct { Domain string Roles []string @@ -40,7 +40,7 @@ func TestRoleToken_SetParams(t *testing.T) { name string fields fields args args - checkFunc func(got *RoleToken) error + checkFunc func(got *Token) error wantErr bool }{ { @@ -50,8 +50,8 @@ func TestRoleToken_SetParams(t *testing.T) { key: "d", value: "dummyd", }, - checkFunc: func(got *RoleToken) error { - expected := &RoleToken{ + checkFunc: func(got *Token) error { + expected := &Token{ Domain: "dummyd", } @@ -69,8 +69,8 @@ func TestRoleToken_SetParams(t *testing.T) { key: "e", value: "1550643321", }, - checkFunc: func(got *RoleToken) error { - expected := &RoleToken{ + checkFunc: func(got *Token) error { + expected := &Token{ ExpiryTime: func() time.Time { t, _ := strconv.ParseInt("1550643321", 10, 64) return time.Unix(t, 0) @@ -91,7 +91,7 @@ func TestRoleToken_SetParams(t *testing.T) { key: "e", value: "1550643321", }, - checkFunc: func(got *RoleToken) error { + checkFunc: func(got *Token) error { // 2019-02-20 06:15:21 +0000 UTC expected := time.Date(2019, 2, 20, 6, 15, 21, 0, time.UTC) if !expected.Equal(got.ExpiryTime) { @@ -117,8 +117,8 @@ func TestRoleToken_SetParams(t *testing.T) { key: "k", value: "dummyk", }, - checkFunc: func(got *RoleToken) error { - expected := &RoleToken{ + checkFunc: func(got *Token) error { + expected := &Token{ KeyID: "dummyk", } @@ -136,8 +136,8 @@ func TestRoleToken_SetParams(t *testing.T) { key: "r", value: "r1,r2", }, - checkFunc: func(got *RoleToken) error { - expected := &RoleToken{ + checkFunc: func(got *Token) error { + expected := &Token{ Roles: []string{"r1", "r2"}, } @@ -155,8 +155,8 @@ func TestRoleToken_SetParams(t *testing.T) { key: "s", value: "dummys", }, - checkFunc: func(got *RoleToken) error { - expected := &RoleToken{ + checkFunc: func(got *Token) error { + expected := &Token{ Signature: "dummys", } @@ -170,7 +170,7 @@ func TestRoleToken_SetParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &RoleToken{ + r := &Token{ Domain: tt.fields.Domain, Roles: tt.fields.Roles, ExpiryTime: tt.fields.ExpiryTime, @@ -179,18 +179,18 @@ func TestRoleToken_SetParams(t *testing.T) { UnsignedToken: tt.fields.UnsignedToken, } if err := r.SetParams(tt.args.key, tt.args.value); (err != nil) != tt.wantErr { - t.Errorf("RoleToken.SetParams() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Token.SetParams() error = %v, wantErr %v", err, tt.wantErr) } if tt.checkFunc != nil { if err := tt.checkFunc(r); err != nil { - t.Errorf("RoleToken set not expected, err: %v", err) + t.Errorf("Token set not expected, err: %v", err) } } }) } } -func TestRoleToken_Expired(t *testing.T) { +func TestToken_Expired(t *testing.T) { type fields struct { Domain string Roles []string @@ -221,7 +221,7 @@ func TestRoleToken_Expired(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &RoleToken{ + r := &Token{ Domain: tt.fields.Domain, Roles: tt.fields.Roles, ExpiryTime: tt.fields.ExpiryTime, @@ -230,7 +230,7 @@ func TestRoleToken_Expired(t *testing.T) { UnsignedToken: tt.fields.UnsignedToken, } if got := r.Expired(); got != tt.want { - t.Errorf("RoleToken.Expired() = %v, want %v", got, tt.want) + t.Errorf("Token.Expired() = %v, want %v", got, tt.want) } }) }