Skip to content

Commit c1bf0af

Browse files
author
ffffwh
committed
write waitCh only once #1021
1 parent 1e4bea6 commit c1bf0af

File tree

7 files changed

+73
-37
lines changed

7 files changed

+73
-37
lines changed

driver/common/utils.go

+8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package common
88

99
import (
1010
"fmt"
11+
"github.com/hashicorp/nomad/plugins/drivers"
1112
"github.com/pingcap/tidb/parser/format"
1213
"regexp"
1314
"strconv"
@@ -37,3 +38,10 @@ func MysqlVersionInDigit(v string) (int, error) {
3738

3839
return m0*10000 + m1*100 + m2, nil
3940
}
41+
42+
func WriteWaitCh(ch chan<- *drivers.ExitResult, r *drivers.ExitResult) {
43+
select {
44+
case ch<-r:
45+
default:
46+
}
47+
}

driver/driver.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -655,14 +655,12 @@ func (d *Driver) handleWait(ctx context.Context, handle *taskHandle, ch chan *dr
655655
return
656656
case <-d.ctx.Done():
657657
return
658-
case result := <-handle.waitCh: // Do not refer to handle.runner.waitCh. It might be nil.
659-
handle.stateLock.Lock()
660-
handle.procState = drivers.TaskStateExited
661-
handle.stateLock.Unlock()
662-
ch <- result
658+
case <-handle.doneCh:
659+
ch <- handle.exitResult.Copy()
663660
}
664661
}
665662

