Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions tool/tctl/common/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,41 +688,6 @@ func (c *databaseServerCollection) writeYAML(w io.Writer) error {
return utils.WriteYAML(w, c.servers)
}

type databaseCollection struct {
databases []types.Database
}

func (c *databaseCollection) Resources() (r []types.Resource) {
for _, resource := range c.databases {
r = append(r, resource)
}
return r
}

func (c *databaseCollection) WriteText(w io.Writer, verbose bool) error {
var rows [][]string
for _, database := range c.databases {
labels := common.FormatLabels(database.GetAllLabels(), verbose)
rows = append(rows, []string{
common.FormatResourceName(database, verbose),
database.GetProtocol(),
database.GetURI(),
labels,
})
}
headers := []string{"Name", "Protocol", "URI", "Labels"}
var t asciitable.Table
if verbose {
t = asciitable.MakeTable(headers, rows...)
} else {
t = asciitable.MakeTableWithTruncatedColumn(headers, rows, "Labels")
}
// stable sort by name.
t.SortRowsBy([]int{0}, true)
_, err := t.AsBuffer().WriteTo(w)
return trace.Wrap(err)
}

type lockCollection struct {
locks []types.Lock
}
Expand Down
2 changes: 1 addition & 1 deletion tool/tctl/common/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func testDatabaseCollection_writeText(t *testing.T) {
rdsDiscoveredNameLabel),
}
test := writeTextTest{
collection: &databaseCollection{databases: databases},
collection: resources.NewDatabaseCollection(databases),
wantNonVerboseTable: func() string {
table := asciitable.MakeTableWithTruncatedColumn(
[]string{"Name", "Protocol", "URI", "Labels"},
Expand Down
176 changes: 11 additions & 165 deletions tool/tctl/common/resource_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"reflect"
"slices"
"sort"
"strings"
"time"

"github.com/alecthomas/kingpin/v2"
Expand Down Expand Up @@ -150,7 +149,6 @@ func (rc *ResourceCommand) Initialize(app *kingpin.Application, _ *tctlcfg.Globa
types.KindNetworkRestrictions: rc.createNetworkRestrictions,
types.KindApp: rc.createApp,
types.KindAppServer: rc.createAppServer,
types.KindDatabase: rc.createDatabase,
types.KindKubernetesCluster: rc.createKubeCluster,
types.KindToken: rc.createToken,
types.KindInstaller: rc.createInstaller,
Expand Down Expand Up @@ -453,7 +451,8 @@ func (rc *ResourceCommand) Create(ctx context.Context, client *authclient.Client
}
return trace.Wrap(err)
}
return nil
// continue to next resource
continue
}
// Else fallback to the legacy logic

Expand Down Expand Up @@ -1267,29 +1266,6 @@ func (rc *ResourceCommand) updateUserTask(ctx context.Context, client *authclien
return nil
}

func (rc *ResourceCommand) createDatabase(ctx context.Context, client *authclient.Client, raw services.UnknownResource) error {
database, err := services.UnmarshalDatabase(raw.Raw, services.DisallowUnknown())
if err != nil {
return trace.Wrap(err)
}
database.SetOrigin(types.OriginDynamic)
if err := client.CreateDatabase(ctx, database); err != nil {
if trace.IsAlreadyExists(err) {
if !rc.force {
return trace.AlreadyExists("database %q already exists", database.GetName())
}
if err := client.UpdateDatabase(ctx, database); err != nil {
return trace.Wrap(err)
}
fmt.Printf("database %q has been updated\n", database.GetName())
return nil
}
return trace.Wrap(err)
}
fmt.Printf("database %q has been created\n", database.GetName())
return nil
}

func (rc *ResourceCommand) createToken(ctx context.Context, client *authclient.Client, raw services.UnknownResource) error {
token, err := services.UnmarshalProvisionToken(raw.Raw, services.DisallowUnknown())
if err != nil {
Expand Down Expand Up @@ -1871,8 +1847,8 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client *authclient.Client
return trace.Wrap(err)
}
resDesc := "database server"
servers = filterByNameOrDiscoveredName(servers, rc.ref.Name)
name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc)
servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name)
name, err := resources.GetOneResourceNameToDelete(servers, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -1893,31 +1869,15 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client *authclient.Client
return trace.Wrap(err)
}
fmt.Printf("application %q has been deleted\n", rc.ref.Name)
case types.KindDatabase:
// TODO(okraport) DELETE IN v21.0.0, replace with regular Collect
databases, err := clientutils.CollectWithFallback(ctx, client.ListDatabases, client.GetDatabases)
if err != nil {
return trace.Wrap(err)
}
resDesc := "database"
databases = filterByNameOrDiscoveredName(databases, rc.ref.Name)
name, err := getOneResourceNameToDelete(databases, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
if err := client.DeleteDatabase(ctx, name); err != nil {
return trace.Wrap(err)
}
fmt.Printf("%s %q has been deleted\n", resDesc, name)
case types.KindKubernetesCluster:
// TODO(okraport) DELETE IN v21.0.0, replace with regular Collect
clusters, err := clientutils.CollectWithFallback(ctx, client.ListKubernetesClusters, client.GetKubernetesClusters)
if err != nil {
return trace.Wrap(err)
}
resDesc := "Kubernetes cluster"
clusters = filterByNameOrDiscoveredName(clusters, rc.ref.Name)
name, err := getOneResourceNameToDelete(clusters, rc.ref, resDesc)
clusters = resources.FilterByNameOrDiscoveredName(clusters, rc.ref.Name)
name, err := resources.GetOneResourceNameToDelete(clusters, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1992,8 +1952,8 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client *authclient.Client
return trace.Wrap(err)
}
resDesc := "Kubernetes server"
servers = filterByNameOrDiscoveredName(servers, rc.ref.Name)
name, err := getOneResourceNameToDelete(servers, rc.ref, resDesc)
servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name)
name, err := resources.GetOneResourceNameToDelete(servers, rc.ref, resDesc)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -2647,7 +2607,7 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient
return &databaseServerCollection{servers: servers}, nil
}

