diff --git a/tool/tctl/common/collection.go b/tool/tctl/common/collection.go index 9ec30ae873de9..c4890fb75d705 100644 --- a/tool/tctl/common/collection.go +++ b/tool/tctl/common/collection.go @@ -688,41 +688,6 @@ func (c *databaseServerCollection) writeYAML(w io.Writer) error { return utils.WriteYAML(w, c.servers) } -type databaseCollection struct { - databases []types.Database -} - -func (c *databaseCollection) Resources() (r []types.Resource) { - for _, resource := range c.databases { - r = append(r, resource) - } - return r -} - -func (c *databaseCollection) WriteText(w io.Writer, verbose bool) error { - var rows [][]string - for _, database := range c.databases { - labels := common.FormatLabels(database.GetAllLabels(), verbose) - rows = append(rows, []string{ - common.FormatResourceName(database, verbose), - database.GetProtocol(), - database.GetURI(), - labels, - }) - } - headers := []string{"Name", "Protocol", "URI", "Labels"} - var t asciitable.Table - if verbose { - t = asciitable.MakeTable(headers, rows...) - } else { - t = asciitable.MakeTableWithTruncatedColumn(headers, rows, "Labels") - } - // stable sort by name. - t.SortRowsBy([]int{0}, true) - _, err := t.AsBuffer().WriteTo(w) - return trace.Wrap(err) -} - type lockCollection struct { locks []types.Lock } diff --git a/tool/tctl/common/collection_test.go b/tool/tctl/common/collection_test.go index 25959f7b77ea0..28851a72a678b 100644 --- a/tool/tctl/common/collection_test.go +++ b/tool/tctl/common/collection_test.go @@ -195,7 +195,7 @@ func testDatabaseCollection_writeText(t *testing.T) { rdsDiscoveredNameLabel), } test := writeTextTest{ - collection: &databaseCollection{databases: databases}, + collection: resources.NewDatabaseCollection(databases), wantNonVerboseTable: func() string { table := asciitable.MakeTableWithTruncatedColumn( []string{"Name", "Protocol", "URI", "Labels"}, diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go index 281cea2786c6b..be64b2681f422 100644 --- a/tool/tctl/common/resource_command.go +++ b/tool/tctl/common/resource_command.go @@ -29,7 +29,6 @@ import ( "reflect" "slices" "sort" - "strings" "time" "github.com/alecthomas/kingpin/v2" @@ -150,7 +149,6 @@ func (rc *ResourceCommand) Initialize(app *kingpin.Application, _ *tctlcfg.Globa types.KindNetworkRestrictions: rc.createNetworkRestrictions, types.KindApp: rc.createApp, types.KindAppServer: rc.createAppServer, - types.KindDatabase: rc.createDatabase, types.KindKubernetesCluster: rc.createKubeCluster, types.KindToken: rc.createToken, types.KindInstaller: rc.createInstaller, @@ -453,7 +451,8 @@ func (rc *ResourceCommand) Create(ctx context.Context, client *authclient.Client } return trace.Wrap(err) } - return nil + // continue to next resource + continue } // Else fallback to the legacy logic @@ -1267,29 +1266,6 @@ func (rc *ResourceCommand) updateUserTask(ctx context.Context, client *authclien return nil } -func (rc *ResourceCommand) createDatabase(ctx context.Context, client *authclient.Client, raw services.UnknownResource) error { - database, err := services.UnmarshalDatabase(raw.Raw, services.DisallowUnknown()) - if err != nil { - return trace.Wrap(err) - } - database.SetOrigin(types.OriginDynamic) - if err := client.CreateDatabase(ctx, database); err != nil { - if trace.IsAlreadyExists(err) { - if !rc.force { - return trace.AlreadyExists("database %q already exists", database.GetName()) - } - if err := client.UpdateDatabase(ctx, database); err != nil { - return trace.Wrap(err) - } - fmt.Printf("database %q has been updated\n", database.GetName()) - return nil - } - return trace.Wrap(err) - } - fmt.Printf("database %q has been created\n", database.GetName()) - return nil -} - func (rc *ResourceCommand) createToken(ctx context.Context, client *authclient.Client, raw services.UnknownResource) error { token, err := services.UnmarshalProvisionToken(raw.Raw, services.DisallowUnknown()) if err != nil { @@ -1871,8 +1847,8 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client *authclient.Client return trace.Wrap(err) } resDesc := "database server" - servers = filterByNameOrDiscoveredName(servers, rc.ref.Name) - name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc) + servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name) + name, err := resources.GetOneResourceNameToDelete(servers, rc.ref, resDesc) if err != nil { return trace.Wrap(err) } @@ -1893,22 +1869,6 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client *authclient.Client return trace.Wrap(err) } fmt.Printf("application %q has been deleted\n", rc.ref.Name) - case types.KindDatabase: - // TODO(okraport) DELETE IN v21.0.0, replace with regular Collect - databases, err := clientutils.CollectWithFallback(ctx, client.ListDatabases, client.GetDatabases) - if err != nil { - return trace.Wrap(err) - } - resDesc := "database" - databases = filterByNameOrDiscoveredName(databases, rc.ref.Name) - name, err := getOneResourceNameToDelete(databases, rc.ref, resDesc) - if err != nil { - return trace.Wrap(err) - } - if err := client.DeleteDatabase(ctx, name); err != nil { - return trace.Wrap(err) - } - fmt.Printf("%s %q has been deleted\n", resDesc, name) case types.KindKubernetesCluster: // TODO(okraport) DELETE IN v21.0.0, replace with regular Collect clusters, err := clientutils.CollectWithFallback(ctx, client.ListKubernetesClusters, client.GetKubernetesClusters) @@ -1916,8 +1876,8 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client *authclient.Client return trace.Wrap(err) } resDesc := "Kubernetes cluster" - clusters = filterByNameOrDiscoveredName(clusters, rc.ref.Name) - name, err := getOneResourceNameToDelete(clusters, rc.ref, resDesc) + clusters = resources.FilterByNameOrDiscoveredName(clusters, rc.ref.Name) + name, err := resources.GetOneResourceNameToDelete(clusters, rc.ref, resDesc) if err != nil { return trace.Wrap(err) } @@ -1992,8 +1952,8 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client *authclient.Client return trace.Wrap(err) } resDesc := "Kubernetes server" - servers = filterByNameOrDiscoveredName(servers, rc.ref.Name) - name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc) + servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name) + name, err := resources.GetOneResourceNameToDelete(servers, rc.ref, resDesc) if err != nil { return trace.Wrap(err) } @@ -2647,7 +2607,7 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient return &databaseServerCollection{servers: servers}, nil } - servers = filterByNameOrDiscoveredName(servers, rc.ref.Name) + servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name) if len(servers) == 0 { return nil, trace.NotFound("database server %q not found", rc.ref.Name) } @@ -2663,7 +2623,7 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient altNameFn := func(r types.KubeServer) string { return r.GetHostname() } - servers = filterByNameOrDiscoveredName(servers, rc.ref.Name, altNameFn) + servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name, altNameFn) if len(servers) == 0 { return nil, trace.NotFound("Kubernetes server %q not found", rc.ref.Name) } @@ -2710,21 +2670,6 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient return nil, trace.Wrap(err) } return &appCollection{apps: []types.Application{app}}, nil - case types.KindDatabase: - // TODO(okraport): DELETE IN v21.0.0, replace with regular Collect - databases, err := clientutils.CollectWithFallback(ctx, client.ListDatabases, client.GetDatabases) - if err != nil { - return nil, trace.Wrap(err) - } - - if rc.ref.Name == "" { - return &databaseCollection{databases: databases}, nil - } - databases = filterByNameOrDiscoveredName(databases, rc.ref.Name) - if len(databases) == 0 { - return nil, trace.NotFound("database %q not found", rc.ref.Name) - } - return &databaseCollection{databases: databases}, nil case types.KindKubernetesCluster: // TODO(okraport) DELETE IN v21.0.0, replace with regular Collect clusters, err := clientutils.CollectWithFallback(ctx, client.ListKubernetesClusters, client.GetKubernetesClusters) @@ -2734,7 +2679,7 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient if rc.ref.Name == "" { return &kubeClusterCollection{clusters: clusters}, nil } - clusters = filterByNameOrDiscoveredName(clusters, rc.ref.Name) + clusters = resources.FilterByNameOrDiscoveredName(clusters, rc.ref.Name) if len(clusters) == 0 { return nil, trace.NotFound("Kubernetes cluster %q not found", rc.ref.Name) } @@ -3729,105 +3674,6 @@ func findDeviceByIDOrTag(ctx context.Context, remote devicepb.DeviceTrustService return nil, trace.BadParameter("found multiple devices for asset tag %q, please retry using the device ID instead", idOrTag) } -// keepFn is a predicate function that returns true if a resource should be -// retained by filterResources. -type keepFn[T types.ResourceWithLabels] func(T) bool - -// filterResources takes a list of resources and returns a filtered list of -// resources for which the `keep` predicate function returns true. -func filterResources[T types.ResourceWithLabels](resources []T, keep keepFn[T]) []T { - out := make([]T, 0, len(resources)) - for _, r := range resources { - if keep(r) { - out = append(out, r) - } - } - return out -} - -// altNameFn is a func that returns an alternative name for a resource. -type altNameFn[T types.ResourceWithLabels] func(T) string - -// filterByNameOrDiscoveredName filters resources by name or "discovered name". -// It prefers exact name filtering first - if none of the resource names match -// exactly (i.e. all of the resources are filtered out), then it retries and -// filters the resources by "discovered name" of resource name instead, which -// comes from an auto-discovery label. -func filterByNameOrDiscoveredName[T types.ResourceWithLabels](resources []T, prefixOrName string, extra ...altNameFn[T]) []T { - // prefer exact names - out := filterByName(resources, prefixOrName, extra...) - if len(out) == 0 { - // fallback to looking for discovered name label matches. - out = filterByDiscoveredName(resources, prefixOrName) - } - return out -} - -// filterByName filters resources by exact name match. -func filterByName[T types.ResourceWithLabels](resources []T, name string, altNameFns ...altNameFn[T]) []T { - return filterResources(resources, func(r T) bool { - if r.GetName() == name { - return true - } - for _, altName := range altNameFns { - if altName(r) == name { - return true - } - } - return false - }) -} - -// filterByDiscoveredName filters resources that have a "discovered name" label -// that matches the given name. -func filterByDiscoveredName[T types.ResourceWithLabels](resources []T, name string) []T { - return filterResources(resources, func(r T) bool { - discoveredName, ok := r.GetLabel(types.DiscoveredNameLabel) - return ok && discoveredName == name - }) -} - -// getOneResourceNameToDelete checks a list of resources to ensure there is -// exactly one resource name among them, and returns that name or an error. -// Heartbeat resources can have the same name but different host ID, so this -// still allows a user to delete multiple heartbeats of the same name, for -// example `$ tctl rm db_server/someDB`. -func getOneResourceNameToDelete[T types.ResourceWithLabels](rs []T, ref services.Ref, resDesc string) (string, error) { - seen := make(map[string]struct{}) - for _, r := range rs { - seen[r.GetName()] = struct{}{} - } - switch len(seen) { - case 1: // need exactly one. - return rs[0].GetName(), nil - case 0: - return "", trace.NotFound("%v %q not found", resDesc, ref.Name) - default: - names := make([]string, 0, len(rs)) - for _, r := range rs { - names = append(names, r.GetName()) - } - msg := formatAmbiguousDeleteMessage(ref, resDesc, names) - return "", trace.BadParameter("%s", msg) - } -} - -// formatAmbiguousDeleteMessage returns a formatted message when a user is -// attempting to delete multiple resources by an ambiguous prefix of the -// resource names. -func formatAmbiguousDeleteMessage(ref services.Ref, resDesc string, names []string) string { - slices.Sort(names) - // choose an actual resource for the example in the error. - exampleRef := ref - exampleRef.Name = names[0] - return fmt.Sprintf(`%s matches multiple auto-discovered %vs: -%v - -Use the full resource name that was generated by the Teleport Discovery service, for example: -$ tctl rm %s`, - ref.String(), resDesc, strings.Join(names, "\n"), exampleRef.String()) -} - func (rc *ResourceCommand) createAuditQuery(ctx context.Context, client *authclient.Client, raw services.UnknownResource) error { in, err := services.UnmarshalAuditQuery(raw.Raw, services.DisallowUnknown()) if err != nil { diff --git a/tool/tctl/common/resource_command_test.go b/tool/tctl/common/resource_command_test.go index 6fbb2c2916e1f..952b6da81a7d1 100644 --- a/tool/tctl/common/resource_command_test.go +++ b/tool/tctl/common/resource_command_test.go @@ -1375,7 +1375,6 @@ func TestDatabaseResource(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, err) test := dynamicResourceTest[*types.DatabaseV3]{ kind: types.KindDatabase, resourceYAML: dbYAML, @@ -1454,113 +1453,6 @@ func TestAppResource(t *testing.T) { test.run(t) } -func TestGetOneResourceNameToDelete(t *testing.T) { - foo1 := mustCreateNewKubeServer(t, "foo-eks", "host-foo1", "foo", nil) - foo2 := mustCreateNewKubeServer(t, "foo-eks", "host-foo2", "foo", nil) - fooBar1 := mustCreateNewKubeServer(t, "foo-bar-eks-us-west-1", "host-foo-bar1", "foo-bar", nil) - fooBar2 := mustCreateNewKubeServer(t, "foo-bar-eks-us-west-2", "host-foo-bar2", "foo-bar", nil) - tests := []struct { - desc string - refName string - wantErrContains string - resources []types.KubeServer - wantName string - }{ - { - desc: "one resource is ok", - refName: "foo-bar-eks-us-west-1", - resources: []types.KubeServer{fooBar1}, - wantName: "foo-bar-eks-us-west-1", - }, - { - desc: "multiple resources with same name is ok", - refName: "foo", - resources: []types.KubeServer{foo1, foo2}, - wantName: "foo-eks", - }, - { - desc: "zero resources is an error", - refName: "xxx", - wantErrContains: `kubernetes server "xxx" not found`, - }, - { - desc: "multiple resources with different names is an error", - refName: "foo-bar", - resources: []types.KubeServer{fooBar1, fooBar2}, - wantErrContains: "matches multiple", - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - ref := services.Ref{Kind: types.KindKubeServer, Name: test.refName} - resDesc := "kubernetes server" - name, err := getOneResourceNameToDelete(test.resources, ref, resDesc) - if test.wantErrContains != "" { - require.ErrorContains(t, err, test.wantErrContains) - return - } - require.Equal(t, test.wantName, name) - }) - } -} - -func TestFilterByNameOrDiscoveredName(t *testing.T) { - foo1 := mustCreateNewKubeServer(t, "foo-eks-us-west-1", "host-foo", "foo", nil) - foo2 := mustCreateNewKubeServer(t, "foo-eks-us-west-2", "host-foo", "foo", nil) - fooBar1 := mustCreateNewKubeServer(t, "foo-bar", "host-foo-bar1", "", nil) - fooBar2 := mustCreateNewKubeServer(t, "foo-bar-eks-us-west-2", "host-foo-bar2", "foo-bar", nil) - resources := []types.KubeServer{ - foo1, foo2, fooBar1, fooBar2, - } - hostNameGetter := func(ks types.KubeServer) string { return ks.GetHostname() } - tests := []struct { - desc string - filter string - altNameGetters []altNameFn[types.KubeServer] - want []types.KubeServer - }{ - { - desc: "filters by exact name", - filter: "foo-eks-us-west-1", - want: []types.KubeServer{foo1}, - }, - { - desc: "filters by exact name over discovered names", - filter: "foo-bar", - want: []types.KubeServer{fooBar1}, - }, - { - desc: "filters by discovered name", - filter: "foo", - want: []types.KubeServer{foo1, foo2}, - }, - { - desc: "checks alt names for exact matches", - filter: "host-foo", - altNameGetters: []altNameFn[types.KubeServer]{hostNameGetter}, - want: []types.KubeServer{foo1, foo2}, - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - got := filterByNameOrDiscoveredName(resources, test.filter, test.altNameGetters...) - require.Empty(t, cmp.Diff(test.want, got)) - }) - } -} - -func TestFormatAmbiguousDeleteMessage(t *testing.T) { - ref := services.Ref{Kind: types.KindDatabase, Name: "x"} - resDesc := "database" - names := []string{"xbbb", "xaaa", "xccc", "xb"} - got := formatAmbiguousDeleteMessage(ref, resDesc, names) - require.Contains(t, got, "db/x matches multiple auto-discovered databases", - "should have formatted the ref used and pluralized the resource description") - wantSortedNames := strings.Join([]string{"xaaa", "xb", "xbbb", "xccc"}, "\n") - require.Contains(t, got, wantSortedNames, "should have sorted the matching names") - require.Contains(t, got, "$ tctl rm db/xaaa", "should have contained an example command") -} - // requireEqual creates an assertion function with a bound `expected` value // for use with table-driven tests func requireEqual(expected any) require.ValueAssertionFunc { diff --git a/tool/tctl/common/resources/database.go b/tool/tctl/common/resources/database.go new file mode 100644 index 0000000000000..a516abdf8f3fd --- /dev/null +++ b/tool/tctl/common/resources/database.go @@ -0,0 +1,159 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package resources + +import ( + "context" + "fmt" + "io" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/clientutils" + "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/services" + sliceutils "github.com/gravitational/teleport/lib/utils/slices" + "github.com/gravitational/teleport/tool/common" +) + +type databaseCollection struct { + databases types.Databases +} + +// NewDatabaseCollection creates a [Collection] over the provided databases. +func NewDatabaseCollection(databases types.Databases) Collection { + return &databaseCollection{databases: databases} +} + +func (c *databaseCollection) Resources() []types.Resource { + return sliceutils.Map(c.databases, func(db types.Database) types.Resource { + return db + }) +} + +func (c *databaseCollection) WriteText(w io.Writer, verbose bool) error { + rows := make([][]string, 0, len(c.databases)) + for _, database := range c.databases { + labels := common.FormatLabels(database.GetAllLabels(), verbose) + rows = append(rows, []string{ + common.FormatResourceName(database, verbose), + database.GetProtocol(), + database.GetURI(), + labels, + }) + } + headers := []string{"Name", "Protocol", "URI", "Labels"} + var t asciitable.Table + if verbose { + t = asciitable.MakeTable(headers, rows...) + } else { + t = asciitable.MakeTableWithTruncatedColumn(headers, rows, "Labels") + } + // stable sort by name. + t.SortRowsBy([]int{0}, true) + _, err := t.AsBuffer().WriteTo(w) + return trace.Wrap(err) +} + +func databaseHandler() Handler { + return Handler{ + getHandler: getDatabase, + createHandler: createDatabase, + updateHandler: updateDatabase, + deleteHandler: deleteDatabase, + singleton: false, + mfaRequired: false, + description: "A dynamic resource representing a database that can be proxied via a database service.", + } +} + +func getDatabase(ctx context.Context, client *authclient.Client, ref services.Ref, opts GetOpts) (Collection, error) { + // TODO(greedy52) implement resource filtering on the backend. + // TODO(okraport) DELETE IN v21.0.0, replace with regular Collect + databases, err := clientutils.CollectWithFallback(ctx, client.ListDatabases, client.GetDatabases) + if err != nil { + return nil, trace.Wrap(err) + } + + if ref.Name == "" { + return NewDatabaseCollection(databases), nil + } + databases = FilterByNameOrDiscoveredName(databases, ref.Name) + if len(databases) == 0 { + return nil, trace.NotFound("database %q not found", ref.Name) + } + return NewDatabaseCollection(databases), nil +} + +func createDatabase(ctx context.Context, client *authclient.Client, raw services.UnknownResource, opts CreateOpts) error { + database, err := services.UnmarshalDatabase(raw.Raw, services.DisallowUnknown()) + if err != nil { + return trace.Wrap(err) + } + database.SetOrigin(types.OriginDynamic) + if err := client.CreateDatabase(ctx, database); err != nil { + if trace.IsAlreadyExists(err) { + if !opts.Force { + return trace.AlreadyExists("database %q already exists", database.GetName()) + } + if err := client.UpdateDatabase(ctx, database); err != nil { + return trace.Wrap(err) + } + fmt.Printf("database %q has been updated\n", database.GetName()) + return nil + } + return trace.Wrap(err) + } + fmt.Printf("database %q has been created\n", database.GetName()) + return nil +} + +func updateDatabase(ctx context.Context, client *authclient.Client, raw services.UnknownResource, opts CreateOpts) error { + database, err := services.UnmarshalDatabase(raw.Raw, services.DisallowUnknown()) + if err != nil { + return trace.Wrap(err) + } + database.SetOrigin(types.OriginDynamic) + if err := client.UpdateDatabase(ctx, database); err != nil { + return trace.Wrap(err) + } + fmt.Printf("database %q has been updated\n", database.GetName()) + return nil +} + +func deleteDatabase(ctx context.Context, client *authclient.Client, ref services.Ref) error { + // TODO(okraport) DELETE IN v21.0.0, replace with regular Collect + databases, err := clientutils.CollectWithFallback(ctx, client.ListDatabases, client.GetDatabases) + if err != nil { + return trace.Wrap(err) + } + resDesc := "database" + databases = FilterByNameOrDiscoveredName(databases, ref.Name) + name, err := GetOneResourceNameToDelete(databases, ref, resDesc) + if err != nil { + return trace.Wrap(err) + } + if err := client.DeleteDatabase(ctx, name); err != nil { + return trace.Wrap(err) + } + fmt.Printf("%s %q has been deleted\n", resDesc, name) + return nil +} diff --git a/tool/tctl/common/resources/resource.go b/tool/tctl/common/resources/resource.go index 4452615bd8640..1ad45be0c39b9 100644 --- a/tool/tctl/common/resources/resource.go +++ b/tool/tctl/common/resources/resource.go @@ -33,8 +33,9 @@ import ( // to the Handler format. func Handlers() map[string]Handler { return map[string]Handler{ - types.KindRole: roleHandler(), - types.KindUser: userHandler(), + types.KindRole: roleHandler(), + types.KindUser: userHandler(), + types.KindDatabase: databaseHandler(), } } diff --git a/tool/tctl/common/resources/utils.go b/tool/tctl/common/resources/utils.go new file mode 100644 index 0000000000000..269b41228ef01 --- /dev/null +++ b/tool/tctl/common/resources/utils.go @@ -0,0 +1,118 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package resources + +import ( + "fmt" + "slices" + "strings" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/iterutils" + "github.com/gravitational/teleport/lib/services" +) + +// AltResourceNameFunc is a func that returns an alternative name for a resource. +type AltResourceNameFunc[T types.ResourceWithLabels] func(T) string + +// FilterByNameOrDiscoveredName filters resources by name or +// "discovered name". It prefers exact name filtering first - if none of the +// resource names match exactly (i.e. all of the resources are filtered out), +// then it retries and filters the resources by "discovered name" of resource +// name instead, which comes from an auto-discovery label. +func FilterByNameOrDiscoveredName[T types.ResourceWithLabels](resources []T, prefixOrName string, extra ...AltResourceNameFunc[T]) []T { + // prefer exact names + out := filterResourcesByName(resources, prefixOrName, extra...) + if len(out) == 0 { + // fallback to looking for discovered name label matches. + out = filterByDiscoveredName(resources, prefixOrName) + } + return out +} + +// filterResourcesByName filters resources by exact name match. +func filterResourcesByName[T types.ResourceWithLabels](resources []T, name string, altNameFns ...AltResourceNameFunc[T]) []T { + return filterResources(resources, func(r T) bool { + if r.GetName() == name { + return true + } + for _, altName := range altNameFns { + if altName(r) == name { + return true + } + } + return false + }) +} + +// filterByDiscoveredName filters resources that have a "discovered name" label +// that matches the given name. +func filterByDiscoveredName[T types.ResourceWithLabels](resources []T, name string) []T { + return filterResources(resources, func(r T) bool { + discoveredName, ok := r.GetLabel(types.DiscoveredNameLabel) + return ok && discoveredName == name + }) +} + +func filterResources[T types.ResourceWithLabels](resources []T, keepFn func(T) bool) []T { + return slices.Collect(iterutils.Filter(keepFn, slices.Values(resources))) +} + +// GetOneResourceNameToDelete checks a list of resources to ensure there is +// exactly one resource name among them, and returns that name or an error. +// Heartbeat resources can have the same name but different host ID, so this +// still allows a user to delete multiple heartbeats of the same name, for +// example `$ tctl rm db_server/someDB`. +func GetOneResourceNameToDelete[T types.ResourceWithLabels](rs []T, ref services.Ref, resDesc string) (string, error) { + seen := make(map[string]struct{}) + for _, r := range rs { + seen[r.GetName()] = struct{}{} + } + switch len(seen) { + case 1: // need exactly one. + return rs[0].GetName(), nil + case 0: + return "", trace.NotFound("%v %q not found", resDesc, ref.Name) + default: + names := make([]string, 0, len(rs)) + for _, r := range rs { + names = append(names, r.GetName()) + } + msg := formatAmbiguousDeleteMessage(ref, resDesc, names) + return "", trace.BadParameter("%s", msg) + } +} + +// formatAmbiguousDeleteMessage returns a formatted message when a user is +// attempting to delete multiple resources by an ambiguous prefix of the +// resource names. +func formatAmbiguousDeleteMessage(ref services.Ref, resDesc string, names []string) string { + slices.Sort(names) + // choose an actual resource for the example in the error. + exampleRef := ref + exampleRef.Name = names[0] + return fmt.Sprintf(`%s matches multiple auto-discovered %vs: +%v + +Use the full resource name that was generated by the Teleport Discovery service, for example: +$ tctl rm %s`, + ref.String(), resDesc, strings.Join(names, "\n"), exampleRef.String()) +} diff --git a/tool/tctl/common/resources/utils_test.go b/tool/tctl/common/resources/utils_test.go new file mode 100644 index 0000000000000..4fc1817159ff3 --- /dev/null +++ b/tool/tctl/common/resources/utils_test.go @@ -0,0 +1,190 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package resources + +import ( + "maps" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" +) + +func TestGetOneResourceNameToDelete(t *testing.T) { + foo1 := mustCreateNewKubeServer(t, "foo-eks", "host-foo1", "foo", nil) + foo2 := mustCreateNewKubeServer(t, "foo-eks", "host-foo2", "foo", nil) + fooBar1 := mustCreateNewKubeServer(t, "foo-bar-eks-us-west-1", "host-foo-bar1", "foo-bar", nil) + fooBar2 := mustCreateNewKubeServer(t, "foo-bar-eks-us-west-2", "host-foo-bar2", "foo-bar", nil) + tests := []struct { + desc string + refName string + wantErrContains string + resources []types.KubeServer + wantName string + }{ + { + desc: "one resource is ok", + refName: "foo-bar-eks-us-west-1", + resources: []types.KubeServer{fooBar1}, + wantName: "foo-bar-eks-us-west-1", + }, + { + desc: "multiple resources with same name is ok", + refName: "foo", + resources: []types.KubeServer{foo1, foo2}, + wantName: "foo-eks", + }, + { + desc: "zero resources is an error", + refName: "xxx", + wantErrContains: `kubernetes server "xxx" not found`, + }, + { + desc: "multiple resources with different names is an error", + refName: "foo-bar", + resources: []types.KubeServer{fooBar1, fooBar2}, + wantErrContains: "matches multiple", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ref := services.Ref{Kind: types.KindKubeServer, Name: test.refName} + resDesc := "kubernetes server" + name, err := GetOneResourceNameToDelete(test.resources, ref, resDesc) + if test.wantErrContains != "" { + require.ErrorContains(t, err, test.wantErrContains) + return + } + require.Equal(t, test.wantName, name) + }) + } +} + +func TestFilterByNameOrDiscoveredName(t *testing.T) { + foo1 := mustCreateNewKubeServer(t, "foo-eks-us-west-1", "host-foo", "foo", nil) + foo2 := mustCreateNewKubeServer(t, "foo-eks-us-west-2", "host-foo", "foo", nil) + fooBar1 := mustCreateNewKubeServer(t, "foo-bar", "host-foo-bar1", "", nil) + fooBar2 := mustCreateNewKubeServer(t, "foo-bar-eks-us-west-2", "host-foo-bar2", "foo-bar", nil) + resources := []types.KubeServer{ + foo1, foo2, fooBar1, fooBar2, + } + hostNameGetter := func(ks types.KubeServer) string { return ks.GetHostname() } + tests := []struct { + desc string + filter string + altNameGetters []AltResourceNameFunc[types.KubeServer] + want []types.KubeServer + }{ + { + desc: "filters by exact name", + filter: "foo-eks-us-west-1", + want: []types.KubeServer{foo1}, + }, + { + desc: "filters by exact name over discovered names", + filter: "foo-bar", + want: []types.KubeServer{fooBar1}, + }, + { + desc: "filters by discovered name", + filter: "foo", + want: []types.KubeServer{foo1, foo2}, + }, + { + desc: "checks alt names for exact matches", + filter: "host-foo", + altNameGetters: []AltResourceNameFunc[types.KubeServer]{hostNameGetter}, + want: []types.KubeServer{foo1, foo2}, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + got := FilterByNameOrDiscoveredName(resources, test.filter, test.altNameGetters...) + require.Empty(t, cmp.Diff(test.want, got)) + }) + } +} + +func TestFormatAmbiguousDeleteMessage(t *testing.T) { + ref := services.Ref{Kind: types.KindDatabase, Name: "x"} + resDesc := "database" + names := []string{"xbbb", "xaaa", "xccc", "xb"} + got := formatAmbiguousDeleteMessage(ref, resDesc, names) + require.Contains(t, got, "db/x matches multiple auto-discovered databases", + "should have formatted the ref used and pluralized the resource description") + wantSortedNames := strings.Join([]string{"xaaa", "xb", "xbbb", "xccc"}, "\n") + require.Contains(t, got, wantSortedNames, "should have sorted the matching names") + require.Contains(t, got, "$ tctl rm db/xaaa", "should have contained an example command") +} + +func makeTestLabels(extraStaticLabels map[string]string) map[string]string { + labels := make(map[string]string) + maps.Copy(labels, staticLabelsFixture) + maps.Copy(labels, extraStaticLabels) + return labels +} + +func mustCreateNewKubeCluster(t *testing.T, name, discoveredName string, extraStaticLabels map[string]string) *types.KubernetesClusterV3 { + t.Helper() + if extraStaticLabels == nil { + extraStaticLabels = make(map[string]string) + } + if discoveredName != "" { + extraStaticLabels[types.DiscoveredNameLabel] = discoveredName + } + cluster, err := types.NewKubernetesClusterV3( + types.Metadata{ + Name: name, + Labels: makeTestLabels(extraStaticLabels), + }, + types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{ + "date": { + Period: types.NewDuration(1 * time.Second), + Command: []string{"date"}, + Result: "Tue 11 Oct 2022 10:21:58 WEST", + }, + }, + }, + ) + require.NoError(t, err) + return cluster +} + +func mustCreateNewKubeServer(t *testing.T, name, hostname, discoveredName string, extraStaticLabels map[string]string) *types.KubernetesServerV3 { + t.Helper() + cluster := mustCreateNewKubeCluster(t, name, discoveredName, extraStaticLabels) + kubeServer, err := types.NewKubernetesServerV3FromCluster(cluster, hostname, uuid.New().String()) + require.NoError(t, err) + return kubeServer +} + +var ( + staticLabelsFixture = map[string]string{ + "label1": "val1", + "label2": "val2", + "label3": "val3", + } +)