diff --git a/cleaner_test.go b/cleaner_test.go index 22c28ab..2cf56c1 100644 --- a/cleaner_test.go +++ b/cleaner_test.go @@ -1,6 +1,7 @@ package rmq import ( + "context" "testing" "time" @@ -9,12 +10,14 @@ import ( ) func TestCleaner(t *testing.T) { - flushConn, err := OpenConnection("cleaner-flush", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + + flushConn, err := OpenConnection(ctx, "cleaner-flush", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) assert.NoError(t, flushConn.stopHeartbeat()) assert.NoError(t, flushConn.flushDb()) - conn, err := OpenConnection("cleaner-conn1", "tcp", "localhost:6379", 1, nil) + conn, err := OpenConnection(ctx, "cleaner-conn1", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queues, err := conn.GetOpenQueues() assert.NoError(t, err) @@ -86,7 +89,7 @@ func TestCleaner(t *testing.T) { require.NotNil(t, consumer.LastDelivery) assert.Equal(t, "del1", consumer.LastDelivery.Payload()) - assert.NoError(t, consumer.LastDelivery.Ack(nil)) + assert.NoError(t, consumer.LastDelivery.Ack()) time.Sleep(10 * time.Millisecond) count, err = queue.unackedCount() assert.NoError(t, err) @@ -109,7 +112,7 @@ func TestCleaner(t *testing.T) { assert.NoError(t, conn.stopHeartbeat()) time.Sleep(time.Millisecond) - conn, err = OpenConnection("cleaner-conn1", "tcp", "localhost:6379", 1, nil) + conn, err = OpenConnection(ctx, "cleaner-conn1", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err = conn.OpenQueue("q1") assert.NoError(t, err) @@ -172,7 +175,7 @@ func TestCleaner(t *testing.T) { assert.Equal(t, int64(6), count) assert.Equal(t, "del5", consumer.LastDelivery.Payload()) - assert.NoError(t, consumer.LastDelivery.Ack(nil)) + assert.NoError(t, consumer.LastDelivery.Ack()) time.Sleep(10 * time.Millisecond) count, err = queue.unackedCount() assert.NoError(t, err) @@ -185,7 +188,7 @@ func TestCleaner(t *testing.T) { assert.NoError(t, conn.stopHeartbeat()) time.Sleep(time.Millisecond) - cleanerConn, err := OpenConnection("cleaner-conn", "tcp", "localhost:6379", 1, nil) + cleanerConn, err := OpenConnection(ctx, "cleaner-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) cleaner := NewCleaner(cleanerConn) returned, err := cleaner.Clean() @@ -198,7 +201,7 @@ func TestCleaner(t *testing.T) { assert.NoError(t, err) assert.Len(t, queues, 2) - conn, err = OpenConnection("cleaner-conn1", "tcp", "localhost:6379", 1, nil) + conn, err = OpenConnection(ctx, "cleaner-conn1", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err = conn.OpenQueue("q1") assert.NoError(t, err) diff --git a/connection.go b/connection.go index 9e104ff..1a6606f 100644 --- a/connection.go +++ b/connection.go @@ -1,6 +1,7 @@ package rmq import ( + "context" "fmt" "strings" "time" @@ -46,6 +47,7 @@ type Connection interface { // Connection is the entry point. Use a connection to access queues, consumers and deliveries // Each connection has a single heartbeat shared among all consumers type redisConnection struct { + ctx context.Context Name string heartbeatKey string // key to keep alive queuesKey string // key to list of queues consumed by this connection @@ -59,31 +61,28 @@ type redisConnection struct { } // OpenConnection opens and returns a new connection -func OpenConnection(tag, network, address string, db int, errChan chan<- error) (Connection, error) { - redisClient := redis.NewClient(&redis.Options{ - Network: network, - Addr: address, - DB: db, - }) - return OpenConnectionWithRedisClient(tag, redisClient, errChan) +func OpenConnection(ctx context.Context, tag string, network string, address string, db int, errChan chan<- error) (Connection, error) { + redisClient := redis.NewClient(&redis.Options{Network: network, Addr: address, DB: db}) + return OpenConnectionWithRedisClient(ctx, tag, redisClient, errChan) } // OpenConnectionWithRedisClient opens and returns a new connection -func OpenConnectionWithRedisClient(tag string, redisClient *redis.Client, errChan chan<- error) (*redisConnection, error) { - return openConnectionWithRedisClient(tag, RedisWrapper{redisClient}, errChan) +func OpenConnectionWithRedisClient(ctx context.Context, tag string, redisClient *redis.Client, errChan chan<- error) (*redisConnection, error) { + return openConnectionWithRedisClient(ctx, tag, RedisWrapper{redisClient}, errChan) } // OpenConnectionWithTestRedisClient opens and returns a new connection which // uses a test redis client internally. This is useful in integration tests. -func OpenConnectionWithTestRedisClient(tag string, errChan chan<- error) (*redisConnection, error) { - return openConnectionWithRedisClient(tag, NewTestRedisClient(), errChan) +func OpenConnectionWithTestRedisClient(ctx context.Context, tag string, errChan chan<- error) (*redisConnection, error) { + return openConnectionWithRedisClient(ctx, tag, NewTestRedisClient(), errChan) } -func openConnectionWithRedisClient(tag string, redisClient RedisClient, errChan chan<- error) (*redisConnection, error) { +func openConnectionWithRedisClient(ctx context.Context, tag string, redisClient RedisClient, errChan chan<- error) (*redisConnection, error) { name := fmt.Sprintf("%s-%s", tag, uniuri.NewLen(6)) connection := &redisConnection{ Name: name, + ctx: ctx, heartbeatKey: strings.Replace(connectionHeartbeatTemplate, phConnection, name, 1), queuesKey: strings.Replace(connectionQueuesTemplate, phConnection, name, 1), redisClient: redisClient, @@ -249,6 +248,7 @@ func (connection *redisConnection) getConsumingQueues() ([]string, error) { // openQueue opens a queue without adding it to the set of queues func (connection *redisConnection) openQueue(name string) Queue { return newQueue( + connection.ctx, name, connection.Name, connection.queuesKey, diff --git a/deliveries.go b/deliveries.go index dd5f1ed..8225462 100644 --- a/deliveries.go +++ b/deliveries.go @@ -1,7 +1,5 @@ package rmq -import "context" - type Deliveries []Delivery func (deliveries Deliveries) Payloads() []string { @@ -19,26 +17,25 @@ func (deliveries Deliveries) Payloads() []string { // functions with retry, see comments in delivery.go (recommended) -func (deliveries Deliveries) Ack(ctx context.Context) (errMap map[int]error) { - return deliveries.each(ctx, Delivery.Ack) +func (deliveries Deliveries) Ack() (errMap map[int]error) { + return deliveries.each(Delivery.Ack) } -func (deliveries Deliveries) Reject(ctx context.Context) (errMap map[int]error) { - return deliveries.each(ctx, Delivery.Reject) +func (deliveries Deliveries) Reject() (errMap map[int]error) { + return deliveries.each(Delivery.Reject) } -func (deliveries Deliveries) Push(ctx context.Context) (errMap map[int]error) { - return deliveries.each(ctx, Delivery.Push) +func (deliveries Deliveries) Push() (errMap map[int]error) { + return deliveries.each(Delivery.Push) } // helper functions func (deliveries Deliveries) each( - ctx context.Context, - f func(Delivery, context.Context) error, + f func(Delivery) error, ) (errMap map[int]error) { for i, delivery := range deliveries { - if err := f(delivery, ctx); err != nil { + if err := f(delivery); err != nil { if errMap == nil { // create error map lazily on demand errMap = map[int]error{} } diff --git a/delivery.go b/delivery.go index b1bc5bd..0ae6647 100644 --- a/delivery.go +++ b/delivery.go @@ -9,12 +9,13 @@ import ( type Delivery interface { Payload() string - Ack(context.Context) error - Reject(context.Context) error - Push(context.Context) error + Ack() error + Reject() error + Push() error } type redisDelivery struct { + ctx context.Context payload string unackedKey string rejectedKey string @@ -24,6 +25,7 @@ type redisDelivery struct { } func newDelivery( + ctx context.Context, payload string, unackedKey string, rejectedKey string, @@ -32,6 +34,7 @@ func newDelivery( errChan chan<- error, ) *redisDelivery { return &redisDelivery{ + ctx: ctx, payload: payload, unackedKey: unackedKey, rejectedKey: rejectedKey, @@ -55,11 +58,7 @@ func (delivery *redisDelivery) Payload() string { // 3. if the context is cancalled or its timeout exceeded, context.Cancelled or // context.DeadlineExceeded will be returned -func (delivery *redisDelivery) Ack(ctx context.Context) error { - if ctx == nil { // TODO: remove this - ctx = context.TODO() - } - +func (delivery *redisDelivery) Ack() error { errorCount := 0 for { count, err := delivery.redisClient.LRem(delivery.unackedKey, 1, delivery.payload) @@ -79,7 +78,7 @@ func (delivery *redisDelivery) Ack(ctx context.Context) error { default: } - if err := ctx.Err(); err != nil { + if err := delivery.ctx.Err(); err != nil { return err } @@ -87,19 +86,19 @@ func (delivery *redisDelivery) Ack(ctx context.Context) error { } } -func (delivery *redisDelivery) Reject(ctx context.Context) error { - return delivery.move(ctx, delivery.rejectedKey) +func (delivery *redisDelivery) Reject() error { + return delivery.move(delivery.rejectedKey) } -func (delivery *redisDelivery) Push(ctx context.Context) error { +func (delivery *redisDelivery) Push() error { if delivery.pushKey == "" { - return delivery.Reject(ctx) // fall back to rejecting + return delivery.Reject() // fall back to rejecting } - return delivery.move(ctx, delivery.pushKey) + return delivery.move(delivery.pushKey) } -func (delivery *redisDelivery) move(ctx context.Context, key string) error { +func (delivery *redisDelivery) move(key string) error { errorCount := 0 for { _, err := delivery.redisClient.LPush(key, delivery.payload) @@ -118,7 +117,7 @@ func (delivery *redisDelivery) move(ctx context.Context, key string) error { time.Sleep(time.Second) } - return delivery.Ack(ctx) + return delivery.Ack() } // lower level functions which don't retry but just return the first error diff --git a/example/batch_consumer/main.go b/example/batch_consumer/main.go index 6483b96..c4b3fce 100644 --- a/example/batch_consumer/main.go +++ b/example/batch_consumer/main.go @@ -41,13 +41,12 @@ func main() { } }() - connection, err := rmq.OpenConnection("consumer", "tcp", "localhost:6379", 2, errChan) + ctx, cancel := context.WithCancel(context.Background()) + connection, err := rmq.OpenConnection(ctx, "consumer", "tcp", "localhost:6379", 2, errChan) if err != nil { panic(err) } - ctx, cancel := context.WithCancel(context.Background()) - for _, queueName := range []string{ "things", "balls", @@ -59,7 +58,7 @@ func main() { if err := queue.StartConsuming(unackedLimit, pollDuration); err != nil { panic(err) } - if _, err := queue.AddBatchConsumer(queueName, batchSize, batchTimeout, NewBatchConsumer(ctx, queueName)); err != nil { + if _, err := queue.AddBatchConsumer(queueName, batchSize, batchTimeout, NewBatchConsumer(queueName)); err != nil { panic(err) } } @@ -82,15 +81,11 @@ func main() { } type BatchConsumer struct { - ctx context.Context tag string } -func NewBatchConsumer(ctx context.Context, tag string) *BatchConsumer { - return &BatchConsumer{ - ctx: ctx, - tag: tag, - } +func NewBatchConsumer(tag string) *BatchConsumer { + return &BatchConsumer{tag: tag} } func (consumer *BatchConsumer) Consume(batch rmq.Deliveries) { @@ -99,7 +94,7 @@ func (consumer *BatchConsumer) Consume(batch rmq.Deliveries) { time.Sleep(consumeDuration) log.Printf("%s consumed %d: %s", consumer.tag, len(batch), batch[0]) - errors := batch.Ack(consumer.ctx) + errors := batch.Ack() if len(errors) == 0 { debugf("acked %q", payloads) return diff --git a/example/cleaner/main.go b/example/cleaner/main.go index c481f75..c2ca848 100644 --- a/example/cleaner/main.go +++ b/example/cleaner/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "time" @@ -8,7 +9,7 @@ import ( ) func main() { - connection, err := rmq.OpenConnection("cleaner", "tcp", "localhost:6379", 2, nil) + connection, err := rmq.OpenConnection(context.Background(), "cleaner", "tcp", "localhost:6379", 2, nil) if err != nil { panic(err) } diff --git a/example/consumer/main.go b/example/consumer/main.go index 2938dbe..e1091ac 100644 --- a/example/consumer/main.go +++ b/example/consumer/main.go @@ -41,7 +41,8 @@ func main() { } }() - connection, err := rmq.OpenConnection("consumer", "tcp", "localhost:6379", 2, errChan) + ctx, cancel := context.WithCancel(context.Background()) + connection, err := rmq.OpenConnection(ctx, "consumer", "tcp", "localhost:6379", 2, errChan) if err != nil { panic(err) } @@ -55,11 +56,9 @@ func main() { panic(err) } - ctx, cancel := context.WithCancel(context.Background()) - for i := 0; i < numConsumers; i++ { name := fmt.Sprintf("consumer %d", i) - if _, err := queue.AddConsumer(name, NewConsumer(ctx, i)); err != nil { + if _, err := queue.AddConsumer(name, NewConsumer(i)); err != nil { panic(err) } } @@ -82,15 +81,13 @@ func main() { } type Consumer struct { - ctx context.Context name string count int before time.Time } -func NewConsumer(ctx context.Context, tag int) *Consumer { +func NewConsumer(tag int) *Consumer { return &Consumer{ - ctx: ctx, name: fmt.Sprintf("consumer%d", tag), count: 0, before: time.Now(), @@ -111,13 +108,13 @@ func (consumer *Consumer) Consume(delivery rmq.Delivery) { } if consumer.count%batchSize > 0 { - if err := delivery.Ack(consumer.ctx); err != nil { + if err := delivery.Ack(); err != nil { debugf("failed to ack %s: %s", payload, err) } else { debugf("acked %s", payload) } } else { // reject one per batch - if err := delivery.Reject(consumer.ctx); err != nil { + if err := delivery.Reject(); err != nil { debugf("failed to reject %s: %s", payload, err) } else { debugf("rejected %s", payload) diff --git a/example/handler/main.go b/example/handler/main.go index 383ed72..5cb190d 100644 --- a/example/handler/main.go +++ b/example/handler/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "log" "net/http" @@ -9,7 +10,7 @@ import ( ) func main() { - connection, err := rmq.OpenConnection("handler", "tcp", "localhost:6379", 2, nil) + connection, err := rmq.OpenConnection(context.Background(), "handler", "tcp", "localhost:6379", 2, nil) if err != nil { panic(err) } diff --git a/example/producer/main.go b/example/producer/main.go index e5132ce..3d0daf2 100644 --- a/example/producer/main.go +++ b/example/producer/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "log" "time" @@ -14,7 +15,7 @@ const ( ) func main() { - connection, err := rmq.OpenConnection("producer", "tcp", "localhost:6379", 2, nil) + connection, err := rmq.OpenConnection(context.Background(), "producer", "tcp", "localhost:6379", 2, nil) if err != nil { panic(err) } diff --git a/example/purger/main.go b/example/purger/main.go index d55e357..1f80366 100644 --- a/example/purger/main.go +++ b/example/purger/main.go @@ -1,12 +1,14 @@ package main import ( - "github.com/adjust/rmq/v2" + "context" "log" + + "github.com/adjust/rmq/v2" ) func main() { - connection, err := rmq.OpenConnection("cleaner", "tcp", "localhost:6379", 2, nil) + connection, err := rmq.OpenConnection(context.Background(), "cleaner", "tcp", "localhost:6379", 2, nil) if err != nil { panic(err) } diff --git a/example/returner/main.go b/example/returner/main.go index be06966..b13421b 100644 --- a/example/returner/main.go +++ b/example/returner/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "math" @@ -8,7 +9,7 @@ import ( ) func main() { - connection, err := rmq.OpenConnection("returner", "tcp", "localhost:6379", 2, nil) + connection, err := rmq.OpenConnection(context.Background(), "returner", "tcp", "localhost:6379", 2, nil) if err != nil { panic(err) } diff --git a/queue.go b/queue.go index e38adb1..df15973 100644 --- a/queue.go +++ b/queue.go @@ -1,6 +1,7 @@ package rmq import ( + "context" "fmt" "strings" "sync" @@ -40,6 +41,7 @@ type Queue interface { } type redisQueue struct { + ctx context.Context name string connectionName string queuesKey string // key to list of queues consumed by this connection @@ -58,6 +60,7 @@ type redisQueue struct { } func newQueue( + ctx context.Context, name string, connectionName string, queuesKey string, @@ -75,6 +78,7 @@ func newQueue( unackedKey = strings.Replace(unackedKey, phQueue, name, 1) queue := &redisQueue{ + ctx: ctx, name: name, connectionName: connectionName, queuesKey: queuesKey, @@ -208,7 +212,7 @@ func (queue *redisQueue) consumeBatch() error { func (queue *redisQueue) newDelivery(payload string) Delivery { return newDelivery( - // queue.ctx, + queue.ctx, payload, queue.unackedKey, queue.rejectedKey, diff --git a/queue_test.go b/queue_test.go index 72ecd02..c105820 100644 --- a/queue_test.go +++ b/queue_test.go @@ -1,6 +1,7 @@ package rmq import ( + "context" "fmt" "math" "strconv" @@ -12,12 +13,13 @@ import ( ) func TestConnections(t *testing.T) { - flushConn, err := OpenConnection("conns-flush", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + flushConn, err := OpenConnection(ctx, "conns-flush", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) assert.NoError(t, flushConn.stopHeartbeat()) assert.NoError(t, flushConn.flushDb()) - connection, err := OpenConnection("conns-conn", "tcp", "localhost:6379", 1, nil) + connection, err := OpenConnection(ctx, "conns-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) require.NotNil(t, connection) _, err = NewCleaner(connection).Clean() @@ -27,14 +29,14 @@ func TestConnections(t *testing.T) { assert.NoError(t, err) assert.Len(t, connections, 1) // cleaner connection remains - conn1, err := OpenConnection("conns-conn1", "tcp", "localhost:6379", 1, nil) + conn1, err := OpenConnection(ctx, "conns-conn1", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) connections, err = connection.getConnections() assert.NoError(t, err) assert.Len(t, connections, 2) assert.Equal(t, ErrorNotFound, connection.hijackConnection("nope").checkHeartbeat()) assert.NoError(t, conn1.checkHeartbeat()) - conn2, err := OpenConnection("conns-conn2", "tcp", "localhost:6379", 1, nil) + conn2, err := OpenConnection(ctx, "conns-conn2", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) connections, err = connection.getConnections() assert.NoError(t, err) @@ -61,7 +63,8 @@ func TestConnections(t *testing.T) { } func TestConnectionQueues(t *testing.T) { - connection, err := OpenConnection("conn-q-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "conn-q-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) require.NotNil(t, connection) @@ -131,7 +134,8 @@ func TestConnectionQueues(t *testing.T) { } func TestQueueCommon(t *testing.T) { - connection, err := OpenConnection("queue-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "queue-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) require.NotNil(t, connection) @@ -184,7 +188,8 @@ func TestQueueCommon(t *testing.T) { } func TestConsumerCommon(t *testing.T) { - connection, err := OpenConnection("cons-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "cons-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) require.NotNil(t, connection) @@ -222,7 +227,7 @@ func TestConsumerCommon(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(2), count) - assert.NoError(t, consumer.LastDeliveries[0].Ack(nil)) + assert.NoError(t, consumer.LastDeliveries[0].Ack()) count, err = queue1.readyCount() assert.NoError(t, err) assert.Equal(t, int64(0), count) @@ -230,7 +235,7 @@ func TestConsumerCommon(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(1), count) - assert.NoError(t, consumer.LastDeliveries[1].Ack(nil)) + assert.NoError(t, consumer.LastDeliveries[1].Ack()) count, err = queue1.readyCount() assert.NoError(t, err) assert.Equal(t, int64(0), count) @@ -238,7 +243,7 @@ func TestConsumerCommon(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(0), count) - assert.Equal(t, ErrorNotFound, consumer.LastDeliveries[0].Ack(nil)) + assert.Equal(t, ErrorNotFound, consumer.LastDeliveries[0].Ack()) assert.NoError(t, queue1.Publish("cons-d3")) time.Sleep(2 * time.Millisecond) @@ -252,7 +257,7 @@ func TestConsumerCommon(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(0), count) assert.Equal(t, "cons-d3", consumer.LastDelivery.Payload()) - assert.NoError(t, consumer.LastDelivery.Reject(nil)) + assert.NoError(t, consumer.LastDelivery.Reject()) count, err = queue1.readyCount() assert.NoError(t, err) assert.Equal(t, int64(0), count) @@ -275,7 +280,7 @@ func TestConsumerCommon(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(1), count) assert.Equal(t, "cons-d4", consumer.LastDelivery.Payload()) - assert.NoError(t, consumer.LastDelivery.Reject(nil)) + assert.NoError(t, consumer.LastDelivery.Reject()) count, err = queue1.readyCount() assert.NoError(t, err) assert.Equal(t, int64(0), count) @@ -303,7 +308,7 @@ func TestConsumerCommon(t *testing.T) { payload := "cons-func-payload" _, err = queue2.AddConsumerFunc("cons-func", func(delivery Delivery) { - err = delivery.Ack(nil) + err = delivery.Ack() assert.NoError(t, err) payloadChan <- delivery.Payload() }) @@ -325,7 +330,8 @@ func TestConsumerCommon(t *testing.T) { } func TestMulti(t *testing.T) { - connection, err := OpenConnection("multi-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "multi-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err := connection.OpenQueue("multi-q") assert.NoError(t, err) @@ -366,7 +372,7 @@ func TestMulti(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(10), count) - assert.NoError(t, consumer.LastDelivery.Ack(nil)) + assert.NoError(t, consumer.LastDelivery.Ack()) time.Sleep(10 * time.Millisecond) count, err = queue.readyCount() assert.NoError(t, err) @@ -384,7 +390,7 @@ func TestMulti(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(10), count) - assert.NoError(t, consumer.LastDelivery.Ack(nil)) + assert.NoError(t, consumer.LastDelivery.Ack()) time.Sleep(10 * time.Millisecond) count, err = queue.readyCount() assert.NoError(t, err) @@ -407,7 +413,8 @@ func TestMulti(t *testing.T) { } func TestBatch(t *testing.T) { - connection, err := OpenConnection("batch-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "batch-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err := connection.OpenQueue("batch-q") assert.NoError(t, err) @@ -434,8 +441,8 @@ func TestBatch(t *testing.T) { require.Len(t, consumer.LastBatch, 2) assert.Equal(t, "batch-d0", consumer.LastBatch[0].Payload()) assert.Equal(t, "batch-d1", consumer.LastBatch[1].Payload()) - assert.NoError(t, consumer.LastBatch[0].Reject(nil)) - assert.NoError(t, consumer.LastBatch[1].Ack(nil)) + assert.NoError(t, consumer.LastBatch[0].Reject()) + assert.NoError(t, consumer.LastBatch[1].Ack()) count, err = queue.unackedCount() assert.NoError(t, err) assert.Equal(t, int64(3), count) @@ -448,8 +455,8 @@ func TestBatch(t *testing.T) { require.Len(t, consumer.LastBatch, 2) assert.Equal(t, "batch-d2", consumer.LastBatch[0].Payload()) assert.Equal(t, "batch-d3", consumer.LastBatch[1].Payload()) - assert.NoError(t, consumer.LastBatch[0].Reject(nil)) - assert.NoError(t, consumer.LastBatch[1].Ack(nil)) + assert.NoError(t, consumer.LastBatch[0].Reject()) + assert.NoError(t, consumer.LastBatch[1].Ack()) count, err = queue.unackedCount() assert.NoError(t, err) assert.Equal(t, int64(1), count) @@ -470,7 +477,7 @@ func TestBatch(t *testing.T) { time.Sleep(60 * time.Millisecond) require.Len(t, consumer.LastBatch, 1) assert.Equal(t, "batch-d4", consumer.LastBatch[0].Payload()) - assert.NoError(t, consumer.LastBatch[0].Reject(nil)) + assert.NoError(t, consumer.LastBatch[0].Reject()) count, err = queue.unackedCount() assert.NoError(t, err) assert.Equal(t, int64(0), count) @@ -480,7 +487,8 @@ func TestBatch(t *testing.T) { } func TestReturnRejected(t *testing.T) { - connection, err := OpenConnection("return-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "return-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err := connection.OpenQueue("return-q") assert.NoError(t, err) @@ -530,12 +538,12 @@ func TestReturnRejected(t *testing.T) { assert.Equal(t, int64(0), count) assert.Len(t, consumer.LastDeliveries, 6) - assert.NoError(t, consumer.LastDeliveries[0].Reject(nil)) - assert.NoError(t, consumer.LastDeliveries[1].Ack(nil)) - assert.NoError(t, consumer.LastDeliveries[2].Reject(nil)) - assert.NoError(t, consumer.LastDeliveries[3].Reject(nil)) + assert.NoError(t, consumer.LastDeliveries[0].Reject()) + assert.NoError(t, consumer.LastDeliveries[1].Ack()) + assert.NoError(t, consumer.LastDeliveries[2].Reject()) + assert.NoError(t, consumer.LastDeliveries[3].Reject()) // delivery 4 still open - assert.NoError(t, consumer.LastDeliveries[5].Reject(nil)) + assert.NoError(t, consumer.LastDeliveries[5].Reject()) time.Sleep(time.Millisecond) count, err = queue.readyCount() @@ -578,7 +586,8 @@ func TestReturnRejected(t *testing.T) { } func TestPushQueue(t *testing.T) { - connection, err := OpenConnection("push", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "push", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue1, err := connection.OpenQueue("queue1") assert.NoError(t, err) @@ -608,7 +617,7 @@ func TestPushQueue(t *testing.T) { assert.Equal(t, int64(1), count) require.Len(t, consumer1.LastDeliveries, 1) - assert.NoError(t, consumer1.LastDelivery.Push(nil)) + assert.NoError(t, consumer1.LastDelivery.Push()) time.Sleep(2 * time.Millisecond) count, err = queue1.unackedCount() assert.NoError(t, err) @@ -618,7 +627,7 @@ func TestPushQueue(t *testing.T) { assert.Equal(t, int64(1), count) require.Len(t, consumer2.LastDeliveries, 1) - assert.NoError(t, consumer2.LastDelivery.Push(nil)) + assert.NoError(t, consumer2.LastDelivery.Push()) time.Sleep(2 * time.Millisecond) count, err = queue2.rejectedCount() assert.NoError(t, err) @@ -626,7 +635,8 @@ func TestPushQueue(t *testing.T) { } func TestConsuming(t *testing.T) { - connection, err := OpenConnection("consume", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "consume", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err := connection.OpenQueue("consume-q") assert.NoError(t, err) @@ -651,7 +661,8 @@ func TestConsuming(t *testing.T) { } func TestStopConsuming_Consumer(t *testing.T) { - connection, err := OpenConnection("consume", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "consume", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err := connection.OpenQueue("consume-q") assert.NoError(t, err) @@ -695,7 +706,8 @@ func TestStopConsuming_Consumer(t *testing.T) { } func TestStopConsuming_BatchConsumer(t *testing.T) { - connection, err := OpenConnection("batchConsume", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "batchConsume", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) queue, err := connection.OpenQueue("batchConsume-q") assert.NoError(t, err) @@ -743,11 +755,13 @@ func TestStopConsuming_BatchConsumer(t *testing.T) { func BenchmarkQueue(b *testing.B) { // open queue - connection, err := OpenConnection("bench-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "bench-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(b, err) queueName := fmt.Sprintf("bench-q%d", b.N) queue, err := connection.OpenQueue(queueName) assert.NoError(b, err) + assert.NoError(b, queue.StartConsuming(10, time.Millisecond)) // add some consumers numConsumers := 10 @@ -756,7 +770,6 @@ func BenchmarkQueue(b *testing.B) { consumer := NewTestConsumer("bench-A") // consumer.SleepDuration = time.Microsecond consumers = append(consumers, consumer) - assert.NoError(b, queue.StartConsuming(10, time.Millisecond)) _, err = queue.AddConsumer("bench-cons", consumer) assert.NoError(b, err) } diff --git a/stats_test.go b/stats_test.go index 4731c63..c560f5a 100644 --- a/stats_test.go +++ b/stats_test.go @@ -1,6 +1,7 @@ package rmq import ( + "context" "testing" "time" @@ -9,14 +10,15 @@ import ( ) func TestStats(t *testing.T) { - connection, err := OpenConnection("stats-conn", "tcp", "localhost:6379", 1, nil) + ctx := context.Background() + connection, err := OpenConnection(ctx, "stats-conn", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) _, err = NewCleaner(connection).Clean() require.NoError(t, err) - conn1, err := OpenConnection("stats-conn1", "tcp", "localhost:6379", 1, nil) + conn1, err := OpenConnection(ctx, "stats-conn1", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) - conn2, err := OpenConnection("stats-conn2", "tcp", "localhost:6379", 1, nil) + conn2, err := OpenConnection(ctx, "stats-conn2", "tcp", "localhost:6379", 1, nil) assert.NoError(t, err) q1, err := conn2.OpenQueue("stats-q1") assert.NoError(t, err) @@ -36,8 +38,8 @@ func TestStats(t *testing.T) { assert.NoError(t, q2.Publish("stats-d3")) assert.NoError(t, q2.Publish("stats-d4")) time.Sleep(2 * time.Millisecond) - assert.NoError(t, consumer.LastDeliveries[0].Ack(nil)) - assert.NoError(t, consumer.LastDeliveries[1].Reject(nil)) + assert.NoError(t, consumer.LastDeliveries[0].Ack()) + assert.NoError(t, consumer.LastDeliveries[1].Reject()) _, err = q2.AddConsumer("stats-cons2", NewTestConsumer("hand-B")) assert.NoError(t, err) diff --git a/test_batch_consumer.go b/test_batch_consumer.go index 60404ff..ba742c5 100644 --- a/test_batch_consumer.go +++ b/test_batch_consumer.go @@ -18,7 +18,7 @@ func (consumer *TestBatchConsumer) Consume(batch Deliveries) { consumer.LastBatch = batch consumer.ConsumedCount += int64(len(batch)) if consumer.AutoFinish { - batch.Ack(nil) + batch.Ack() } else { <-consumer.finish // log.Printf("TestBatchConsumer.Consume() finished") diff --git a/test_consumer.go b/test_consumer.go index 68fadda..877c11d 100644 --- a/test_consumer.go +++ b/test_consumer.go @@ -37,7 +37,7 @@ func (consumer *TestConsumer) Consume(delivery Delivery) { time.Sleep(consumer.SleepDuration) } if consumer.AutoAck { - if err := delivery.Ack(nil); err != nil { + if err := delivery.Ack(); err != nil { panic(err) } } diff --git a/test_delivery.go b/test_delivery.go index 0e94b85..e094424 100644 --- a/test_delivery.go +++ b/test_delivery.go @@ -1,7 +1,6 @@ package rmq import ( - "context" "encoding/json" ) @@ -33,7 +32,7 @@ func (delivery *TestDelivery) Payload() string { return delivery.payload } -func (delivery *TestDelivery) Ack(context.Context) error { +func (delivery *TestDelivery) Ack() error { if delivery.State != Unacked { return ErrorNotFound } @@ -41,7 +40,7 @@ func (delivery *TestDelivery) Ack(context.Context) error { return nil } -func (delivery *TestDelivery) Reject(context.Context) error { +func (delivery *TestDelivery) Reject() error { if delivery.State != Unacked { return ErrorNotFound } @@ -49,7 +48,7 @@ func (delivery *TestDelivery) Reject(context.Context) error { return nil } -func (delivery *TestDelivery) Push(context.Context) error { +func (delivery *TestDelivery) Push() error { if delivery.State != Unacked { return ErrorNotFound } diff --git a/test_delivery_test.go b/test_delivery_test.go index a97ad0d..a2ced26 100644 --- a/test_delivery_test.go +++ b/test_delivery_test.go @@ -9,28 +9,28 @@ import ( func TestDeliveryPayload(t *testing.T) { var delivery Delivery delivery = NewTestDelivery("p23") - assert.NoError(t, delivery.Ack(nil)) + assert.NoError(t, delivery.Ack()) assert.Equal(t, "p23", delivery.Payload()) } func TestDeliveryAck(t *testing.T) { delivery := NewTestDelivery("p") assert.Equal(t, Unacked, delivery.State) - assert.NoError(t, delivery.Ack(nil)) + assert.NoError(t, delivery.Ack()) assert.Equal(t, Acked, delivery.State) - assert.Equal(t, ErrorNotFound, delivery.Ack(nil)) - assert.Equal(t, ErrorNotFound, delivery.Reject(nil)) + assert.Equal(t, ErrorNotFound, delivery.Ack()) + assert.Equal(t, ErrorNotFound, delivery.Reject()) assert.Equal(t, Acked, delivery.State) } func TestDeliveryReject(t *testing.T) { delivery := NewTestDelivery("p") assert.Equal(t, Unacked, delivery.State) - assert.NoError(t, delivery.Reject(nil)) + assert.NoError(t, delivery.Reject()) assert.Equal(t, Rejected, delivery.State) - assert.Equal(t, ErrorNotFound, delivery.Reject(nil)) - assert.Equal(t, ErrorNotFound, delivery.Ack(nil)) + assert.Equal(t, ErrorNotFound, delivery.Reject()) + assert.Equal(t, ErrorNotFound, delivery.Ack()) assert.Equal(t, Rejected, delivery.State) }