Skip to content

Commit 9e910de

Browse files
committed
Add Support for PG PUT Batch
1 parent 11f1d64 commit 9e910de

File tree

3 files changed

+286
-0
lines changed

3 files changed

+286
-0
lines changed

lib/backend/pgbk/put_batch.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Teleport
3+
* Copyright (C) 2025 Gravitational, Inc.
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU Affero General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU Affero General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU Affero General Public License
16+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
17+
*/
18+
19+
package pgbk
20+
21+
import (
22+
"context"
23+
"slices"
24+
25+
"github.com/gravitational/trace"
26+
"github.com/jackc/pgx/v5/pgtype/zeronull"
27+
28+
"github.com/gravitational/teleport/lib/backend"
29+
pgcommon "github.com/gravitational/teleport/lib/backend/pgbk/common"
30+
)
31+
32+
const (
33+
defaultUpsertBatchChunk = 100
34+
putBatchStmt = `
35+
INSERT INTO kv (key, value, expires, revision)
36+
SELECT * FROM UNNEST(
37+
$1::bytea[],
38+
$2::bytea[],
39+
$3::timestamptz[],
40+
$4::uuid[]
41+
)
42+
ON CONFLICT (key) DO UPDATE
43+
SET
44+
value = EXCLUDED.value,
45+
expires = EXCLUDED.expires,
46+
revision = EXCLUDED.revision;
47+
`
48+
)
49+
50+
// PutBatch puts multiple items into the backend in a single transaction.
51+
func (b *Backend) PutBatch(ctx context.Context, items []backend.Item) ([]string, error) {
52+
if len(items) == 0 {
53+
return []string{}, nil
54+
}
55+
revOut := make([]string, 0, len(items))
56+
for chunk := range slices.Chunk(items, defaultUpsertBatchChunk) {
57+
keys := make([][]byte, 0, len(chunk))
58+
values := make([][]byte, 0, len(chunk))
59+
expires := make([]zeronull.Timestamptz, 0, len(chunk))
60+
revs := make([]revision, 0, len(chunk))
61+
62+
for _, item := range chunk {
63+
keys = append(keys, nonNilKey(item.Key))
64+
values = append(values, nonNil(item.Value))
65+
expires = append(expires, zeronull.Timestamptz(item.Expires.UTC()))
66+
67+
revVal := newRevision()
68+
revs = append(revs, revVal)
69+
revOut = append(revOut, revisionToString(revVal))
70+
}
71+
72+
if _, err := pgcommon.Retry(ctx, b.log, func() (struct{}, error) {
73+
_, err := b.pool.Exec(ctx, putBatchStmt, keys, values, expires, revs)
74+
return struct{}{}, trace.Wrap(err)
75+
}); err != nil {
76+
return nil, trace.Wrap(err)
77+
}
78+
}
79+
return revOut, nil
80+
}

