Skip to content

Commit f43d5e2

Browse files
committed
Finish pbench forward
1 parent 4a2ae07 commit f43d5e2

File tree

10 files changed

+192
-46
lines changed

10 files changed

+192
-46
lines changed

cmd/forward.go

+6
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ var forwardCmd = &cobra.Command{
3131
return fmt.Errorf("the forward target server host at position %d is identical to the source server host %s", i, sourceUrl.Host)
3232
}
3333
}
34+
for _, isTrino := range forward.PrestoFlagsArray.IsTrino {
35+
if isTrino {
36+
return fmt.Errorf("forward command does not support Trino yet")
37+
}
38+
}
3439
return nil
3540
},
3641
Short: "Watch incoming query workloads from the first Presto cluster (cluster 0) and forward them to the rest clusters.",
@@ -39,6 +44,7 @@ var forwardCmd = &cobra.Command{
3944
func init() {
4045
RootCmd.AddCommand(forwardCmd)
4146
forward.PrestoFlagsArray.Install(forwardCmd)
47+
_ = forwardCmd.Flags().MarkHidden("trino")
4248
wd, _ := os.Getwd()
4349
forwardCmd.Flags().StringVarP(&forward.OutputPath, "output-path", "o", wd, "Output directory path")
4450
forwardCmd.Flags().StringVarP(&forward.RunName, "name", "n", fmt.Sprintf("forward_%s", time.Now().Format(utils.DirectoryNameTimeFormat)), `Assign a name to this run. (default: "forward_<current time>")`)

cmd/forward/main.go

+125-28
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@ package forward
22

33
import (
44
"context"
5-
"fmt"
65
"github.com/spf13/cobra"
6+
"net/http"
7+
"os"
8+
"os/signal"
9+
"path/filepath"
710
"pbench/log"
811
"pbench/presto"
912
"pbench/utils"
1013
"sync"
14+
"sync/atomic"
15+
"syscall"
1116
"time"
1217
)
1318

@@ -17,49 +22,141 @@ var (
1722
RunName string
1823
PollInterval time.Duration
1924

20-
runningTasks sync.WaitGroup
25+
runningTasks sync.WaitGroup
26+
failedToForward atomic.Uint32
27+
forwarded atomic.Uint32
2128
)
2229

23-
type QueryHistory struct {
24-
QueryId string `presto:"query_id"`
25-
Query string `presto:"query"`
26-
Created *time.Time `presto:"created"`
27-
}
28-
2930
func Run(_ *cobra.Command, _ []string) {
30-
//OutputPath = filepath.Join(OutputPath, RunName)
31-
//utils.PrepareOutputDirectory(OutputPath)
32-
//
33-
//// also start to write logs to the output directory from this point on.
34-
//logPath := filepath.Join(OutputPath, "forward.log")
35-
//flushLog := utils.InitLogFile(logPath)
36-
//defer flushLog()
31+
OutputPath = filepath.Join(OutputPath, RunName)
32+
utils.PrepareOutputDirectory(OutputPath)
3733

38-
prestoClusters := PrestoFlagsArray.Assemble()
34+
// also start to write logs to the output directory from this point on.
35+
logPath := filepath.Join(OutputPath, "forward.log")
36+
flushLog := utils.InitLogFile(logPath)
37+
defer flushLog()
38+
39+
ctx, cancel := context.WithCancel(context.Background())
40+
timeToExit := make(chan os.Signal, 1)
41+
signal.Notify(timeToExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
42+
// Handle SIGINT, SIGTERM, and SIGQUIT. When ctx is canceled, in-progress MySQL transactions and InfluxDB operations will roll back.
43+
go func() {
44+
sig := <-timeToExit
45+
if sig != nil {
46+
log.Info().Msg("abort forwarding")
47+
cancel()
48+
}
49+
}()
50+
51+
prestoClusters := PrestoFlagsArray.Pivot()
3952
// The design here is to forward the traffic from cluster 0 to the rest.
4053
sourceClusterSize := 0
4154
clients := make([]*presto.Client, 0, len(prestoClusters))
4255
for i, cluster := range prestoClusters {
4356
clients = append(clients, cluster.NewPrestoClient())
44-
if stats, _, err := clients[i].GetClusterInfo(context.Background()); err != nil {
45-
log.Fatal().Err(err).Msgf("cannot connect to cluster at position %d", i)
57+
// Check if we can connect to the cluster.
58+
if stats, _, err := clients[i].GetClusterInfo(ctx); err != nil {
59+
log.Fatal().Err(err).Msgf("cannot connect to cluster at position %d: %s", i, cluster.ServerUrl)
4660
} else if i == 0 {
4761
sourceClusterSize = stats.ActiveWorkers
4862
} else if stats.ActiveWorkers != sourceClusterSize {
49-
log.Warn().Msgf("source cluster size does not match target cluster %d size (%d != %d)", i, stats.ActiveWorkers, sourceClusterSize)
63+
log.Warn().Msgf("the source cluster and target cluster %d do not match in size (%d != %d)", i, sourceClusterSize, stats.ActiveWorkers)
5064
}
5165
}
5266

5367
sourceClient := clients[0]
5468
trueValue := true
55-
states, _, err := sourceClient.GetQueryState(context.Background(), &presto.GetQueryStatsOptions{
56-
IncludeAllQueries: &trueValue,
57-
IncludeAllQueryProgressStats: nil,
58-
ExcludeResourceGroupPathInfo: nil,
59-
QueryTextSizeLimit: nil,
60-
})
61-
if err != nil {
62-
log.Fatal().Err(err).Msgf("cannot get query states")
69+
// lastQueryStateCheckCutoffTime is the query create time of the most recent query in the previous batch.
70+
// We only look at queries created later than this timestamp in the following batch.
71+
lastQueryStateCheckCutoffTime := time.Time{}
72+
firstBatch := true
73+
// Keep running until the source cluster becomes unavailable or the user interrupts or quits using Ctrl + C or Ctrl + D.
74+
for ctx.Err() == nil {
75+
states, _, err := sourceClient.GetQueryState(ctx, &presto.GetQueryStatsOptions{IncludeAllQueries: &trueValue})
76+
if err != nil {
77+
log.Error().Err(err).Msgf("failed to get query states")
78+
break
79+
}
80+
newCutoffTime := time.Time{}
81+
for _, state := range states {
82+
if !state.CreateTime.After(lastQueryStateCheckCutoffTime) {
83+
// We looked at this query in the previous batch.
84+
continue
85+
}
86+
if newCutoffTime.Before(state.CreateTime) {
87+
newCutoffTime = state.CreateTime
88+
}
89+
if !firstBatch {
90+
runningTasks.Add(1)
91+
go forwardQuery(ctx, &state, clients)
92+
}
93+
}
94+
firstBatch = false
95+
if newCutoffTime.After(lastQueryStateCheckCutoffTime) {
96+
lastQueryStateCheckCutoffTime = newCutoffTime
97+
}
98+
timer := time.NewTimer(PollInterval)
99+
select {
100+
case <-ctx.Done():
101+
case <-timer.C:
102+
}
103+
}
104+
runningTasks.Wait()
105+
// This causes the signal handler to exit.
106+
close(timeToExit)
107+
log.Info().Uint32("forwarded", forwarded.Load()).Uint32("failed_to_forward", failedToForward.Load()).
108+
Msgf("finished forwarding queries")
109+
}
110+
111+
func forwardQuery(ctx context.Context, queryState *presto.QueryStateInfo, clients []*presto.Client) {
112+
defer runningTasks.Done()
113+
queryInfo, _, queryInfoErr := clients[0].GetQueryInfo(ctx, queryState.QueryId, false, nil)
114+
if queryInfoErr != nil {
115+
log.Error().Str("query_id", queryState.QueryId).Err(queryInfoErr).Msg("failed to get query info for forwarding")
116+
failedToForward.Add(1)
117+
return
118+
}
119+
SessionPropertyHeader := clients[0].GenerateSessionParamsHeaderValue(queryInfo.Session.CollectSessionProperties())
120+
successful, failed := atomic.Uint32{}, atomic.Uint32{}
121+
forwardedQueries := sync.WaitGroup{}
122+
for i := 1; i < len(clients); i++ {
123+
forwardedQueries.Add(1)
124+
go func(client *presto.Client) {
125+
defer forwardedQueries.Done()
126+
clientResult, _, queryErr := client.Query(ctx, queryInfo.Query, func(req *http.Request) {
127+
if queryInfo.Session.Catalog != nil {
128+
req.Header.Set(presto.CatalogHeader, *queryInfo.Session.Catalog)
129+
}
130+
if queryInfo.Session.Schema != nil {
131+
req.Header.Set(presto.SchemaHeader, *queryInfo.Session.Schema)
132+
}
133+
req.Header.Set(presto.SessionHeader, SessionPropertyHeader)
134+
req.Header.Set(presto.SourceHeader, queryInfo.QueryId)
135+
})
136+
if queryErr != nil {
137+
log.Error().Str("source_query_id", queryInfo.QueryId).
138+
Str("target_host", client.GetHost()).Err(queryErr).Msg("failed to execute query")
139+
failed.Add(1)
140+
return
141+
}
142+
rowCount := 0
143+
drainErr := clientResult.Drain(ctx, func(qr *presto.QueryResults) error {
144+
rowCount += len(qr.Data)
145+
return nil
146+
})
147+
if drainErr != nil {
148+
log.Error().Str("source_query_id", queryInfo.QueryId).
149+
Str("target_host", client.GetHost()).Err(drainErr).Msg("failed to fetch query result")
150+
failed.Add(1)
151+
return
152+
}
153+
successful.Add(1)
154+
log.Info().Str("source_query_id", queryInfo.QueryId).
155+
Str("target_host", client.GetHost()).Int("row_count", rowCount).Msg("query executed successfully")
156+
}(clients[i])
63157
}
64-
fmt.Printf("%#v", states)
158+
forwardedQueries.Wait()
159+
log.Info().Str("source_query_id", queryInfo.QueryId).Uint32("successful", successful.Load()).
160+
Uint32("failed", failed.Load()).Msg("query forwarding finished")
161+
forwarded.Add(1)
65162
}

cmd/replay.go

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ We also expect the queries in this CSV file are sorted by "create_time" in ascen
2424
}
2525
utils.ExpandHomeDirectory(&replay.OutputPath)
2626
utils.ExpandHomeDirectory(&args[0])
27+
if replay.PrestoFlags.IsTrino {
28+
return fmt.Errorf("replay command does not support Trino yet")
29+
}
2730
return nil
2831
},
2932
Run: replay.Run,
@@ -33,6 +36,7 @@ func init() {
3336
RootCmd.AddCommand(replayCmd)
3437
wd, _ := os.Getwd()
3538
replay.PrestoFlags.Install(replayCmd)
39+
_ = replayCmd.Flags().MarkHidden("trino")
3640
replayCmd.Flags().StringVarP(&replay.OutputPath, "output-path", "o", wd, "Output directory path")
3741
replayCmd.Flags().StringVarP(&replay.RunName, "name", "n", fmt.Sprintf("replay_%s", time.Now().Format(utils.DirectoryNameTimeFormat)), `Assign a name to this run. (default: "replay_<current time>")`)
3842
}

presto/client.go

+4
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ func GenerateHttpQueryParameter(v any) string {
9191
return queryBuilder.String()
9292
}
9393

94+
func (c *Client) GetHost() string {
95+
return c.serverUrl.Host
96+
}
97+
9498
func (c *Client) setHeader(key, value string) {
9599
if c.isTrino {
96100
key = strings.Replace(key, "X-Presto", "X-Trino", 1)

presto/client_test.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"github.com/stretchr/testify/assert"
77
"pbench/presto"
8+
"pbench/presto/query_json"
89
"strings"
910
"syscall"
1011
"testing"
@@ -42,9 +43,14 @@ func TestQuery(t *testing.T) {
4243
assert.Equal(t, 150000, rowCount)
4344

4445
buf := &strings.Builder{}
45-
_, err = client.GetQueryInfo(context.Background(), qr.Id, false, buf)
46+
var queryInfo *query_json.QueryInfo
47+
queryInfo, _, err = client.GetQueryInfo(context.Background(), qr.Id, false, buf)
4648
assert.Nil(t, err)
49+
assert.Nil(t, queryInfo)
4750
assert.Greater(t, buf.Len(), 0)
51+
queryInfo, _, err = client.GetQueryInfo(context.Background(), qr.Id, true, nil)
52+
assert.Nil(t, err)
53+
assert.Equal(t, qr.Id, queryInfo.QueryId)
4854
}
4955

5056
func TestGenerateQueryParameter(t *testing.T) {

presto/query.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"io"
66
"net/http"
7+
"pbench/presto/query_json"
78
)
89

910
func (c *Client) requestQueryResults(ctx context.Context, req *http.Request) (*QueryResults, *http.Response, error) {
@@ -49,19 +50,30 @@ func (c *Client) CancelQuery(ctx context.Context, nextUri string, opts ...Reques
4950
return c.requestQueryResults(ctx, req)
5051
}
5152

52-
func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, writer io.Writer, opts ...RequestOption) (*http.Response, error) {
53+
// GetQueryInfo retrieves the query JSON for the given query ID.
54+
// If writer is nil, we return deserialized QueryInfo. Otherwise, we just return the raw buffer.
55+
func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, writer io.Writer, opts ...RequestOption) (*query_json.QueryInfo, *http.Response, error) {
5356
urlStr := "v1/query/" + queryId
5457
if pretty {
5558
urlStr += "?pretty"
5659
}
5760
req, err := c.NewRequest("GET",
5861
urlStr, nil, opts...)
5962
if err != nil {
60-
return nil, err
63+
return nil, nil, err
64+
}
65+
var (
66+
resp *http.Response
67+
queryInfo *query_json.QueryInfo
68+
)
69+
if writer != nil {
70+
resp, err = c.Do(ctx, req, writer)
71+
} else {
72+
queryInfo = new(query_json.QueryInfo)
73+
resp, err = c.Do(ctx, req, queryInfo)
6174
}
62-
resp, err := c.Do(ctx, req, writer)
6375
if err != nil {
64-
return resp, err
76+
return nil, resp, err
6577
}
66-
return resp, nil
78+
return queryInfo, resp, nil
6779
}

presto/query_json/session.go

+16
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,19 @@ func (s *Session) PrepareForInsert() {
5151
s.SessionPropertiesJson = string(jsonBytes[:len(jsonBytes)-1])
5252
}
5353
}
54+
55+
func (s *Session) CollectSessionProperties() map[string]any {
56+
sessionParams := make(map[string]any)
57+
if s == nil {
58+
return sessionParams
59+
}
60+
for k, v := range s.SystemProperties {
61+
sessionParams[k] = v
62+
}
63+
for catalog, catalogProps := range s.CatalogProperties {
64+
for k, v := range catalogProps {
65+
sessionParams[catalog+"."+k] = v
66+
}
67+
}
68+
return sessionParams
69+
}

presto/query_state.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ import (
1111
// https://github.com/prestodb/presto/blob/master/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfo.java
1212
// Unused fields are commented out for now.
1313
type QueryStateInfo struct {
14-
QueryId string `json:"queryId"`
15-
QueryState string `json:"queryState"`
14+
QueryId string `json:"queryId"`
15+
//QueryState string `json:"queryState"`
1616
//ResourceGroupId []string `json:"resourceGroupId"`
17-
Query string `json:"query"`
18-
QueryTruncated bool `json:"queryTruncated"`
19-
CreateTime time.Time `json:"createTime"`
20-
User string `json:"user"`
21-
Authenticated bool `json:"authenticated"`
22-
Source string `json:"source"`
23-
Catalog string `json:"catalog"`
17+
//Query string `json:"query"`
18+
//QueryTruncated bool `json:"queryTruncated"`
19+
CreateTime time.Time `json:"createTime"`
20+
//User string `json:"user"`
21+
//Authenticated bool `json:"authenticated"`
22+
//Source string `json:"source,omitempty"`
23+
//Catalog string `json:"catalog"`
2424
//Progress struct {
2525
// ElapsedTimeMillis int `json:"elapsedTimeMillis"`
2626
// QueuedTimeMillis int `json:"queuedTimeMillis"`
@@ -61,7 +61,7 @@ func (c *Client) GetQueryState(ctx context.Context, reqOpt *GetQueryStatsOptions
6161
}
6262

6363
infoArray := make([]QueryStateInfo, 0, 16)
64-
resp, err := c.Do(ctx, req, infoArray)
64+
resp, err := c.Do(ctx, req, &infoArray)
6565
if err != nil {
6666
return nil, resp, err
6767
}

stage/stage_utils.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func (s *Stage) saveQueryJsonFile(result *QueryResult) {
270270
checkErr(err)
271271
if err == nil {
272272
// We need to save the query json file even if the stage context is canceled.
273-
_, err = s.Client.GetQueryInfo(utils.GetCtxWithTimeout(time.Second*5), result.QueryId, false, queryJsonFile)
273+
_, _, err = s.Client.GetQueryInfo(utils.GetCtxWithTimeout(time.Second*5), result.QueryId, false, queryJsonFile)
274274
checkErr(err)
275275
checkErr(queryJsonFile.Close())
276276
}

utils/presto_flags.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ func (a *PrestoFlagsArray) Install(cmd *cobra.Command) {
5656
cmd.Flags().StringArrayVarP(&a.Password, "password", "p", []string{""}, "Presto user password (optional)")
5757
}
5858

59-
func (a *PrestoFlagsArray) Assemble() []PrestoFlags {
59+
// Pivot generates PrestoFlags array that is suitable for creating Presto clients conveniently.
60+
func (a *PrestoFlagsArray) Pivot() []PrestoFlags {
6061
ret := make([]PrestoFlags, 0, len(a.ServerUrl))
6162
for _, url := range a.ServerUrl {
6263
ret = append(ret, PrestoFlags{

0 commit comments

Comments
 (0)