Skip to content

Commit 131a7ee

Browse files
committed
Add Support for PG PUT Batch
1 parent 11f1d64 commit 131a7ee

File tree

3 files changed

+286
-0
lines changed

3 files changed

+286
-0
lines changed

lib/backend/pgbk/put_batch.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Teleport
3+
* Copyright (C) 2025 Gravitational, Inc.
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU Affero General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU Affero General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU Affero General Public License
16+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
17+
*/
18+
19+
package pgbk
20+
21+
import (
22+
"context"
23+
"slices"
24+
25+
"github.com/gravitational/trace"
26+
"github.com/jackc/pgx/v5/pgtype/zeronull"
27+
28+
"github.com/gravitational/teleport/lib/backend"
29+
pgcommon "github.com/gravitational/teleport/lib/backend/pgbk/common"
30+
)
31+
32+
const (
33+
defaultUpsertBatchChunk = 1000
34+
putBatchStmt = `
35+
INSERT INTO kv (key, value, expires, revision)
36+
SELECT k, v, e, r FROM UNNEST(
37+
$1::bytea[],
38+
$2::bytea[],
39+
$3::timestamptz[],
40+
$4::uuid[]
41+
) AS t(k, v, e, r)
42+
ON CONFLICT (key) DO UPDATE
43+
SET
44+
value = EXCLUDED.value,
45+
expires = EXCLUDED.expires,
46+
revision = EXCLUDED.revision;
47+
`
48+
)
49+
50+
// PutBatch puts multiple items into the backend in a single transaction.
51+
func (b *Backend) PutBatch(ctx context.Context, items []backend.Item) ([]string, error) {
52+
if len(items) == 0 {
53+
return nil, trace.BadParameter("at least one item must be provided")
54+
}
55+
revOut := make([]string, 0, len(items))
56+
for chunk := range slices.Chunk(items, defaultUpsertBatchChunk) {
57+
keys := make([][]byte, 0, len(chunk))
58+
values := make([][]byte, 0, len(chunk))
59+
expires := make([]zeronull.Timestamptz, 0, len(chunk))
60+
revs := make([]revision, 0, len(chunk))
61+
62+
for i := range chunk {
63+
chunk[i].Expires = chunk[i].Expires.UTC()
64+
keys = append(keys, nonNilKey(chunk[i].Key))
65+
values = append(values, nonNil(chunk[i].Value))
66+
expires = append(expires, zeronull.Timestamptz(chunk[i].Expires))
67+
68+
revVal := newRevision()
69+
revs = append(revs, revVal)
70+
revOut = append(revOut, revisionToString(revVal))
71+
}
72+
73+
if _, err := pgcommon.Retry(ctx, b.log, func() (struct{}, error) {
74+
_, err := b.pool.Exec(ctx, putBatchStmt, keys, values, expires, revs)
75+
return struct{}{}, trace.Wrap(err)
76+
}); err != nil {
77+
return nil, trace.Wrap(err)
78+
}
79+
}
80+
return revOut, nil
81+
}

