Skip to content

Commit d20617e

Browse files
committed
[v18][api] Add ListSnowflakeSessions
Backport #59591 to branch/v18
1 parent bfa5b7d commit d20617e

File tree

13 files changed

+1847
-1025
lines changed

13 files changed

+1847
-1025
lines changed

api/client/client.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,6 +1572,27 @@ func (c *Client) GetSnowflakeSessions(ctx context.Context) ([]types.WebSession,
15721572
return out, nil
15731573
}
15741574

1575+
// ListSnowflakeSessions returns a page of Snowflake web sessions.
1576+
func (c *Client) ListSnowflakeSessions(ctx context.Context, limit int, start string) ([]types.WebSession, string, error) {
1577+
resp, err := c.grpc.ListSnowflakeSessions(ctx, &proto.ListSnowflakeSessionsRequest{
1578+
PageSize: int32(limit),
1579+
PageToken: start,
1580+
})
1581+
if err != nil {
1582+
return nil, "", trace.Wrap(err)
1583+
}
1584+
sessions := make([]types.WebSession, len(resp.Sessions))
1585+
for i := range resp.Sessions {
1586+
sessions[i] = resp.Sessions[i]
1587+
}
1588+
return sessions, resp.NextPageToken, nil
1589+
}
1590+
1591+
// RangeSnowflakeSessions returns Snowflake web sessions within the range [start, end).
1592+
func (c *Client) RangeSnowflakeSessions(ctx context.Context, start, end string) iter.Seq2[types.WebSession, error] {
1593+
return clientutils.RangeResources(ctx, start, end, c.ListSnowflakeSessions, types.WebSession.GetName)
1594+
}
1595+
15751596
// CreateAppSession creates an application web session. Application web
15761597
// sessions represent a browser session the client holds.
15771598
func (c *Client) CreateAppSession(ctx context.Context, req *proto.CreateAppSessionRequest) (types.WebSession, error) {

api/client/proto/authservice.pb.go

Lines changed: 1418 additions & 951 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/client/proto/authservice_grpc.pb.go

Lines changed: 40 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/proto/teleport/legacy/client/proto/authservice.proto

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,6 +2776,22 @@ message ListKubernetesClustersResponse {
27762776
string next_page_token = 2;
27772777
}
27782778

2779+
message ListSnowflakeSessionsRequest {
2780+
// The maximum number of items to return.
2781+
// The server may impose a different page size at its discretion.
2782+
int32 page_size = 1;
2783+
// The next_page_token value returned from a previous List request, if any.
2784+
string page_token = 2;
2785+
}
2786+
2787+
message ListSnowflakeSessionsResponse {
2788+
// Sessions is a list of Snowflake web sessions.
2789+
repeated types.WebSessionV2 sessions = 1;
2790+
// Token to retrieve the next page of results, or empty if there are no
2791+
// more results in the list.
2792+
string next_page_token = 2;
2793+
}
2794+
27792795
// AuthService is authentication/authorization service implementation
27802796
service AuthService {
27812797
// InventoryControlStream is the per-instance stream used to advertise teleport instance
@@ -2993,6 +3009,9 @@ service AuthService {
29933009
rpc GetSnowflakeSession(GetSnowflakeSessionRequest) returns (GetSnowflakeSessionResponse);
29943010
// GetSnowflakeSessions gets all Snowflake web sessions.
29953011
rpc GetSnowflakeSessions(google.protobuf.Empty) returns (GetSnowflakeSessionsResponse);
3012+
// ListSnowflakeSessions returns a page of Snowflake web sessions.
3013+
rpc ListSnowflakeSessions(ListSnowflakeSessionsRequest) returns (ListSnowflakeSessionsResponse);
3014+
29963015
// DeleteSnowflakeSession removes a Snowflake web session.
29973016
rpc DeleteSnowflakeSession(DeleteSnowflakeSessionRequest) returns (google.protobuf.Empty);
29983017
// DeleteAllSnowflakeSessions removes all Snowflake web sessions.

lib/auth/auth_with_roles.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5802,11 +5802,27 @@ func (a *ServerWithRoles) GetSnowflakeSessions(ctx context.Context) ([]types.Web
58025802
}
58035803
}
58045804

5805-
sessions, err := a.authServer.GetSnowflakeSessions(ctx)
5805+
out, err := iterstream.Collect(a.authServer.RangeSnowflakeSessions(ctx, "", ""))
58065806
if err != nil {
58075807
return nil, trace.Wrap(err)
58085808
}
5809-
return sessions, nil
5809+
5810+
return out, nil
5811+
}
5812+
5813+
// ListSnowflakeSessions returns a page of Snowflake web sessions.
5814+
func (a *ServerWithRoles) ListSnowflakeSessions(ctx context.Context, limit int, start string) ([]types.WebSession, string, error) {
5815+
if !a.hasBuiltinRole(types.RoleDatabase) {
5816+
if err := a.authorizeAction(types.KindWebSession, types.VerbList, types.VerbRead); err != nil {
5817+
return nil, "", trace.Wrap(err)
5818+
}
5819+
}
5820+
5821+
return generic.CollectPageAndCursor(
5822+
a.authServer.RangeSnowflakeSessions(ctx, start, ""),
5823+
limit,
5824+
types.WebSession.GetName,
5825+
)
58105826
}
58115827

58125828
// CreateAppSession creates an application web session. Application web

lib/auth/auth_with_roles_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9018,6 +9018,56 @@ func TestGetSnowflakeSessions(t *testing.T) {
90189018
}
90199019
}
90209020

