Skip to content

Commit 48ffcbc

Browse files
authored
feat: when reloading, remain those servers whose upstream pull conf with net error (#16)
* feat: remain those servers whose upstream pull conf with net error when reloading * fix: potential nil pointer error * refine code * use brute-force-match instead of hash * test: test and fix * clean code
1 parent 2c71fd9 commit 48ffcbc

File tree

6 files changed

+104
-35
lines changed

6 files changed

+104
-35
lines changed

config/config.go

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package config
22

33
import (
44
"encoding/json"
5+
"errors"
56
"flag"
67
"fmt"
78
"github.com/Qv2ray/mmp-go/cipher"
89
"github.com/Qv2ray/mmp-go/infra/lru"
910
"log"
11+
"net"
1012
"os"
1113
"sync"
1214
"time"
@@ -17,17 +19,46 @@ type Config struct {
1719
Groups []Group `json:"groups"`
1820
}
1921
type Server struct {
20-
Target string `json:"target"`
21-
Method string `json:"method"`
22-
Password string `json:"password"`
23-
MasterKey []byte `json:"-"`
22+
Target string `json:"target"`
23+
Method string `json:"method"`
24+
Password string `json:"password"`
25+
MasterKey []byte `json:"-"`
26+
UpstreamConf *UpstreamConf `json:"-"`
2427
}
2528
type Group struct {
26-
Port int `json:"port"`
27-
Servers []Server `json:"servers"`
28-
Upstreams []map[string]string `json:"upstreams"`
29-
LRUSize int `json:"lruSize"`
30-
UserContextPool *UserContextPool `json:"-"`
29+
Port int `json:"port"`
30+
Servers []Server `json:"servers"`
31+
Upstreams []UpstreamConf `json:"upstreams"`
32+
UserContextPool *UserContextPool `json:"-"`
33+
}
34+
type UpstreamConf map[string]string
35+
36+
const (
37+
PullingErrorKey = "__pulling_error__"
38+
PullingErrorNetError = "net_error"
39+
)
40+
41+
func (uc UpstreamConf) InitPullingError() {
42+
if _, ok := uc[PullingErrorKey]; !ok {
43+
uc[PullingErrorKey] = ""
44+
}
45+
}
46+
47+
func (uc UpstreamConf) Equal(that UpstreamConf) bool {
48+
uc.InitPullingError()
49+
that.InitPullingError()
50+
if len(uc) != len(that) {
51+
return false
52+
}
53+
for k, v := range uc {
54+
if k == PullingErrorKey {
55+
continue
56+
}
57+
if vv, ok := that[k]; !ok || vv != v {
58+
return false
59+
}
60+
}
61+
return true
3162
}
3263

3364
const (
@@ -85,28 +116,34 @@ func parseUpstreams(config *Config) (err error) {
85116
logged := false
86117
for i := range config.Groups {
87118
g := &config.Groups[i]
88-
for j, u := range g.Upstreams {
119+
for j, upstreamConf := range g.Upstreams {
89120
var upstream Upstream
90-
switch u["type"] {
121+
switch upstreamConf["type"] {
91122
case "outline":
92123
var outline Outline
93-
err = Map2upstream(u, &outline)
124+
err = Map2Upstream(upstreamConf, &outline)
94125
if err != nil {
95126
return
96127
}
97128
upstream = outline
98129
default:
99-
return fmt.Errorf("unknown upstream type: %v", u["type"])
130+
return fmt.Errorf("unknown upstream type: %v", upstreamConf["type"])
100131
}
101132
if !logged {
102133
log.Println("pulling configures from upstreams...")
103134
logged = true
104135
}
105136
servers, err := upstream.GetServers()
106137
if err != nil {
107-
log.Printf("[warning] Failed to retrieve configure from groups[%d].upstreams[%d]: %v\n", i, j, err)
138+
if netError := new(net.Error); errors.As(err, netError) {
139+
upstreamConf[PullingErrorKey] = PullingErrorNetError
140+
}
141+
log.Printf("[warning] Failed to retrieve configure from groups[%d].upstreams[%d]: %v: %v\n", i, j, err, upstreamConf[PullingErrorKey])
108142
continue
109143
}
144+
for i := range servers {
145+
servers[i].UpstreamConf = &upstreamConf
146+
}
110147
g.Servers = append(g.Servers, servers...)
111148
}
112149
}

config/outline.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (outline Outline) getConfig() ([]byte, error) {
6666
// concatenate errors
6767
err = errs[0]
6868
for i := 1; i < len(errs); i++ {
69-
err = fmt.Errorf("%v; %v", err, errs[i])
69+
err = fmt.Errorf("%w; %s", err, errs[i].Error())
7070
}
7171
return nil, err
7272
}
@@ -83,7 +83,7 @@ func (outline Outline) getConfigFromLink() ([]byte, error) {
8383
}
8484
resp, err := client.Get(outline.Link)
8585
if err != nil {
86-
return nil, fmt.Errorf("getConfigFromLink failed: %v", err)
86+
return nil, fmt.Errorf("getConfigFromLink failed: %w", err)
8787
}
8888
defer resp.Body.Close()
8989
return io.ReadAll(resp.Body)
@@ -113,7 +113,7 @@ func (outline Outline) getConfigFromApi() ([]byte, error) {
113113
outline.ApiUrl = strings.TrimSuffix(outline.ApiUrl, "/")
114114
resp, err := client.Get(fmt.Sprintf("%v/access-keys", outline.ApiUrl))
115115
if err != nil {
116-
return nil, fmt.Errorf("getConfigFromLink failed: %v", err)
116+
return nil, fmt.Errorf("getConfigFromApi failed: %w", err)
117117
}
118118
defer resp.Body.Close()
119119
return io.ReadAll(resp.Body)
@@ -130,7 +130,7 @@ func (outline Outline) getConfigFromSSH() ([]byte, error) {
130130
if outline.SSHPrivateKey != "" {
131131
signer, err := ssh.ParsePrivateKey([]byte(outline.SSHPrivateKey))
132132
if err != nil {
133-
return nil, fmt.Errorf("parse privateKey error: %v", err)
133+
return nil, fmt.Errorf("parse privateKey error: %w", err)
134134
}
135135
authMethods = append(authMethods, ssh.PublicKeys(signer))
136136
}
@@ -151,18 +151,18 @@ func (outline Outline) getConfigFromSSH() ([]byte, error) {
151151
}
152152
client, err := ssh.Dial("tcp", net.JoinHostPort(outline.Server, port), conf)
153153
if err != nil {
154-
return nil, fmt.Errorf("failed to dial: %v", err)
154+
return nil, fmt.Errorf("failed to dial: %w", err)
155155
}
156156
defer client.Close()
157157

158158
session, err := client.NewSession()
159159
if err != nil {
160-
return nil, fmt.Errorf("failed to create session: %v", err)
160+
return nil, fmt.Errorf("failed to create session: %w", err)
161161
}
162162
defer session.Close()
163163
out, err := session.CombinedOutput("cat /opt/outline/persisted-state/shadowbox_config.json")
164164
if err != nil {
165-
err = fmt.Errorf("%v: %v", string(bytes.TrimSpace(out)), err)
165+
err = fmt.Errorf("%v: %w", string(bytes.TrimSpace(out)), err)
166166
return nil, err
167167
}
168168
return out, nil
@@ -171,7 +171,7 @@ func (outline Outline) getConfigFromSSH() ([]byte, error) {
171171
func (outline Outline) GetServers() (servers []Server, err error) {
172172
defer func() {
173173
if err != nil {
174-
err = fmt.Errorf("outline.GetGroups: %v", err)
174+
err = fmt.Errorf("outline.GetGroups: %w", err)
175175
}
176176
}()
177177
b, err := outline.getConfig()

config/upstream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ type Upstream interface {
1111

1212
var InvalidUpstreamErr = fmt.Errorf("invalid upstream")
1313

14-
func Map2upstream(m map[string]string, upstream interface{}) error {
14+
func Map2Upstream(m map[string]string, upstream interface{}) error {
1515
v := reflect.ValueOf(upstream)
1616
if !v.IsValid() {
1717
return fmt.Errorf("upstream should not be nil")

dispatcher/tcp/tcp.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,8 @@ func (d *TCP) Listen() (err error) {
4444
for {
4545
conn, err := d.l.Accept()
4646
if err != nil {
47-
switch err := err.(type) {
48-
case *net.OpError:
49-
if errors.Is(err.Unwrap(), net.ErrClosed) {
50-
return nil
51-
}
47+
if errors.Is(err, net.ErrClosed) {
48+
return nil
5249
}
5350
log.Printf("[error] ReadFrom: %v", err)
5451
continue
@@ -89,7 +86,7 @@ func (d *TCP) handleConn(conn net.Conn) error {
8986
defer pool.Put(buf)
9087
n, err := io.ReadFull(conn, data)
9188
if err != nil {
92-
return fmt.Errorf("[tcp] handleConn readfull error: %v", err)
89+
return fmt.Errorf("[tcp] handleConn readfull error: %w", err)
9390
}
9491

9592
// get user's context (preference)
@@ -110,13 +107,13 @@ func (d *TCP) handleConn(conn net.Conn) error {
110107
// dial and relay
111108
rc, err := net.Dial("tcp", server.Target)
112109
if err != nil {
113-
return fmt.Errorf("[tcp] handleConn dial error: %v", err)
110+
return fmt.Errorf("[tcp] handleConn dial error: %w", err)
114111
}
115112

116113
_ = rc.SetDeadline(time.Now().Add(DefaultTimeout))
117114
_, err = rc.Write(data[:n])
118115
if err != nil {
119-
return fmt.Errorf("[tcp] handleConn write error: %v", err)
116+
return fmt.Errorf("[tcp] handleConn write error: %w", err)
120117
}
121118

122119
log.Printf("[tcp] %s <-> %s <-> %s", conn.RemoteAddr(), conn.LocalAddr(), rc.RemoteAddr())
@@ -125,7 +122,7 @@ func (d *TCP) handleConn(conn net.Conn) error {
125122
if err, ok := err.(net.Error); ok && err.Timeout() {
126123
return nil // ignore i/o timeout
127124
}
128-
return fmt.Errorf("[tcp] handleConn relay error: %v", err)
125+
return fmt.Errorf("[tcp] handleConn relay error: %w", err)
129126
}
130127
return nil
131128
}

dispatcher/udp/udp.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ func (d *UDP) handleConn(laddr net.Addr, data []byte, n int) (err error) {
8181
if err == AuthFailedErr {
8282
return nil
8383
}
84-
return fmt.Errorf("[udp] handleConn dial target error: %v", err)
84+
return fmt.Errorf("[udp] handleConn dial target error: %w", err)
8585
}
8686

8787
// send packet
8888
if _, err = rc.Write(data[:n]); err != nil {
89-
return fmt.Errorf("[udp] handleConn write error: %v", err)
89+
return fmt.Errorf("[udp] handleConn write error: %w", err)
9090
}
9191
return nil
9292
}
@@ -136,7 +136,7 @@ func (d *UDP) GetOrBuildUCPConn(laddr net.Addr, data []byte) (rc *net.UDPConn, e
136136
d.nm.Lock()
137137
d.nm.Remove(socketIdent) // close channel to inform that establishment ends
138138
d.nm.Unlock()
139-
return nil, fmt.Errorf("GetOrBuildUCPConn dial error: %v", err)
139+
return nil, fmt.Errorf("GetOrBuildUCPConn dial error: %w", err)
140140
}
141141
rc = rconn.(*net.UDPConn)
142142
d.nm.Lock()

reload.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,46 @@ func ReloadConfig() {
1717

1818
// rebuild config
1919
confPath := config.GetConfig().ConfPath
20+
oldConf := config.GetConfig()
2021
newConf, err := config.BuildConfig(confPath)
2122
if err != nil {
2223
log.Printf("failed to reload configuration: %v", err)
2324
return
2425
}
26+
// check if there is any net error when pulling the upstream configurations
27+
for i := range newConf.Groups {
28+
newGroup := &newConf.Groups[i]
29+
for j := range newGroup.Upstreams {
30+
newUpstream := newGroup.Upstreams[j]
31+
if newUpstream[config.PullingErrorKey] != config.PullingErrorNetError {
32+
continue
33+
}
34+
// net error, remain those servers
35+
36+
// find the group in the oldConf
37+
var oldGroup *config.Group
38+
for k := range oldConf.Groups {
39+
// they should have the same port
40+
if oldConf.Groups[k].Port != newGroup.Port {
41+
continue
42+
}
43+
oldGroup = &oldConf.Groups[k]
44+
break
45+
}
46+
if oldGroup == nil {
47+
// cannot find the corresponding old group
48+
continue
49+
}
50+
// check if upstreamConf can match
51+
for k := range oldGroup.Servers {
52+
oldServer := oldGroup.Servers[k]
53+
if oldServer.UpstreamConf != nil && newUpstream.Equal(*oldServer.UpstreamConf) {
54+
// remain the server
55+
newGroup.Servers = append(newGroup.Servers, oldServer)
56+
}
57+
}
58+
}
59+
}
2560
config.SetConfig(newConf)
2661
c := newConf
2762

0 commit comments

Comments
 (0)