Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,27 @@ func (c *Client) GetSnowflakeSessions(ctx context.Context) ([]types.WebSession,
return out, nil
}

// ListSnowflakeSessions returns a page of Snowflake web sessions.
func (c *Client) ListSnowflakeSessions(ctx context.Context, limit int, start string) ([]types.WebSession, string, error) {
resp, err := c.grpc.ListSnowflakeSessions(ctx, &proto.ListSnowflakeSessionsRequest{
PageSize: int32(limit),
PageToken: start,
})
if err != nil {
return nil, "", trace.Wrap(err)
}
sessions := make([]types.WebSession, len(resp.Sessions))
for i := range resp.Sessions {
sessions[i] = resp.Sessions[i]
}
return sessions, resp.NextPageToken, nil
}

// RangeSnowflakeSessions returns Snowflake web sessions within the range [start, end).
func (c *Client) RangeSnowflakeSessions(ctx context.Context, start, end string) iter.Seq2[types.WebSession, error] {
return clientutils.RangeResources(ctx, start, end, c.ListSnowflakeSessions, types.WebSession.GetName)
}

// CreateAppSession creates an application web session. Application web
// sessions represent a browser session the client holds.
func (c *Client) CreateAppSession(ctx context.Context, req *proto.CreateAppSessionRequest) (types.WebSession, error) {
Expand Down
2,369 changes: 1,418 additions & 951 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions api/client/proto/authservice_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions api/proto/teleport/legacy/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2776,6 +2776,22 @@ message ListKubernetesClustersResponse {
string next_page_token = 2;
}

message ListSnowflakeSessionsRequest {
// The maximum number of items to return.
// The server may impose a different page size at its discretion.
int32 page_size = 1;
// The next_page_token value returned from a previous List request, if any.
string page_token = 2;
}

message ListSnowflakeSessionsResponse {
// Sessions is a list of Snowflake web sessions.
repeated types.WebSessionV2 sessions = 1;
// Token to retrieve the next page of results, or empty if there are no
// more results in the list.
string next_page_token = 2;
}

// AuthService is authentication/authorization service implementation
service AuthService {
// InventoryControlStream is the per-instance stream used to advertise teleport instance
Expand Down Expand Up @@ -2993,6 +3009,9 @@ service AuthService {
rpc GetSnowflakeSession(GetSnowflakeSessionRequest) returns (GetSnowflakeSessionResponse);
// GetSnowflakeSessions gets all Snowflake web sessions.
rpc GetSnowflakeSessions(google.protobuf.Empty) returns (GetSnowflakeSessionsResponse);
// ListSnowflakeSessions returns a page of Snowflake web sessions.
rpc ListSnowflakeSessions(ListSnowflakeSessionsRequest) returns (ListSnowflakeSessionsResponse);

// DeleteSnowflakeSession removes a Snowflake web session.
rpc DeleteSnowflakeSession(DeleteSnowflakeSessionRequest) returns (google.protobuf.Empty);
// DeleteAllSnowflakeSessions removes all Snowflake web sessions.
Expand Down
20 changes: 18 additions & 2 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -5802,11 +5802,27 @@ func (a *ServerWithRoles) GetSnowflakeSessions(ctx context.Context) ([]types.Web
}
}

sessions, err := a.authServer.GetSnowflakeSessions(ctx)
out, err := iterstream.Collect(a.authServer.RangeSnowflakeSessions(ctx, "", ""))
if err != nil {
return nil, trace.Wrap(err)
}
return sessions, nil

return out, nil
}

// ListSnowflakeSessions returns a page of Snowflake web sessions.
func (a *ServerWithRoles) ListSnowflakeSessions(ctx context.Context, limit int, start string) ([]types.WebSession, string, error) {
if !a.hasBuiltinRole(types.RoleDatabase) {
if err := a.authorizeAction(types.KindWebSession, types.VerbList, types.VerbRead); err != nil {
return nil, "", trace.Wrap(err)
}
}

return generic.CollectPageAndCursor(
a.authServer.RangeSnowflakeSessions(ctx, start, ""),
limit,
types.WebSession.GetName,
)
}

// CreateAppSession creates an application web session. Application web
Expand Down
50 changes: 50 additions & 0 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9018,6 +9018,56 @@ func TestGetSnowflakeSessions(t *testing.T) {
}
}