9021+
func TestListSnowflakeSessions(t *testing.T) {
9022+
t.Parallel()
9023+
srv := newTestTLSServer(t)
9024+
alice, bob, admin := createSessionTestUsers(t, srv.Auth())
9025+
9026+
client, err := srv.NewClient(authtest.TestBuiltin(types.RoleDatabase))
9027+
require.NoError(t, err)
9028+
ctx := t.Context()
9029+
opts := []cmp.Option{
9030+
cmpopts.SortSlices(func(a, b types.WebSession) bool {
9031+
return a.GetName() < b.GetName()
9032+
}),
9033+
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
9034+
}
9035+
9036+
createSession := func(user string) types.WebSession {
9037+
session, err := client.CreateSnowflakeSession(ctx, types.CreateSnowflakeSessionRequest{
9038+
Username: user,
9039+
TokenTTL: time.Minute * 15,
9040+
SessionToken: "test-token-" + user,
9041+
})
9042+
require.NoError(t, err)
9043+
return session
9044+
}
9045+
9046+
expected := []types.WebSession{
9047+
createSession(alice),
9048+
createSession(bob),
9049+
createSession(admin),
9050+
}
9051+
9052+
sessions, next, err := client.ListSnowflakeSessions(ctx, 0, "")
9053+
require.NoError(t, err)
9054+
require.Empty(t, next)
9055+
require.Len(t, sessions, 3)
9056+
require.Empty(t, cmp.Diff(expected, sessions, opts...))
9057+
9058+
page1, next, err := client.ListSnowflakeSessions(ctx, 2, "")
9059+
require.NoError(t, err)
9060+
require.NotEmpty(t, next)
9061+
require.Len(t, page1, 2)
9062+
9063+
page2, next, err := client.ListSnowflakeSessions(ctx, 0, next)
9064+
require.NoError(t, err)
9065+
require.Empty(t, next)
9066+
require.Len(t, page2, 1)
9067+
require.Empty(t, cmp.Diff(expected, append(page1, page2...), opts...))
9068+
9069+
}
9070+
90219071
func TestDeleteSnowflakeSession(t *testing.T) {
90229072
t.Parallel()
90239073
srv := newTestTLSServer(t)

lib/auth/authclient/api.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,15 @@ type Cache interface {
10811081
// GetSnowflakeSession gets a Snowflake web session.
10821082
GetSnowflakeSession(context.Context, types.GetSnowflakeSessionRequest) (types.WebSession, error)
10831083

1084+
// GetSnowflakeSessions returns all Snowflake session resources.
1085+
GetSnowflakeSessions(ctx context.Context) ([]types.WebSession, error)
1086+
1087+
// ListSnowflakeSessions returns a page of Snowflake session resources.
1088+
ListSnowflakeSessions(ctx context.Context, limit int, startKey string) ([]types.WebSession, string, error)
1089+
1090+
// RangeSnowflakeSessions returns Snowflake session resources within the range [start, end).
1091+
RangeSnowflakeSessions(ctx context.Context, start, end string) iter.Seq2[types.WebSession, error]
1092+
10841093
// GetWebSession gets a web session for the given request
10851094
GetWebSession(context.Context, types.GetWebSessionRequest) (types.WebSession, error)
10861095

lib/auth/grpcserver.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,6 +1631,34 @@ func (g *GRPCServer) GetSnowflakeSessions(ctx context.Context, e *emptypb.Empty)
16311631
}, nil
16321632
}
16331633

