diff --git a/internal/pool/conn.go b/internal/pool/conn.go index c1087b401..67dcc2ab5 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -7,7 +7,9 @@ import ( "sync/atomic" "time" + "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/pushnotif" ) var noDeadline = time.Time{} @@ -25,6 +27,10 @@ type Conn struct { createdAt time.Time onClose func() error + + // Push notification processor for handling push notifications on this connection + // This is set when the connection is created and is a reference to the processor + PushNotificationProcessor pushnotif.ProcessorInterface } func NewConn(netConn net.Conn) *Conn { @@ -72,11 +78,23 @@ func (cn *Conn) RemoteAddr() net.Addr { func (cn *Conn) WithReader( ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error, ) error { + // Process any pending push notifications before executing the read function + // This ensures push notifications are handled as soon as they arrive + if cn.PushNotificationProcessor != nil { + // Type assert to the processor interface + if err := cn.PushNotificationProcessor.ProcessPendingNotifications(ctx, cn.rd); err != nil { + // Log the error but don't fail the read operation + // Push notification processing errors shouldn't break normal Redis operations + internal.Logger.Printf(ctx, "push: error processing pending notifications in WithReader: %v", err) + } + } + if timeout >= 0 { if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { return err } } + return fn(cn.rd) } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 3ee3dea6d..efadfaaef 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -9,6 +9,7 @@ import ( "time" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pushnotif" ) var ( @@ -71,6 +72,13 @@ type Options struct { MaxActiveConns int ConnMaxIdleTime time.Duration ConnMaxLifetime time.Duration + + // Push notification processor for connections + // This is an interface to avoid circular imports + PushNotificationProcessor pushnotif.ProcessorInterface + + // Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without) + Protocol int } type lastDialErrorWrap struct { @@ -228,6 +236,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { cn := NewConn(netConn) cn.pooled = pooled + + // Set push notification processor if available + if p.cfg.PushNotificationProcessor != nil { + cn.PushNotificationProcessor = p.cfg.PushNotificationProcessor + } + return cn, nil } @@ -377,9 +391,24 @@ func (p *ConnPool) popIdle() (*Conn, error) { func (p *ConnPool) Put(ctx context.Context, cn *Conn) { if cn.rd.Buffered() > 0 { - internal.Logger.Printf(ctx, "Conn has unread data") - p.Remove(ctx, cn, BadConnError{}) - return + // Check if this might be push notification data + if cn.PushNotificationProcessor != nil && p.cfg.Protocol == 3 { + // Only process for RESP3 clients (push notifications only available in RESP3) + err := cn.PushNotificationProcessor.ProcessPendingNotifications(ctx, cn.rd) + if err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications: %v", err) + } + // Check again if there's still unread data after processing push notifications + if cn.rd.Buffered() > 0 { + internal.Logger.Printf(ctx, "Conn has unread data after processing push notifications") + p.Remove(ctx, cn, BadConnError{}) + return + } + } else { + internal.Logger.Printf(ctx, "Conn has unread data") + p.Remove(ctx, cn, BadConnError{}) + return + } } if !cn.pooled { @@ -523,8 +552,24 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool { return false } - if connCheck(cn.netConn) != nil { - return false + // Check connection health, but be aware of push notifications + if err := connCheck(cn.netConn); err != nil { + // If there's unexpected data and we have push notification support, + // it might be push notifications (only for RESP3) + if err == errUnexpectedRead && cn.PushNotificationProcessor != nil && p.cfg.Protocol == 3 { + // Try to process any pending push notifications (only for RESP3) + ctx := context.Background() + if procErr := cn.PushNotificationProcessor.ProcessPendingNotifications(ctx, cn.rd); procErr != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications during health check: %v", procErr) + return false + } + // Check again after processing push notifications + if connCheck(cn.netConn) != nil { + return false + } + } else { + return false + } } cn.SetUsedAt(now) diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 8d23817fe..8daa08a1d 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -90,6 +90,27 @@ func (r *Reader) PeekReplyType() (byte, error) { return b[0], nil } +func (r *Reader) PeekPushNotificationName() (string, error) { + // peek 32 bytes, should be enough to read the push notification name + buf, err := r.rd.Peek(32) + if err != nil { + return "", err + } + if buf[0] != RespPush { + return "", fmt.Errorf("redis: can't parse push notification: %q", buf) + } + // remove push notification type and length + nextLine := buf[2:] + for i := 1; i < len(buf); i++ { + if buf[i] == '\r' && buf[i+1] == '\n' { + nextLine = buf[i+2:] + break + } + } + // return notification name or error + return r.readStringReply(nextLine) +} + // ReadLine Return a valid reply, it will check the protocol or redis error, // and discard the attribute type. func (r *Reader) ReadLine() ([]byte, error) { diff --git a/internal/pushnotif/processor.go b/internal/pushnotif/processor.go new file mode 100644 index 000000000..4476ecb84 --- /dev/null +++ b/internal/pushnotif/processor.go @@ -0,0 +1,186 @@ +package pushnotif + +import ( + "context" + "fmt" + + "github.com/redis/go-redis/v9/internal/proto" +) + +// Processor handles push notifications with a registry of handlers. +type Processor struct { + registry *Registry +} + +// NewProcessor creates a new push notification processor. +func NewProcessor() *Processor { + return &Processor{ + registry: NewRegistry(), + } +} + +// GetHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (p *Processor) GetHandler(pushNotificationName string) Handler { + return p.registry.GetHandler(pushNotificationName) +} + +// RegisterHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (p *Processor) RegisterHandler(pushNotificationName string, handler Handler, protected bool) error { + return p.registry.RegisterHandler(pushNotificationName, handler, protected) +} + +// UnregisterHandler removes a handler for a specific push notification name. +// Returns an error if the handler is protected or doesn't exist. +func (p *Processor) UnregisterHandler(pushNotificationName string) error { + return p.registry.UnregisterHandler(pushNotificationName) +} + +// ProcessPendingNotifications checks for and processes any pending push notifications. +func (p *Processor) ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error { + // Check for nil reader + if rd == nil { + return nil + } + + // Check if there are any buffered bytes that might contain push notifications + if rd.Buffered() == 0 { + return nil + } + + // Process all available push notifications + for { + // Peek at the next reply type to see if it's a push notification + replyType, err := rd.PeekReplyType() + if err != nil { + // No more data available or error reading + break + } + + // Push notifications use RespPush type in RESP3 + if replyType != proto.RespPush { + break + } + + notificationName, err := rd.PeekPushNotificationName() + if err != nil { + // Error reading - continue to next iteration + break + } + + // Skip notifications that should be handled by other systems + if shouldSkipNotification(notificationName) { + break + } + + // Try to read the push notification + reply, err := rd.ReadReply() + if err != nil { + return fmt.Errorf("failed to read push notification: %w", err) + } + + // Convert to slice of interfaces + notification, ok := reply.([]interface{}) + if !ok { + continue + } + + // Handle the notification directly + if len(notification) > 0 { + // Extract the notification type (first element) + if notificationType, ok := notification[0].(string); ok { + // Skip notifications that should be handled by other systems + if shouldSkipNotification(notificationType) { + continue + } + + // Get the handler for this notification type + if handler := p.registry.GetHandler(notificationType); handler != nil { + // Handle the notification + handler.HandlePushNotification(ctx, notification) + } + } + } + } + + return nil +} + +// shouldSkipNotification checks if a notification type should be ignored by the push notification +// processor and handled by other specialized systems instead (pub/sub, streams, keyspace, etc.). +func shouldSkipNotification(notificationType string) bool { + switch notificationType { + // Pub/Sub notifications - handled by pub/sub system + case "message", // Regular pub/sub message + "pmessage", // Pattern pub/sub message + "subscribe", // Subscription confirmation + "unsubscribe", // Unsubscription confirmation + "psubscribe", // Pattern subscription confirmation + "punsubscribe", // Pattern unsubscription confirmation + "smessage", // Sharded pub/sub message (Redis 7.0+) + "ssubscribe", // Sharded subscription confirmation + "sunsubscribe", // Sharded unsubscription confirmation + + // Stream notifications - handled by stream consumers + "xread-from", // Stream reading notifications + "xreadgroup-from", // Stream consumer group notifications + + // Client tracking notifications - handled by client tracking system + "invalidate", // Client-side caching invalidation + + // Keyspace notifications - handled by keyspace notification subscribers + // Note: Keyspace notifications typically have prefixes like "__keyspace@0__:" or "__keyevent@0__:" + // but we'll handle the base notification types here + "expired", // Key expiration events + "evicted", // Key eviction events + "set", // Key set events + "del", // Key deletion events + "rename", // Key rename events + "move", // Key move events + "copy", // Key copy events + "restore", // Key restore events + "sort", // Sort operation events + "flushdb", // Database flush events + "flushall": // All databases flush events + return true + default: + return false + } +} + +// VoidProcessor discards all push notifications without processing them. +type VoidProcessor struct{} + +// NewVoidProcessor creates a new void push notification processor. +func NewVoidProcessor() *VoidProcessor { + return &VoidProcessor{} +} + +// GetHandler returns nil for void processor since it doesn't maintain handlers. +func (v *VoidProcessor) GetHandler(pushNotificationName string) Handler { + return nil +} + +// RegisterHandler returns an error for void processor since it doesn't maintain handlers. +// This helps developers identify when they're trying to register handlers on disabled push notifications. +func (v *VoidProcessor) RegisterHandler(pushNotificationName string, handler Handler, protected bool) error { + return fmt.Errorf("cannot register push notification handler '%s': push notifications are disabled (using void processor)", pushNotificationName) +} + +// UnregisterHandler returns an error for void processor since it doesn't maintain handlers. +// This helps developers identify when they're trying to unregister handlers on disabled push notifications. +func (v *VoidProcessor) UnregisterHandler(pushNotificationName string) error { + return fmt.Errorf("cannot unregister push notification handler '%s': push notifications are disabled (using void processor)", pushNotificationName) +} + +// ProcessPendingNotifications for VoidProcessor does nothing since push notifications +// are only available in RESP3 and this processor is used when they're disabled. +// This avoids unnecessary buffer scanning overhead. +func (v *VoidProcessor) ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error { + // VoidProcessor is used when push notifications are disabled (typically RESP2 or disabled RESP3). + // Since push notifications only exist in RESP3, we can safely skip all processing + // to avoid unnecessary buffer scanning overhead. + return nil +} diff --git a/internal/pushnotif/pushnotif_test.go b/internal/pushnotif/pushnotif_test.go new file mode 100644 index 000000000..3fa84e885 --- /dev/null +++ b/internal/pushnotif/pushnotif_test.go @@ -0,0 +1,768 @@ +package pushnotif + +import ( + "context" + "io" + "strings" + "testing" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/proto" +) + +// TestHandler implements Handler interface for testing +type TestHandler struct { + name string + handled [][]interface{} + returnValue bool +} + +func NewTestHandler(name string, returnValue bool) *TestHandler { + return &TestHandler{ + name: name, + handled: make([][]interface{}, 0), + returnValue: returnValue, + } +} + +func (h *TestHandler) HandlePushNotification(ctx context.Context, notification []interface{}) bool { + h.handled = append(h.handled, notification) + return h.returnValue +} + +func (h *TestHandler) GetHandledNotifications() [][]interface{} { + return h.handled +} + +func (h *TestHandler) Reset() { + h.handled = make([][]interface{}, 0) +} + +// TestReaderInterface defines the interface needed for testing +type TestReaderInterface interface { + PeekReplyType() (byte, error) + PeekPushNotificationName() (string, error) + ReadReply() (interface{}, error) +} + +// MockReader implements TestReaderInterface for testing +type MockReader struct { + peekReplies []peekReply + peekIndex int + readReplies []interface{} + readErrors []error + readIndex int +} + +type peekReply struct { + replyType byte + err error +} + +func NewMockReader() *MockReader { + return &MockReader{ + peekReplies: make([]peekReply, 0), + readReplies: make([]interface{}, 0), + readErrors: make([]error, 0), + readIndex: 0, + peekIndex: 0, + } +} + +func (m *MockReader) AddPeekReplyType(replyType byte, err error) { + m.peekReplies = append(m.peekReplies, peekReply{replyType: replyType, err: err}) +} + +func (m *MockReader) AddReadReply(reply interface{}, err error) { + m.readReplies = append(m.readReplies, reply) + m.readErrors = append(m.readErrors, err) +} + +func (m *MockReader) PeekReplyType() (byte, error) { + if m.peekIndex >= len(m.peekReplies) { + return 0, io.EOF + } + peek := m.peekReplies[m.peekIndex] + m.peekIndex++ + return peek.replyType, peek.err +} + +func (m *MockReader) ReadReply() (interface{}, error) { + if m.readIndex >= len(m.readReplies) { + return nil, io.EOF + } + reply := m.readReplies[m.readIndex] + err := m.readErrors[m.readIndex] + m.readIndex++ + return reply, err +} + +func (m *MockReader) PeekPushNotificationName() (string, error) { + // return the notification name from the next read reply + if m.readIndex >= len(m.readReplies) { + return "", io.EOF + } + reply := m.readReplies[m.readIndex] + if reply == nil { + return "", nil + } + notification, ok := reply.([]interface{}) + if !ok { + return "", nil + } + if len(notification) == 0 { + return "", nil + } + name, ok := notification[0].(string) + if !ok { + return "", nil + } + return name, nil +} + +func (m *MockReader) Reset() { + m.readIndex = 0 + m.peekIndex = 0 +} + +// testProcessPendingNotifications is a test version that accepts our mock reader +func testProcessPendingNotifications(processor *Processor, ctx context.Context, reader TestReaderInterface) error { + if reader == nil { + return nil + } + + for { + // Check if there are push notifications available + replyType, err := reader.PeekReplyType() + if err != nil { + // No more data or error - this is normal + break + } + + // Only process push notifications + if replyType != proto.RespPush { + break + } + + notificationName, err := reader.PeekPushNotificationName() + if err != nil { + // Error reading - continue to next iteration + break + } + + // Skip notifications that should be handled by other systems + if shouldSkipNotification(notificationName) { + break + } + + // Read the push notification + reply, err := reader.ReadReply() + if err != nil { + // Error reading - continue to next iteration + internal.Logger.Printf(ctx, "push: error reading push notification: %v", err) + continue + } + + // Convert to slice of interfaces + notification, ok := reply.([]interface{}) + if !ok { + continue + } + + // Handle the notification directly + if len(notification) > 0 { + // Extract the notification type (first element) + if notificationType, ok := notification[0].(string); ok { + // Get the handler for this notification type + if handler := processor.registry.GetHandler(notificationType); handler != nil { + // Handle the notification + handler.HandlePushNotification(ctx, notification) + } + } + } + } + + return nil +} + +// TestRegistry tests the Registry implementation +func TestRegistry(t *testing.T) { + t.Run("NewRegistry", func(t *testing.T) { + registry := NewRegistry() + if registry == nil { + t.Error("NewRegistry should return a non-nil registry") + } + if registry.handlers == nil { + t.Error("Registry handlers map should be initialized") + } + if registry.protected == nil { + t.Error("Registry protected map should be initialized") + } + }) + + t.Run("RegisterHandler", func(t *testing.T) { + registry := NewRegistry() + handler := NewTestHandler("test", true) + + // Test successful registration + err := registry.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Errorf("RegisterHandler should succeed, got error: %v", err) + } + + // Test duplicate registration + err = registry.RegisterHandler("MOVING", handler, false) + if err == nil { + t.Error("RegisterHandler should return error for duplicate registration") + } + if !strings.Contains(err.Error(), "handler already registered") { + t.Errorf("Expected error about duplicate registration, got: %v", err) + } + + // Test protected registration + err = registry.RegisterHandler("MIGRATING", handler, true) + if err != nil { + t.Errorf("RegisterHandler with protected=true should succeed, got error: %v", err) + } + }) + + t.Run("GetHandler", func(t *testing.T) { + registry := NewRegistry() + handler := NewTestHandler("test", true) + + // Test getting non-existent handler + result := registry.GetHandler("NONEXISTENT") + if result != nil { + t.Error("GetHandler should return nil for non-existent handler") + } + + // Test getting existing handler + err := registry.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + result = registry.GetHandler("MOVING") + if result != handler { + t.Error("GetHandler should return the registered handler") + } + }) + + t.Run("UnregisterHandler", func(t *testing.T) { + registry := NewRegistry() + handler := NewTestHandler("test", true) + + // Test unregistering non-existent handler + err := registry.UnregisterHandler("NONEXISTENT") + if err == nil { + t.Error("UnregisterHandler should return error for non-existent handler") + } + if !strings.Contains(err.Error(), "no handler registered") { + t.Errorf("Expected error about no handler registered, got: %v", err) + } + + // Test unregistering regular handler + err = registry.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + err = registry.UnregisterHandler("MOVING") + if err != nil { + t.Errorf("UnregisterHandler should succeed for regular handler, got error: %v", err) + } + + // Verify handler is removed + result := registry.GetHandler("MOVING") + if result != nil { + t.Error("Handler should be removed after unregistration") + } + + // Test unregistering protected handler + err = registry.RegisterHandler("MIGRATING", handler, true) + if err != nil { + t.Fatalf("Failed to register protected handler: %v", err) + } + + err = registry.UnregisterHandler("MIGRATING") + if err == nil { + t.Error("UnregisterHandler should return error for protected handler") + } + if !strings.Contains(err.Error(), "cannot unregister protected handler") { + t.Errorf("Expected error about protected handler, got: %v", err) + } + + // Verify protected handler is still there + result = registry.GetHandler("MIGRATING") + if result != handler { + t.Error("Protected handler should still be registered after failed unregistration") + } + }) + + t.Run("GetRegisteredPushNotificationNames", func(t *testing.T) { + registry := NewRegistry() + handler1 := NewTestHandler("test1", true) + handler2 := NewTestHandler("test2", true) + + // Test empty registry + names := registry.GetRegisteredPushNotificationNames() + if len(names) != 0 { + t.Errorf("Empty registry should return empty slice, got: %v", names) + } + + // Test with registered handlers + err := registry.RegisterHandler("MOVING", handler1, false) + if err != nil { + t.Fatalf("Failed to register handler1: %v", err) + } + + err = registry.RegisterHandler("MIGRATING", handler2, true) + if err != nil { + t.Fatalf("Failed to register handler2: %v", err) + } + + names = registry.GetRegisteredPushNotificationNames() + if len(names) != 2 { + t.Errorf("Expected 2 registered names, got: %d", len(names)) + } + + // Check that both names are present (order doesn't matter) + nameMap := make(map[string]bool) + for _, name := range names { + nameMap[name] = true + } + + if !nameMap["MOVING"] { + t.Error("MOVING should be in registered names") + } + if !nameMap["MIGRATING"] { + t.Error("MIGRATING should be in registered names") + } + }) +} + +// TestProcessor tests the Processor implementation +func TestProcessor(t *testing.T) { + t.Run("NewProcessor", func(t *testing.T) { + processor := NewProcessor() + if processor == nil { + t.Error("NewProcessor should return a non-nil processor") + } + if processor.registry == nil { + t.Error("Processor should have a non-nil registry") + } + }) + + t.Run("GetHandler", func(t *testing.T) { + processor := NewProcessor() + handler := NewTestHandler("test", true) + + // Test getting non-existent handler + result := processor.GetHandler("NONEXISTENT") + if result != nil { + t.Error("GetHandler should return nil for non-existent handler") + } + + // Test getting existing handler + err := processor.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + result = processor.GetHandler("MOVING") + if result != handler { + t.Error("GetHandler should return the registered handler") + } + }) + + t.Run("RegisterHandler", func(t *testing.T) { + processor := NewProcessor() + handler := NewTestHandler("test", true) + + // Test successful registration + err := processor.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Errorf("RegisterHandler should succeed, got error: %v", err) + } + + // Test duplicate registration + err = processor.RegisterHandler("MOVING", handler, false) + if err == nil { + t.Error("RegisterHandler should return error for duplicate registration") + } + }) + + t.Run("UnregisterHandler", func(t *testing.T) { + processor := NewProcessor() + handler := NewTestHandler("test", true) + + // Test unregistering non-existent handler + err := processor.UnregisterHandler("NONEXISTENT") + if err == nil { + t.Error("UnregisterHandler should return error for non-existent handler") + } + + // Test successful unregistration + err = processor.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + err = processor.UnregisterHandler("MOVING") + if err != nil { + t.Errorf("UnregisterHandler should succeed, got error: %v", err) + } + }) + + t.Run("ProcessPendingNotifications", func(t *testing.T) { + processor := NewProcessor() + handler := NewTestHandler("test", true) + ctx := context.Background() + + // Test with nil reader + err := processor.ProcessPendingNotifications(ctx, nil) + if err != nil { + t.Errorf("ProcessPendingNotifications with nil reader should not error, got: %v", err) + } + + // Test with empty reader (no buffered data) + reader := proto.NewReader(strings.NewReader("")) + err = processor.ProcessPendingNotifications(ctx, reader) + if err != nil { + t.Errorf("ProcessPendingNotifications with empty reader should not error, got: %v", err) + } + + // Register a handler for testing + err = processor.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + // Test with mock reader - peek error (no push notifications available) + mockReader := NewMockReader() + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // EOF means no more data + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle peek EOF gracefully, got: %v", err) + } + + // Test with mock reader - non-push reply type + mockReader = NewMockReader() + mockReader.AddPeekReplyType(proto.RespString, nil) // Not RespPush + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle non-push reply types gracefully, got: %v", err) + } + + // Test with mock reader - push notification with ReadReply error + mockReader = NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + mockReader.AddReadReply(nil, io.ErrUnexpectedEOF) // ReadReply fails + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle ReadReply errors gracefully, got: %v", err) + } + + // Test with mock reader - push notification with invalid reply type + mockReader = NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + mockReader.AddReadReply("not-a-slice", nil) // Invalid reply type + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle invalid reply types gracefully, got: %v", err) + } + + // Test with mock reader - valid push notification with handler + mockReader = NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + notification := []interface{}{"MOVING", "slot", "12345"} + mockReader.AddReadReply(notification, nil) + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + + handler.Reset() + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle valid notifications, got: %v", err) + } + + // Check that handler was called + handled := handler.GetHandledNotifications() + if len(handled) != 1 { + t.Errorf("Expected 1 handled notification, got: %d", len(handled)) + } else if len(handled[0]) != 3 || handled[0][0] != "MOVING" { + t.Errorf("Expected MOVING notification, got: %v", handled[0]) + } + + // Test with mock reader - valid push notification without handler + mockReader = NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + notification = []interface{}{"UNKNOWN", "data"} + mockReader.AddReadReply(notification, nil) + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle notifications without handlers, got: %v", err) + } + + // Test with mock reader - empty notification + mockReader = NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + emptyNotification := []interface{}{} + mockReader.AddReadReply(emptyNotification, nil) + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle empty notifications, got: %v", err) + } + + // Test with mock reader - notification with non-string type + mockReader = NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + invalidTypeNotification := []interface{}{123, "data"} // First element is not string + mockReader.AddReadReply(invalidTypeNotification, nil) + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle invalid notification types, got: %v", err) + } + + // Test the actual ProcessPendingNotifications method with real proto.Reader + // Test with nil reader + err = processor.ProcessPendingNotifications(ctx, nil) + if err != nil { + t.Errorf("ProcessPendingNotifications with nil reader should not error, got: %v", err) + } + + // Test with empty reader (no buffered data) + protoReader := proto.NewReader(strings.NewReader("")) + err = processor.ProcessPendingNotifications(ctx, protoReader) + if err != nil { + t.Errorf("ProcessPendingNotifications with empty reader should not error, got: %v", err) + } + + // Test with reader that has some data but not push notifications + protoReader = proto.NewReader(strings.NewReader("+OK\r\n")) + err = processor.ProcessPendingNotifications(ctx, protoReader) + if err != nil { + t.Errorf("ProcessPendingNotifications with non-push data should not error, got: %v", err) + } + }) +} + +// TestVoidProcessor tests the VoidProcessor implementation +func TestVoidProcessor(t *testing.T) { + t.Run("NewVoidProcessor", func(t *testing.T) { + processor := NewVoidProcessor() + if processor == nil { + t.Error("NewVoidProcessor should return a non-nil processor") + } + }) + + t.Run("GetHandler", func(t *testing.T) { + processor := NewVoidProcessor() + + // VoidProcessor should always return nil for any handler name + result := processor.GetHandler("MOVING") + if result != nil { + t.Error("VoidProcessor GetHandler should always return nil") + } + + result = processor.GetHandler("MIGRATING") + if result != nil { + t.Error("VoidProcessor GetHandler should always return nil") + } + + result = processor.GetHandler("") + if result != nil { + t.Error("VoidProcessor GetHandler should always return nil for empty string") + } + }) + + t.Run("RegisterHandler", func(t *testing.T) { + processor := NewVoidProcessor() + handler := NewTestHandler("test", true) + + // VoidProcessor should always return error for registration + err := processor.RegisterHandler("MOVING", handler, false) + if err == nil { + t.Error("VoidProcessor RegisterHandler should always return error") + } + if !strings.Contains(err.Error(), "cannot register push notification handler") { + t.Errorf("Expected error about cannot register, got: %v", err) + } + if !strings.Contains(err.Error(), "push notifications are disabled") { + t.Errorf("Expected error about disabled push notifications, got: %v", err) + } + + // Test with protected flag + err = processor.RegisterHandler("MIGRATING", handler, true) + if err == nil { + t.Error("VoidProcessor RegisterHandler should always return error even with protected=true") + } + + // Test with empty handler name + err = processor.RegisterHandler("", handler, false) + if err == nil { + t.Error("VoidProcessor RegisterHandler should always return error even with empty name") + } + }) + + t.Run("UnregisterHandler", func(t *testing.T) { + processor := NewVoidProcessor() + + // VoidProcessor should always return error for unregistration + err := processor.UnregisterHandler("MOVING") + if err == nil { + t.Error("VoidProcessor UnregisterHandler should always return error") + } + if !strings.Contains(err.Error(), "cannot unregister push notification handler") { + t.Errorf("Expected error about cannot unregister, got: %v", err) + } + if !strings.Contains(err.Error(), "push notifications are disabled") { + t.Errorf("Expected error about disabled push notifications, got: %v", err) + } + + // Test with empty handler name + err = processor.UnregisterHandler("") + if err == nil { + t.Error("VoidProcessor UnregisterHandler should always return error even with empty name") + } + }) + + t.Run("ProcessPendingNotifications", func(t *testing.T) { + processor := NewVoidProcessor() + ctx := context.Background() + + // VoidProcessor should always succeed and do nothing + err := processor.ProcessPendingNotifications(ctx, nil) + if err != nil { + t.Errorf("VoidProcessor ProcessPendingNotifications should never error, got: %v", err) + } + + // Test with various readers + reader := proto.NewReader(strings.NewReader("")) + err = processor.ProcessPendingNotifications(ctx, reader) + if err != nil { + t.Errorf("VoidProcessor ProcessPendingNotifications should never error, got: %v", err) + } + + reader = proto.NewReader(strings.NewReader("some data")) + err = processor.ProcessPendingNotifications(ctx, reader) + if err != nil { + t.Errorf("VoidProcessor ProcessPendingNotifications should never error, got: %v", err) + } + }) +} + +// TestShouldSkipNotification tests the shouldSkipNotification function +func TestShouldSkipNotification(t *testing.T) { + t.Run("PubSubMessages", func(t *testing.T) { + pubSubMessages := []string{ + "message", // Regular pub/sub message + "pmessage", // Pattern pub/sub message + "subscribe", // Subscription confirmation + "unsubscribe", // Unsubscription confirmation + "psubscribe", // Pattern subscription confirmation + "punsubscribe", // Pattern unsubscription confirmation + "smessage", // Sharded pub/sub message (Redis 7.0+) + } + + for _, msgType := range pubSubMessages { + if !shouldSkipNotification(msgType) { + t.Errorf("shouldSkipNotification(%q) should return true", msgType) + } + } + }) + + t.Run("NonPubSubMessages", func(t *testing.T) { + nonPubSubMessages := []string{ + "MOVING", // Cluster slot migration + "MIGRATING", // Cluster slot migration + "MIGRATED", // Cluster slot migration + "FAILING_OVER", // Cluster failover + "FAILED_OVER", // Cluster failover + "unknown", // Unknown message type + "", // Empty string + "MESSAGE", // Case sensitive - should not match + "PMESSAGE", // Case sensitive - should not match + } + + for _, msgType := range nonPubSubMessages { + if shouldSkipNotification(msgType) { + t.Errorf("shouldSkipNotification(%q) should return false", msgType) + } + } + }) +} + +// TestPubSubFiltering tests that pub/sub messages are filtered out during processing +func TestPubSubFiltering(t *testing.T) { + t.Run("PubSubMessagesIgnored", func(t *testing.T) { + processor := NewProcessor() + handler := NewTestHandler("test", true) + ctx := context.Background() + + // Register a handler for a non-pub/sub notification + err := processor.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + // Test with mock reader - pub/sub message should be ignored + mockReader := NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + pubSubNotification := []interface{}{"message", "channel", "data"} + mockReader.AddReadReply(pubSubNotification, nil) + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + + handler.Reset() + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle pub/sub messages gracefully, got: %v", err) + } + + // Check that handler was NOT called for pub/sub message + handled := handler.GetHandledNotifications() + if len(handled) != 0 { + t.Errorf("Expected 0 handled notifications for pub/sub message, got: %d", len(handled)) + } + }) + + t.Run("NonPubSubMessagesProcessed", func(t *testing.T) { + processor := NewProcessor() + handler := NewTestHandler("test", true) + ctx := context.Background() + + // Register a handler for a non-pub/sub notification + err := processor.RegisterHandler("MOVING", handler, false) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + // Test with mock reader - non-pub/sub message should be processed + mockReader := NewMockReader() + mockReader.AddPeekReplyType(proto.RespPush, nil) + clusterNotification := []interface{}{"MOVING", "slot", "12345"} + mockReader.AddReadReply(clusterNotification, nil) + mockReader.AddPeekReplyType(proto.RespString, io.EOF) // No more push notifications + + handler.Reset() + err = testProcessPendingNotifications(processor, ctx, mockReader) + if err != nil { + t.Errorf("ProcessPendingNotifications should handle cluster notifications, got: %v", err) + } + + // Check that handler WAS called for cluster notification + handled := handler.GetHandledNotifications() + if len(handled) != 1 { + t.Errorf("Expected 1 handled notification for cluster message, got: %d", len(handled)) + } else if len(handled[0]) != 3 || handled[0][0] != "MOVING" { + t.Errorf("Expected MOVING notification, got: %v", handled[0]) + } + }) +} diff --git a/internal/pushnotif/registry.go b/internal/pushnotif/registry.go new file mode 100644 index 000000000..eb3ebfbdf --- /dev/null +++ b/internal/pushnotif/registry.go @@ -0,0 +1,84 @@ +package pushnotif + +import ( + "fmt" + "sync" +) + +// Registry manages push notification handlers. +type Registry struct { + mu sync.RWMutex + handlers map[string]Handler + protected map[string]bool +} + +// NewRegistry creates a new push notification registry. +func NewRegistry() *Registry { + return &Registry{ + handlers: make(map[string]Handler), + protected: make(map[string]bool), + } +} + +// RegisterHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (r *Registry) RegisterHandler(pushNotificationName string, handler Handler, protected bool) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.handlers[pushNotificationName]; exists { + return fmt.Errorf("handler already registered for push notification: %s", pushNotificationName) + } + + r.handlers[pushNotificationName] = handler + r.protected[pushNotificationName] = protected + return nil +} + +// UnregisterHandler removes a handler for a specific push notification name. +// Returns an error if the handler is protected or doesn't exist. +func (r *Registry) UnregisterHandler(pushNotificationName string) error { + r.mu.Lock() + defer r.mu.Unlock() + + _, exists := r.handlers[pushNotificationName] + if !exists { + return fmt.Errorf("no handler registered for push notification: %s", pushNotificationName) + } + + if r.protected[pushNotificationName] { + return fmt.Errorf("cannot unregister protected handler for push notification: %s", pushNotificationName) + } + + delete(r.handlers, pushNotificationName) + delete(r.protected, pushNotificationName) + return nil +} + +// GetHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (r *Registry) GetHandler(pushNotificationName string) Handler { + r.mu.RLock() + defer r.mu.RUnlock() + + handler, exists := r.handlers[pushNotificationName] + if !exists { + return nil + } + return handler +} + +// GetRegisteredPushNotificationNames returns a list of all registered push notification names. +func (r *Registry) GetRegisteredPushNotificationNames() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.handlers)) + for name := range r.handlers { + names = append(names, name) + } + return names +} + + diff --git a/internal/pushnotif/types.go b/internal/pushnotif/types.go new file mode 100644 index 000000000..e60250e70 --- /dev/null +++ b/internal/pushnotif/types.go @@ -0,0 +1,29 @@ +package pushnotif + +import ( + "context" + + "github.com/redis/go-redis/v9/internal/proto" +) + +// Handler defines the interface for push notification handlers. +type Handler interface { + // HandlePushNotification processes a push notification. + // Returns true if the notification was handled, false otherwise. + HandlePushNotification(ctx context.Context, notification []interface{}) bool +} + +// ProcessorInterface defines the interface for push notification processors. +type ProcessorInterface interface { + GetHandler(pushNotificationName string) Handler + ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error + RegisterHandler(pushNotificationName string, handler Handler, protected bool) error +} + +// RegistryInterface defines the interface for push notification registries. +type RegistryInterface interface { + RegisterHandler(pushNotificationName string, handler Handler, protected bool) error + UnregisterHandler(pushNotificationName string) error + GetHandler(pushNotificationName string) Handler + GetRegisteredPushNotificationNames() []string +} diff --git a/options.go b/options.go index b87a234a4..2ffb8603c 100644 --- a/options.go +++ b/options.go @@ -216,6 +216,21 @@ type Options struct { // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. // When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult UnstableResp3 bool + + // PushNotifications enables general push notification processing. + // When enabled, the client will process RESP3 push notifications and + // route them to registered handlers. + // + // For RESP3 connections (Protocol: 3), push notifications are always enabled + // and cannot be disabled. To avoid push notifications, use Protocol: 2 (RESP2). + // For RESP2 connections, push notifications are not available. + // + // default: always enabled for RESP3, disabled for RESP2 + PushNotifications bool + + // PushNotificationProcessor is the processor for handling push notifications. + // If nil, a default processor will be created when PushNotifications is enabled. + PushNotificationProcessor PushNotificationProcessorInterface } func (opt *Options) init() { @@ -592,5 +607,9 @@ func newConnPool( MaxActiveConns: opt.MaxActiveConns, ConnMaxIdleTime: opt.ConnMaxIdleTime, ConnMaxLifetime: opt.ConnMaxLifetime, + // Pass push notification processor for connection initialization + PushNotificationProcessor: opt.PushNotificationProcessor, + // Pass protocol version for push notification optimization + Protocol: opt.Protocol, }) } diff --git a/pubsub.go b/pubsub.go index 2a0e7a81e..da16d319d 100644 --- a/pubsub.go +++ b/pubsub.go @@ -38,12 +38,21 @@ type PubSub struct { chOnce sync.Once msgCh *channel allCh *channel + + // Push notification processor for handling generic push notifications + pushProcessor PushNotificationProcessorInterface } func (c *PubSub) init() { c.exit = make(chan struct{}) } +// SetPushNotificationProcessor sets the push notification processor for handling +// generic push notifications received on this PubSub connection. +func (c *PubSub) SetPushNotificationProcessor(processor PushNotificationProcessorInterface) { + c.pushProcessor = processor +} + func (c *PubSub) String() string { c.mu.Lock() defer c.mu.Unlock() @@ -367,6 +376,18 @@ func (p *Pong) String() string { return "Pong" } +// PushNotificationMessage represents a generic push notification received on a PubSub connection. +type PushNotificationMessage struct { + // Command is the push notification command (e.g., "MOVING", "CUSTOM_EVENT"). + Command string + // Args are the arguments following the command. + Args []interface{} +} + +func (m *PushNotificationMessage) String() string { + return fmt.Sprintf("push: %s", m.Command) +} + func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { switch reply := reply.(type) { case string: @@ -413,6 +434,19 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { Payload: reply[1].(string), }, nil default: + // Try to handle as generic push notification + ctx := c.getContext() + handler := c.pushProcessor.GetHandler(kind) + if handler != nil { + handled := handler.HandlePushNotification(ctx, reply) + if handled { + // Return a special message type to indicate it was handled + return &PushNotificationMessage{ + Command: kind, + Args: reply[1:], + }, nil + } + } return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind) } default: @@ -658,6 +692,9 @@ func (c *channel) initMsgChan() { // Ignore. case *Pong: // Ignore. + case *PushNotificationMessage: + // Ignore push notifications in message-only channel + // They are already handled by the push notification processor case *Message: timer.Reset(c.chanSendTimeout) select { @@ -712,7 +749,7 @@ func (c *channel) initAllChan() { switch msg := msg.(type) { case *Pong: // Ignore. - case *Subscription, *Message: + case *Subscription, *Message, *PushNotificationMessage: timer.Reset(c.chanSendTimeout) select { case c.allCh <- msg: diff --git a/push_notifications.go b/push_notifications.go new file mode 100644 index 000000000..18544f856 --- /dev/null +++ b/push_notifications.go @@ -0,0 +1,147 @@ +package redis + +import ( + "context" + + "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/pushnotif" +) + +// PushNotificationHandler defines the interface for push notification handlers. +// This is an alias to the internal push notification handler interface. +type PushNotificationHandler = pushnotif.Handler + +// PushNotificationProcessorInterface defines the interface for push notification processors. +// This is an alias to the internal push notification processor interface. +type PushNotificationProcessorInterface = pushnotif.ProcessorInterface + +// PushNotificationRegistry manages push notification handlers. +type PushNotificationRegistry struct { + registry *pushnotif.Registry +} + +// NewPushNotificationRegistry creates a new push notification registry. +func NewPushNotificationRegistry() *PushNotificationRegistry { + return &PushNotificationRegistry{ + registry: pushnotif.NewRegistry(), + } +} + +// RegisterHandler registers a handler for a specific push notification name. +func (r *PushNotificationRegistry) RegisterHandler(pushNotificationName string, handler PushNotificationHandler, protected bool) error { + return r.registry.RegisterHandler(pushNotificationName, handler, protected) +} + +// UnregisterHandler removes a handler for a specific push notification name. +func (r *PushNotificationRegistry) UnregisterHandler(pushNotificationName string) error { + return r.registry.UnregisterHandler(pushNotificationName) +} + +// GetHandler returns the handler for a specific push notification name. +func (r *PushNotificationRegistry) GetHandler(pushNotificationName string) PushNotificationHandler { + return r.registry.GetHandler(pushNotificationName) +} + +// GetRegisteredPushNotificationNames returns a list of all registered push notification names. +func (r *PushNotificationRegistry) GetRegisteredPushNotificationNames() []string { + return r.registry.GetRegisteredPushNotificationNames() +} + +// PushNotificationProcessor handles push notifications with a registry of handlers. +type PushNotificationProcessor struct { + processor *pushnotif.Processor +} + +// NewPushNotificationProcessor creates a new push notification processor. +func NewPushNotificationProcessor() *PushNotificationProcessor { + return &PushNotificationProcessor{ + processor: pushnotif.NewProcessor(), + } +} + +// GetHandler returns the handler for a specific push notification name. +func (p *PushNotificationProcessor) GetHandler(pushNotificationName string) PushNotificationHandler { + return p.processor.GetHandler(pushNotificationName) +} + +// RegisterHandler registers a handler for a specific push notification name. +func (p *PushNotificationProcessor) RegisterHandler(pushNotificationName string, handler PushNotificationHandler, protected bool) error { + return p.processor.RegisterHandler(pushNotificationName, handler, protected) +} + +// UnregisterHandler removes a handler for a specific push notification name. +func (p *PushNotificationProcessor) UnregisterHandler(pushNotificationName string) error { + return p.processor.UnregisterHandler(pushNotificationName) +} + +// ProcessPendingNotifications checks for and processes any pending push notifications. +func (p *PushNotificationProcessor) ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error { + return p.processor.ProcessPendingNotifications(ctx, rd) +} + +// VoidPushNotificationProcessor discards all push notifications without processing them. +type VoidPushNotificationProcessor struct { + processor *pushnotif.VoidProcessor +} + +// NewVoidPushNotificationProcessor creates a new void push notification processor. +func NewVoidPushNotificationProcessor() *VoidPushNotificationProcessor { + return &VoidPushNotificationProcessor{ + processor: pushnotif.NewVoidProcessor(), + } +} + +// GetHandler returns nil for void processor since it doesn't maintain handlers. +func (v *VoidPushNotificationProcessor) GetHandler(pushNotificationName string) PushNotificationHandler { + return nil +} + +// RegisterHandler returns an error for void processor since it doesn't maintain handlers. +func (v *VoidPushNotificationProcessor) RegisterHandler(pushNotificationName string, handler PushNotificationHandler, protected bool) error { + return v.processor.RegisterHandler(pushNotificationName, nil, protected) +} + +// ProcessPendingNotifications reads and discards any pending push notifications. +func (v *VoidPushNotificationProcessor) ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error { + return v.processor.ProcessPendingNotifications(ctx, rd) +} + +// Redis Cluster push notification names +const ( + PushNotificationMoving = "MOVING" + PushNotificationMigrating = "MIGRATING" + PushNotificationMigrated = "MIGRATED" + PushNotificationFailingOver = "FAILING_OVER" + PushNotificationFailedOver = "FAILED_OVER" +) + +// PushNotificationInfo contains metadata about a push notification. +type PushNotificationInfo struct { + Name string + Args []interface{} +} + +// ParsePushNotificationInfo extracts information from a push notification. +func ParsePushNotificationInfo(notification []interface{}) *PushNotificationInfo { + if len(notification) == 0 { + return nil + } + + name, ok := notification[0].(string) + if !ok { + return nil + } + + return &PushNotificationInfo{ + Name: name, + Args: notification[1:], + } +} + +// String returns a string representation of the push notification info. +func (info *PushNotificationInfo) String() string { + if info == nil { + return "" + } + return info.Name +} diff --git a/redis.go b/redis.go index a368623aa..b9e54fb88 100644 --- a/redis.go +++ b/redis.go @@ -207,6 +207,9 @@ type baseClient struct { hooksMixin onClose func() error // hook called when client is closed + + // Push notification processing + pushProcessor PushNotificationProcessorInterface } func (c *baseClient) clone() *baseClient { @@ -383,7 +386,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. - if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil { + if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil { // Authentication successful with HELLO command } else if !isRedisError(err) { // When the server responds with the RESP protocol and the result is not a normal @@ -530,7 +533,9 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) { readReplyFunc = cmd.readRawReply } - if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), readReplyFunc); err != nil { + if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error { + return readReplyFunc(rd) + }); err != nil { if cmd.readTimeout() == nil { atomic.StoreUint32(&retryTimeout, 1) } else { @@ -744,12 +749,25 @@ func NewClient(opt *Options) *Client { } opt.init() + // Push notifications are always enabled for RESP3 (cannot be disabled) + // Only override if no custom processor is provided + if opt.Protocol == 3 && opt.PushNotificationProcessor == nil { + opt.PushNotifications = true + } + c := Client{ baseClient: &baseClient{ opt: opt, }, } c.init() + + // Initialize push notification processor + c.initializePushProcessor() + + // Update options with the initialized push processor for connection pool + opt.PushNotificationProcessor = c.pushProcessor + c.connPool = newConnPool(opt, c.dialHook) return &c @@ -787,6 +805,47 @@ func (c *Client) Options() *Options { return c.opt } +// initializePushProcessor initializes the push notification processor for any client type. +// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient. +func initializePushProcessor(opt *Options) PushNotificationProcessorInterface { + // Always use custom processor if provided + if opt.PushNotificationProcessor != nil { + return opt.PushNotificationProcessor + } + + // For regular clients, respect the PushNotifications setting + if opt.PushNotifications { + // Create default processor when push notifications are enabled + return NewPushNotificationProcessor() + } + + // Create void processor when push notifications are disabled + return NewVoidPushNotificationProcessor() +} + +// initializePushProcessor initializes the push notification processor for this client. +func (c *Client) initializePushProcessor() { + c.pushProcessor = initializePushProcessor(c.opt) +} + +// RegisterPushNotificationHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (c *Client) RegisterPushNotificationHandler(pushNotificationName string, handler PushNotificationHandler, protected bool) error { + return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected) +} + +// GetPushNotificationProcessor returns the push notification processor. +func (c *Client) GetPushNotificationProcessor() PushNotificationProcessorInterface { + return c.pushProcessor +} + +// GetPushNotificationHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (c *Client) GetPushNotificationHandler(pushNotificationName string) PushNotificationHandler { + return c.pushProcessor.GetHandler(pushNotificationName) +} + type PoolStats pool.Stats // PoolStats returns connection pool stats. @@ -833,6 +892,10 @@ func (c *Client) pubSub() *PubSub { closeConn: c.connPool.CloseConn, } pubsub.init() + + // Set the push notification processor + pubsub.SetPushNotificationProcessor(c.pushProcessor) + return pubsub } @@ -916,6 +979,10 @@ func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn c.hooksMixin = parentHooks.clone() } + // Initialize push notification processor using shared helper + // Use void processor by default for connections (typically don't need push notifications) + c.pushProcessor = initializePushProcessor(opt) + c.cmdable = c.Process c.statefulCmdable = c.Process c.initHooks(hooks{ @@ -934,6 +1001,18 @@ func (c *Conn) Process(ctx context.Context, cmd Cmder) error { return err } +// RegisterPushNotificationHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (c *Conn) RegisterPushNotificationHandler(pushNotificationName string, handler PushNotificationHandler, protected bool) error { + return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected) +} + +// GetPushNotificationProcessor returns the push notification processor. +func (c *Conn) GetPushNotificationProcessor() PushNotificationProcessorInterface { + return c.pushProcessor +} + func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(ctx, fn) } diff --git a/sentinel.go b/sentinel.go index 04c0f7269..36283c5ba 100644 --- a/sentinel.go +++ b/sentinel.go @@ -61,6 +61,10 @@ type FailoverOptions struct { Protocol int Username string Password string + + // PushNotifications enables push notifications for RESP3. + // Defaults to true for RESP3 connections. + PushNotifications bool // CredentialsProvider allows the username and password to be updated // before reconnecting. It should return the current username and password. CredentialsProvider func() (username string, password string) @@ -129,6 +133,7 @@ func (opt *FailoverOptions) clientOptions() *Options { Protocol: opt.Protocol, Username: opt.Username, Password: opt.Password, + PushNotifications: opt.PushNotifications, CredentialsProvider: opt.CredentialsProvider, CredentialsProviderContext: opt.CredentialsProviderContext, StreamingCredentialsProvider: opt.StreamingCredentialsProvider, @@ -426,6 +431,10 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { } rdb.init() + // Initialize push notification processor using shared helper + // Use void processor by default for failover clients (typically don't need push notifications) + rdb.pushProcessor = initializePushProcessor(opt) + connPool = newConnPool(opt, rdb.dialHook) rdb.connPool = connPool rdb.onClose = rdb.wrappedOnClose(failover.Close) @@ -492,6 +501,10 @@ func NewSentinelClient(opt *Options) *SentinelClient { }, } + // Initialize push notification processor using shared helper + // Use void processor by default for sentinel clients (typically don't need push notifications) + c.pushProcessor = initializePushProcessor(opt) + c.initHooks(hooks{ dial: c.baseClient.dial, process: c.baseClient.process, @@ -501,6 +514,24 @@ func NewSentinelClient(opt *Options) *SentinelClient { return c } +// GetPushNotificationProcessor returns the push notification processor. +func (c *SentinelClient) GetPushNotificationProcessor() PushNotificationProcessorInterface { + return c.pushProcessor +} + +// GetPushNotificationHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (c *SentinelClient) GetPushNotificationHandler(pushNotificationName string) PushNotificationHandler { + return c.pushProcessor.GetHandler(pushNotificationName) +} + +// RegisterPushNotificationHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (c *SentinelClient) RegisterPushNotificationHandler(pushNotificationName string, handler PushNotificationHandler, protected bool) error { + return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected) +} + func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { err := c.processHook(ctx, cmd) cmd.SetErr(err) diff --git a/tx.go b/tx.go index 0daa222e3..67689f57a 100644 --- a/tx.go +++ b/tx.go @@ -24,9 +24,10 @@ type Tx struct { func (c *Client) newTx() *Tx { tx := Tx{ baseClient: baseClient{ - opt: c.opt, - connPool: pool.NewStickyConnPool(c.connPool), - hooksMixin: c.hooksMixin.clone(), + opt: c.opt, + connPool: pool.NewStickyConnPool(c.connPool), + hooksMixin: c.hooksMixin.clone(), + pushProcessor: c.pushProcessor, // Copy push processor from parent client }, } tx.init()