diff --git a/lib/auth/authclient/api.go b/lib/auth/authclient/api.go index 7d590837c6371..950c03bfdd696 100644 --- a/lib/auth/authclient/api.go +++ b/lib/auth/authclient/api.go @@ -1311,6 +1311,9 @@ type Cache interface { // GetPluginStaticCredentialsByLabels will get a list of plugin static credentials resource by matching labels. GetPluginStaticCredentialsByLabels(ctx context.Context, labels map[string]string) ([]types.PluginStaticCredentials, error) + // PluginGetter defines methods for fetching plugins. + services.PluginGetter + // GitServerGetter defines methods for fetching Git servers. services.GitServerGetter diff --git a/lib/cache/plugins.go b/lib/cache/plugins.go index 4a1bc2b99f5c0..633c38f965bd9 100644 --- a/lib/cache/plugins.go +++ b/lib/cache/plugins.go @@ -165,6 +165,31 @@ func (c *Cache) ListPlugins(ctx context.Context, limit int, startKey string, wit return plugins, nextKey, nil } +// HasPluginType will return true if a plugin of the given type is registered. +func (c *Cache) HasPluginType(ctx context.Context, pluginType types.PluginType) (bool, error) { + _, span := c.Tracer.Start(ctx, "cache/HasPluginType") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.plugins) + if err != nil { + return false, trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + // Cache is currently not available; check for the plugin type existence upstream. + ok, err := c.Config.Plugin.HasPluginType(ctx, pluginType) + return ok, trace.Wrap(err) + } + + for plugin := range rg.store.resources(pluginNameIndex, "", "") { + if plugin.GetType() == pluginType { + return true, nil + } + } + return false, nil +} + // stripPluginSecrets returns a cloned plugin, optionally removing secrets. // This allows conditional filtering based on the `withSecrets` flag. func stripAndClonePluginSecrets(in types.Plugin, withSecrets bool) types.Plugin { diff --git a/lib/cache/plugins_test.go b/lib/cache/plugins_test.go index 272f123e9386f..534933ca9d2ce 100644 --- a/lib/cache/plugins_test.go +++ b/lib/cache/plugins_test.go @@ -19,13 +19,16 @@ package cache import ( "context" "testing" + "time" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" ) -func newPlugin(name string) types.Plugin { +func newPlugin(name string) *types.PluginV1 { return &types.PluginV1{ Metadata: types.Metadata{Name: name}, Spec: types.PluginSpecV1{ @@ -38,7 +41,7 @@ func newPlugin(name string) types.Plugin { } } -func newPluginWithCreds(name string) types.Plugin { +func newPluginWithCreds(name string) *types.PluginV1 { item := newPlugin(name) creds := types.PluginCredentialsV1{ Credentials: &types.PluginCredentialsV1_StaticCredentialsRef{ @@ -142,3 +145,49 @@ func TestPlugin(t *testing.T) { }) }) } + +func TestPlugin_HasPluginType(t *testing.T) { + t.Parallel() + ctx := t.Context() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + slackPlugin := newPlugin("test_slack_1") + slackPlugin.Spec.Settings = &types.PluginSpecV1_SlackAccessPlugin{ + SlackAccessPlugin: &types.PluginSlackAccessSettings{ + FallbackChannel: "#foo", + }, + } + scimPlugin := newPlugin("test_scim_1") + scimPlugin.Spec.Settings = &types.PluginSpecV1_Scim{ + Scim: &types.PluginSCIMSettings{ + SamlConnectorName: "example-saml-connector", + }, + } + + err := p.plugin.CreatePlugin(ctx, slackPlugin) + require.NoError(t, err) + + err = p.plugin.CreatePlugin(ctx, scimPlugin) + require.NoError(t, err) + + // Wait for cache propagation. + require.EventuallyWithT(t, func(t *assert.CollectT) { + plugins, _, err := p.cache.ListPlugins(ctx, 0, "", false) + require.NoError(t, err) + require.Len(t, plugins, 2) + }, 15*time.Second, 100*time.Millisecond) + + has, err := p.cache.HasPluginType(ctx, types.PluginTypeSlack) + require.NoError(t, err) + require.True(t, has) + + has, err = p.cache.HasPluginType(ctx, types.PluginTypeSCIM) + require.NoError(t, err) + require.True(t, has) + + has, err = p.cache.HasPluginType(ctx, types.PluginTypeOkta) + require.NoError(t, err) + require.False(t, has) +} diff --git a/lib/services/plugins.go b/lib/services/plugins.go index 571761925c506..5a9c749c2df4d 100644 --- a/lib/services/plugins.go +++ b/lib/services/plugins.go @@ -29,16 +29,20 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +type PluginGetter interface { + GetPlugin(ctx context.Context, name string, withSecrets bool) (types.Plugin, error) + GetPlugins(ctx context.Context, withSecrets bool) ([]types.Plugin, error) + ListPlugins(ctx context.Context, limit int, startKey string, withSecrets bool) ([]types.Plugin, string, error) + HasPluginType(ctx context.Context, pluginType types.PluginType) (bool, error) +} + // Plugins is the plugin service type Plugins interface { + PluginGetter CreatePlugin(ctx context.Context, plugin types.Plugin) error UpdatePlugin(ctx context.Context, plugin types.Plugin) (types.Plugin, error) DeleteAllPlugins(ctx context.Context) error DeletePlugin(ctx context.Context, name string) error - GetPlugin(ctx context.Context, name string, withSecrets bool) (types.Plugin, error) - GetPlugins(ctx context.Context, withSecrets bool) ([]types.Plugin, error) - ListPlugins(ctx context.Context, limit int, startKey string, withSecrets bool) ([]types.Plugin, string, error) - HasPluginType(ctx context.Context, pluginType types.PluginType) (bool, error) SetPluginCredentials(ctx context.Context, name string, creds types.PluginCredentials) error SetPluginStatus(ctx context.Context, name string, creds types.PluginStatus) error }