From 8560f7e21b4af6c3b64e7da9729d19e405d9e55b Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Tue, 30 Sep 2025 11:14:29 -0700 Subject: [PATCH 1/2] use new join service for host joins --- integration/proxy/proxy_helpers.go | 7 +- lib/auth/bot_test.go | 36 +++--- lib/auth/join/join.go | 4 + lib/auth/join_test.go | 14 +-- .../workloadidentityv1_test.go | 4 +- lib/auth/storage/storage.go | 25 ++++ lib/auth/tls_test.go | 60 ++++------ lib/join/joinclient/join.go | 27 ++++- lib/service/connect.go | 111 +++++++++++------- lib/service/service.go | 21 +++- 10 files changed, 193 insertions(+), 116 deletions(-) diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index 9fd91dc400115..3c14595c3dae1 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -52,10 +52,10 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/lib" - "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/kube/kubeconfig" testingkubemock "github.com/gravitational/teleport/lib/kube/proxy/testing/kube_server" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -653,11 +653,10 @@ func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token str t.Setenv("AWS_REGION", "us-west-2") node := uuid.NewString() - _, err = join.Register(context.TODO(), join.RegisterParams{ + _, err = joinclient.Join(t.Context(), joinclient.JoinParams{ Token: token, ID: state.IdentityID{ - Role: types.RoleNode, - HostUUID: node, + Role: types.RoleInstance, NodeName: node, }, ProxyServer: proxyAddr, diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index f857b5e2f5b2d..a650748587d0f 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -57,7 +57,6 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" @@ -65,6 +64,7 @@ import ( libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/kube/token" "github.com/gravitational/teleport/lib/oidc/fakeissuer" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -154,7 +154,7 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) { require.NoError(t, err) require.NoError(t, client.CreateToken(ctx, token)) - result, err := join.Register(ctx, join.RegisterParams{ + result, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -294,7 +294,7 @@ func TestBotJoinAttrs_Kubernetes(t *testing.T) { require.NoError(t, err) require.NoError(t, client.CreateToken(ctx, tok)) - result, err := join.Register(ctx, join.RegisterParams{ + result, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: tok.GetName(), JoinMethod: types.JoinMethodKubernetes, ID: state.IdentityID{ @@ -406,7 +406,7 @@ func TestRegisterBotInstance(t *testing.T) { require.NoError(t, err) require.NoError(t, client.CreateToken(ctx, token)) - result, err := join.Register(ctx, join.RegisterParams{ + result, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -552,7 +552,7 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) { require.NoError(t, err) require.NoError(t, client.CreateToken(ctx, token)) - result, err := join.Register(ctx, join.RegisterParams{ + result, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -628,7 +628,7 @@ func TestRegisterBotCertificateExtensions(t *testing.T) { require.NoError(t, err) require.NoError(t, client.CreateToken(ctx, token)) - result, err := join.Register(ctx, join.RegisterParams{ + result, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -823,8 +823,8 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { } // authClientForRegisterResult is a test helper that creats an auth client for -// the given [*join.RegisterResult]. -func authClientForRegisterResult(t *testing.T, ctx context.Context, addr *utils.NetAddr, result *join.RegisterResult) *authclient.Client { +// the given [*joinclient.JoinResult]. +func authClientForRegisterResult(t *testing.T, ctx context.Context, addr *utils.NetAddr, result *joinclient.JoinResult) *authclient.Client { privateKeyPEM, err := keys.MarshalPrivateKey(result.PrivateKey) require.NoError(t, err) sshPub, err := ssh.NewPublicKey(result.PrivateKey.Public()) @@ -895,14 +895,14 @@ func instanceIDFromCerts(t *testing.T, certs *proto.Certs) (string, uint64) { return ident.BotInstanceID, ident.Generation } -// registerHelper calls `join.Register` with the given token, prefilling params +// registerHelper calls `joinclient.Join` with the given token, prefilling params // where possible. Overrides may be applied with `fns`. func registerHelper( ctx context.Context, token types.ProvisionToken, addr *utils.NetAddr, - fns ...func(*join.RegisterParams), -) (*join.RegisterResult, error) { - params := join.RegisterParams{ + fns ...func(*joinclient.JoinParams), +) (*joinclient.JoinResult, error) { + params := joinclient.JoinParams{ JoinMethod: token.GetJoinMethod(), Token: token.GetName(), ID: state.IdentityID{ @@ -918,7 +918,7 @@ func registerHelper( fn(¶ms) } - result, err := join.Register(ctx, params) + result, err := joinclient.Join(ctx, params) return result, trace.Wrap(err) } @@ -1015,7 +1015,7 @@ func TestRegisterBot_BotInstanceRejoin(t *testing.T) { require.NoError(t, a.UpsertToken(ctx, awsToken)) // Join as a "bot" with both token types. - k8sResult, err := registerHelper(ctx, k8sToken, addr, func(p *join.RegisterParams) { + k8sResult, err := registerHelper(ctx, k8sToken, addr, func(p *joinclient.JoinParams) { p.KubernetesReadFileFunc = k8sReadFileFunc }) require.NoError(t, err) @@ -1035,7 +1035,7 @@ func TestRegisterBot_BotInstanceRejoin(t *testing.T) { // Rejoin using the k8s client and make sure we're issued certs with the // same instance ID. k8sClient := authClientForRegisterResult(t, ctx, addr, k8sResult) - rejoinedK8sResult, err := registerHelper(ctx, k8sToken, addr, func(p *join.RegisterParams) { + rejoinedK8sResult, err := registerHelper(ctx, k8sToken, addr, func(p *joinclient.JoinParams) { p.KubernetesReadFileFunc = k8sReadFileFunc p.AuthClient = k8sClient }) @@ -1049,7 +1049,7 @@ func TestRegisterBot_BotInstanceRejoin(t *testing.T) { // join service, the instance ID must be provided to auth by the proxy as // part of the `RegisterUsingTokenRequest`. iamClient := authClientForRegisterResult(t, ctx, addr, awsResult) - rejoinedAWSResult, err := registerHelper(ctx, awsToken, addr, func(p *join.RegisterParams) { + rejoinedAWSResult, err := registerHelper(ctx, awsToken, addr, func(p *joinclient.JoinParams) { p.AuthClient = iamClient }) require.NoError(t, err) @@ -1229,7 +1229,7 @@ func TestRegisterBotMultipleTokens(t *testing.T) { require.NoError(t, err) require.NoError(t, client.CreateToken(ctx, tokenB)) - resultA, err := join.Register(ctx, join.RegisterParams{ + resultA, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: tokenA.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -1242,7 +1242,7 @@ func TestRegisterBotMultipleTokens(t *testing.T) { initialInstanceA, _ := instanceIDFromCerts(t, certsA) require.NotEmpty(t, initialInstanceA) - resultB, err := join.Register(ctx, join.RegisterParams{ + resultB, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: tokenB.GetName(), ID: state.IdentityID{ Role: types.RoleBot, diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go index 226a386016a89..98734617f532b 100644 --- a/lib/auth/join/join.go +++ b/lib/auth/join/join.go @@ -261,6 +261,10 @@ type RegisterResult struct { // running on a different host than the auth server. This method requires a // provision token that will be used to authenticate as an identity that should // be allowed to join the cluster. +// +// Deprecated: this function is superceded by lib/join/joinclient.Join +// +// TODO(nklaassen): DELETE IN 20 func Register(ctx context.Context, params RegisterParams) (result *RegisterResult, err error) { ctx, span := tracer.Start(ctx, "Register") defer func() { tracing.EndSpan(span, err) }() diff --git a/lib/auth/join_test.go b/lib/auth/join_test.go index d690c380bea99..6d35271fc0e9a 100644 --- a/lib/auth/join_test.go +++ b/lib/auth/join_test.go @@ -36,12 +36,12 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -299,9 +299,9 @@ func newBotToken(t *testing.T, tokenName, botName string, role types.SystemRole, return token } -// TestRegister_Bot tests that a provision token can be used to generate +// TestJoin_Bot tests that a provision token can be used to generate // renewable certificates for a non-interactive user. -func TestRegister_Bot(t *testing.T) { +func TestJoin_Bot(t *testing.T) { t.Parallel() ctx := context.Background() @@ -369,7 +369,7 @@ func TestRegister_Bot(t *testing.T) { } { t.Run(test.desc, func(t *testing.T) { start := srv.Clock().Now() - result, err := join.Register(ctx, join.RegisterParams{ + result, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: test.token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -413,9 +413,9 @@ func TestRegister_Bot(t *testing.T) { } } -// TestRegister_Bot_Expiry checks that bot certificate expiry can be set, and +// TestJoin_Bot_Expiry checks that bot certificate expiry can be set, and // does not exceed the limit. -func TestRegister_Bot_Expiry(t *testing.T) { +func TestJoin_Bot_Expiry(t *testing.T) { t.Parallel() ctx := context.Background() @@ -465,7 +465,7 @@ func TestRegister_Bot_Expiry(t *testing.T) { tok := newBotToken(t, uuid.NewString(), botName, types.RoleBot, srv.Clock().Now().Add(time.Hour)) require.NoError(t, srv.Auth().UpsertToken(ctx, tok)) - result, err := join.Register(ctx, join.RegisterParams{ + result, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: tok.GetName(), ID: state.IdentityID{ Role: types.RoleBot, diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go index c1b8f66d6e93a..886101f2eb0ca 100644 --- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go +++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go @@ -56,12 +56,12 @@ import ( "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/cryptosuites" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/join/joinclient" libjwt "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/oidc/fakeissuer" @@ -298,7 +298,7 @@ func TestIssueWorkloadIdentityE2E(t *testing.T) { require.NoError(t, err) // With the basic setup complete, we can now "fake" a join. - botCerts, err := join.Register(ctx, join.RegisterParams{ + botCerts, err := joinclient.Join(ctx, joinclient.JoinParams{ Token: token.GetName(), JoinMethod: types.JoinMethodKubernetes, ID: state.IdentityID{ diff --git a/lib/auth/storage/storage.go b/lib/auth/storage/storage.go index 5bb255bcc2a2c..dc7c573ac52a7 100644 --- a/lib/auth/storage/storage.go +++ b/lib/auth/storage/storage.go @@ -382,6 +382,31 @@ func readHostIDFromStorages(ctx context.Context, dataDir string, kubeBackend sta return hostID, trace.Wrap(err) } +// PersistAssignedHostID writes an assigned host ID to state storage and the +// host_uuid file. This should not be called in the same process as +// ReadOrGenerateHostID, it is intended to persist a host UUID assigned by the +// Auth service that was not generated locally. With the new auth-assigned host +// persisted to storage to maintain compatibility with any other processes that +// UUID flow the agent doesn't even need to read the host ID, it is only +// may read it. +func (p *ProcessStorage) PersistAssignedHostID(ctx context.Context, cfg *servicecfg.Config, hostID string) error { + if p.stateStorage != nil { + if _, err := p.stateStorage.Put( + ctx, + backend.Item{ + Key: backend.NewKey(hostid.FileName), + Value: []byte(hostID), + }, + ); err != nil { + return trace.Wrap(err, "persisting host ID to state storage") + } + } + if err := hostid.WriteFile(cfg.DataDir, hostID); err != nil { + return trace.Wrap(err, "persisting host ID to file") + } + return nil +} + // persistHostIDToStorages writes the host ID to local data and to // Kubernetes Secret if this process is running on a Kubernetes Cluster. func persistHostIDToStorages(ctx context.Context, cfg *servicecfg.Config, hostID string, kubeBackend stateBackend) error { diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 5efb4e71d2906..ea9a828bb9125 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -65,7 +65,6 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/authz" @@ -75,6 +74,7 @@ import ( "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/itertools/stream" + "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/modules/modulestest" @@ -3613,9 +3613,8 @@ func TestTLSFailover(t *testing.T) { } } -// TestRegisterCAPin makes sure that registration only works with a valid -// CA pin. -func TestRegisterCAPin(t *testing.T) { +// TestJoinCAPin makes sure that joining only works with a valid CA pin. +func TestJoinCAPin(t *testing.T) { t.Parallel() ctx := context.Background() @@ -3639,14 +3638,13 @@ func TestRegisterCAPin(t *testing.T) { require.Len(t, caPins, 1) caPin := caPins[0] - // Attempt to register with valid CA pin, should work. - _, err = join.Register(ctx, join.RegisterParams{ + // Attempt to join with valid CA pin, should work. + _, err = joinclient.Join(ctx, joinclient.JoinParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ - HostUUID: "once", NodeName: "node-name", - Role: types.RoleProxy, + Role: types.RoleInstance, }, AdditionalPrincipals: []string{"example.com"}, CAPins: []string{caPin}, @@ -3654,15 +3652,14 @@ func TestRegisterCAPin(t *testing.T) { }) require.NoError(t, err) - // Attempt to register with multiple CA pins where the auth server only + // Attempt to join with multiple CA pins where the auth server only // matches one, should work. - _, err = join.Register(ctx, join.RegisterParams{ + _, err = joinclient.Join(ctx, joinclient.JoinParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ - HostUUID: "once", NodeName: "node-name", - Role: types.RoleProxy, + Role: types.RoleInstance, }, AdditionalPrincipals: []string{"example.com"}, CAPins: []string{"sha256:123", caPin}, @@ -3670,14 +3667,13 @@ func TestRegisterCAPin(t *testing.T) { }) require.NoError(t, err) - // Attempt to register with invalid CA pin, should fail. - _, err = join.Register(ctx, join.RegisterParams{ + // Attempt to join with invalid CA pin, should fail. + _, err = joinclient.Join(ctx, joinclient.JoinParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ - HostUUID: "once", NodeName: "node-name", - Role: types.RoleProxy, + Role: types.RoleInstance, }, AdditionalPrincipals: []string{"example.com"}, CAPins: []string{"sha256:123"}, @@ -3685,14 +3681,13 @@ func TestRegisterCAPin(t *testing.T) { }) require.Error(t, err) - // Attempt to register with multiple invalid CA pins, should fail. - _, err = join.Register(ctx, join.RegisterParams{ + // Attempt to join with multiple invalid CA pins, should fail. + _, err = joinclient.Join(ctx, joinclient.JoinParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ - HostUUID: "once", NodeName: "node-name", - Role: types.RoleProxy, + Role: types.RoleInstance, }, AdditionalPrincipals: []string{"example.com"}, CAPins: []string{"sha256:123", "sha256:456"}, @@ -3719,14 +3714,13 @@ func TestRegisterCAPin(t *testing.T) { require.NoError(t, err) require.Len(t, caPins, 2) - // Attempt to register with multiple CA pins, should work - _, err = join.Register(ctx, join.RegisterParams{ + // Attempt to join with multiple CA pins, should work + _, err = joinclient.Join(ctx, joinclient.JoinParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ - HostUUID: "once", NodeName: "node-name", - Role: types.RoleProxy, + Role: types.RoleInstance, }, AdditionalPrincipals: []string{"example.com"}, CAPins: caPins, @@ -3735,9 +3729,9 @@ func TestRegisterCAPin(t *testing.T) { require.NoError(t, err) } -// TestRegisterCAPath makes sure registration only works with a valid CA +// TestJoinCAPath makes sure joining only works with a valid CA // file on disk. -func TestRegisterCAPath(t *testing.T) { +func TestJoinCAPath(t *testing.T) { t.Parallel() ctx := context.Background() @@ -3754,14 +3748,13 @@ func TestRegisterCAPath(t *testing.T) { testSrv.Auth(), ) - // Attempt to register with nothing at the CA path, should work. - _, err := join.Register(ctx, join.RegisterParams{ + // Attempt to join with nothing at the CA path, should work. + _, err := joinclient.Join(ctx, joinclient.JoinParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ - HostUUID: "once", NodeName: "node-name", - Role: types.RoleProxy, + Role: types.RoleInstance, }, AdditionalPrincipals: []string{"example.com"}, Clock: clock, @@ -3781,14 +3774,13 @@ func TestRegisterCAPath(t *testing.T) { err = os.WriteFile(caPath, certPem, teleport.FileMaskOwnerOnly) require.NoError(t, err) - // Attempt to register with valid CA path, should work. - _, err = join.Register(ctx, join.RegisterParams{ + // Attempt to join with valid CA path, should work. + _, err = joinclient.Join(ctx, joinclient.JoinParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ - HostUUID: "once", NodeName: "node-name", - Role: types.RoleProxy, + Role: types.RoleInstance, }, AdditionalPrincipals: []string{"example.com"}, CAPath: caPath, diff --git a/lib/join/joinclient/join.go b/lib/join/joinclient/join.go index a6444173f065e..cfb50f9cc86b4 100644 --- a/lib/join/joinclient/join.go +++ b/lib/join/joinclient/join.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/join/internal/messages" "github.com/gravitational/teleport/lib/join/joinv1" + "github.com/gravitational/teleport/lib/utils/hostid" ) type ( @@ -48,17 +49,41 @@ func Join(ctx context.Context, params JoinParams) (*JoinResult, error) { if err := params.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } + if params.ID.HostUUID != "" { + return nil, trace.BadParameter("HostUUID must not be provided to Join, it will be assigned by the Auth server") + } + if params.ID.Role != types.RoleInstance && params.ID.Role != types.RoleBot { + return nil, trace.BadParameter("Only Instance and Bot roles may be used for direct join attempts") + } slog.InfoContext(ctx, "Trying to join with the new join service") result, err := joinNew(ctx, params) if trace.IsNotImplemented(err) { // Fall back to joining via legacy service. slog.InfoContext(ctx, "Falling back to joining via the legacy join service", "error", err) - result, err := authjoin.Register(ctx, params) + // Non-bots must generate their own host UUID when joining via legacy service. + if params.ID.Role != types.RoleBot { + hostID, err := hostid.Generate(ctx, params.JoinMethod) + if err != nil { + return nil, trace.Wrap(err, "generating host ID") + } + params.ID.HostUUID = hostID + } + result, err := LegacyJoin(ctx, params) return result, trace.Wrap(err) } return result, trace.Wrap(err) } +// LegacyJoin is used to join the cluster via the legacy service with client-chosen host UUIDs. +func LegacyJoin(ctx context.Context, params JoinParams) (*JoinResult, error) { + if params.ID.Role != types.RoleBot && params.ID.HostUUID == "" { + return nil, trace.BadParameter("HostUUID is required for LegacyJoin") + } + //nolint:staticcheck // SA1019 falling back to deprecated method for compatibility. + result, err := authjoin.Register(ctx, params) + return result, trace.Wrap(err) +} + func joinNew(ctx context.Context, params JoinParams) (*JoinResult, error) { if params.AuthClient != nil { return joinViaAuthClient(ctx, params, params.AuthClient) diff --git a/lib/service/connect.go b/lib/service/connect.go index 2c63c9b0f15fb..ea293041f0b18 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -46,17 +46,16 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/entitlements" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" - "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/openssh" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -426,6 +425,15 @@ func (process *TeleportProcess) firstTimeConnect(role types.SystemRole) (*Connec if err != nil { return nil, trace.NewAggregate(err, connector.Close()) } + + if role == types.RoleInstance { + // Instance always joins first, only try to persist host ID after + // successfully completing the instance join and persisting the state + // and identity. + if err := process.storage.PersistAssignedHostID(process.GracefulExitContext(), process.Config, identity.ID.HostID()); err != nil { + return nil, trace.Wrap(err, "persisting host ID to storage") + } + } process.logger.InfoContext(process.ExitContext(), "The process successfully wrote the credentials and state to the disk.", "identity", role) return connector, nil } @@ -463,7 +471,7 @@ func (process *TeleportProcess) firstTimeConnectIdentityRemote(role types.System if role == types.RoleInstance { // Always need to go through the join process to get the first Instance // identity. - return process.join(role) + return process.instanceJoin() } // Wait for the instance connector to see if it can be used to reregister // without going through the join process. @@ -488,7 +496,7 @@ func (process *TeleportProcess) firstTimeConnectIdentityRemote(role types.System // with this requested role, which should only happen if the new join // service with auth-assigned host UUIDs is not available. process.Config.Logger.InfoContext(process.GracefulExitContext(), "Instance identity does not include required system role, must re-join with a provision token", "role", role) - return process.joinWithHostUUID(role, instanceIdentity.ID.HostID()) + return process.legacyJoinWithHostUUID(role, instanceIdentity.ID.HostID()) } // The instance connector does have the role requested, we can reregister // without going through the join process. @@ -511,46 +519,76 @@ func (process *TeleportProcess) firstTimeConnectIdentityRemote(role types.System return identity, trace.Wrap(err) } -func (process *TeleportProcess) join(role types.SystemRole) (*state.Identity, error) { - // TODO(nklaassen): Host UUID should be generated and assigned by the auth - // service during joining. - hostUUID, err := process.storage.ReadOrGenerateHostID(process.GracefulExitContext(), process.Config) +func (process *TeleportProcess) instanceJoin() (*state.Identity, error) { + id := state.IdentityID{ + Role: types.RoleInstance, + NodeName: process.Config.Hostname, + } + additionalPrincipals, dnsNames := process.instanceAdditionalPrincipals() + joinParams, err := process.makeJoinParams(id, additionalPrincipals, dnsNames) if err != nil { return nil, trace.Wrap(err) } - if _, err := uuid.Parse(hostUUID); err != nil && !aws.IsEC2NodeID(hostUUID) { - process.Config.Logger.WarnContext(process.GracefulExitContext(), "Host UUID is not a true UUID (not eligible for UUID-based proxying)", "host_uuid", hostUUID) - } - return process.joinWithHostUUID(role, hostUUID) -} - -func (process *TeleportProcess) joinWithHostUUID(role types.SystemRole, hostUUID string) (*state.Identity, error) { - if !process.Config.HasToken() { - return nil, trace.BadParameter("%v must join a cluster and needs a provisioning token", role) - } - process.logger.InfoContext(process.ExitContext(), "Joining the cluster with a secure token.") - token, err := process.Config.Token() + joinResult, err := joinclient.Join(process.GracefulExitContext(), *joinParams) if err != nil { + if utils.IsUntrustedCertErr(err) { + return nil, trace.WrapWithMessage(err, utils.SelfSignedCertsMsg) + } return nil, trace.Wrap(err) } - - dataDir := defaults.DataDir - if process.Config.DataDir != "" { - dataDir = process.Config.DataDir + privateKeyPEM, err := keys.MarshalPrivateKey(joinResult.PrivateKey) + if err != nil { + return nil, trace.Wrap(err) } + identity, err := state.ReadIdentityFromKeyPair(privateKeyPEM, joinResult.Certs) + return identity, trace.Wrap(err) +} +func (process *TeleportProcess) legacyJoinWithHostUUID(role types.SystemRole, hostUUID string) (*state.Identity, error) { id := state.IdentityID{ Role: role, - HostUUID: hostUUID, NodeName: process.Config.Hostname, + HostUUID: hostUUID, } additionalPrincipals, dnsNames, err := process.getAdditionalPrincipals(role, hostUUID) if err != nil { return nil, trace.Wrap(err) } + joinParams, err := process.makeJoinParams(id, additionalPrincipals, dnsNames) + if err != nil { + return nil, trace.Wrap(err) + } + process.logger.InfoContext(process.ExitContext(), "Joining the cluster with a secure token.") + joinResult, err := joinclient.LegacyJoin(process.GracefulExitContext(), *joinParams) + if err != nil { + if utils.IsUntrustedCertErr(err) { + return nil, trace.WrapWithMessage(err, utils.SelfSignedCertsMsg) + } + return nil, trace.Wrap(err) + } + privateKeyPEM, err := keys.MarshalPrivateKey(joinResult.PrivateKey) + if err != nil { + return nil, trace.Wrap(err) + } + identity, err := state.ReadIdentityFromKeyPair(privateKeyPEM, joinResult.Certs) + return identity, trace.Wrap(err) +} - registerParams := join.RegisterParams{ +func (process *TeleportProcess) makeJoinParams( + id state.IdentityID, + additionalPrincipals []string, + dnsNames []string, +) (*joinclient.JoinParams, error) { + if !process.Config.HasToken() { + return nil, trace.BadParameter("must join a cluster but no token was configured") + } + token, err := process.Config.Token() + if err != nil { + return nil, trace.Wrap(err) + } + dataDir := cmp.Or(process.Config.DataDir, defaults.DataDir) + joinParams := &joinclient.JoinParams{ Token: token, ID: id, AuthServers: process.Config.AuthServerAddresses(), @@ -569,27 +607,12 @@ func (process *TeleportProcess) joinWithHostUUID(role types.SystemRole, hostUUID FIPS: process.Config.FIPS, Insecure: lib.IsInsecureDevMode(), } - if registerParams.JoinMethod == types.JoinMethodAzure { - registerParams.AzureParams = join.AzureParams{ + if joinParams.JoinMethod == types.JoinMethodAzure { + joinParams.AzureParams = joinclient.AzureParams{ ClientID: process.Config.JoinParams.Azure.ClientID, } } - - result, err := join.Register(process.ExitContext(), registerParams) - if err != nil { - if utils.IsUntrustedCertErr(err) { - return nil, trace.WrapWithMessage(err, utils.SelfSignedCertsMsg) - } - return nil, trace.Wrap(err) - } - - privateKeyPEM, err := keys.MarshalPrivateKey(result.PrivateKey) - if err != nil { - return nil, trace.Wrap(err) - } - - identity, err := state.ReadIdentityFromKeyPair(privateKeyPEM, result.Certs) - return identity, trace.Wrap(err) + return joinParams, nil } func (process *TeleportProcess) initOpenSSH() { diff --git a/lib/service/service.go b/lib/service/service.go index d90e18a0ff631..bcab8dea8dc71 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4240,11 +4240,22 @@ func (process *TeleportProcess) initTracingService() error { return nil } +func (process *TeleportProcess) instanceAdditionalPrincipals() (principals []string, dnsNames []string) { + if process.Config.Hostname != "" { + principals = append(principals, process.Config.Hostname) + if lh := utils.ToLowerCaseASCII(process.Config.Hostname); lh != process.Config.Hostname { + // openssh expects all hostnames to be lowercase + principals = append(principals, lh) + } + } + // Add default DNSNames to the dnsNames list. + dnsNames = append(dnsNames, auth.DefaultDNSNamesForRole(types.RoleInstance)...) + return principals, dnsNames +} + // getAdditionalPrincipals returns a list of additional principals to add // to role's service certificates. -func (process *TeleportProcess) getAdditionalPrincipals(role types.SystemRole, hostUUID string) ([]string, []string, error) { - var principals []string - var dnsNames []string +func (process *TeleportProcess) getAdditionalPrincipals(role types.SystemRole, hostUUID string) (principals []string, dnsNames []string, err error) { if process.Config.Hostname != "" { principals = append(principals, process.Config.Hostname) if lh := utils.ToLowerCaseASCII(process.Config.Hostname); lh != process.Config.Hostname { @@ -4252,12 +4263,10 @@ func (process *TeleportProcess) getAdditionalPrincipals(role types.SystemRole, h principals = append(principals, lh) } } - var addrs []utils.NetAddr - // Add default DNSNames to the dnsNames list. - // For identities generated by teleport <= v6.1.6 the teleport.cluster.local DNS is not present dnsNames = append(dnsNames, auth.DefaultDNSNamesForRole(role)...) + var addrs []utils.NetAddr switch role { case types.RoleProxy: addrs = append(process.Config.Proxy.PublicAddrs, From 00605da825b670dc73b368e8138f895b1e2c239f Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 15 Oct 2025 11:25:29 -0700 Subject: [PATCH 2/2] fall back to legacy join on connection errors --- lib/join/joinclient/join.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/join/joinclient/join.go b/lib/join/joinclient/join.go index cfb50f9cc86b4..2f333712cd1a3 100644 --- a/lib/join/joinclient/join.go +++ b/lib/join/joinclient/join.go @@ -21,6 +21,7 @@ import ( "crypto" "crypto/x509" "encoding/pem" + "errors" "log/slog" "github.com/gravitational/trace" @@ -57,7 +58,7 @@ func Join(ctx context.Context, params JoinParams) (*JoinResult, error) { } slog.InfoContext(ctx, "Trying to join with the new join service") result, err := joinNew(ctx, params) - if trace.IsNotImplemented(err) { + if trace.IsNotImplemented(err) || errors.As(err, new(*connectionError)) { // Fall back to joining via legacy service. slog.InfoContext(ctx, "Falling back to joining via the legacy join service", "error", err) // Non-bots must generate their own host UUID when joining via legacy service. @@ -129,7 +130,7 @@ func joinViaProxy(ctx context.Context, params JoinParams, proxyAddr string) (*Jo }, ) if err != nil { - return nil, trace.Wrap(err) + return nil, &connectionError{trace.Wrap(err, "building proxy client")} } defer conn.Close() return joinWithClient(ctx, params, joinv1.NewClientFromConn(conn)) @@ -138,7 +139,7 @@ func joinViaProxy(ctx context.Context, params JoinParams, proxyAddr string) (*Jo func joinViaAuth(ctx context.Context, params JoinParams) (*JoinResult, error) { authClient, err := authjoin.NewAuthClient(ctx, params) if err != nil { - return nil, trace.Wrap(err, "building auth client") + return nil, &connectionError{trace.Wrap(err, "building auth client")} } defer authClient.Close() return joinViaAuthClient(ctx, params, authClient) @@ -168,7 +169,10 @@ func joinWithClient(ctx context.Context, params JoinParams, client *joinv1.Clien defer cancel() stream, err := client.Join(ctx) if err != nil { - return nil, trace.Wrap(err) + // Connection errors are usually delayed until the first request is + // attempted, wrap with a connectionError here to allow a fallback to + // the legacy join method. + return nil, &connectionError{trace.Wrap(err, "initiating join stream")} } defer stream.CloseSend() @@ -363,3 +367,15 @@ func generateKeys(ctx context.Context, suite types.SignatureAlgorithmSuite) (cry PublicSSHKey: sshPub.Marshal(), }, nil } + +type connectionError struct { + wrapped error +} + +func (e *connectionError) Error() string { + return e.wrapped.Error() +} + +func (e *connectionError) Unwrap() error { + return e.wrapped +}