1634+
// ListSnowflakeSessions returns a page of Snowflake sessions.
1635+
func (g *GRPCServer) ListSnowflakeSessions(ctx context.Context, req *authpb.ListSnowflakeSessionsRequest) (*authpb.ListSnowflakeSessionsResponse, error) {
1636+
auth, err := g.authenticate(ctx)
1637+
if err != nil {
1638+
return nil, trace.Wrap(err)
1639+
}
1640+
1641+
sessions, next, err := auth.ListSnowflakeSessions(ctx, int(req.PageSize), req.PageToken)
1642+
if err != nil {
1643+
return nil, trace.Wrap(err)
1644+
}
1645+
1646+
resp := &authpb.ListSnowflakeSessionsResponse{
1647+
Sessions: make([]*types.WebSessionV2, 0, len(sessions)),
1648+
NextPageToken: next,
1649+
}
1650+
1651+
for _, session := range sessions {
1652+
webessionV2, ok := session.(*types.WebSessionV2)
1653+
if !ok {
1654+
return nil, trace.BadParameter("unsupported web session type %T", session)
1655+
}
1656+
resp.Sessions = append(resp.Sessions, webessionV2)
1657+
}
1658+
1659+
return resp, nil
1660+
}
1661+
16341662
func (g *GRPCServer) DeleteSnowflakeSession(ctx context.Context, req *authpb.DeleteSnowflakeSessionRequest) (*emptypb.Empty, error) {
16351663
auth, err := g.authenticate(ctx)
16361664
if err != nil {

lib/cache/web_session.go

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ const snowflakeSessionNameIndex snowflakeSessionIndex = "name"
243243

244244
func newSnowflakeSessionCollection(upstream services.SnowflakeSession, w types.WatchKind) (*collection[types.WebSession, snowflakeSessionIndex], error) {
245245
if upstream == nil {
246-
return nil, trace.BadParameter("missing parameter AppSession")
246+
return nil, trace.BadParameter("missing parameter upstream")
247247
}
248248

249249
return &collection[types.WebSession, snowflakeSessionIndex]{
@@ -254,7 +254,8 @@ func newSnowflakeSessionCollection(upstream services.SnowflakeSession, w types.W
254254
snowflakeSessionNameIndex: types.WebSession.GetName,
255255
}),
256256
fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) {
257-
webSessions, err := upstream.GetSnowflakeSessions(ctx)
257+
// TODO(okraport): DELETE IN v21.0.0, replace with regular collect
258+
webSessions, err := clientutils.CollectWithFallback(ctx, upstream.ListSnowflakeSessions, upstream.GetSnowflakeSessions)
258259
if err != nil {
259260
return nil, trace.Wrap(err)
260261
}
@@ -312,3 +313,73 @@ func (c *Cache) GetSnowflakeSession(ctx context.Context, req types.GetSnowflakeS
312313
}
313314
return out, trace.Wrap(err)
314315
}
316+
317+
// RangeSnowflakeSessions returns Snowflake session resources within the range [start, end).
318+
func (c *Cache) RangeSnowflakeSessions(ctx context.Context, start, end string) iter.Seq2[types.WebSession, error] {
319+
lister := genericLister[types.WebSession, snowflakeSessionIndex]{
320+
cache: c,
321+
collection: c.collections.snowflakeSessions,
322+
index: snowflakeSessionNameIndex,
323+
upstreamList: c.Config.SnowflakeSession.ListSnowflakeSessions,
324+
nextToken: types.WebSession.GetName,
325+
// TODO(lokraszewski): DELETE IN v21.0.0
326+
fallbackGetter: c.Config.SnowflakeSession.GetSnowflakeSessions,
327+
}
328+
329+
return func(yield func(types.WebSession, error) bool) {
330+
ctx, span := c.Tracer.Start(ctx, "cache/RangeSnowflakeSessions")
331+
defer span.End()
332+
333+
for db, err := range lister.RangeWithFallback(ctx, start, end) {
334+
if !yield(db, err) {
335+
return
336+
}
337+
338+
if err != nil {
339+
return
340+
}
341+
}
342+
}
343+
}
344+
345+
// ListSnowflakeSessions returns a page of Snowflake session resources.
346+
func (c *Cache) ListSnowflakeSessions(ctx context.Context, limit int, startKey string) ([]types.WebSession, string, error) {
347+
ctx, span := c.Tracer.Start(ctx, "cache/ListSnowflakeSessions")
348+
defer span.End()
349+
350+
lister := genericLister[types.WebSession, snowflakeSessionIndex]{
351+
cache: c,
352+
collection: c.collections.snowflakeSessions,
353+
index: snowflakeSessionNameIndex,
354+
upstreamList: c.Config.SnowflakeSession.ListSnowflakeSessions,
355+
nextToken: func(a types.WebSession) string {
356+
return a.GetMetadata().Name
357+
},
358+
}
359+
out, next, err := lister.list(ctx, limit, startKey)
360+
return out, next, trace.Wrap(err)
361+
}
362+
363+
// GetSnowflakeSessions returns all Snowflake session resources.
364+
func (c *Cache) GetSnowflakeSessions(ctx context.Context) ([]types.WebSession, error) {
365+
ctx, span := c.Tracer.Start(ctx, "cache/GetSnowflakeSessions")
366+
defer span.End()
367+
368+
rg, err := acquireReadGuard(c, c.collections.snowflakeSessions)
369+
if err != nil {
370+
return nil, trace.Wrap(err)
371+
}
372+
defer rg.Release()
373+
374+
if !rg.ReadCache() {
375+
sessions, err := c.Config.SnowflakeSession.GetSnowflakeSessions(ctx)
376+
return sessions, trace.Wrap(err)
377+
}
378+
379+
out := make([]types.WebSession, 0, rg.store.len())
380+
for a := range rg.store.resources(snowflakeSessionNameIndex, "", "") {
381+
out = append(out, a.Copy())
382+
}
383+
384+
return out, nil
385+
}

0 commit comments

Comments
 (0)