Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"strings"
"testing"
"text/template"
"time"

"github.com/digitorus/pkcs7"
Expand Down Expand Up @@ -64,6 +67,7 @@ import (
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/join/iamjoin"
"github.com/gravitational/teleport/lib/join/joinclient"
"github.com/gravitational/teleport/lib/kube/token"
"github.com/gravitational/teleport/lib/oidc/fakeissuer"
Expand Down Expand Up @@ -689,9 +693,9 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
remoteAddr := "42.42.42.42:42"

t.Run("IAM method", func(t *testing.T) {
a.SetHTTPClientForAWSSTS(&mockClient{
a.SetHTTPClientForAWSSTS(&mockSTSClient{
respStatusCode: http.StatusOK,
respBody: responseFromAWSIdentity(auth.AWSIdentity{
respBody: responseFromAWSIdentity(iamjoin.AWSIdentity{
Account: "1234",
Arn: "arn:aws::1111",
}),
Expand Down Expand Up @@ -822,6 +826,54 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
})
}

func responseFromAWSIdentity(id iamjoin.AWSIdentity) string {
return fmt.Sprintf(`{
"GetCallerIdentityResponse": {
"GetCallerIdentityResult": {
"Account": "%s",
"Arn": "%s"
}}}`, id.Account, id.Arn)
}

type mockSTSClient struct {
respStatusCode int
respBody string
}

func (c *mockSTSClient) Do(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: c.respStatusCode,
Body: io.NopCloser(strings.NewReader(c.respBody)),
}, nil
}

var identityRequestTemplate = template.Must(template.New("sts-request").Parse(`POST / HTTP/1.1
Host: {{.Host}}
User-Agent: aws-sdk-go/1.37.17 (go1.17.1; darwin; amd64)
Content-Length: 43
Accept: application/json
Authorization: AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20211102/us-east-1/sts/aws4_request, SignedHeaders=accept;content-length;content-type;host;x-amz-date;x-amz-security-token;{{.SignedHeader}}, Signature=111
Content-Type: application/x-www-form-urlencoded; charset=utf-8
X-Amz-Date: 20211102T204300Z
X-Amz-Security-Token: aaa
X-Teleport-Challenge: {{.Challenge}}

Action=GetCallerIdentity&Version=2011-06-15`))

type identityRequestTemplateInput struct {
Host string
SignedHeader string
Challenge string
}

func defaultIdentityRequestTemplateInput(challenge string) identityRequestTemplateInput {
return identityRequestTemplateInput{
Host: "sts.amazonaws.com",
SignedHeader: "x-teleport-challenge;",
Challenge: challenge,
}
}

// authClientForRegisterResult is a test helper that creats an auth client for
// the given [*joinclient.JoinResult].
func authClientForRegisterResult(t *testing.T, ctx context.Context, addr *utils.NetAddr, result *joinclient.JoinResult) *authclient.Client {
Expand Down Expand Up @@ -948,9 +1000,9 @@ func TestRegisterBot_BotInstanceRejoin(t *testing.T) {
return nil, errMockInvalidToken
})

a.SetHTTPClientForAWSSTS(&mockClient{
a.SetHTTPClientForAWSSTS(&mockSTSClient{
respStatusCode: http.StatusOK,
respBody: responseFromAWSIdentity(auth.AWSIdentity{
respBody: responseFromAWSIdentity(iamjoin.AWSIdentity{
Account: "1234",
Arn: "arn:aws::1111",
}),
Expand Down
15 changes: 0 additions & 15 deletions lib/auth/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"net/url"
"time"

"github.com/coreos/go-semver/semver"
"github.com/jonboulle/clockwork"
"github.com/julienschmidt/httprouter"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -197,10 +196,6 @@ func (a *Server) ResetPassword(ctx context.Context, username string) error {
return a.resetPassword(ctx, username)
}

func (a *Server) SetHTTPClientForAWSSTS(clt utils.HTTPDoClient) {
a.httpClientForAWSSTS = clt
}

func (a *Server) SetJWKSValidator(clt JWKSValidator) {
a.k8sJWKSValidator = clt
}
Expand Down Expand Up @@ -410,7 +405,6 @@ func CheckOracleAllowRules(claims oracle.Claims, token string, allowRules []*typ
}

type GitHubManager = githubManager
type AWSIdentity = awsIdentity
type AttestedData = attestedData
type SignedAttestedData = signedAttestedData
type JWKSValidator = k8sJWKSValidator
Expand All @@ -421,7 +415,6 @@ type AzureVerifyTokenFunc = azureVerifyTokenFunc
type AccessTokenClaims = accessTokenClaims
type EC2Client = ec2Client
type EC2ClientKey = ec2ClientKey
type IAMRegisterOption = iamRegisterOption

func WithAzureCerts(certs []*x509.Certificate) AzureRegisterOption {
return func(cfg *AzureRegisterConfig) {
Expand All @@ -441,14 +434,6 @@ func WithAzureVMClientGetter(getVMClient vmClientGetter) AzureRegisterOption {
}
}

func WithFIPS(b bool) iamRegisterOption {
return withFips(b)
}

func WithAuthVersion(v *semver.Version) iamRegisterOption {
return withAuthVersion(v)
}

func (s *TLSServer) GRPCServer() *GRPCServer {
return s.grpcServer
}
1 change: 1 addition & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5852,6 +5852,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
Authorizer: cfg.Authorizer,
AuthService: cfg.AuthServer,
Clock: cfg.AuthServer.clock,
FIPS: cfg.AuthServer.fips,
}))

integrationServiceServer, err := integrationv1.NewService(&integrationv1.ServiceConfig{
Expand Down
35 changes: 4 additions & 31 deletions lib/auth/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ package auth

import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"log/slog"
Expand All @@ -44,6 +42,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/join"
"github.com/gravitational/teleport/lib/join/joinutils"
)

// checkTokenJoinRequestCommon checks all token join rules that are common to
Expand Down Expand Up @@ -92,7 +91,7 @@ func (a *Server) handleJoinFailure(
}

// Fetch and encode rawJoinAttrs if they are available.
attributesStruct, err := rawJoinAttrsToStruct(rawJoinAttrs)
attributesStruct, err := joinutils.RawJoinAttrsToStruct(rawJoinAttrs)
if err != nil {
a.logger.WarnContext(ctx, "Unable to fetch join attributes from join method", "error", err)
}
Expand Down Expand Up @@ -394,7 +393,7 @@ func (a *Server) GenerateBotCertsForJoin(
},
}
var err error
joinEvent.Attributes, err = rawJoinAttrsToStruct(params.RawJoinClaims)
joinEvent.Attributes, err = joinutils.RawJoinAttrsToStruct(params.RawJoinClaims)
if err != nil {
a.logger.WarnContext(
ctx,
Expand Down Expand Up @@ -552,7 +551,7 @@ func (a *Server) GenerateHostCertsForJoin(
RemoteAddr: params.RemoteAddr,
},
}
joinEvent.Attributes, err = rawJoinAttrsToStruct(params.RawJoinClaims)
joinEvent.Attributes, err = joinutils.RawJoinAttrsToStruct(params.RawJoinClaims)
if err != nil {
a.logger.WarnContext(ctx, "Unable to fetch join attributes from join method", "error", err)
}
Expand All @@ -562,21 +561,6 @@ func (a *Server) GenerateHostCertsForJoin(
return certs, nil
}

func rawJoinAttrsToStruct(in any) (*apievents.Struct, error) {
if in == nil {
return nil, nil
}
attrBytes, err := json.Marshal(in)
if err != nil {
return nil, trace.Wrap(err, "marshaling join attributes")
}
out := &apievents.Struct{}
if err := out.UnmarshalJSON(attrBytes); err != nil {
return nil, trace.Wrap(err, "unmarshaling join attributes")
}
return out, nil
}

func rawJoinAttrsToGoogleStruct(in any) (*structpb.Struct, error) {
if in == nil {
return nil, nil
Expand All @@ -591,14 +575,3 @@ func rawJoinAttrsToGoogleStruct(in any) (*structpb.Struct, error) {
}
return out, nil
}

func generateChallenge(encoding *base64.Encoding, length int) (string, error) {
// read crypto-random bytes to generate the challenge
challengeRawBytes := make([]byte, length)
if _, err := rand.Read(challengeRawBytes); err != nil {
return "", trace.Wrap(err)
}

// encode the challenge to base64 so it can be sent over HTTP
return encoding.EncodeToString(challengeRawBytes), nil
}
8 changes: 4 additions & 4 deletions lib/auth/join/iam/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,18 @@ type stsIdentityRequestOptions struct {
imdsClient imdsClient
}

type stsIdentityRequestOption func(cfg *stsIdentityRequestOptions)
type STSIdentityRequestOption func(cfg *stsIdentityRequestOptions)

// WithFIPSEndpoint is a functional option to use a FIPS STS endpoint. In non-US
// regions, this will use the us-east-1 FIPS endpoint.
func WithFIPSEndpoint(useFIPS bool) stsIdentityRequestOption {
func WithFIPSEndpoint(useFIPS bool) STSIdentityRequestOption {
return func(opts *stsIdentityRequestOptions) {
opts.useFIPS = useFIPS
}
}

// WithIMDSClient is a functional option to use a custom IMDS client.
func WithIMDSClient(clt imdsClient) stsIdentityRequestOption {
func WithIMDSClient(clt imdsClient) STSIdentityRequestOption {
return func(opts *stsIdentityRequestOptions) {
opts.imdsClient = clt
}
Expand All @@ -75,7 +75,7 @@ type imdsClient interface {

// CreateSignedSTSIdentityRequest is called on the client side and returns an
// sts:GetCallerIdentity request signed with the local AWS credentials
func CreateSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) {
func CreateSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...STSIdentityRequestOption) ([]byte, error) {
var options stsIdentityRequestOptions
for _, opt := range opts {
opt(&options)
Expand Down
9 changes: 8 additions & 1 deletion lib/auth/join/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ type RegisterParams struct {
GitlabParams GitlabParams
// BoundKeypairParams contains parameters specific to bound keypair joining.
BoundKeypairParams *BoundKeypairParams
// CreateSignedSTSIdentityRequestFunc overrides the function used to
// generate a signed AWs sts:GetCallerIdentity request.
CreateSignedSTSIdentityRequestFunc func(ctx context.Context, challenge string, opts ...iam.STSIdentityRequestOption) ([]byte, error)
}

func (r *RegisterParams) CheckAndSetDefaults() error {
Expand All @@ -205,6 +208,10 @@ func (r *RegisterParams) CheckAndSetDefaults() error {
return trace.BadParameter("no auth or proxy servers set")
}

if r.CreateSignedSTSIdentityRequestFunc == nil {
r.CreateSignedSTSIdentityRequestFunc = iam.CreateSignedSTSIdentityRequest
}

return nil
}

Expand Down Expand Up @@ -797,7 +804,7 @@ func registerUsingIAMMethod(
// Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request.
certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
// create the signed sts:GetCallerIdentity request and include the challenge
signedRequest, err := iam.CreateSignedSTSIdentityRequest(ctx, challenge,
signedRequest, err := params.CreateSignedSTSIdentityRequestFunc(ctx, challenge,
iam.WithFIPSEndpoint(params.FIPS),
)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/join/joinutils"
liboidc "github.com/gravitational/teleport/lib/oidc"
"github.com/gravitational/teleport/lib/utils"
)
Expand Down Expand Up @@ -430,7 +431,7 @@ func (a *Server) checkAzureRequest(
}

func generateAzureChallenge() (string, error) {
challenge, err := generateChallenge(base64.RawURLEncoding, 24)
challenge, err := joinutils.GenerateChallenge(base64.RawURLEncoding, 24)
return challenge, trace.Wrap(err)
}

Expand Down
18 changes: 2 additions & 16 deletions lib/auth/join_gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ package auth

import (
"context"
"regexp"
"strings"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/gitlab"
"github.com/gravitational/teleport/lib/join/joinutils"
)

type gitlabIDTokenValidator interface {
Expand Down Expand Up @@ -87,20 +86,7 @@ func joinRuleGlobMatch(want string, got string) (bool, error) {
if want == "" {
return true, nil
}
return globMatch(want, got)
}

// globMatch performs simple a simple glob-style match test on a string.
// - '*' matches zero or more characters.
// - '?' matches any single character.
// It returns true if a match is detected.
func globMatch(pattern, str string) (bool, error) {
pattern = regexp.QuoteMeta(pattern)
pattern = strings.ReplaceAll(pattern, `\*`, ".*")
pattern = strings.ReplaceAll(pattern, `\?`, ".")
pattern = "^" + pattern + "$"
matched, err := regexp.MatchString(pattern, str)
return matched, trace.Wrap(err)
return joinutils.GlobMatch(want, got)
}

func checkGitLabAllowRules(token *types.ProvisionTokenV2, claims *gitlab.IDTokenClaims) error {
Expand Down
Loading
Loading