lib/backend/test/put_batch.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
* Teleport
3+
* Copyright (C) 2025 Gravitational, Inc.
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU Affero General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU Affero General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU Affero General Public License
16+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
17+
*/
18+
19+
package test
20+
21+
import (
22+
"bytes"
23+
"context"
24+
"fmt"
25+
"testing"
26+
"time"
27+
28+
gocmp "github.com/google/go-cmp/cmp"
29+
"github.com/stretchr/testify/require"
30+
31+
"github.com/gravitational/teleport/api/types"
32+
"github.com/gravitational/teleport/lib/backend"
33+
)
34+
35+
type PutBatcher interface {
36+
PutBatch(ctx context.Context, items []backend.Item) ([]string, error)
37+
}
38+
39+
const (
40+
watchInitTimeout = 3 * time.Second
41+
watchEventTimeout = 3 * time.Second
42+
)
43+
44+
func runPutBatch(t *testing.T, newBackend Constructor) {
45+
t.Helper()
46+
47+
ctx := context.Background()
48+
bk, _, err := newBackend()
49+
require.NoError(t, err)
50+
t.Cleanup(func() { _ = bk.Close() })
51+
52+
batcher, ok := bk.(PutBatcher)
53+
if !ok {
54+
t.Skip("backend does not implement PutBatch; skipping PutBatch suite")
55+
}
56+
57+
prefix := MakePrefix()
58+
rangeStart := prefix("")
59+
rangeEnd := backend.RangeEnd(prefix(""))
60+
61+
itemEqual := func(a, b backend.Item) bool {
62+
return a.Key.String() == b.Key.String() &&
63+
a.Revision == b.Revision &&
64+
string(a.Value) == string(b.Value) &&
65+
a.Expires.Equal(b.Expires)
66+
}
67+
68+
assertItemsEqual := func(t *testing.T, want, got []backend.Item) {
69+
t.Helper()
70+
diff := gocmp.Diff(want, got, gocmp.Comparer(itemEqual))
71+
require.Equal(t, "", diff, "items differ (-want +got):\n%s", diff)
72+
}
73+
74+
buildWant := func(items []backend.Item, rev []string) []backend.Item {
75+
out := make([]backend.Item, 0, len(items))
76+
for _, it := range items {
77+
out = append(out, backend.Item{
78+
Key: it.Key,
79+
Value: it.Value,
80+
Revision: rev[i],
81+
Expires: it.Expires,
82+
})
83+
}
84+
return out
85+
}
86+
87+
newTestItems := func() []backend.Item {
88+
return []backend.Item{
89+
{Key: prefix("a"), Value: []byte("A"), Expires: time.Now().Add(1 * time.Hour)},
90+
{Key: prefix("b"), Value: []byte("B")},
91+
{Key: prefix("c"), Value: []byte("C"), Expires: time.Now().Add(2 * time.Hour)},
92+
}
93+
}
94+
t.Run("put batch items should be propagated in event stream", func(t *testing.T) {
95+
w, err := bk.NewWatcher(t.Context(), backend.Watch{})
96+
require.NoError(t, err)
97+
t.Cleanup(func() { w.Close() })
98+
99+
select {
100+
case <-w.Done():
101+
t.Fatal("watcher closed immediately")
102+
case ev := <-w.Events():
103+
require.Equal(t, types.OpInit, ev.Type)
104+
case <-time.After(watchInitTimeout):
105+
t.Fatal("timed out waiting for init event")
106+
}
107+
108+
items := newTestItems()
109+
rev, err := batcher.PutBatch(ctx, items)
110+
require.NoError(t, err)
111+
require.NotEmpty(t, rev)
112+
113+
got := waitForEvents(t, w, len(items), watchEventTimeout)
114+
want := buildWant(items, rev)
115+
assertItemsEqual(t, want, got)
116+
require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
117+
})
118+
119+
t.Run("put-create-update", func(t *testing.T) {
120+
items := newTestItems()
121+
rev1, err := batcher.PutBatch(ctx, items)
122+
require.NoError(t, err)
123+
require.NotEmpty(t, rev1)
124+
125+
res, err := bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
126+
require.NoError(t, err)
127+
128+
want := buildWant(items, rev1)
129+
got := res.Items
130+
assertItemsEqual(t, want, got)
131+
132+
items[0].Value = []byte("A2")
133+
items[1].Value = []byte("B2")
134+
items[2].Value = []byte("C2")
135+
136+
rev2, err := batcher.PutBatch(ctx, items)
137+
require.NoError(t, err)
138+
require.NotEmpty(t, rev2)
139+
require.NotEqual(t, rev1, rev2)
140+
141+
res, err = bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
142+
require.NoError(t, err)
143+
144+
want = buildWant(items, rev2)
145+
got = res.Items
146+
assertItemsEqual(t, want, got)
147+
148+
require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
149+
})
150+
151+
t.Run("over-chunk-size", func(t *testing.T) {
152+
const itemCount = 500
153+
const payloadSize = 400 * 1024 // 3 KiB
154+
items := make([]backend.Item, 0, itemCount)
155+
for i := 0; i < itemCount; i++ {
156+
items = append(items, backend.Item{
157+
Key: prefix(fmt.Sprintf("item/%04d", i)),
158+
Value: bytes.Repeat([]byte(fmt.Sprintf("%d", i)), payloadSize),
159+
Expires: time.Now().Add(5 * time.Minute),
160+
})
161+
}
162+
163+
rev, err := batcher.PutBatch(ctx, items)
164+
require.NoError(t, err)
165+
require.NotEmpty(t, rev)
166+
167+
res, err := bk.GetRange(ctx, rangeStart, rangeEnd, backend.NoLimit)
168+
require.NoError(t, err)
169+
170+
want := buildWant(items, rev)
171+
got := res.Items
172+
assertItemsEqual(t, want, got)
173+
174+
require.NoError(t, bk.DeleteRange(ctx, rangeStart, rangeEnd))
175+
})
176+
}
177+
178+
func waitForEvents(t *testing.T, w backend.Watcher, wantCount int, timeout time.Duration) []backend.Item {
179+
t.Helper()
180+
181+
var out []backend.Item
182+
deadline := time.NewTimer(timeout)
183+
defer deadline.Stop()
184+
185+
for len(out) < wantCount {
186+
select {
187+
case ev, ok := <-w.Events():
188+
if !ok {
189+
t.Fatalf("watcher closed before receiving all events: got=%d want=%d", len(out), wantCount)
190+
}
191+
if ev.Type == types.OpPut {
192+
out = append(out, ev.Item)
193+
}
194+
case <-w.Done():
195+
t.Fatalf("watcher done before receiving all events: got=%d want=%d", len(out), wantCount)
196+
case <-deadline.C:
197+
t.Fatalf("timed out waiting for events: got=%d want=%d", len(out), wantCount)
198+
}
199+
}
200+
return out
201+
}

lib/backend/test/suite.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ func RunBackendComplianceSuite(t *testing.T, newBackend Constructor) {
198198
t.Run("ConditionalDelete", func(t *testing.T) {
199199
testConditionalDelete(t, newBackend)
200200
})
201+
202+
t.Run("PutBatch", func(t *testing.T) {
203+
runPutBatch(t, newBackend)
204+
})
201205
}
202206

203207
// RequireItems asserts that the supplied `actual` items collection matches

0 commit comments

Comments
 (0)