diff --git a/examples/zinx_kcp/kcp_client.go b/examples/zinx_kcp/kcp_client.go index 0c856052..05264a6e 100644 --- a/examples/zinx_kcp/kcp_client.go +++ b/examples/zinx_kcp/kcp_client.go @@ -6,6 +6,7 @@ import ( "github.com/aceld/zinx/zpack" "github.com/xtaci/kcp-go" "io" + "time" ) // 模拟客户端 @@ -19,8 +20,8 @@ func main() { } dp := zpack.Factory().NewPack(ziface.ZinxDataPack) - msg, _ := dp.Pack(zpack.NewMsgPackage(1, []byte("client test message"))) - _, err = conn.Write(msg) + sendMsg, _ := dp.Pack(zpack.NewMsgPackage(1, []byte("client test message"))) + _, err = conn.Write(sendMsg) if err != nil { fmt.Println("client write err: ", err) return @@ -55,6 +56,13 @@ func main() { } fmt.Printf("==> Client receive Msg: ID = %d, len = %d , data = %s\n", msg.ID, msg.DataLen, msg.Data) + + time.Sleep(1 * time.Second) + _, err = conn.Write(sendMsg) + if err != nil { + fmt.Println("client write err: ", err) + return + } } } } diff --git a/examples/zinx_kcp/server.go b/examples/zinx_kcp/server.go index 03213ae4..6df6d95f 100644 --- a/examples/zinx_kcp/server.go +++ b/examples/zinx_kcp/server.go @@ -14,6 +14,8 @@ type TestRouter struct { znet.BaseRouter } +var dealTimes = 0 + // PreHandle - func (t *TestRouter) PreHandle(req ziface.IRequest) { start := time.Now() @@ -30,27 +32,24 @@ func (t *TestRouter) PreHandle(req ziface.IRequest) { func (t *TestRouter) Handle(req ziface.IRequest) { fmt.Println("--> Call Handle") - // Simulated scenario - In the event of an expected error such as incorrect permissions or incorrect information, - // subsequent function execution will be stopped, but this function will be fully executed. - // 模拟场景- 出现意料之中的错误 如权限不对或者信息错误 则停止后续函数执行,但是次函数会执行完毕 if err := Err(); err != nil { req.Abort() fmt.Println("Insufficient permission") } - // Simulation scenario - In case of a certain situation, repeat the above operation. - // 模拟场景- 出现某种情况,重复上面的操作 - /* - if err := Err(); err != nil { - req.Goto(znet.PRE_HANDLE) - fmt.Println("repeat") - } - */ + dealTimes++ + req.GetConnection().AddCloseCallback(nil, nil, func() { + fmt.Println("run close callback") + }) if err := req.GetConnection().SendMsg(0, []byte("test2")); err != nil { fmt.Println(err) } + if dealTimes == 5 { + req.GetConnection().Stop() + } + time.Sleep(1 * time.Millisecond) } @@ -79,5 +78,11 @@ func main() { LogFile: "test.log", }) s.AddRouter(1, &TestRouter{}) + s.SetOnConnStart(func(conn ziface.IConnection) { + fmt.Println("--> OnConnStart") + }) + s.SetOnConnStop(func(conn ziface.IConnection) { + fmt.Println("--> OnConnStop") + }) s.Serve() } diff --git a/ziface/iconnection.go b/ziface/iconnection.go index 842c8bad..64ae0b3a 100644 --- a/ziface/iconnection.go +++ b/ziface/iconnection.go @@ -53,4 +53,8 @@ type IConnection interface { RemoveProperty(key string) // Remove connection property IsAlive() bool // Check if the current connection is alive(判断当前连接是否存活) SetHeartBeat(checker IHeartbeatChecker) // Set the heartbeat detector (设置心跳检测器) + + AddCloseCallback(handler, key interface{}, callback func()) // Add a close callback function (添加关闭回调函数) + RemoveCloseCallback(handler, key interface{}) // Remove a close callback function (删除关闭回调函数) + InvokeCloseCallbacks() // Trigger the close callback function (触发关闭回调函数,独立协程完成) } diff --git a/znet/callbacks.go b/znet/callbacks.go new file mode 100644 index 00000000..de7b9f00 --- /dev/null +++ b/znet/callbacks.go @@ -0,0 +1,57 @@ +package znet + +type callbackCommon struct { + handler interface{} + key interface{} + call func() + next *callbackCommon +} + +type callbacks struct { + first *callbackCommon + last *callbackCommon +} + +func (t *callbacks) Add(handler, key interface{}, callback func()) { + if callback == nil { + return + } + newItem := &callbackCommon{handler, key, callback, nil} + if t.first == nil { + t.first = newItem + } else { + t.last.next = newItem + } + t.last = newItem +} + +func (t *callbacks) Remove(handler, key interface{}) { + var prev *callbackCommon + for callback := t.first; callback != nil; prev, callback = callback, callback.next { + if callback.handler == handler && callback.key == key { + if t.first == callback { + t.first = callback.next + } else if prev != nil { + prev.next = callback.next + } + if t.last == callback { + t.last = prev + } + return + } + } +} + +func (t *callbacks) Invoke() { + for callback := t.first; callback != nil; callback = callback.next { + callback.call() + } +} + +func (t *callbacks) Count() int { + var count int + for callback := t.first; callback != nil; callback = callback.next { + count++ + } + return count +} diff --git a/znet/callbacks_test.go b/znet/callbacks_test.go new file mode 100644 index 00000000..2ee06015 --- /dev/null +++ b/znet/callbacks_test.go @@ -0,0 +1,29 @@ +package znet + +import "testing" + +func TestCallback(t *testing.T) { + cb := &callbacks{} + var count, expected int + + cb.Add("handler", "a", func() { + count++ + }) + cb.Add("handler", "b", func() { + count++ + }) + cb.Invoke() + + expected = 2 + if count != expected { + t.Errorf("returned %d, expected %d", count, expected) + } + + count = 0 + expected = 1 + cb.Remove("handler", "b") + cb.Invoke() + if count != expected { + t.Errorf("returned %d, expected %d", count, expected) + } +} diff --git a/znet/connection.go b/znet/connection.go index 223fd38d..48c152c8 100644 --- a/znet/connection.go +++ b/znet/connection.go @@ -110,6 +110,12 @@ type Connection struct { // Remote address of the current connection // (当前链接的远程地址) remoteAddr string + + // Close callback + closeCallback callbacks + + // Close callback mutex + closeCallbackMutex sync.RWMutex } // newServerConn :for Server, method to create a Server-side connection with Server-specific properties @@ -487,6 +493,21 @@ func (c *Connection) finalizer() { c.connManager.Remove(c) } + // Close all channels associated with the connection + if c.msgBuffChan != nil { + close(c.msgBuffChan) + } + + go func() { + defer func() { + if err := recover(); err != nil { + zlog.Ins().ErrorF("Conn finalizer panic: %v", err) + } + }() + + c.InvokeCloseCallbacks() + }() + zlog.Ins().InfoF("Conn Stop()...ConnID = %d", c.connID) } @@ -549,3 +570,27 @@ func (c *Connection) setClose() bool { func (c *Connection) setStartWriterFlag() bool { return atomic.CompareAndSwapInt32(&c.startWriterFlag, 0, 1) } + +func (s *Connection) AddCloseCallback(handler, key interface{}, f func()) { + if s.isClosed() { + return + } + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + s.closeCallback.Add(handler, key, f) +} + +func (s *Connection) RemoveCloseCallback(handler, key interface{}) { + if s.isClosed() { + return + } + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + s.closeCallback.Remove(handler, key) +} + +func (s *Connection) InvokeCloseCallbacks() { + s.closeCallbackMutex.RLock() + defer s.closeCallbackMutex.RUnlock() + s.closeCallback.Invoke() +} diff --git a/znet/kcp_connection.go b/znet/kcp_connection.go index 0272596c..b1dea1f2 100644 --- a/znet/kcp_connection.go +++ b/znet/kcp_connection.go @@ -7,6 +7,7 @@ import ( "net" "strconv" "sync" + "sync/atomic" "time" "github.com/aceld/zinx/ziface" @@ -69,7 +70,7 @@ type KcpConnection struct { // The current connection's close state // (当前连接的关闭状态) - isClosed bool + closed int32 // Which Connection Manager the current connection belongs to // (当前链接是属于哪个Connection Manager的) @@ -110,6 +111,12 @@ type KcpConnection struct { // Remote address of the current connection // (当前链接的远程地址) remoteAddr string + + // Close callback + closeCallback callbacks + + // Close callback mutex + closeCallbackMutex sync.RWMutex } // newKcpServerConn :for Server, method to create a Server-side connection with Server-specific properties @@ -120,7 +127,6 @@ func newKcpServerConn(server ziface.IServer, conn *kcp.UDPSession, connID uint64 conn: conn, connID: connID, connIdStr: strconv.FormatUint(connID, 10), - isClosed: false, msgBuffChan: nil, property: nil, name: server.ServerName(), @@ -157,7 +163,6 @@ func newKcpClientConn(client ziface.IClient, conn *kcp.UDPSession) ziface.IConne conn: conn, connID: 0, // client ignore connIdStr: "", // client ignore - isClosed: false, msgBuffChan: nil, property: nil, name: client.GetName(), @@ -346,7 +351,7 @@ func (c *KcpConnection) LocalAddr() net.Addr { func (c *KcpConnection) Send(data []byte) error { c.msgLock.RLock() defer c.msgLock.RUnlock() - if c.isClosed == true { + if c.isClosed() { return errors.New("connection closed when send msg") } @@ -375,7 +380,7 @@ func (c *KcpConnection) SendToQueue(data []byte) error { idleTimeout := time.NewTimer(5 * time.Millisecond) defer idleTimeout.Stop() - if c.isClosed == true { + if c.isClosed() { return errors.New("Connection closed when send buff msg") } @@ -396,7 +401,7 @@ func (c *KcpConnection) SendToQueue(data []byte) error { // SendMsg directly sends Message data to the remote KCP client. // (直接将Message数据发送数据给远程的KCP客户端) func (c *KcpConnection) SendMsg(msgID uint32, data []byte) error { - if c.isClosed == true { + if c.isClosed() { return errors.New("connection closed when send msg") } // Pack data and send it @@ -416,7 +421,7 @@ func (c *KcpConnection) SendMsg(msgID uint32, data []byte) error { } func (c *KcpConnection) SendBuffMsg(msgID uint32, data []byte) error { - if c.isClosed == true { + if c.isClosed() { return errors.New("connection closed when send buff msg") } if c.msgBuffChan == nil { @@ -479,18 +484,23 @@ func (c *KcpConnection) Context() context.Context { } func (c *KcpConnection) finalizer() { + // If the connection has already been closed + if c.isClosed() == true { + return + } + + //set closed + if !c.setClose() { + return + } + // Call the callback function registered by the user when closing the connection if it exists - // (如果用户注册了该链接的 关闭回调业务,那么在此刻应该显示调用) + //(如果用户注册了该链接的 关闭回调业务,那么在此刻应该显示调用) c.callOnConnStop() c.msgLock.Lock() defer c.msgLock.Unlock() - // If the connection has already been closed - if c.isClosed == true { - return - } - // Stop the heartbeat detector associated with the connection if c.hc != nil { c.hc.Stop() @@ -509,7 +519,15 @@ func (c *KcpConnection) finalizer() { close(c.msgBuffChan) } - c.isClosed = true + go func() { + defer func() { + if err := recover(); err != nil { + zlog.Ins().ErrorF("Conn finalizer panic: %v", err) + } + }() + + c.InvokeCloseCallbacks() + }() zlog.Ins().InfoF("Conn Stop()...ConnID = %d", c.connID) } @@ -529,7 +547,7 @@ func (c *KcpConnection) callOnConnStop() { } func (c *KcpConnection) IsAlive() bool { - if c.isClosed { + if c.isClosed() { return false } // Check the last activity time of the connection. If it's beyond the heartbeat interval, @@ -562,6 +580,39 @@ func (c *KcpConnection) GetMsgHandler() ziface.IMsgHandle { return c.msgHandler } +func (c *KcpConnection) isClosed() bool { + return atomic.LoadInt32(&c.closed) != 0 +} + +func (c *KcpConnection) setClose() bool { + return atomic.CompareAndSwapInt32(&c.closed, 0, 1) +} + +func (s *KcpConnection) AddCloseCallback(handler, key interface{}, f func()) { + if s.isClosed() { + return + } + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + s.closeCallback.Add(handler, key, f) +} + +func (s *KcpConnection) RemoveCloseCallback(handler, key interface{}) { + if s.isClosed() { + return + } + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + s.closeCallback.Remove(handler, key) +} + +// invokeCloseCallbacks 触发 close callback, 在独立协程完成 +func (s *KcpConnection) InvokeCloseCallbacks() { + s.closeCallbackMutex.RLock() + defer s.closeCallbackMutex.RUnlock() + s.closeCallback.Invoke() +} + // Implement other KCP specific methods here... // ... // ... diff --git a/znet/ws_connection.go b/znet/ws_connection.go index 29cd91e2..099b8d88 100644 --- a/znet/ws_connection.go +++ b/znet/ws_connection.go @@ -100,6 +100,12 @@ type WsConnection struct { // remoteAddr is the remote address of the current connection. (当前链接的远程地址) remoteAddr string + + // Close callback + closeCallback callbacks + + // Close callback mutex + closeCallbackMutex sync.RWMutex } // newServerConn: for Server, a method to create a connection with Server characteristics @@ -518,6 +524,16 @@ func (c *WsConnection) finalizer() { // Set the flag to indicate that the connection is closed. (设置标志位) c.isClosed = true + go func() { + defer func() { + if err := recover(); err != nil { + zlog.Ins().ErrorF("Conn finalizer panic: %v", err) + } + }() + + c.InvokeCloseCallbacks() + }() + zlog.Ins().InfoF("Conn Stop()...ConnID = %d", c.connID) } @@ -568,3 +584,27 @@ func (c *WsConnection) GetName() string { func (c *WsConnection) GetMsgHandler() ziface.IMsgHandle { return c.msgHandler } + +func (s *WsConnection) AddCloseCallback(handler, key interface{}, f func()) { + if s.isClosed { + return + } + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + s.closeCallback.Add(handler, key, f) +} + +func (s *WsConnection) RemoveCloseCallback(handler, key interface{}) { + if s.isClosed { + return + } + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + s.closeCallback.Remove(handler, key) +} + +func (s *WsConnection) InvokeCloseCallbacks() { + s.closeCallbackMutex.RLock() + defer s.closeCallbackMutex.RUnlock() + s.closeCallback.Invoke() +}