663+
// StopTask will not be called if the task has already exited (e.g. onError).
666664
func (d *Driver) StopTask(taskID string, timeout time.Duration, signal string) error {
667665
d.logger.Info("StopTask", "id", taskID, "signal", signal)
668666
handle, ok := d.tasks.Get(taskID)
@@ -785,10 +783,15 @@ func (d *Driver) SignalTask(taskID string, signal string) error {
785783
return errors.New(string(bs))
786784
}
787785
case "finish":
786+
if h.runner == nil {
787+
return fmt.Errorf("h.runner is nil")
788+
}
788789
return h.runner.Finish1()
789790
case "pause":
790791
d.logger.Info("pause a task", "taskID", taskID)
791-
h := d.tasks.store[taskID]
792+
if h.runner == nil {
793+
return fmt.Errorf("h.runner is nil")
794+
}
792795
err := h.runner.Shutdown()
793796
if err != nil {
794797
d.logger.Error("error when pausing a task", "taskID", taskID, "err", err)
@@ -798,7 +801,6 @@ func (d *Driver) SignalTask(taskID string, signal string) error {
798801
return nil
799802
case "resume":
800803
d.logger.Info("resume a task", "taskID", taskID)
801-
h := d.tasks.store[taskID]
802804
err := h.resumeTask(d)
803805
if err != nil {
804806
d.logger.Error("error when resuming a task", "taskID", taskID, "err", err)
@@ -809,7 +811,7 @@ func (d *Driver) SignalTask(taskID string, signal string) error {
809811
return nil
810812
}
811813

812-
return nil
814+
return fmt.Errorf("DTLE_BUG SignalTask. should not reach here")
813815
}
814816

815817
func (d *Driver) ExecTask(taskID string, cmd []string, timeout time.Duration) (*drivers.ExecTaskResult, error) {

driver/handle.go

+33-10
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type taskHandle struct {
3737
stats *common.TaskStatistics
3838

3939
driverConfig *common.MySQLDriverConfig
40+
shutdown bool
4041
}
4142

4243
func newDtleTaskHandle(logger g.LoggerType, cfg *drivers.TaskConfig, state drivers.TaskState, started time.Time) *taskHandle {
@@ -51,9 +52,21 @@ func newDtleTaskHandle(logger g.LoggerType, cfg *drivers.TaskConfig, state drive
5152
waitCh: make(chan *drivers.ExitResult),
5253
doneCh: make(chan struct{}),
5354
}
55+
go h.watchWaitCh()
5456
return h
5557
}
5658

59+
func (h *taskHandle) watchWaitCh() {
60+
select {
61+
case r := <-h.waitCh:
62+
h.stateLock.Lock()
63+
h.exitResult = r
64+
h.stateLock.Unlock()
65+
close(h.doneCh)
66+
case <-h.doneCh:
67+
}
68+
}
69+
5770
func (h *taskHandle) TaskStatus() (*drivers.TaskStatus, error) {
5871
h.stateLock.RLock()
5972
defer h.stateLock.RUnlock()
@@ -81,13 +94,14 @@ func (h *taskHandle) TaskStatus() (*drivers.TaskStatus, error) {
8194
}, nil
8295
}
8396

97+
// used when h.runner has not been setup
8498
func (h *taskHandle) onError(err error) {
85-
h.waitCh <- &drivers.ExitResult{
99+
common.WriteWaitCh(h.waitCh, &drivers.ExitResult{
86100
ExitCode: common.TaskStateDead,
87101
Signal: 0,
88102
OOMKilled: false,
89103
Err: err,
90-
}
104+
})
91105
}
92106

93107
func (h *taskHandle) IsRunning() bool {
@@ -136,7 +150,7 @@ func (h *taskHandle) run(d *Driver) {
136150
for {
137151
select {
138152
case <-h.doneCh:
139-
t.Stop()
153+
if !t.Stop() { <-t.C }
140154
return
141155
case <-t.C:
142156
if h.runner != nil {
@@ -264,26 +278,35 @@ func (h *taskHandle) emitStats(ru *common.TaskStatistics) {
264278
}
265279
}
266280

267-
func (h *taskHandle) Destroy() bool {
268-
h.stateLock.RLock()
269-
defer h.stateLock.RUnlock()
281+
func (h *taskHandle) Destroy() {
282+
if h.shutdown {
283+
return
284+
}
285+
h.stateLock.Lock()
286+
h.shutdown = true
287+
h.stateLock.Unlock()
270288

271-
close(h.doneCh)
289+
common.WriteWaitCh(h.waitCh, &drivers.ExitResult{
290+
ExitCode: 0,
291+
Signal: 0,
292+
OOMKilled: false,
293+
Err: nil,
294+
})
272295

273296
if h.runner != nil {
274297
err := h.runner.Shutdown()
275298
if err != nil {
276299
h.logger.Error("error in h.runner.Shutdown", "err", err)
277300
}
278301
}
279-
280-
return h.procState == drivers.TaskStateExited
281302
}
282303

283304
type DriverHandle interface {
284305
Run()
285306

286-
// Shutdown is used to stop the task
307+
// Shutdown is used to stop the task.
308+
// Do not send ExitResult in Shutdown().
309+
// pause API will call Shutdown and the task should not exit.
287310
Shutdown() error
288311

289312
// Stats returns aggregated stats of the driver

driver/kafka/kafka3.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ type KafkaRunner struct {
4646
natsConn *gonats.Conn
4747
waitCh chan *drivers.ExitResult
4848

49-
ctx context.Context
50-
shutdown bool
51-
shutdownCh chan struct{}
49+
ctx context.Context
50+
shutdown bool
51+
shutdownCh chan struct{}
52+
shutdownLock sync.Mutex
5253

5354
kafkaConfig *common.KafkaConfig
5455
kafkaMgr *KafkaManager
@@ -157,9 +158,14 @@ func (kr *KafkaRunner) updateGtidLoop() {
157158
}
158159

159160
func (kr *KafkaRunner) Shutdown() error {
161+
kr.logger.Debug("Shutting down")
162+
kr.shutdownLock.Lock()
163+
defer kr.shutdownLock.Unlock()
164+
160165
if kr.shutdown {
161166
return nil
162167
}
168+
163169
if kr.natsConn != nil {
164170
kr.natsConn.Close()
165171
}
@@ -660,12 +666,12 @@ func (kr *KafkaRunner) onError(state int, err error) {
660666
}
661667
}
662668

663-
kr.waitCh <- &drivers.ExitResult{
669+
common.WriteWaitCh(kr.waitCh, &drivers.ExitResult{
664670
ExitCode: state,
665671
Signal: 0,
666672
OOMKilled: false,
667673
Err: err,
668-
}
674+
})
669675
_ = kr.Shutdown()
670676
}
671677

driver/mysql/applier.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -1044,22 +1044,20 @@ func (a *Applier) onError(state int, err error) {
10441044
}
10451045

10461046
a.logger.Debug("onError. nats published")
1047-
// Do not send ExitResult in Shutdown().
1048-
// pause API will call Shutdown and the task should not exit.
1049-
a.waitCh <- &drivers.ExitResult{
1047+
common.WriteWaitCh(a.waitCh, &drivers.ExitResult{
10501048
ExitCode: state,
10511049
Signal: 0,
10521050
OOMKilled: false,
10531051
Err: err,
1054-
}
1052+
})
10551053
_ = a.Shutdown()
10561054
}
10571055

10581056
func (a *Applier) Shutdown() error {
1059-
a.logger.Info("Shutting down")
1060-
1057+
a.logger.Debug("Shutting down")
10611058
a.shutdownLock.Lock()
10621059
defer a.shutdownLock.Unlock()
1060+
10631061
if a.shutdown {
10641062
return nil
10651063
}
@@ -1083,7 +1081,7 @@ func (a *Applier) Shutdown() error {
10831081
_ = sql.CloseConns(a.dbs...)
10841082
a.logger.Debug("Shutdown. CloseConns. after")
10851083

1086-
a.logger.Info("Shutdown")
1084+
a.logger.Info("Shutting down")
10871085
return nil
10881086
}
10891087

driver/mysql/extractor.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ func (e *Extractor) initNatsPubClient(natsAddr string) (err error) {
641641

642642
switch ctrlMsg.Type {
643643
case common.ControlMsgError:
644-
e.onError(common.TaskStateDead, fmt.Errorf("applier error/restart: %v", ctrlMsg.Msg))
644+
e.onError(common.TaskStateDead, fmt.Errorf("applier error: %v", ctrlMsg.Msg))
645645
return
646646
}
647647
})
@@ -1564,12 +1564,12 @@ func (e *Extractor) onError(state int, err error) {
15641564
if e.shutdown {
15651565
return
15661566
}
1567-
e.waitCh <- &drivers.ExitResult{
1567+
common.WriteWaitCh(e.waitCh, &drivers.ExitResult{
15681568
ExitCode: state,
15691569
Signal: 0,
15701570
OOMKilled: false,
15711571
Err: err,
1572-
}
1572+
})
15731573
_ = e.Shutdown()
15741574
}
15751575
// Shutdown is used to tear down the extractor
@@ -1618,7 +1618,6 @@ func (e *Extractor) Shutdown() error {
16181618
e.logger.Error("Shutdown error close e.db.", "err", err)
16191619
}
16201620

1621-
//close(e.binlogChannel)
16221621
e.logger.Info("Shutting down")
16231622
return nil
16241623
}

driver/oracle/extractor/extractor_oracle.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ func (e *ExtractorOracle) Run() {
197197

198198
switch ctrlMsg.Type {
199199
case common.ControlMsgError:
200-
e.onError(common.TaskStateDead, fmt.Errorf("applier error/restart: %v", ctrlMsg.Msg))
200+
e.onError(common.TaskStateDead, fmt.Errorf("applier error: %v", ctrlMsg.Msg))
201201
return
202202
}
203203
})
@@ -823,12 +823,12 @@ func (e *ExtractorOracle) onError(state int, err error) {
823823
if e.shutdown {
824824
return
825825
}
826-
e.waitCh <- &drivers.ExitResult{
826+
common.WriteWaitCh(e.waitCh, &drivers.ExitResult{
827827
ExitCode: state,
828828
Signal: 0,
829829
OOMKilled: false,
830830
Err: err,
831-
}
831+
})
832832
_ = e.Shutdown()
833833
}
834834

0 commit comments

Comments
 (0)