@@ -10,6 +10,7 @@ import (
1010 "errors"
1111 "net"
1212 "net/netip"
13+ "runtime"
1314 "strconv"
1415 "sync"
1516 "syscall"
@@ -22,16 +23,21 @@ var (
2223 _ Bind = (* StdNetBind )(nil )
2324)
2425
25- // StdNetBind implements Bind for all platforms except Windows.
26+ // StdNetBind implements Bind for all platforms. While Windows has its own Bind
27+ // (see bind_windows.go), it may fall back to StdNetBind.
28+ // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
29+ // methods for sending and receiving multiple datagrams per-syscall. See the
30+ // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
2631type StdNetBind struct {
27- mu sync.Mutex // protects following fields
28- ipv4 * net.UDPConn
29- ipv6 * net.UDPConn
30- blackhole4 bool
31- blackhole6 bool
32- ipv4PC * ipv4.PacketConn
33- ipv6PC * ipv6.PacketConn
34- udpAddrPool sync.Pool
32+ mu sync.Mutex // protects following fields
33+ ipv4 * net.UDPConn
34+ ipv6 * net.UDPConn
35+ blackhole4 bool
36+ blackhole6 bool
37+ ipv4PC * ipv4.PacketConn // will be nil on non-Linux
38+ ipv6PC * ipv6.PacketConn // will be nil on non-Linux
39+
40+ udpAddrPool sync.Pool // following fields are not guarded by mu
3541 ipv4MsgsPool sync.Pool
3642 ipv6MsgsPool sync.Pool
3743}
@@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
154160again:
155161 port := int (uport )
156162 var v4conn , v6conn * net.UDPConn
163+ var v4pc * ipv4.PacketConn
164+ var v6pc * ipv6.PacketConn
157165
158166 v4conn , port , err = listenNet ("udp4" , port )
159167 if err != nil && ! errors .Is (err , syscall .EAFNOSUPPORT ) {
@@ -173,63 +181,92 @@ again:
173181 }
174182 var fns []ReceiveFunc
175183 if v4conn != nil {
176- fns = append (fns , s .receiveIPv4 )
184+ if runtime .GOOS == "linux" {
185+ v4pc = ipv4 .NewPacketConn (v4conn )
186+ s .ipv4PC = v4pc
187+ }
188+ fns = append (fns , s .makeReceiveIPv4 (v4pc , v4conn ))
177189 s .ipv4 = v4conn
178190 }
179191 if v6conn != nil {
180- fns = append (fns , s .receiveIPv6 )
192+ if runtime .GOOS == "linux" {
193+ v6pc = ipv6 .NewPacketConn (v6conn )
194+ s .ipv6PC = v6pc
195+ }
196+ fns = append (fns , s .makeReceiveIPv6 (v6pc , v6conn ))
181197 s .ipv6 = v6conn
182198 }
183199 if len (fns ) == 0 {
184200 return nil , 0 , syscall .EAFNOSUPPORT
185201 }
186202
187- s .ipv4PC = ipv4 .NewPacketConn (s .ipv4 )
188- s .ipv6PC = ipv6 .NewPacketConn (s .ipv6 )
189-
190203 return fns , uint16 (port ), nil
191204}
192205
193- func (s * StdNetBind ) receiveIPv4 (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
194- msgs := s .ipv4MsgsPool .Get ().(* []ipv4.Message )
195- defer s .ipv4MsgsPool .Put (msgs )
196- for i := range buffs {
197- (* msgs )[i ].Buffers [0 ] = buffs [i ]
198- }
199- numMsgs , err := s .ipv4PC .ReadBatch (* msgs , 0 )
200- if err != nil {
201- return 0 , err
202- }
203- for i := 0 ; i < numMsgs ; i ++ {
204- msg := & (* msgs )[i ]
205- sizes [i ] = msg .N
206- addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
207- ep := asEndpoint (addrPort )
208- getSrcFromControl (msg .OOB , ep )
209- eps [i ] = ep
206+ func (s * StdNetBind ) makeReceiveIPv4 (pc * ipv4.PacketConn , conn * net.UDPConn ) ReceiveFunc {
207+ return func (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
208+ msgs := s .ipv4MsgsPool .Get ().(* []ipv4.Message )
209+ defer s .ipv4MsgsPool .Put (msgs )
210+ for i := range buffs {
211+ (* msgs )[i ].Buffers [0 ] = buffs [i ]
212+ }
213+ var numMsgs int
214+ if runtime .GOOS == "linux" {
215+ numMsgs , err = pc .ReadBatch (* msgs , 0 )
216+ if err != nil {
217+ return 0 , err
218+ }
219+ } else {
220+ msg := & (* msgs )[0 ]
221+ msg .N , msg .NN , _ , msg .Addr , err = conn .ReadMsgUDP (msg .Buffers [0 ], msg .OOB )
222+ if err != nil {
223+ return 0 , err
224+ }
225+ numMsgs = 1
226+ }
227+ for i := 0 ; i < numMsgs ; i ++ {
228+ msg := & (* msgs )[i ]
229+ sizes [i ] = msg .N
230+ addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
231+ ep := asEndpoint (addrPort )
232+ getSrcFromControl (msg .OOB , ep )
233+ eps [i ] = ep
234+ }
235+ return numMsgs , nil
210236 }
211- return numMsgs , nil
212237}
213238
214- func (s * StdNetBind ) receiveIPv6 (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
215- msgs := s .ipv6MsgsPool .Get ().(* []ipv6.Message )
216- defer s .ipv6MsgsPool .Put (msgs )
217- for i := range buffs {
218- (* msgs )[i ].Buffers [0 ] = buffs [i ]
219- }
220- numMsgs , err := s .ipv6PC .ReadBatch (* msgs , 0 )
221- if err != nil {
222- return 0 , err
223- }
224- for i := 0 ; i < numMsgs ; i ++ {
225- msg := & (* msgs )[i ]
226- sizes [i ] = msg .N
227- addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
228- ep := asEndpoint (addrPort )
229- getSrcFromControl (msg .OOB , ep )
230- eps [i ] = ep
239+ func (s * StdNetBind ) makeReceiveIPv6 (pc * ipv6.PacketConn , conn * net.UDPConn ) ReceiveFunc {
240+ return func (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
241+ msgs := s .ipv4MsgsPool .Get ().(* []ipv6.Message )
242+ defer s .ipv4MsgsPool .Put (msgs )
243+ for i := range buffs {
244+ (* msgs )[i ].Buffers [0 ] = buffs [i ]
245+ }
246+ var numMsgs int
247+ if runtime .GOOS == "linux" {
248+ numMsgs , err = pc .ReadBatch (* msgs , 0 )
249+ if err != nil {
250+ return 0 , err
251+ }
252+ } else {
253+ msg := & (* msgs )[0 ]
254+ msg .N , msg .NN , _ , msg .Addr , err = conn .ReadMsgUDP (msg .Buffers [0 ], msg .OOB )
255+ if err != nil {
256+ return 0 , err
257+ }
258+ numMsgs = 1
259+ }
260+ for i := 0 ; i < numMsgs ; i ++ {
261+ msg := & (* msgs )[i ]
262+ sizes [i ] = msg .N
263+ addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
264+ ep := asEndpoint (addrPort )
265+ getSrcFromControl (msg .OOB , ep )
266+ eps [i ] = ep
267+ }
268+ return numMsgs , nil
231269 }
232- return numMsgs , nil
233270}
234271
235272// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
@@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
246283 if s .ipv4 != nil {
247284 err1 = s .ipv4 .Close ()
248285 s .ipv4 = nil
286+ s .ipv4PC = nil
249287 }
250288 if s .ipv6 != nil {
251289 err2 = s .ipv6 .Close ()
252290 s .ipv6 = nil
291+ s .ipv6PC = nil
253292 }
254293 s .blackhole4 = false
255294 s .blackhole6 = false
@@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
263302 s .mu .Lock ()
264303 blackhole := s .blackhole4
265304 conn := s .ipv4
305+ var (
306+ pc4 * ipv4.PacketConn
307+ pc6 * ipv6.PacketConn
308+ )
266309 is6 := false
267310 if endpoint .DstIP ().Is6 () {
268311 blackhole = s .blackhole6
269312 conn = s .ipv6
313+ pc6 = s .ipv6PC
270314 is6 = true
315+ } else {
316+ pc4 = s .ipv4PC
271317 }
272318 s .mu .Unlock ()
273319
@@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
278324 return syscall .EAFNOSUPPORT
279325 }
280326 if is6 {
281- return s .send6 (s . ipv6PC , endpoint , buffs )
327+ return s .send6 (conn , pc6 , endpoint , buffs )
282328 } else {
283- return s .send4 (s . ipv4PC , endpoint , buffs )
329+ return s .send4 (conn , pc4 , endpoint , buffs )
284330 }
285331}
286332
287- func (s * StdNetBind ) send4 (conn * ipv4.PacketConn , ep Endpoint , buffs [][]byte ) error {
333+ func (s * StdNetBind ) send4 (conn * net. UDPConn , pc * ipv4.PacketConn , ep Endpoint , buffs [][]byte ) error {
288334 ua := s .udpAddrPool .Get ().(* net.UDPAddr )
289335 as4 := ep .DstIP ().As4 ()
290336 copy (ua .IP , as4 [:])
@@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
301347 err error
302348 start int
303349 )
304- for {
305- n , err = conn .WriteBatch ((* msgs )[start :len (buffs )], 0 )
306- if err != nil || n == len ((* msgs )[start :len (buffs )]) {
307- break
350+ if runtime .GOOS == "linux" {
351+ for {
352+ n , err = pc .WriteBatch ((* msgs )[start :len (buffs )], 0 )
353+ if err != nil || n == len ((* msgs )[start :len (buffs )]) {
354+ break
355+ }
356+ start += n
357+ }
358+ } else {
359+ for i , buff := range buffs {
360+ _ , _ , err = conn .WriteMsgUDP (buff , (* msgs )[i ].OOB , ua )
361+ if err != nil {
362+ break
363+ }
308364 }
309- start += n
310365 }
311366 s .udpAddrPool .Put (ua )
312367 s .ipv4MsgsPool .Put (msgs )
313368 return err
314369}
315370
316- func (s * StdNetBind ) send6 (conn * ipv6.PacketConn , ep Endpoint , buffs [][]byte ) error {
371+ func (s * StdNetBind ) send6 (conn * net. UDPConn , pc * ipv6.PacketConn , ep Endpoint , buffs [][]byte ) error {
317372 ua := s .udpAddrPool .Get ().(* net.UDPAddr )
318373 as16 := ep .DstIP ().As16 ()
319374 copy (ua .IP , as16 [:])
@@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
330385 err error
331386 start int
332387 )
333- for {
334- n , err = conn .WriteBatch ((* msgs )[start :len (buffs )], 0 )
335- if err != nil || n == len ((* msgs )[start :len (buffs )]) {
336- break
388+ if runtime .GOOS == "linux" {
389+ for {
390+ n , err = pc .WriteBatch ((* msgs )[start :len (buffs )], 0 )
391+ if err != nil || n == len ((* msgs )[start :len (buffs )]) {
392+ break
393+ }
394+ start += n
395+ }
396+ } else {
397+ for i , buff := range buffs {
398+ _ , _ , err = conn .WriteMsgUDP (buff , (* msgs )[i ].OOB , ua )
399+ if err != nil {
400+ break
401+ }
337402 }
338- start += n
339403 }
340404 s .udpAddrPool .Put (ua )
341405 s .ipv6MsgsPool .Put (msgs )
0 commit comments