func TestListSnowflakeSessions(t *testing.T) {
t.Parallel()
srv := newTestTLSServer(t)
alice, bob, admin := createSessionTestUsers(t, srv.Auth())

client, err := srv.NewClient(authtest.TestBuiltin(types.RoleDatabase))
require.NoError(t, err)
ctx := t.Context()
opts := []cmp.Option{
cmpopts.SortSlices(func(a, b types.WebSession) bool {
return a.GetName() < b.GetName()
}),
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
}

createSession := func(user string) types.WebSession {
session, err := client.CreateSnowflakeSession(ctx, types.CreateSnowflakeSessionRequest{
Username: user,
TokenTTL: time.Minute * 15,
SessionToken: "test-token-" + user,
})
require.NoError(t, err)
return session
}

expected := []types.WebSession{
createSession(alice),
createSession(bob),
createSession(admin),
}

sessions, next, err := client.ListSnowflakeSessions(ctx, 0, "")
require.NoError(t, err)
require.Empty(t, next)
require.Len(t, sessions, 3)
require.Empty(t, cmp.Diff(expected, sessions, opts...))

page1, next, err := client.ListSnowflakeSessions(ctx, 2, "")
require.NoError(t, err)
require.NotEmpty(t, next)
require.Len(t, page1, 2)

page2, next, err := client.ListSnowflakeSessions(ctx, 0, next)
require.NoError(t, err)
require.Empty(t, next)
require.Len(t, page2, 1)
require.Empty(t, cmp.Diff(expected, append(page1, page2...), opts...))

}

func TestDeleteSnowflakeSession(t *testing.T) {
t.Parallel()
srv := newTestTLSServer(t)
Expand Down
9 changes: 9 additions & 0 deletions lib/auth/authclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,15 @@ type Cache interface {
// GetSnowflakeSession gets a Snowflake web session.
GetSnowflakeSession(context.Context, types.GetSnowflakeSessionRequest) (types.WebSession, error)

// GetSnowflakeSessions returns all Snowflake session resources.
GetSnowflakeSessions(ctx context.Context) ([]types.WebSession, error)

// ListSnowflakeSessions returns a page of Snowflake session resources.
ListSnowflakeSessions(ctx context.Context, limit int, startKey string) ([]types.WebSession, string, error)

// RangeSnowflakeSessions returns Snowflake session resources within the range [start, end).
RangeSnowflakeSessions(ctx context.Context, start, end string) iter.Seq2[types.WebSession, error]

// GetWebSession gets a web session for the given request
GetWebSession(context.Context, types.GetWebSessionRequest) (types.WebSession, error)

Expand Down
28 changes: 28 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1631,6 +1631,34 @@ func (g *GRPCServer) GetSnowflakeSessions(ctx context.Context, e *emptypb.Empty)
}, nil
}