servers = filterByNameOrDiscoveredName(servers, rc.ref.Name)
servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name)
if len(servers) == 0 {
return nil, trace.NotFound("database server %q not found", rc.ref.Name)
}
Expand All @@ -2663,7 +2623,7 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient
altNameFn := func(r types.KubeServer) string {
return r.GetHostname()
}
servers = filterByNameOrDiscoveredName(servers, rc.ref.Name, altNameFn)
servers = resources.FilterByNameOrDiscoveredName(servers, rc.ref.Name, altNameFn)
if len(servers) == 0 {
return nil, trace.NotFound("Kubernetes server %q not found", rc.ref.Name)
}
Expand Down Expand Up @@ -2710,21 +2670,6 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient
return nil, trace.Wrap(err)
}
return &appCollection{apps: []types.Application{app}}, nil
case types.KindDatabase:
// TODO(okraport): DELETE IN v21.0.0, replace with regular Collect
databases, err := clientutils.CollectWithFallback(ctx, client.ListDatabases, client.GetDatabases)
if err != nil {
return nil, trace.Wrap(err)
}

if rc.ref.Name == "" {
return &databaseCollection{databases: databases}, nil
}
databases = filterByNameOrDiscoveredName(databases, rc.ref.Name)
if len(databases) == 0 {
return nil, trace.NotFound("database %q not found", rc.ref.Name)
}
return &databaseCollection{databases: databases}, nil
case types.KindKubernetesCluster:
// TODO(okraport) DELETE IN v21.0.0, replace with regular Collect
clusters, err := clientutils.CollectWithFallback(ctx, client.ListKubernetesClusters, client.GetKubernetesClusters)
Expand All @@ -2734,7 +2679,7 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client *authclient
if rc.ref.Name == "" {
return &kubeClusterCollection{clusters: clusters}, nil
}
clusters = filterByNameOrDiscoveredName(clusters, rc.ref.Name)
clusters = resources.FilterByNameOrDiscoveredName(clusters, rc.ref.Name)
if len(clusters) == 0 {
return nil, trace.NotFound("Kubernetes cluster %q not found", rc.ref.Name)
}
Expand Down Expand Up @@ -3729,105 +3674,6 @@ func findDeviceByIDOrTag(ctx context.Context, remote devicepb.DeviceTrustService
return nil, trace.BadParameter("found multiple devices for asset tag %q, please retry using the device ID instead", idOrTag)
}

