@@ -2,12 +2,17 @@ package forward
2
2
3
3
import (
4
4
"context"
5
- "fmt"
6
5
"github.com/spf13/cobra"
6
+ "net/http"
7
+ "os"
8
+ "os/signal"
9
+ "path/filepath"
7
10
"pbench/log"
8
11
"pbench/presto"
9
12
"pbench/utils"
10
13
"sync"
14
+ "sync/atomic"
15
+ "syscall"
11
16
"time"
12
17
)
13
18
@@ -17,49 +22,141 @@ var (
17
22
RunName string
18
23
PollInterval time.Duration
19
24
20
- runningTasks sync.WaitGroup
25
+ runningTasks sync.WaitGroup
26
+ failedToForward atomic.Uint32
27
+ forwarded atomic.Uint32
21
28
)
22
29
23
- type QueryHistory struct {
24
- QueryId string `presto:"query_id"`
25
- Query string `presto:"query"`
26
- Created * time.Time `presto:"created"`
27
- }
28
-
29
30
func Run (_ * cobra.Command , _ []string ) {
30
- //OutputPath = filepath.Join(OutputPath, RunName)
31
- //utils.PrepareOutputDirectory(OutputPath)
32
- //
33
- //// also start to write logs to the output directory from this point on.
34
- //logPath := filepath.Join(OutputPath, "forward.log")
35
- //flushLog := utils.InitLogFile(logPath)
36
- //defer flushLog()
31
+ OutputPath = filepath .Join (OutputPath , RunName )
32
+ utils .PrepareOutputDirectory (OutputPath )
37
33
38
- prestoClusters := PrestoFlagsArray .Assemble ()
34
+ // also start to write logs to the output directory from this point on.
35
+ logPath := filepath .Join (OutputPath , "forward.log" )
36
+ flushLog := utils .InitLogFile (logPath )
37
+ defer flushLog ()
38
+
39
+ ctx , cancel := context .WithCancel (context .Background ())
40
+ timeToExit := make (chan os.Signal , 1 )
41
+ signal .Notify (timeToExit , syscall .SIGINT , syscall .SIGTERM , syscall .SIGQUIT )
42
+ // Handle SIGINT, SIGTERM, and SIGQUIT. When ctx is canceled, in-progress MySQL transactions and InfluxDB operations will roll back.
43
+ go func () {
44
+ sig := <- timeToExit
45
+ if sig != nil {
46
+ log .Info ().Msg ("abort forwarding" )
47
+ cancel ()
48
+ }
49
+ }()
50
+
51
+ prestoClusters := PrestoFlagsArray .Pivot ()
39
52
// The design here is to forward the traffic from cluster 0 to the rest.
40
53
sourceClusterSize := 0
41
54
clients := make ([]* presto.Client , 0 , len (prestoClusters ))
42
55
for i , cluster := range prestoClusters {
43
56
clients = append (clients , cluster .NewPrestoClient ())
44
- if stats , _ , err := clients [i ].GetClusterInfo (context .Background ()); err != nil {
45
- log .Fatal ().Err (err ).Msgf ("cannot connect to cluster at position %d" , i )
57
+ // Check if we can connect to the cluster.
58
+ if stats , _ , err := clients [i ].GetClusterInfo (ctx ); err != nil {
59
+ log .Fatal ().Err (err ).Msgf ("cannot connect to cluster at position %d: %s" , i , cluster .ServerUrl )
46
60
} else if i == 0 {
47
61
sourceClusterSize = stats .ActiveWorkers
48
62
} else if stats .ActiveWorkers != sourceClusterSize {
49
- log .Warn ().Msgf ("source cluster size does not match target cluster %d size (%d != %d)" , i , stats .ActiveWorkers , sourceClusterSize )
63
+ log .Warn ().Msgf ("the source cluster and target cluster %d do not match in size (%d != %d)" , i , sourceClusterSize , stats .ActiveWorkers )
50
64
}
51
65
}
52
66
53
67
sourceClient := clients [0 ]
54
68
trueValue := true
55
- states , _ , err := sourceClient .GetQueryState (context .Background (), & presto.GetQueryStatsOptions {
56
- IncludeAllQueries : & trueValue ,
57
- IncludeAllQueryProgressStats : nil ,
58
- ExcludeResourceGroupPathInfo : nil ,
59
- QueryTextSizeLimit : nil ,
60
- })
61
- if err != nil {
62
- log .Fatal ().Err (err ).Msgf ("cannot get query states" )
69
+ // lastQueryStateCheckCutoffTime is the query create time of the most recent query in the previous batch.
70
+ // We only look at queries created later than this timestamp in the following batch.
71
+ lastQueryStateCheckCutoffTime := time.Time {}
72
+ firstBatch := true
73
+ // Keep running until the source cluster becomes unavailable or the user interrupts or quits using Ctrl + C or Ctrl + D.
74
+ for ctx .Err () == nil {
75
+ states , _ , err := sourceClient .GetQueryState (ctx , & presto.GetQueryStatsOptions {IncludeAllQueries : & trueValue })
76
+ if err != nil {
77
+ log .Error ().Err (err ).Msgf ("failed to get query states" )
78
+ break
79
+ }
80
+ newCutoffTime := time.Time {}
81
+ for _ , state := range states {
82
+ if ! state .CreateTime .After (lastQueryStateCheckCutoffTime ) {
83
+ // We looked at this query in the previous batch.
84
+ continue
85
+ }
86
+ if newCutoffTime .Before (state .CreateTime ) {
87
+ newCutoffTime = state .CreateTime
88
+ }
89
+ if ! firstBatch {
90
+ runningTasks .Add (1 )
91
+ go forwardQuery (ctx , & state , clients )
92
+ }
93
+ }
94
+ firstBatch = false
95
+ if newCutoffTime .After (lastQueryStateCheckCutoffTime ) {
96
+ lastQueryStateCheckCutoffTime = newCutoffTime
97
+ }
98
+ timer := time .NewTimer (PollInterval )
99
+ select {
100
+ case <- ctx .Done ():
101
+ case <- timer .C :
102
+ }
103
+ }
104
+ runningTasks .Wait ()
105
+ // This causes the signal handler to exit.
106
+ close (timeToExit )
107
+ log .Info ().Uint32 ("forwarded" , forwarded .Load ()).Uint32 ("failed_to_forward" , failedToForward .Load ()).
108
+ Msgf ("finished forwarding queries" )
109
+ }
110
+
111
+ func forwardQuery (ctx context.Context , queryState * presto.QueryStateInfo , clients []* presto.Client ) {
112
+ defer runningTasks .Done ()
113
+ queryInfo , _ , queryInfoErr := clients [0 ].GetQueryInfo (ctx , queryState .QueryId , false , nil )
114
+ if queryInfoErr != nil {
115
+ log .Error ().Str ("query_id" , queryState .QueryId ).Err (queryInfoErr ).Msg ("failed to get query info for forwarding" )
116
+ failedToForward .Add (1 )
117
+ return
118
+ }
119
+ SessionPropertyHeader := clients [0 ].GenerateSessionParamsHeaderValue (queryInfo .Session .CollectSessionProperties ())
120
+ successful , failed := atomic.Uint32 {}, atomic.Uint32 {}
121
+ forwardedQueries := sync.WaitGroup {}
122
+ for i := 1 ; i < len (clients ); i ++ {
123
+ forwardedQueries .Add (1 )
124
+ go func (client * presto.Client ) {
125
+ defer forwardedQueries .Done ()
126
+ clientResult , _ , queryErr := client .Query (ctx , queryInfo .Query , func (req * http.Request ) {
127
+ if queryInfo .Session .Catalog != nil {
128
+ req .Header .Set (presto .CatalogHeader , * queryInfo .Session .Catalog )
129
+ }
130
+ if queryInfo .Session .Schema != nil {
131
+ req .Header .Set (presto .SchemaHeader , * queryInfo .Session .Schema )
132
+ }
133
+ req .Header .Set (presto .SessionHeader , SessionPropertyHeader )
134
+ req .Header .Set (presto .SourceHeader , queryInfo .QueryId )
135
+ })
136
+ if queryErr != nil {
137
+ log .Error ().Str ("source_query_id" , queryInfo .QueryId ).
138
+ Str ("target_host" , client .GetHost ()).Err (queryErr ).Msg ("failed to execute query" )
139
+ failed .Add (1 )
140
+ return
141
+ }
142
+ rowCount := 0
143
+ drainErr := clientResult .Drain (ctx , func (qr * presto.QueryResults ) error {
144
+ rowCount += len (qr .Data )
145
+ return nil
146
+ })
147
+ if drainErr != nil {
148
+ log .Error ().Str ("source_query_id" , queryInfo .QueryId ).
149
+ Str ("target_host" , client .GetHost ()).Err (drainErr ).Msg ("failed to fetch query result" )
150
+ failed .Add (1 )
151
+ return
152
+ }
153
+ successful .Add (1 )
154
+ log .Info ().Str ("source_query_id" , queryInfo .QueryId ).
155
+ Str ("target_host" , client .GetHost ()).Int ("row_count" , rowCount ).Msg ("query executed successfully" )
156
+ }(clients [i ])
63
157
}
64
- fmt .Printf ("%#v" , states )
158
+ forwardedQueries .Wait ()
159
+ log .Info ().Str ("source_query_id" , queryInfo .QueryId ).Uint32 ("successful" , successful .Load ()).
160
+ Uint32 ("failed" , failed .Load ()).Msg ("query forwarding finished" )
161
+ forwarded .Add (1 )
65
162
}
0 commit comments