// ListSnowflakeSessions returns a page of Snowflake sessions.
func (g *GRPCServer) ListSnowflakeSessions(ctx context.Context, req *authpb.ListSnowflakeSessionsRequest) (*authpb.ListSnowflakeSessionsResponse, error) {
auth, err := g.authenticate(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

sessions, next, err := auth.ListSnowflakeSessions(ctx, int(req.PageSize), req.PageToken)
if err != nil {
return nil, trace.Wrap(err)
}

resp := &authpb.ListSnowflakeSessionsResponse{
Sessions: make([]*types.WebSessionV2, 0, len(sessions)),
NextPageToken: next,
}

for _, session := range sessions {
webessionV2, ok := session.(*types.WebSessionV2)
if !ok {
return nil, trace.BadParameter("unsupported web session type %T", session)
}
resp.Sessions = append(resp.Sessions, webessionV2)
}

return resp, nil
}

func (g *GRPCServer) DeleteSnowflakeSession(ctx context.Context, req *authpb.DeleteSnowflakeSessionRequest) (*emptypb.Empty, error) {
auth, err := g.authenticate(ctx)
if err != nil {
Expand Down
75 changes: 73 additions & 2 deletions lib/cache/web_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ const snowflakeSessionNameIndex snowflakeSessionIndex = "name"

func newSnowflakeSessionCollection(upstream services.SnowflakeSession, w types.WatchKind) (*collection[types.WebSession, snowflakeSessionIndex], error) {
if upstream == nil {
return nil, trace.BadParameter("missing parameter AppSession")
return nil, trace.BadParameter("missing parameter upstream")
}

return &collection[types.WebSession, snowflakeSessionIndex]{
Expand All @@ -254,7 +254,8 @@ func newSnowflakeSessionCollection(upstream services.SnowflakeSession, w types.W
snowflakeSessionNameIndex: types.WebSession.GetName,
}),
fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) {
webSessions, err := upstream.GetSnowflakeSessions(ctx)
// TODO(okraport): DELETE IN v21.0.0, replace with regular collect
webSessions, err := clientutils.CollectWithFallback(ctx, upstream.ListSnowflakeSessions, upstream.GetSnowflakeSessions)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -312,3 +313,73 @@ func (c *Cache) GetSnowflakeSession(ctx context.Context, req types.GetSnowflakeS
}
return out, trace.Wrap(err)
}

// RangeSnowflakeSessions returns Snowflake session resources within the range [start, end).
func (c *Cache) RangeSnowflakeSessions(ctx context.Context, start, end string) iter.Seq2[types.WebSession, error] {
lister := genericLister[types.WebSession, snowflakeSessionIndex]{
cache: c,
collection: c.collections.snowflakeSessions,
index: snowflakeSessionNameIndex,
upstreamList: c.Config.SnowflakeSession.ListSnowflakeSessions,
nextToken: types.WebSession.GetName,
// TODO(lokraszewski): DELETE IN v21.0.0
fallbackGetter: c.Config.SnowflakeSession.GetSnowflakeSessions,
}

return func(yield func(types.WebSession, error) bool) {
ctx, span := c.Tracer.Start(ctx, "cache/RangeSnowflakeSessions")
defer span.End()

for db, err := range lister.RangeWithFallback(ctx, start, end) {
if !yield(db, err) {
return
}

if err != nil {
return
}
}
}
}

// ListSnowflakeSessions returns a page of Snowflake session resources.
func (c *Cache) ListSnowflakeSessions(ctx context.Context, limit int, startKey string) ([]types.WebSession, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListSnowflakeSessions")
defer span.End()

lister := genericLister[types.WebSession, snowflakeSessionIndex]{
cache: c,
collection: c.collections.snowflakeSessions,
index: snowflakeSessionNameIndex,
upstreamList: c.Config.SnowflakeSession.ListSnowflakeSessions,
nextToken: func(a types.WebSession) string {
return a.GetMetadata().Name
},
}
out, next, err := lister.list(ctx, limit, startKey)
return out, next, trace.Wrap(err)
}

// GetSnowflakeSessions returns all Snowflake session resources.
func (c *Cache) GetSnowflakeSessions(ctx context.Context) ([]types.WebSession, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetSnowflakeSessions")
defer span.End()

rg, err := acquireReadGuard(c, c.collections.snowflakeSessions)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()

if !rg.ReadCache() {
sessions, err := c.Config.SnowflakeSession.GetSnowflakeSessions(ctx)
return sessions, trace.Wrap(err)
}

out := make([]types.WebSession, 0, rg.store.len())
for a := range rg.store.resources(snowflakeSessionNameIndex, "", "") {
out = append(out, a.Copy())
}

return out, nil
}
Loading
Loading