
Commit 2f4156f

Add forward command
1 parent 195ecf6 commit 2f4156f

8 files changed: +284 -8 lines

cmd/forward.go (+46)
@@ -0,0 +1,46 @@
+package cmd
+
+import (
+	"fmt"
+	"github.com/spf13/cobra"
+	"net/url"
+	"os"
+	"pbench/cmd/forward"
+	"pbench/utils"
+	"time"
+)
+
+var forwardCmd = &cobra.Command{
+	Use:                   `forward [flags]`,
+	DisableFlagsInUseLine: true,
+	Run:                   forward.Run,
+	Args: func(cmd *cobra.Command, args []string) error {
+		utils.ExpandHomeDirectory(&forward.OutputPath)
+		if len(forward.PrestoFlagsArray.ServerUrl) < 2 {
+			return fmt.Errorf("information for at least two clusters is required to do workload forwarding")
+		}
+		var sourceUrl *url.URL
+		for i, serverUrl := range forward.PrestoFlagsArray.ServerUrl {
+			parsedUrl, err := url.Parse(serverUrl)
+			if err != nil {
+				return fmt.Errorf("failed to parse server URL at position %d: %w", i, err)
+			}
+			if i == 0 {
+				sourceUrl = parsedUrl
+			} else if parsedUrl.Host == sourceUrl.Host {
+				return fmt.Errorf("the forward target server host at position %d is identical to the source server host %s", i, sourceUrl.Host)
+			}
+		}
+		return nil
+	},
+	Short: "Watch incoming query workloads from the first Presto cluster (cluster 0) and forward them to the rest clusters.",
+}
+
+func init() {
+	RootCmd.AddCommand(forwardCmd)
+	forward.PrestoFlagsArray.Install(forwardCmd)
+	wd, _ := os.Getwd()
+	forwardCmd.Flags().StringVarP(&forward.OutputPath, "output-path", "o", wd, "Output directory path")
+	forwardCmd.Flags().StringVarP(&forward.RunName, "name", "n", fmt.Sprintf("forward_%s", time.Now().Format(utils.DirectoryNameTimeFormat)), `Assign a name to this run. (default: "forward_<current time>")`)
+	forwardCmd.Flags().DurationVarP(&forward.PollInterval, "poll-interval", "i", time.Second*5, "Interval between polls to the source cluster")
+}

cmd/forward/main.go (+65)
@@ -0,0 +1,65 @@
+package forward
+
+import (
+	"context"
+	"fmt"
+	"github.com/spf13/cobra"
+	"pbench/log"
+	"pbench/presto"
+	"pbench/utils"
+	"sync"
+	"time"
+)
+
+var (
+	PrestoFlagsArray utils.PrestoFlagsArray
+	OutputPath       string
+	RunName          string
+	PollInterval     time.Duration
+
+	runningTasks sync.WaitGroup
+)
+
+type QueryHistory struct {
+	QueryId string     `presto:"query_id"`
+	Query   string     `presto:"query"`
+	Created *time.Time `presto:"created"`
+}
+
+func Run(_ *cobra.Command, _ []string) {
+	//OutputPath = filepath.Join(OutputPath, RunName)
+	//utils.PrepareOutputDirectory(OutputPath)
+	//
+	//// also start to write logs to the output directory from this point on.
+	//logPath := filepath.Join(OutputPath, "forward.log")
+	//flushLog := utils.InitLogFile(logPath)
+	//defer flushLog()
+
+	prestoClusters := PrestoFlagsArray.Assemble()
+	// The design here is to forward the traffic from cluster 0 to the rest.
+	sourceClusterSize := 0
+	clients := make([]*presto.Client, 0, len(prestoClusters))
+	for i, cluster := range prestoClusters {
+		clients = append(clients, cluster.NewPrestoClient())
+		if stats, _, err := clients[i].GetClusterInfo(context.Background()); err != nil {
+			log.Fatal().Err(err).Msgf("cannot connect to cluster at position %d", i)
+		} else if i == 0 {
+			sourceClusterSize = stats.ActiveWorkers
+		} else if stats.ActiveWorkers != sourceClusterSize {
+			log.Warn().Msgf("source cluster size does not match target cluster %d size (%d != %d)", i, stats.ActiveWorkers, sourceClusterSize)
+		}
+	}
+
+	sourceClient := clients[0]
+	trueValue := true
+	states, _, err := sourceClient.GetQueryState(context.Background(), &presto.GetQueryStatsOptions{
+		IncludeAllQueries:            &trueValue,
+		IncludeAllQueryProgressStats: nil,
+		ExcludeResourceGroupPathInfo: nil,
+		QueryTextSizeLimit:           nil,
+	})
+	if err != nil {
+		log.Fatal().Err(err).Msgf("cannot get query states")
+	}
+	fmt.Printf("%#v", states)
+}
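
Run currently stops after printing the initial query states. Below is a minimal sketch, not part of this commit, of how PollInterval could drive the watch-and-forward loop described in the in-code comment. The function pollSourceCluster and the handle callback are hypothetical names, and it assumes log.Warn supports the same Err chaining used with log.Fatal above; de-duplication and replay on the target clusters are left out.

package forward

import (
	"context"
	"pbench/log"
	"pbench/presto"
	"time"
)

// pollSourceCluster fetches query states from the source cluster on every tick
// and hands them to a caller-supplied callback.
func pollSourceCluster(ctx context.Context, source *presto.Client, handle func([]presto.QueryStateInfo)) {
	trueValue := true
	ticker := time.NewTicker(PollInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			states, _, err := source.GetQueryState(ctx, &presto.GetQueryStatsOptions{
				IncludeAllQueries: &trueValue,
			})
			if err != nil {
				// Keep polling on transient errors; the command only aborts during startup.
				log.Warn().Err(err).Msgf("failed to poll query states")
				continue
			}
			handle(states)
		}
	}
}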

presto/client.go (+32)
@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"net/url"
 	"pbench/log"
+	"reflect"
 	"strings"
 	"time"
 )
@@ -59,6 +60,37 @@ func NewClient(serverUrl string, isTrino bool) (*Client, error) {
 	return client, nil
 }
 
+func derefValue(v *reflect.Value) reflect.Kind {
+	k := v.Kind()
+	for k == reflect.Pointer || k == reflect.Interface {
+		*v = v.Elem()
+		k = v.Kind()
+	}
+	return k
+}
+
+func GenerateHttpQueryParameter(v any) string {
+	rv := reflect.ValueOf(v)
+	if rvk := derefValue(&rv); rvk != reflect.Struct {
+		return ""
+	}
+	queryBuilder := strings.Builder{}
+	vt := rv.Type()
+	for i := 0; i < vt.NumField(); i++ {
+		fv, ft := rv.Field(i), vt.Field(i)
+		if fvk := derefValue(&fv); fvk == reflect.Invalid || !fv.CanInterface() {
+			continue
+		}
+		if tag := ft.Tag.Get("query"); tag != "" {
+			if queryBuilder.Len() > 0 {
+				queryBuilder.WriteString("&")
+			}
+			queryBuilder.WriteString(fmt.Sprintf("%s=%s", url.QueryEscape(tag), url.QueryEscape(fmt.Sprint(fv.Interface()))))
+		}
+	}
+	return queryBuilder.String()
+}
+
 func (c *Client) setHeader(key, value string) {
 	if c.isTrino {
 		key = strings.Replace(key, "X-Presto", "X-Trino", 1)

presto/client_test.go (+19)
@@ -14,6 +14,9 @@ func TestQuery(t *testing.T) {
 	// This test requires Presto hive query runner.
 	client, err := presto.NewClient("http://127.0.0.1:8080", false)
 	assert.Nil(t, err)
+	if _, _, err = client.GetClusterInfo(context.Background()); err != nil {
+		t.Skip("local cluster is not ready")
+	}
 	qr, _, err := client.
 		User("ethan").
 		Catalog("tpch").
@@ -43,3 +46,19 @@ func TestQuery(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Greater(t, buf.Len(), 0)
 }
+
+func TestGenerateQueryParameter(t *testing.T) {
+	stringValue := "was it clear (already)?"
+	serializedQuery := presto.GenerateHttpQueryParameter(&struct {
+		StringField *string `query:"stringField"`
+		BoolField   bool    `query:"boolField"`
+		IntField    int     `query:"intField"`
+		BoolPtr     *bool   `query:"boolPtr"`
+		StringPtr   *string `query:"stringPtr"`
+	}{
+		StringField: &stringValue,
+		BoolField:   true,
+		IntField:    123,
+	})
+	assert.Equal(t, `stringField=was+it+clear+%28already%29%3F&boolField=true&intField=123`, serializedQuery)
+}

presto/cluster.go (+35)
@@ -0,0 +1,35 @@
+package presto
+
+import (
+	"context"
+	"net/http"
+)
+
+type ClusterStats struct {
+	RunningQueries    int     `json:"runningQueries"`
+	BlockedQueries    int     `json:"blockedQueries"`
+	QueuedQueries     int     `json:"queuedQueries"`
+	ActiveWorkers     int     `json:"activeWorkers"`
+	RunningDrivers    int     `json:"runningDrivers"`
+	RunningTasks      int     `json:"runningTasks"`
+	ReservedMemory    float64 `json:"reservedMemory"`
+	TotalInputRows    int     `json:"totalInputRows"`
+	TotalInputBytes   int     `json:"totalInputBytes"`
+	TotalCpuTimeSecs  int     `json:"totalCpuTimeSecs"`
+	AdjustedQueueSize int     `json:"adjustedQueueSize"`
+}
+
+func (c *Client) GetClusterInfo(ctx context.Context, opts ...RequestOption) (*ClusterStats, *http.Response, error) {
+	req, err := c.NewRequest("GET",
+		"v1/cluster", nil, opts...)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	stats := new(ClusterStats)
+	resp, err := c.Do(ctx, req, stats)
+	if err != nil {
+		return nil, resp, err
+	}
+	return stats, resp, nil
+}
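
For illustration only, here is how a caller might use the new GetClusterInfo wrapper to check cluster readiness, mirroring the startup check in cmd/forward/main.go. The standalone program and the server address are assumptions, not part of this commit.

package main

import (
	"context"
	"fmt"

	"pbench/presto"
)

func main() {
	// Same constructor the tests use; the address is just an example.
	client, err := presto.NewClient("http://127.0.0.1:8080", false)
	if err != nil {
		panic(err)
	}
	stats, _, err := client.GetClusterInfo(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Printf("active workers: %d, running queries: %d\n", stats.ActiveWorkers, stats.RunningQueries)
}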

presto/query_state.go (+69)
@@ -0,0 +1,69 @@
+package presto
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"time"
+)
+
+// QueryStateInfo is the Go translation of the QueryStateInfo class in Presto Java:
+// https://github.com/prestodb/presto/blob/master/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfo.java
+// Unused fields are commented out for now.
+type QueryStateInfo struct {
+	QueryId    string `json:"queryId"`
+	QueryState string `json:"queryState"`
+	//ResourceGroupId []string `json:"resourceGroupId"`
+	Query          string    `json:"query"`
+	QueryTruncated bool      `json:"queryTruncated"`
+	CreateTime     time.Time `json:"createTime"`
+	User           string    `json:"user"`
+	Authenticated  bool      `json:"authenticated"`
+	Source         string    `json:"source"`
+	Catalog        string    `json:"catalog"`
+	//Progress struct {
+	//	ElapsedTimeMillis        int  `json:"elapsedTimeMillis"`
+	//	QueuedTimeMillis         int  `json:"queuedTimeMillis"`
+	//	ExecutionTimeMillis      int  `json:"executionTimeMillis"`
+	//	CpuTimeMillis            int  `json:"cpuTimeMillis"`
+	//	ScheduledTimeMillis      int  `json:"scheduledTimeMillis"`
+	//	CurrentMemoryBytes       int  `json:"currentMemoryBytes"`
+	//	PeakMemoryBytes          int  `json:"peakMemoryBytes"`
+	//	PeakTotalMemoryBytes     int  `json:"peakTotalMemoryBytes"`
+	//	PeakTaskTotalMemoryBytes int  `json:"peakTaskTotalMemoryBytes"`
+	//	CumulativeUserMemory     int  `json:"cumulativeUserMemory"`
+	//	CumulativeTotalMemory    int  `json:"cumulativeTotalMemory"`
+	//	InputRows                int  `json:"inputRows"`
+	//	InputBytes               int  `json:"inputBytes"`
+	//	Blocked                  bool `json:"blocked"`
+	//	QueuedDrivers            int  `json:"queuedDrivers"`
+	//	RunningDrivers           int  `json:"runningDrivers"`
+	//	CompletedDrivers         int  `json:"completedDrivers"`
+	//} `json:"progress"`
+	//WarningCodes []interface{} `json:"warningCodes"`
+}
+
+// GetQueryStatsOptions includes parameters in https://github.com/prestodb/presto/blob/a7af002182098ba5a61248edfcaaa66e5d50e489/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfoResource.java#L89-L95
+type GetQueryStatsOptions struct {
+	User                         *string `query:"user"`
+	IncludeLocalQueryOnly        *bool   `query:"includeLocalQueryOnly"`
+	IncludeAllQueries            *bool   `query:"includeAllQueries"`
+	IncludeAllQueryProgressStats *bool   `query:"includeAllQueryProgressStats"`
+	ExcludeResourceGroupPathInfo *bool   `query:"excludeResourceGroupPathInfo"`
+	QueryTextSizeLimit           *int    `query:"queryTextSizeLimit"`
+}
+
+func (c *Client) GetQueryState(ctx context.Context, reqOpt *GetQueryStatsOptions, opts ...RequestOption) ([]QueryStateInfo, *http.Response, error) {
+	req, err := c.NewRequest("GET",
+		fmt.Sprintf("v1/queryState?%s", GenerateHttpQueryParameter(reqOpt)), nil, opts...)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	infoArray := make([]QueryStateInfo, 0, 16)
+	resp, err := c.Do(ctx, req, infoArray)
+	if err != nil {
+		return nil, resp, err
+	}
+	return infoArray, resp, nil
+}
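
As a quick illustration of how GetQueryStatsOptions becomes the request path: GenerateHttpQueryParameter skips nil fields, so only the options that were explicitly set appear in the query string. The standalone program below is hypothetical; the types and the v1/queryState path come from this commit.

package main

import (
	"fmt"

	"pbench/presto"
)

func main() {
	includeAll := true
	limit := 1000
	opts := presto.GetQueryStatsOptions{
		IncludeAllQueries:  &includeAll,
		QueryTextSizeLimit: &limit,
	}
	// Mirrors what GetQueryState builds internally.
	fmt.Printf("v1/queryState?%s\n", presto.GenerateHttpQueryParameter(&opts))
	// Prints: v1/queryState?includeAllQueries=true&queryTextSizeLimit=1000
}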

stage/stage.go (+6 -6)
@@ -48,10 +48,10 @@ type Stage struct {
 	// When RandomExecution is turned on, we randomly pick queries to run until a certain number of queries/a specific
 	// duration has passed. Expected row counts will not be checked in this mode because we cannot figure out the correct
 	// expected row count offset.
-	RandomExecution bool `json:"random_execution,omitempty"`
+	RandomExecution *bool `json:"random_execution,omitempty"`
 	// Use RandomlyExecuteUntil to specify a duration like "1h" or an integer as the number of queries should be executed
 	// before exiting.
-	RandomlyExecuteUntil string `json:"randomly_execute_until,omitempty"`
+	RandomlyExecuteUntil *string `json:"randomly_execute_until,omitempty"`
 	// If not set, the default is 1. The default value is set when the stage is run.
 	ColdRuns int `json:"cold_runs,omitempty"`
 	// If not set, the default is 0.
@@ -225,7 +225,7 @@ func (s *Stage) run(ctx context.Context) (returnErr error) {
 	s.prepareClient()
 	s.propagateStates()
 	if len(s.Queries)+len(s.QueryFiles) > 0 {
-		if s.RandomExecution {
+		if *s.RandomExecution {
 			returnErr = s.runRandomly(ctx)
 		} else {
 			returnErr = s.runSequentially(ctx)
@@ -306,17 +306,17 @@ func (s *Stage) runQueryFile(ctx context.Context, queryFile string, expectedRowC
 
 func (s *Stage) runRandomly(ctx context.Context) error {
 	var continueExecution func(queryCount int) bool
-	if dur, parseErr := time.ParseDuration(s.RandomlyExecuteUntil); parseErr == nil {
+	if dur, parseErr := time.ParseDuration(*s.RandomlyExecuteUntil); parseErr == nil {
 		endTime := time.Now().Add(dur)
 		continueExecution = func(_ int) bool {
 			return time.Now().Before(endTime)
 		}
-	} else if count, atoiErr := strconv.Atoi(s.RandomlyExecuteUntil); atoiErr == nil {
+	} else if count, atoiErr := strconv.Atoi(*s.RandomlyExecuteUntil); atoiErr == nil {
 		continueExecution = func(queryCount int) bool {
 			return queryCount <= count
 		}
 	} else {
-		err := fmt.Errorf("failed to parse randomly_execute_until %s", s.RandomlyExecuteUntil)
+		err := fmt.Errorf("failed to parse randomly_execute_until %s", *s.RandomlyExecuteUntil)
 		if *s.AbortOnError {
 			s.States.exitCode.CompareAndSwap(0, 5)
 			return err
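
The comment on RandomlyExecuteUntil says the value is either a duration like "1h" or an integer query count. A small worked example of that parsing rule, mirroring the order used in runRandomly; the describe helper and the standalone program are hypothetical.

package main

import (
	"fmt"
	"strconv"
	"time"
)

// describe tries a duration first, then falls back to an integer query count.
func describe(randomlyExecuteUntil string) string {
	if dur, err := time.ParseDuration(randomlyExecuteUntil); err == nil {
		return fmt.Sprintf("run randomly for %v", dur)
	}
	if count, err := strconv.Atoi(randomlyExecuteUntil); err == nil {
		return fmt.Sprintf("run %d random queries", count)
	}
	return "invalid randomly_execute_until value"
}

func main() {
	fmt.Println(describe("1h"))  // run randomly for 1h0m0s
	fmt.Println(describe("500")) // run 500 random queries
	fmt.Println(describe("1x"))  // invalid randomly_execute_until value
}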

stage/stage_utils.go (+12 -2)
@@ -60,8 +60,12 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
 			delete(s.ExpectedRowCounts, k)
 		}
 	}
-	s.RandomExecution = other.RandomExecution
-	s.RandomlyExecuteUntil = other.RandomlyExecuteUntil
+	if other.RandomExecution != nil {
+		s.RandomExecution = other.RandomExecution
+	}
+	if other.RandomlyExecuteUntil != nil {
+		s.RandomlyExecuteUntil = other.RandomlyExecuteUntil
+	}
 	if other.ColdRuns > 0 {
 		s.ColdRuns = other.ColdRuns
 	}
@@ -206,6 +210,12 @@ func (s *Stage) propagateStates() {
 		if nextStage.TimeZone == nil {
 			nextStage.TimeZone = s.TimeZone
 		}
+		if nextStage.RandomExecution == nil {
+			nextStage.RandomExecution = s.RandomExecution
+		}
+		if nextStage.RandomlyExecuteUntil == nil {
+			nextStage.RandomlyExecuteUntil = s.RandomlyExecuteUntil
+		}
 		for k, v := range s.SessionParams {
 			if v == nil {
 				continue
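
The switch to pointer fields is what lets MergeWith and propagateStates distinguish an unset value from an explicit false or empty string. A minimal sketch with a hypothetical mini struct shows the difference after JSON unmarshaling.

package main

import (
	"encoding/json"
	"fmt"
)

// miniStage is a hypothetical stand-in for the relevant part of Stage.
type miniStage struct {
	RandomExecution *bool `json:"random_execution,omitempty"`
}

func main() {
	var unset, explicit miniStage
	_ = json.Unmarshal([]byte(`{}`), &unset)
	_ = json.Unmarshal([]byte(`{"random_execution": false}`), &explicit)
	fmt.Println(unset.RandomExecution == nil)                               // true: value may be merged or propagated from a parent stage
	fmt.Println(explicit.RandomExecution != nil, *explicit.RandomExecution) // true false: an explicit false is preserved
}

With the old plain bool and string fields, an absent key was indistinguishable from an explicit false or empty string, so MergeWith would silently overwrite a parent's setting.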
