Skip to content

Commit e0e1340

Browse files
committed
Use separate runstates for each run
Clients sending KE messages after timeout occurred were incrementing the KE count for later runs, which would trigger the server to rebroadcast all KE messages before the last peer's message was received. This then caused the peer which may have responded before KE timeout actually occured to be erronously removed and a new run started, and again wrongly increment the KE count for the next run. This would continue until all peers were removed and the mix was aborted. Using separate runstates (message counts and signal channels) keyed by the run prevents this issue, as clients will only ever modify the runstate for the run they are performing.
1 parent 65afb08 commit e0e1340

File tree

1 file changed

+80
-79
lines changed

1 file changed

+80
-79
lines changed

server/server.go

Lines changed: 80 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,6 @@ type runState struct {
9797
confCount uint32
9898
rsCount uint32
9999

100-
run int
101-
mtot int
102-
clients []*client
103-
excluded []*client
104-
vk []ed25519.PublicKey
105-
mcounts []int
106-
roots []*big.Int
107-
108100
allKEs chan struct{}
109101
allSRs chan struct{}
110102
allDCs chan struct{}
@@ -116,7 +108,7 @@ type runState struct {
116108
}
117109

118110
type session struct {
119-
runState
111+
runs []runState
120112

121113
sid []byte
122114
msgses *messages.Session
@@ -125,6 +117,14 @@ type session struct {
125117
newm func() (Mixer, error)
126118
mix Mixer
127119

120+
run int
121+
mtot int
122+
clients []*client
123+
excluded []*client
124+
vk []ed25519.PublicKey
125+
mcounts []int
126+
roots []*big.Int
127+
128128
pids map[string]int
129129
mu sync.Mutex
130130

@@ -422,28 +422,28 @@ func (s *Server) pairSessions(ctx context.Context) error {
422422
}
423423
sid := s.sidPRNG.Next(32)
424424
ses := &session{
425-
runState: runState{
426-
allKEs: make(chan struct{}),
427-
allSRs: make(chan struct{}),
428-
allDCs: make(chan struct{}),
429-
allConfs: make(chan struct{}),
430-
allRSs: make(chan struct{}),
431-
mtot: totalMessages,
432-
clients: clients,
433-
vk: vk,
434-
mcounts: mcounts,
435-
blaming: make(chan struct{}),
436-
rerunning: make(chan struct{}),
437-
},
438-
sid: sid,
439-
msgses: messages.NewSession(sid, 0, vk),
440-
br: messages.BeginRun(vk, mcounts, sid),
441-
msize: s.msize,
442-
newm: newm,
443-
mix: mix,
444-
pids: pids,
445-
report: s.report,
425+
sid: sid,
426+
msgses: messages.NewSession(sid, 0, vk),
427+
br: messages.BeginRun(vk, mcounts, sid),
428+
msize: s.msize,
429+
newm: newm,
430+
mix: mix,
431+
mtot: totalMessages,
432+
clients: clients,
433+
vk: vk,
434+
mcounts: mcounts,
435+
pids: pids,
436+
report: s.report,
446437
}
438+
ses.runs = append(ses.runs, runState{
439+
allKEs: make(chan struct{}),
440+
allSRs: make(chan struct{}),
441+
allDCs: make(chan struct{}),
442+
allConfs: make(chan struct{}),
443+
allRSs: make(chan struct{}),
444+
blaming: make(chan struct{}),
445+
rerunning: make(chan struct{}),
446+
})
447447
pairs = append(pairs, ses)
448448
}
449449
s.pairingsMu.Unlock()
@@ -524,7 +524,7 @@ func (s *session) exclude(blamed []int) error {
524524
defer s.mu.Unlock()
525525
s.mu.Lock()
526526

527-
close(s.rerunning)
527+
close(s.runs[s.run].rerunning)
528528
s.run++
529529

530530
log.Printf("excluding %v", blamed)
@@ -574,22 +574,19 @@ func (s *session) exclude(blamed []int) error {
574574
return err
575575
}
576576

577-
s.keCount = 0
578-
s.srCount = 0
579-
s.dcCount = 0
580-
s.confCount = 0
581-
s.rsCount = 0
582-
s.allKEs = make(chan struct{})
583-
s.allSRs = make(chan struct{})
584-
s.allDCs = make(chan struct{})
585-
s.allConfs = make(chan struct{})
586-
s.allRSs = make(chan struct{})
577+
s.runs = append(s.runs, runState{
578+
allKEs: make(chan struct{}),
579+
allSRs: make(chan struct{}),
580+
allDCs: make(chan struct{}),
581+
allConfs: make(chan struct{}),
582+
allRSs: make(chan struct{}),
583+
blaming: make(chan struct{}),
584+
rerunning: make(chan struct{}),
585+
})
587586
s.roots = nil
588587
s.msgses = messages.NewSession(s.sid, s.run, s.vk)
589588
s.br = messages.BeginRun(s.vk, s.mcounts, s.sid)
590589
s.mix = mix
591-
s.blaming = make(chan struct{})
592-
s.rerunning = make(chan struct{})
593590

594591
for _, c := range s.clients {
595592
select {
@@ -651,12 +648,13 @@ func (s *session) doRun(ctx context.Context) (err error) {
651648
}()
652649

653650
var blamed blamePIDs
651+
st := &s.runs[s.run]
654652

655653
// Wait for all KE messages, or KE timeout.
656654
select {
657655
case <-ctx.Done():
658656
return ctx.Err()
659-
case <-s.allKEs:
657+
case <-st.allKEs:
660658
log.Print("received all KE messages")
661659
case <-time.After(recvTimeout):
662660
log.Print("KE timeout")
@@ -699,15 +697,15 @@ func (s *session) doRun(ctx context.Context) (err error) {
699697
select {
700698
case <-ctx.Done():
701699
return ctx.Err()
702-
case <-s.allSRs:
700+
case <-st.allSRs:
703701
log.Print("received all SR messages")
704702
case <-time.After(recvTimeout):
705703
log.Print("SR timeout")
706704
}
707705

708706
// Solve roots.
709707
s.mu.Lock()
710-
blaming := s.blaming
708+
blaming := st.blaming
711709
vs := make([][]*big.Int, 0, len(s.clients))
712710
for i, c := range s.clients {
713711
if c.sr == nil {
@@ -748,7 +746,7 @@ func (s *session) doRun(ctx context.Context) (err error) {
748746
select {
749747
case <-ctx.Done():
750748
return ctx.Err()
751-
case <-s.allDCs:
749+
case <-st.allDCs:
752750
log.Print("received all DC messages")
753751
case <-time.After(recvTimeout):
754752
log.Print("DC timeout")
@@ -805,7 +803,7 @@ func (s *session) doRun(ctx context.Context) (err error) {
805803
select {
806804
case <-ctx.Done():
807805
return ctx.Err()
808-
case <-s.allConfs:
806+
case <-st.allConfs:
809807
log.Print("received all CM messages")
810808
case <-time.After(recvTimeout):
811809
log.Print("CM timeout")
@@ -899,13 +897,14 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
899897
log.Printf("recv(%v) KE Run:%d Commitment:%x", c.raddr(), ke.Run, ke.Commitment)
900898

901899
s.mu.Lock()
900+
st := &s.runs[run]
902901
c.ke = ke
903-
s.keCount++
904-
if s.keCount == uint32(len(s.clients)) {
905-
close(s.allKEs)
902+
st.keCount++
903+
if st.keCount == uint32(len(s.clients)) {
904+
close(st.allKEs)
906905
}
907-
blaming := s.blaming
908-
rerunning := s.rerunning
906+
blaming := st.blaming
907+
rerunning := st.rerunning
909908
s.mu.Unlock()
910909

911910
select {
@@ -916,7 +915,7 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
916915
if err != nil {
917916
return err
918917
}
919-
return c.blame(ctx, s)
918+
return c.blame(ctx, s, run)
920919
case kes := <-c.out:
921920
err := c.sendDeadline(kes, sendTimeout)
922921
if err != nil {
@@ -950,12 +949,12 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
950949
}
951950
}
952951
c.sr = sr
953-
s.srCount++
954-
if s.srCount == uint32(len(s.clients)) {
955-
close(s.allSRs)
952+
st.srCount++
953+
if st.srCount == uint32(len(s.clients)) {
954+
close(st.allSRs)
956955
}
957-
blaming = s.blaming
958-
rerunning = s.rerunning
956+
blaming = st.blaming
957+
rerunning = st.rerunning
959958
s.mu.Unlock()
960959

961960
select {
@@ -966,7 +965,7 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
966965
if err != nil {
967966
return err
968967
}
969-
return c.blame(ctx, s)
968+
return c.blame(ctx, s, run)
970969
case mix := <-c.out:
971970
err = c.sendDeadline(mix, sendTimeout)
972971
if err != nil {
@@ -998,17 +997,17 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
998997

999998
s.mu.Lock()
1000999
c.dc = dc
1001-
s.dcCount++
1002-
if s.dcCount == uint32(len(s.clients)) {
1003-
close(s.allDCs)
1000+
st.dcCount++
1001+
if st.dcCount == uint32(len(s.clients)) {
1002+
close(st.allDCs)
10041003
}
10051004
mix := c.mix
1006-
blaming = s.blaming
1007-
rerunning = s.rerunning
1005+
blaming = st.blaming
1006+
rerunning = st.rerunning
10081007
s.mu.Unlock()
10091008

10101009
if dc.RevealSecrets {
1011-
return c.blame(ctx, s)
1010+
return c.blame(ctx, s, run)
10121011
}
10131012

10141013
// Send unconfirmed mix
@@ -1020,7 +1019,7 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
10201019
if err != nil {
10211020
return err
10221021
}
1023-
return c.blame(ctx, s)
1022+
return c.blame(ctx, s, run)
10241023
case mix := <-c.out:
10251024
err = c.sendDeadline(mix, sendTimeout)
10261025
if err != nil {
@@ -1044,16 +1043,16 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
10441043
s.mu.Lock()
10451044
c.cm = cm
10461045
c.mix = mix
1047-
s.confCount++
1048-
if s.confCount == uint32(len(s.clients)) {
1049-
close(s.allConfs)
1046+
st.confCount++
1047+
if st.confCount == uint32(len(s.clients)) {
1048+
close(st.allConfs)
10501049
}
1051-
blaming = s.blaming
1052-
rerunning = s.rerunning
1050+
blaming = st.blaming
1051+
rerunning = st.rerunning
10531052
s.mu.Unlock()
10541053

10551054
if cm.RevealSecrets {
1056-
return c.blame(ctx, s)
1055+
return c.blame(ctx, s, run)
10571056
}
10581057

10591058
// Send signed mix
@@ -1065,7 +1064,7 @@ func (c *client) run(ctx context.Context, run int, s *session, ke *messages.KE)
10651064
if err != nil {
10661065
return err
10671066
}
1068-
return c.blame(ctx, s)
1067+
return c.blame(ctx, s, run)
10691068
case out := <-c.out:
10701069
err = c.sendDeadline(out, sendTimeout)
10711070
if err != nil {
@@ -1135,10 +1134,11 @@ func (s *session) blame(ctx context.Context, reported []int) (err error) {
11351134
}
11361135

11371136
// Wait for all secrets, or timeout.
1137+
st := &s.runs[s.run]
11381138
select {
11391139
case <-ctx.Done():
11401140
return ctx.Err()
1141-
case <-s.allRSs:
1141+
case <-st.allRSs:
11421142
log.Print("received all RS messages")
11431143
case <-time.After(5000 * time.Millisecond):
11441144
s.mu.Lock()
@@ -1307,18 +1307,19 @@ DCLoop:
13071307

13081308
var errRerun = errors.New("rerun")
13091309

1310-
func (c *client) blame(ctx context.Context, s *session) error {
1310+
func (c *client) blame(ctx context.Context, s *session, run int) error {
13111311
rs := new(messages.RS)
13121312
err := c.readDeadline(rs, recvTimeout)
13131313
if err != nil {
13141314
return err
13151315
}
13161316

13171317
s.mu.Lock()
1318+
st := &s.runs[run]
13181319
c.rs = rs
1319-
s.rsCount++
1320-
if s.rsCount == uint32(len(s.clients)) {
1321-
close(s.allRSs)
1320+
st.rsCount++
1321+
if st.rsCount == uint32(len(s.clients)) {
1322+
close(st.allRSs)
13221323
}
13231324
s.mu.Unlock()
13241325

0 commit comments

Comments
 (0)