// keepFn is a predicate function that returns true if a resource should be
// retained by filterResources.
type keepFn[T types.ResourceWithLabels] func(T) bool

// filterResources takes a list of resources and returns a filtered list of
// resources for which the `keep` predicate function returns true.
func filterResources[T types.ResourceWithLabels](resources []T, keep keepFn[T]) []T {
out := make([]T, 0, len(resources))
for _, r := range resources {
if keep(r) {
out = append(out, r)
}
}
return out
}

// altNameFn is a func that returns an alternative name for a resource.
type altNameFn[T types.ResourceWithLabels] func(T) string

// filterByNameOrDiscoveredName filters resources by name or "discovered name".
// It prefers exact name filtering first - if none of the resource names match
// exactly (i.e. all of the resources are filtered out), then it retries and
// filters the resources by "discovered name" of resource name instead, which
// comes from an auto-discovery label.
func filterByNameOrDiscoveredName[T types.ResourceWithLabels](resources []T, prefixOrName string, extra ...altNameFn[T]) []T {
// prefer exact names
out := filterByName(resources, prefixOrName, extra...)
if len(out) == 0 {
// fallback to looking for discovered name label matches.
out = filterByDiscoveredName(resources, prefixOrName)
}
return out
}

// filterByName filters resources by exact name match.
func filterByName[T types.ResourceWithLabels](resources []T, name string, altNameFns ...altNameFn[T]) []T {
return filterResources(resources, func(r T) bool {
if r.GetName() == name {
return true
}
for _, altName := range altNameFns {
if altName(r) == name {
return true
}
}
return false
})
}

// filterByDiscoveredName filters resources that have a "discovered name" label
// that matches the given name.
func filterByDiscoveredName[T types.ResourceWithLabels](resources []T, name string) []T {
return filterResources(resources, func(r T) bool {
discoveredName, ok := r.GetLabel(types.DiscoveredNameLabel)
return ok && discoveredName == name
})
}

// getOneResourceNameToDelete checks a list of resources to ensure there is
// exactly one resource name among them, and returns that name or an error.
// Heartbeat resources can have the same name but different host ID, so this
// still allows a user to delete multiple heartbeats of the same name, for
// example `$ tctl rm db_server/someDB`.
func getOneResourceNameToDelete[T types.ResourceWithLabels](rs []T, ref services.Ref, resDesc string) (string, error) {
seen := make(map[string]struct{})
for _, r := range rs {
seen[r.GetName()] = struct{}{}
}
switch len(seen) {
case 1: // need exactly one.
return rs[0].GetName(), nil
case 0:
return "", trace.NotFound("%v %q not found", resDesc, ref.Name)
default:
names := make([]string, 0, len(rs))
for _, r := range rs {
names = append(names, r.GetName())
}
msg := formatAmbiguousDeleteMessage(ref, resDesc, names)
return "", trace.BadParameter("%s", msg)
}
}

// formatAmbiguousDeleteMessage returns a formatted message when a user is
// attempting to delete multiple resources by an ambiguous prefix of the
// resource names.
func formatAmbiguousDeleteMessage(ref services.Ref, resDesc string, names []string) string {
slices.Sort(names)
// choose an actual resource for the example in the error.
exampleRef := ref
exampleRef.Name = names[0]
return fmt.Sprintf(`%s matches multiple auto-discovered %vs:
%v

Use the full resource name that was generated by the Teleport Discovery service, for example:
$ tctl rm %s`,
ref.String(), resDesc, strings.Join(names, "\n"), exampleRef.String())
}

func (rc *ResourceCommand) createAuditQuery(ctx context.Context, client *authclient.Client, raw services.UnknownResource) error {
in, err := services.UnmarshalAuditQuery(raw.Raw, services.DisallowUnknown())
if err != nil {
Expand Down
Loading
Loading