80 changes: 80 additions & 0 deletions lib/backend/pgbk/put_batch.go
@@ -0,0 +1,80 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package pgbk

import (
"context"
"slices"

"github.com/gravitational/trace"
"github.com/jackc/pgx/v5/pgtype/zeronull"

"github.com/gravitational/teleport/lib/backend"
pgcommon "github.com/gravitational/teleport/lib/backend/pgbk/common"
)

const (
defaultUpsertBatchChunk = 100
putBatchStmt = `
INSERT INTO kv (key, value, expires, revision)
SELECT * FROM UNNEST(
$1::bytea[],
$2::bytea[],
$3::timestamptz[],
$4::uuid[]
)
ON CONFLICT (key) DO UPDATE
SET
value = EXCLUDED.value,
expires = EXCLUDED.expires,
revision = EXCLUDED.revision;
`
Comment on lines +34 to +47
Contributor:

Putting this in a single line makes pg_stat_activity bearable to look at, and avoids accidental mixed indentation.

Suggested change
putBatchStmt = `
INSERT INTO kv (key, value, expires, revision)
SELECT * FROM UNNEST(
$1::bytea[],
$2::bytea[],
$3::timestamptz[],
$4::uuid[]
)
ON CONFLICT (key) DO UPDATE
SET
value = EXCLUDED.value,
expires = EXCLUDED.expires,
revision = EXCLUDED.revision;
`
putBatchStmt = "INSERT INTO kv (key, value, expires, revision) SELECT * FROM UNNEST($1::bytea[], $2::bytea[], $3::timestamptz[], $4::uuid[]) ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value, expires = EXCLUDED.expires, revision = EXCLUDED.revision;"

)
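
As background, a self-contained illustration of how slices.Chunk (Go 1.23+) partitions a slice the way the PutBatch loop below does with defaultUpsertBatchChunk = 100 (illustrative snippet, not part of this PR):

package main

import (
	"fmt"
	"slices"
)

func main() {
	items := make([]int, 250)
	// With a chunk size of 100, 250 items yield chunks of 100, 100, and 50;
	// PutBatch issues one UNNEST statement per such chunk.
	for chunk := range slices.Chunk(items, 100) {
		fmt.Println(len(chunk)) // prints 100, 100, 50
	}
}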

// PutBatch upserts multiple items into the backend, issuing one statement per
// chunk of up to defaultUpsertBatchChunk items.
func (b *Backend) PutBatch(ctx context.Context, items []backend.Item) ([]string, error) {
if len(items) == 0 {
return []string{}, nil
}
revOut := make([]string, 0, len(items))
for chunk := range slices.Chunk(items, defaultUpsertBatchChunk) {
keys := make([][]byte, 0, len(chunk))
values := make([][]byte, 0, len(chunk))
expires := make([]zeronull.Timestamptz, 0, len(chunk))
revs := make([]revision, 0, len(chunk))

for _, item := range chunk {
keys = append(keys, nonNilKey(item.Key))
values = append(values, nonNil(item.Value))
expires = append(expires, zeronull.Timestamptz(item.Expires.UTC()))
Contributor:

Does this behave as expected if item.Expires is zeroed, i.e. the item never expires?

Contributor Author:

It should. The same approach is used in the Put flow, where zeronull.Timestamptz handles IsZero.

Contributor:

Gotcha. Please just make sure we have a test case for this 🙏🏾 🙏🏾 🙏🏾
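
A minimal sketch of such a test case (illustrative only, assuming pgx's zeronull.Timestamptz implements the pgtype.TimestamptzValuer interface as documented):

package pgbk

import (
	"testing"
	"time"

	"github.com/jackc/pgx/v5/pgtype/zeronull"
	"github.com/stretchr/testify/require"
)

// Sketch: a zeroed time.Time should encode as SQL NULL ("never expires"),
// while a real timestamp should encode as a valid timestamptz.
func TestZeroExpiresEncodesAsNull(t *testing.T) {
	var never time.Time // zero value, i.e. the item never expires
	tz, err := zeronull.Timestamptz(never).TimestamptzValue()
	require.NoError(t, err)
	require.False(t, tz.Valid, "zero time should map to NULL")

	tz, err = zeronull.Timestamptz(time.Now().UTC()).TimestamptzValue()
	require.NoError(t, err)
	require.True(t, tz.Valid, "non-zero time should map to a real timestamptz")
}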


revVal := newRevision()
revs = append(revs, revVal)
revOut = append(revOut, revisionToString(revVal))
Comment on lines +57 to +69
Contributor:

Suggested change
keys := make([][]byte, 0, len(chunk))
values := make([][]byte, 0, len(chunk))
expires := make([]zeronull.Timestamptz, 0, len(chunk))
revs := make([]revision, 0, len(chunk))
for _, item := range chunk {
keys = append(keys, nonNilKey(item.Key))
values = append(values, nonNil(item.Value))
expires = append(expires, zeronull.Timestamptz(item.Expires.UTC()))
revVal := newRevision()
revs = append(revs, revVal)
revOut = append(revOut, revisionToString(revVal))
keys := make([][]byte, len(chunk))
values := make([][]byte, len(chunk))
expires := make([]zeronull.Timestamptz, len(chunk))
revs := make([]revision, len(chunk))
for i, item := range chunk {
keys[i] = nonNilKey(item.Key)
values[i] = nonNil(item.Value)
expires[i] = zeronull.Timestamptz(item.Expires.UTC())
revVal := newRevision()
revs[i] = revVal
revOut = append(revOut, revisionToString(revVal))

/nit

Direct indexing is slightly more efficient since the runtime can skip bounds checks.

Disclaimer: I haven't actually tested this.

Contributor:

I'd really like to see a benchmark for that before we go and use a pattern that's more error-prone.

Contributor (@cthach, Nov 5, 2025):

> I'd really like to see a benchmark for that before we go and use a pattern that's more error-prone.

I gotchu. Here's a simple benchmark.

package teleport_test

import "testing"

func BenchmarkAppendNoPrealloc(b *testing.B) {
	for b.Loop() {
		var s []int

		for i := range 10000 {
			s = append(s, i)
		}
	}
}

func BenchmarkAppendWithPrealloc(b *testing.B) {
	for b.Loop() {
		s := make([]int, 0, 10000)

		for i := range 10000 {
			s = append(s, i)
		}
	}
}

func BenchmarkDirectIndexWithPrealloc(b *testing.B) {
	for b.Loop() {
		s := make([]int, 10000)

		for i := range 10000 {
			s[i] = i
		}
	}
}

Results:

❯ go test -bench=. -benchmem ./benchmark_prealloc_test.go

goos: darwin
goarch: arm64
cpu: Apple M4 Pro
BenchmarkAppendNoPrealloc-12               12636             90313 ns/op          357627 B/op         19 allocs/op
BenchmarkAppendWithPrealloc-12             94554             21342 ns/op           81920 B/op          1 allocs/op
BenchmarkDirectIndexWithPrealloc-12        65499             15631 ns/op           81920 B/op          1 allocs/op
PASS

Contributor:


With a chunk size of 1000 and copying the types we have in the loop in four slices (but with the revision getting copied as a string) I'm getting

BenchmarkAppendWithPrealloc-14         	   45082	     25864 ns/op	   98113 B/op	    1004 allocs/op
BenchmarkDirectIndexWithPrealloc-14    	   48951	     24134 ns/op	   98113 B/op	    1004 allocs/op

and if we add the generation of the revision, like we have in code

BenchmarkAppendWithPrealloc-14         	    5415	    220249 ns/op	  114116 B/op	    2004 allocs/op
BenchmarkDirectIndexWithPrealloc-14    	    5626	    216876 ns/op	  114113 B/op	    2004 allocs/op

an improvement of 6.6% and 1.5% respectively, which is a lot less worth it compared to the potential for misuse that direct indexing has, especially considering that this is minor preparation for a much larger amount of network I/O. It's definitely worth keeping in mind for very tight loops that mainly deal in memory though; thank you for pointing that out.

Contributor (@espadolini, Nov 5, 2025):


Benchmark code:
package foo

import (
	"fmt"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgtype/zeronull"

	"github.com/gravitational/teleport/lib/backend"
)

func BenchmarkAppendWithPrealloc(b *testing.B) {
	var chunk []backend.Item
	for i := range 1000 {
		chunk = append(chunk, backend.Item{
			Key:      backend.KeyFromString(fmt.Sprintf("/%d", i)),
			Value:    []byte("foo"),
			Expires:  time.Now(),
			Revision: uuid.NewString(),
		})
	}

	for b.Loop() {
		keys := make([][]byte, 0, len(chunk))
		values := make([][]byte, 0, len(chunk))
		expires := make([]zeronull.Timestamptz, 0, len(chunk))
		revs := make([]uuid.UUID, 0, len(chunk))
		// revs := make([]string, 0, len(chunk))

		for _, item := range chunk {
			keys = append(keys, []byte(item.Key.String()))
			values = append(values, item.Value)
			expires = append(expires, zeronull.Timestamptz(item.Expires.UTC()))
			revs = append(revs, uuid.New())
			// revs = append(revs, item.Revision)
		}
	}
}

func BenchmarkDirectIndexWithPrealloc(b *testing.B) {
	var chunk []backend.Item
	for i := range 1000 {
		chunk = append(chunk, backend.Item{
			Key:      backend.KeyFromString(fmt.Sprintf("/%d", i)),
			Value:    []byte("foo"),
			Expires:  time.Now(),
			Revision: uuid.NewString(),
		})
	}

	for b.Loop() {
		keys := make([][]byte, len(chunk))
		values := make([][]byte, len(chunk))
		expires := make([]zeronull.Timestamptz, len(chunk))
		revs := make([]uuid.UUID, len(chunk))
		// revs := make([]string, len(chunk))

		for i, item := range chunk {
			keys[i] = []byte(item.Key.String())
			values[i] = item.Value
			expires[i] = zeronull.Timestamptz(item.Expires.UTC())
			revs[i] = uuid.New()
			// revs[i] = item.Revision
		}
	}
}

Contributor:


Nice, thanks for providing that! Yep, the efficiency gain is very small and probably only worth the tradeoff in performance-critical code.

That being said, I don't have a preference, so will defer to @smallinsky on which direction we go with.

}

if _, err := pgcommon.Retry(ctx, b.log, func() (struct{}, error) {
_, err := b.pool.Exec(ctx, putBatchStmt, keys, values, expires, revs)
Contributor:

I've liked having the SQL right in the Exec/Query in the rest of pgbk so it's obvious what the positional parameters are.

Contributor:

Any chance of moving the SQL here so we can see that the parameters are in the right order?
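
For illustration, a sketch of that inlined shape (a hypothetical rearrangement of the code above, not a tested change):

if _, err := pgcommon.Retry(ctx, b.log, func() (struct{}, error) {
	// SQL inline, so the arguments read next to their placeholders:
	// $1=keys, $2=values, $3=expires, $4=revs.
	_, err := b.pool.Exec(ctx,
		"INSERT INTO kv (key, value, expires, revision)"+
			" SELECT * FROM UNNEST($1::bytea[], $2::bytea[], $3::timestamptz[], $4::uuid[])"+
			" ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, expires = EXCLUDED.expires, revision = EXCLUDED.revision",
		keys, values, expires, revs)
	return struct{}{}, trace.Wrap(err)
}); err != nil {
	return nil, trace.Wrap(err)
}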

return struct{}{}, trace.Wrap(err)
}); err != nil {
return nil, trace.Wrap(err)
}
}
return revOut, nil
}
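
For orientation, a minimal caller-side sketch (hypothetical keys and values, assuming a constructed *Backend b and a ctx in scope):

items := []backend.Item{
	{Key: backend.KeyFromString("/nodes/a"), Value: []byte("A")},
	{Key: backend.KeyFromString("/nodes/b"), Value: []byte("B"), Expires: time.Now().Add(time.Hour)},
}
revs, err := b.PutBatch(ctx, items)
if err != nil {
	return trace.Wrap(err)
}
// revs[i] is the string form of the new revision assigned to items[i].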
202 changes: 202 additions & 0 deletions lib/backend/test/put_batch.go
@@ -0,0 +1,202 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package test

import (
"bytes"
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/backend"
)

type PutBatcher interface {
PutBatch(ctx context.Context, items []backend.Item) ([]string, error)
}
Comment on lines +34 to +36
Contributor:

Isn't this just backend.BatchPutter?


const (
watchInitTimeout = 10 * time.Second
watchEventTimeout = 3 * time.Second
)

func runPutBatch(t *testing.T, newBackend Constructor) {
t.Helper()

ctx := context.Background()
bk, _, err := newBackend()
require.NoError(t, err)
t.Cleanup(func() { _ = bk.Close() })

batcher, ok := bk.(PutBatcher)
if !ok {
t.Skip("backend does not implement PutBatch; skipping PutBatch suite")
}

prefix := MakePrefix()
rangeStart := prefix("")
rangeEnd := backend.RangeEnd(prefix(""))

itemEqual := func(a, b backend.Item) bool {
return a.Key.String() == b.Key.String() &&
a.Revision == b.Revision &&
string(a.Value) == string(b.Value) &&
a.Expires.Equal(b.Expires)
}

assertItemsEqual := func(t *testing.T, want, got []backend.Item) {
t.Helper()
require.Len(t, got, len(want))
for i := range want {
require.True(t, itemEqual(want[i], got[i]))
}
}

buildWant := func(items []backend.Item, rev []string) []backend.Item {
out := make([]backend.Item, 0, len(items))
for i, it := range items {
out = append(out, backend.Item{
Key: it.Key,
Value: it.Value,
Revision: rev[i],
Expires: it.Expires,
})
}
return out
}

newTestItems := func() []backend.Item {
return []backend.Item{
{Key: prefix("a"), Value: []byte("A"), Expires: time.Now().Add(1 * time.Hour)},
{Key: prefix("b"), Value: []byte("B")},
{Key: prefix("c"), Value: []byte("C"), Expires: time.Now().Add(2 * time.Hour)},
}
}
t.Run("put batch items should be propagated in event stream", func(t *testing.T) {
w, err := bk.NewWatcher(t.Context(), backend.Watch{})
require.NoError(t, err)
t.Cleanup(func() { w.Close() })

select {
case <-w.Done():
t.Fatal("watcher closed immediately")
case ev := <-w.Events():
require.Equal(t, types.OpInit, ev.Type)
case <-time.After(watchInitTimeout):
t.Fatal("timed out waiting for init event")
Comment on lines +105 to +106
Contributor:

I think that waiting on arbitrary timeouts for test steps is an antipattern; there's already a mechanism to set and enforce a timeout in the test harness, and a test runner that's slow enough might fail this for no reason other than needing some more time.
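
One way to read that suggestion, as a sketch (assuming the harness-level go test -timeout is acceptable as the sole bound):

select {
case <-w.Done():
	t.Fatal("watcher closed immediately")
case ev := <-w.Events():
	// Block until the init event arrives; the go test -timeout enforced by
	// the harness bounds the whole test, so no per-step timer is needed.
	require.Equal(t, types.OpInit, ev.Type)
}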

}

items := newTestItems()
rev, err := batcher.PutBatch(ctx, items)
require.NoError(t, err)
require.NotEmpty(t, rev)

got := waitForEvents(t, w, len(items), watchEventTimeout)
want := buildWant(items, rev)
assertItemsEqual(t, want, got)
require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
})

t.Run("put-create-update", func(t *testing.T) {
items := newTestItems()
rev1, err := batcher.PutBatch(ctx, items)
require.NoError(t, err)
require.NotEmpty(t, rev1)

res, err := bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
require.NoError(t, err)

want := buildWant(items, rev1)
got := res.Items
assertItemsEqual(t, want, got)

items[0].Value = []byte("A2")
items[1].Value = []byte("B2")
items[2].Value = []byte("C2")

rev2, err := batcher.PutBatch(ctx, items)
require.NoError(t, err)
require.NotEmpty(t, rev2)
require.NotEqual(t, rev1, rev2)

res, err = bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
require.NoError(t, err)

want = buildWant(items, rev2)
got = res.Items
assertItemsEqual(t, want, got)

require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
})

t.Run("over-chunk-size", func(t *testing.T) {
const itemCount = 1000
const payloadSize = 300 * 1024 // 300 KiB
items := make([]backend.Item, 0, itemCount)
for i := 0; i < itemCount; i++ {
items = append(items, backend.Item{
Key: prefix(fmt.Sprintf("item/%04d", i)),
Value: bytes.Repeat([]byte(fmt.Sprintf("%d", i)), payloadSize),
Contributor:

Do we want variable-sized payloads? I think the byte slice can vary from 1-3 bytes * 300 * 1024.

Wouldn't this consume a large amount of memory in order to run this test? Could we perhaps make the value fixed or reduce the number of items while getting the same test coverage?
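
A sketch of the fixed-size variant (hypothetical, not the merged code), where every value is exactly payloadSize bytes no matter how many digits i has:

for i := range itemCount {
	items = append(items, backend.Item{
		Key: prefix(fmt.Sprintf("item/%04d", i)),
		// One repeated byte keeps each payload at exactly payloadSize
		// (~300 KiB) instead of 1-3x depending on the width of i.
		Value:   bytes.Repeat([]byte{byte('0' + i%10)}, payloadSize),
		Expires: time.Now().Add(5 * time.Minute),
	})
}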

Expires: time.Now().Add(5 * time.Minute),
})
}

rev, err := batcher.PutBatch(ctx, items)
require.NoError(t, err)
require.NotEmpty(t, rev)

res, err := bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
require.NoError(t, err)

want := buildWant(items, rev)
got := res.Items
assertItemsEqual(t, want, got)

require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
})
}

func waitForEvents(t *testing.T, w backend.Watcher, wantCount int, timeout time.Duration) []backend.Item {
t.Helper()

var out []backend.Item
deadline := time.NewTimer(timeout)
defer deadline.Stop()

for len(out) < wantCount {
select {
case ev, ok := <-w.Events():
if !ok {
t.Fatalf("watcher closed before receiving all events: got=%d want=%d", len(out), wantCount)
}
if ev.Type == types.OpPut {
out = append(out, ev.Item)
}
case <-w.Done():
t.Fatalf("watcher done before receiving all events: got=%d want=%d", len(out), wantCount)
case <-deadline.C:
t.Fatalf("timed out waiting for events: got=%d want=%d", len(out), wantCount)
}
}
return out
}
4 changes: 4 additions & 0 deletions lib/backend/test/suite.go
@@ -198,6 +198,10 @@ func RunBackendComplianceSuite(t *testing.T, newBackend Constructor) {
t.Run("ConditionalDelete", func(t *testing.T) {
testConditionalDelete(t, newBackend)
})

t.Run("PutBatch", func(t *testing.T) {
runPutBatch(t, newBackend)
})
}

// RequireItems asserts that the supplied `actual` items collection matches