lib/backend/test/put_batch.go

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
/*
2+
* Teleport
3+
* Copyright (C) 2025 Gravitational, Inc.
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU Affero General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU Affero General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU Affero General Public License
16+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
17+
*/
18+
19+
package test
20+
21+
import (
22+
"bytes"
23+
"context"
24+
"fmt"
25+
"testing"
26+
"time"
27+
28+
"github.com/stretchr/testify/require"
29+
30+
"github.com/gravitational/teleport/api/types"
31+
"github.com/gravitational/teleport/lib/backend"
32+
)
33+
34+
type PutBatcher interface {
35+
PutBatch(ctx context.Context, items []backend.Item) ([]string, error)
36+
}
37+
38+
const (
39+
watchInitTimeout = 10 * time.Second
40+
watchEventTimeout = 3 * time.Second
41+
)
42+
43+
func runPutBatch(t *testing.T, newBackend Constructor) {
44+
t.Helper()
45+
46+
ctx := context.Background()
47+
bk, _, err := newBackend()
48+
require.NoError(t, err)
49+
t.Cleanup(func() { _ = bk.Close() })
50+
51+
batcher, ok := bk.(PutBatcher)
52+
if !ok {
53+
t.Skip("backend does not implement PutBatch; skipping PutBatch suite")
54+
}
55+
56+
prefix := MakePrefix()
57+
rangeStart := prefix("")
58+
rangeEnd := backend.RangeEnd(prefix(""))
59+
60+
itemEqual := func(a, b backend.Item) bool {
61+
return a.Key.String() == b.Key.String() &&
62+
a.Revision == b.Revision &&
63+
string(a.Value) == string(b.Value) &&
64+
a.Expires.Equal(b.Expires)
65+
}
66+
67+
assertItemsEqual := func(t *testing.T, want, got []backend.Item) {
68+
t.Helper()
69+
require.Len(t, want, len(got))
70+
for i := range want {
71+
require.True(t, itemEqual(want[i], got[i]))
72+
}
73+
}
74+
75+
buildWant := func(items []backend.Item, rev []string) []backend.Item {
76+
out := make([]backend.Item, 0, len(items))
77+
for i, it := range items {
78+
out = append(out, backend.Item{
79+
Key: it.Key,
80+
Value: it.Value,
81+
Revision: rev[i],
82+
Expires: it.Expires,
83+
})
84+
}
85+
return out
86+
}
87+
88+
newTestItems := func() []backend.Item {
89+
return []backend.Item{
90+
{Key: prefix("a"), Value: []byte("A"), Expires: time.Now().Add(1 * time.Hour)},
91+
{Key: prefix("b"), Value: []byte("B")},
92+
{Key: prefix("c"), Value: []byte("C"), Expires: time.Now().Add(2 * time.Hour)},
93+
}
94+
}
95+
t.Run("put batch items should be propagated in event stream", func(t *testing.T) {
96+
w, err := bk.NewWatcher(t.Context(), backend.Watch{})
97+
require.NoError(t, err)
98+
t.Cleanup(func() { w.Close() })
99+
100+
select {
101+
case <-w.Done():
102+
t.Fatal("watcher closed immediately")
103+
case ev := <-w.Events():
104+
require.Equal(t, types.OpInit, ev.Type)
105+
case <-time.After(watchInitTimeout):
106+
t.Fatal("timed out waiting for init event")
107+
}
108+
109+
items := newTestItems()
110+
rev, err := batcher.PutBatch(ctx, items)
111+
require.NoError(t, err)
112+
require.NotEmpty(t, rev)
113+
114+
got := waitForEvents(t, w, len(items), watchEventTimeout)
115+
want := buildWant(items, rev)
116+
assertItemsEqual(t, want, got)
117+
require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
118+
})
119+
120+
t.Run("put-create-update", func(t *testing.T) {
121+
items := newTestItems()
122+
rev1, err := batcher.PutBatch(ctx, items)
123+
require.NoError(t, err)
124+
require.NotEmpty(t, rev1)
125+
126+
res, err := bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
127+
require.NoError(t, err)
128+
129+
want := buildWant(items, rev1)
130+
got := res.Items
131+
assertItemsEqual(t, want, got)
132+
133+
items[0].Value = []byte("A2")
134+
items[1].Value = []byte("B2")
135+
items[2].Value = []byte("C2")
136+
137+
rev2, err := batcher.PutBatch(ctx, items)
138+
require.NoError(t, err)
139+
require.NotEmpty(t, rev2)
140+
require.NotEqual(t, rev1, rev2)
141+
142+
res, err = bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
143+
require.NoError(t, err)
144+
145+
want = buildWant(items, rev2)
146+
got = res.Items
147+
assertItemsEqual(t, want, got)
148+
149+
require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
150+
})
151+
152+
t.Run("over-chunk-size", func(t *testing.T) {
153+
const itemCount = 1000
154+
const payloadSize = 300 * 1024 // 300 KiB
155+
items := make([]backend.Item, 0, itemCount)
156+
for i := 0; i < itemCount; i++ {
157+
items = append(items, backend.Item{
158+
Key: prefix(fmt.Sprintf("item/%04d", i)),
159+
Value: bytes.Repeat([]byte(fmt.Sprintf("%d", i)), payloadSize),
160+
Expires: time.Now().Add(5 * time.Minute),
161+
})
162+
}
163+
164+
rev, err := batcher.PutBatch(ctx, items)
165+
require.NoError(t, err)
166+
require.NotEmpty(t, rev)
167+
168+
res, err := bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
169+
require.NoError(t, err)
170+
171+
want := buildWant(items, rev)
172+
got := res.Items
173+
assertItemsEqual(t, want, got)
174+
175+
require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
176+
})
177+
}
178+
179+
func waitForEvents(t *testing.T, w backend.Watcher, wantCount int, timeout time.Duration) []backend.Item {
180+
t.Helper()
181+
182+
var out []backend.Item
183+
deadline := time.NewTimer(timeout)
184+
defer deadline.Stop()
185+
186+
for len(out) < wantCount {
187+
select {
188+
case ev, ok := <-w.Events():
189+
if !ok {
190+
t.Fatalf("watcher closed before receiving all events: got=%d want=%d", len(out), wantCount)
191+
}
192+
if ev.Type == types.OpPut {
193+
out = append(out, ev.Item)
194+
}
195+
case <-w.Done():
196+
t.Fatalf("watcher done before receiving all events: got=%d want=%d", len(out), wantCount)
197+
case <-deadline.C:
198+
t.Fatalf("timed out waiting for events: got=%d want=%d", len(out), wantCount)
199+
}
200+
}
201+
return out
202+
}

lib/backend/test/suite.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ func RunBackendComplianceSuite(t *testing.T, newBackend Constructor) {
198198
t.Run("ConditionalDelete", func(t *testing.T) {
199199
testConditionalDelete(t, newBackend)
200200
})
201+
202+
t.Run("PutBatch", func(t *testing.T) {
203+
runPutBatch(t, newBackend)
204+
})
201205
}
202206

203207
// RequireItems asserts that the supplied `actual` items collection matches

0 commit comments

Comments
 (0)