Skip to content

Commit

Permalink
Inject context into OpenConnection() functions
Browse files Browse the repository at this point in the history
And pass it to queues and deliveries. Don't inject context into delivery
functions like Ack().
  • Loading branch information
wellle committed May 23, 2020
1 parent b4dae04 commit 2c8c791
Show file tree
Hide file tree
Showing 18 changed files with 142 additions and 127 deletions.
17 changes: 10 additions & 7 deletions cleaner_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rmq

import (
"context"
"testing"
"time"

Expand All @@ -9,12 +10,14 @@ import (
)

func TestCleaner(t *testing.T) {
flushConn, err := OpenConnection("cleaner-flush", "tcp", "localhost:6379", 1, nil)
ctx := context.Background()

flushConn, err := OpenConnection(ctx, "cleaner-flush", "tcp", "localhost:6379", 1, nil)
assert.NoError(t, err)
assert.NoError(t, flushConn.stopHeartbeat())
assert.NoError(t, flushConn.flushDb())

conn, err := OpenConnection("cleaner-conn1", "tcp", "localhost:6379", 1, nil)
conn, err := OpenConnection(ctx, "cleaner-conn1", "tcp", "localhost:6379", 1, nil)
assert.NoError(t, err)
queues, err := conn.GetOpenQueues()
assert.NoError(t, err)
Expand Down Expand Up @@ -86,7 +89,7 @@ func TestCleaner(t *testing.T) {

require.NotNil(t, consumer.LastDelivery)
assert.Equal(t, "del1", consumer.LastDelivery.Payload())
assert.NoError(t, consumer.LastDelivery.Ack(nil))
assert.NoError(t, consumer.LastDelivery.Ack())
time.Sleep(10 * time.Millisecond)
count, err = queue.unackedCount()
assert.NoError(t, err)
Expand All @@ -109,7 +112,7 @@ func TestCleaner(t *testing.T) {
assert.NoError(t, conn.stopHeartbeat())
time.Sleep(time.Millisecond)

conn, err = OpenConnection("cleaner-conn1", "tcp", "localhost:6379", 1, nil)
conn, err = OpenConnection(ctx, "cleaner-conn1", "tcp", "localhost:6379", 1, nil)
assert.NoError(t, err)
queue, err = conn.OpenQueue("q1")
assert.NoError(t, err)
Expand Down Expand Up @@ -172,7 +175,7 @@ func TestCleaner(t *testing.T) {
assert.Equal(t, int64(6), count)

assert.Equal(t, "del5", consumer.LastDelivery.Payload())
assert.NoError(t, consumer.LastDelivery.Ack(nil))
assert.NoError(t, consumer.LastDelivery.Ack())
time.Sleep(10 * time.Millisecond)
count, err = queue.unackedCount()
assert.NoError(t, err)
Expand All @@ -185,7 +188,7 @@ func TestCleaner(t *testing.T) {
assert.NoError(t, conn.stopHeartbeat())
time.Sleep(time.Millisecond)

cleanerConn, err := OpenConnection("cleaner-conn", "tcp", "localhost:6379", 1, nil)
cleanerConn, err := OpenConnection(ctx, "cleaner-conn", "tcp", "localhost:6379", 1, nil)
assert.NoError(t, err)
cleaner := NewCleaner(cleanerConn)
returned, err := cleaner.Clean()
Expand All @@ -198,7 +201,7 @@ func TestCleaner(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, queues, 2)

conn, err = OpenConnection("cleaner-conn1", "tcp", "localhost:6379", 1, nil)
conn, err = OpenConnection(ctx, "cleaner-conn1", "tcp", "localhost:6379", 1, nil)
assert.NoError(t, err)
queue, err = conn.OpenQueue("q1")
assert.NoError(t, err)
Expand Down
24 changes: 12 additions & 12 deletions connection.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rmq

import (
"context"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -46,6 +47,7 @@ type Connection interface {
// Connection is the entry point. Use a connection to access queues, consumers and deliveries
// Each connection has a single heartbeat shared among all consumers
type redisConnection struct {
ctx context.Context
Name string
heartbeatKey string // key to keep alive
queuesKey string // key to list of queues consumed by this connection
Expand All @@ -59,31 +61,28 @@ type redisConnection struct {
}

// OpenConnection opens and returns a new connection
func OpenConnection(tag, network, address string, db int, errChan chan<- error) (Connection, error) {
redisClient := redis.NewClient(&redis.Options{
Network: network,
Addr: address,
DB: db,
})
return OpenConnectionWithRedisClient(tag, redisClient, errChan)
func OpenConnection(ctx context.Context, tag string, network string, address string, db int, errChan chan<- error) (Connection, error) {
redisClient := redis.NewClient(&redis.Options{Network: network, Addr: address, DB: db})
return OpenConnectionWithRedisClient(ctx, tag, redisClient, errChan)
}

// OpenConnectionWithRedisClient opens and returns a new connection
func OpenConnectionWithRedisClient(tag string, redisClient *redis.Client, errChan chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(tag, RedisWrapper{redisClient}, errChan)
func OpenConnectionWithRedisClient(ctx context.Context, tag string, redisClient *redis.Client, errChan chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(ctx, tag, RedisWrapper{redisClient}, errChan)
}

// OpenConnectionWithTestRedisClient opens and returns a new connection which
// uses a test redis client internally. This is useful in integration tests.
func OpenConnectionWithTestRedisClient(tag string, errChan chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(tag, NewTestRedisClient(), errChan)
func OpenConnectionWithTestRedisClient(ctx context.Context, tag string, errChan chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(ctx, tag, NewTestRedisClient(), errChan)
}

func openConnectionWithRedisClient(tag string, redisClient RedisClient, errChan chan<- error) (*redisConnection, error) {
func openConnectionWithRedisClient(ctx context.Context, tag string, redisClient RedisClient, errChan chan<- error) (*redisConnection, error) {
name := fmt.Sprintf("%s-%s", tag, uniuri.NewLen(6))

connection := &redisConnection{
Name: name,
ctx: ctx,
heartbeatKey: strings.Replace(connectionHeartbeatTemplate, phConnection, name, 1),
queuesKey: strings.Replace(connectionQueuesTemplate, phConnection, name, 1),
redisClient: redisClient,
Expand Down Expand Up @@ -249,6 +248,7 @@ func (connection *redisConnection) getConsumingQueues() ([]string, error) {
// openQueue opens a queue without adding it to the set of queues
func (connection *redisConnection) openQueue(name string) Queue {
return newQueue(
connection.ctx,
name,
connection.Name,
connection.queuesKey,
Expand Down
19 changes: 8 additions & 11 deletions deliveries.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package rmq

import "context"

type Deliveries []Delivery

func (deliveries Deliveries) Payloads() []string {
Expand All @@ -19,26 +17,25 @@ func (deliveries Deliveries) Payloads() []string {

// functions with retry, see comments in delivery.go (recommended)

func (deliveries Deliveries) Ack(ctx context.Context) (errMap map[int]error) {
return deliveries.each(ctx, Delivery.Ack)
func (deliveries Deliveries) Ack() (errMap map[int]error) {
return deliveries.each(Delivery.Ack)
}

func (deliveries Deliveries) Reject(ctx context.Context) (errMap map[int]error) {
return deliveries.each(ctx, Delivery.Reject)
func (deliveries Deliveries) Reject() (errMap map[int]error) {
return deliveries.each(Delivery.Reject)
}

func (deliveries Deliveries) Push(ctx context.Context) (errMap map[int]error) {
return deliveries.each(ctx, Delivery.Push)
func (deliveries Deliveries) Push() (errMap map[int]error) {
return deliveries.each(Delivery.Push)
}

// helper functions

func (deliveries Deliveries) each(
ctx context.Context,
f func(Delivery, context.Context) error,
f func(Delivery) error,
) (errMap map[int]error) {
for i, delivery := range deliveries {
if err := f(delivery, ctx); err != nil {
if err := f(delivery); err != nil {
if errMap == nil { // create error map lazily on demand
errMap = map[int]error{}
}
Expand Down
31 changes: 15 additions & 16 deletions delivery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
type Delivery interface {
Payload() string

Ack(context.Context) error
Reject(context.Context) error
Push(context.Context) error
Ack() error
Reject() error
Push() error
}

type redisDelivery struct {
ctx context.Context
payload string
unackedKey string
rejectedKey string
Expand All @@ -24,6 +25,7 @@ type redisDelivery struct {
}

func newDelivery(
ctx context.Context,
payload string,
unackedKey string,
rejectedKey string,
Expand All @@ -32,6 +34,7 @@ func newDelivery(
errChan chan<- error,
) *redisDelivery {
return &redisDelivery{
ctx: ctx,
payload: payload,
unackedKey: unackedKey,
rejectedKey: rejectedKey,
Expand All @@ -55,11 +58,7 @@ func (delivery *redisDelivery) Payload() string {
// 3. if the context is cancalled or its timeout exceeded, context.Cancelled or
// context.DeadlineExceeded will be returned

func (delivery *redisDelivery) Ack(ctx context.Context) error {
if ctx == nil { // TODO: remove this
ctx = context.TODO()
}

func (delivery *redisDelivery) Ack() error {
errorCount := 0
for {
count, err := delivery.redisClient.LRem(delivery.unackedKey, 1, delivery.payload)
Expand All @@ -79,27 +78,27 @@ func (delivery *redisDelivery) Ack(ctx context.Context) error {
default:
}

if err := ctx.Err(); err != nil {
if err := delivery.ctx.Err(); err != nil {
return err
}

time.Sleep(time.Second)
}
}

func (delivery *redisDelivery) Reject(ctx context.Context) error {
return delivery.move(ctx, delivery.rejectedKey)
func (delivery *redisDelivery) Reject() error {
return delivery.move(delivery.rejectedKey)
}

func (delivery *redisDelivery) Push(ctx context.Context) error {
func (delivery *redisDelivery) Push() error {
if delivery.pushKey == "" {
return delivery.Reject(ctx) // fall back to rejecting
return delivery.Reject() // fall back to rejecting
}

return delivery.move(ctx, delivery.pushKey)
return delivery.move(delivery.pushKey)
}

func (delivery *redisDelivery) move(ctx context.Context, key string) error {
func (delivery *redisDelivery) move(key string) error {
errorCount := 0
for {
_, err := delivery.redisClient.LPush(key, delivery.payload)
Expand All @@ -118,7 +117,7 @@ func (delivery *redisDelivery) move(ctx context.Context, key string) error {
time.Sleep(time.Second)
}

return delivery.Ack(ctx)
return delivery.Ack()
}

// lower level functions which don't retry but just return the first error
17 changes: 6 additions & 11 deletions example/batch_consumer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ func main() {
}
}()

connection, err := rmq.OpenConnection("consumer", "tcp", "localhost:6379", 2, errChan)
ctx, cancel := context.WithCancel(context.Background())
connection, err := rmq.OpenConnection(ctx, "consumer", "tcp", "localhost:6379", 2, errChan)
if err != nil {
panic(err)
}

ctx, cancel := context.WithCancel(context.Background())

for _, queueName := range []string{
"things",
"balls",
Expand All @@ -59,7 +58,7 @@ func main() {
if err := queue.StartConsuming(unackedLimit, pollDuration); err != nil {
panic(err)
}
if _, err := queue.AddBatchConsumer(queueName, batchSize, batchTimeout, NewBatchConsumer(ctx, queueName)); err != nil {
if _, err := queue.AddBatchConsumer(queueName, batchSize, batchTimeout, NewBatchConsumer(queueName)); err != nil {
panic(err)
}
}
Expand All @@ -82,15 +81,11 @@ func main() {
}

type BatchConsumer struct {
ctx context.Context
tag string
}

func NewBatchConsumer(ctx context.Context, tag string) *BatchConsumer {
return &BatchConsumer{
ctx: ctx,
tag: tag,
}
func NewBatchConsumer(tag string) *BatchConsumer {
return &BatchConsumer{tag: tag}
}

func (consumer *BatchConsumer) Consume(batch rmq.Deliveries) {
Expand All @@ -99,7 +94,7 @@ func (consumer *BatchConsumer) Consume(batch rmq.Deliveries) {
time.Sleep(consumeDuration)

log.Printf("%s consumed %d: %s", consumer.tag, len(batch), batch[0])
errors := batch.Ack(consumer.ctx)
errors := batch.Ack()
if len(errors) == 0 {
debugf("acked %q", payloads)
return
Expand Down
3 changes: 2 additions & 1 deletion example/cleaner/main.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package main

import (
"context"
"log"
"time"

"github.com/adjust/rmq/v2"
)

func main() {
connection, err := rmq.OpenConnection("cleaner", "tcp", "localhost:6379", 2, nil)
connection, err := rmq.OpenConnection(context.Background(), "cleaner", "tcp", "localhost:6379", 2, nil)
if err != nil {
panic(err)
}
Expand Down
15 changes: 6 additions & 9 deletions example/consumer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func main() {
}
}()

connection, err := rmq.OpenConnection("consumer", "tcp", "localhost:6379", 2, errChan)
ctx, cancel := context.WithCancel(context.Background())
connection, err := rmq.OpenConnection(ctx, "consumer", "tcp", "localhost:6379", 2, errChan)
if err != nil {
panic(err)
}
Expand All @@ -55,11 +56,9 @@ func main() {
panic(err)
}

ctx, cancel := context.WithCancel(context.Background())

for i := 0; i < numConsumers; i++ {
name := fmt.Sprintf("consumer %d", i)
if _, err := queue.AddConsumer(name, NewConsumer(ctx, i)); err != nil {
if _, err := queue.AddConsumer(name, NewConsumer(i)); err != nil {
panic(err)
}
}
Expand All @@ -82,15 +81,13 @@ func main() {
}

type Consumer struct {
ctx context.Context
name string
count int
before time.Time
}

func NewConsumer(ctx context.Context, tag int) *Consumer {
func NewConsumer(tag int) *Consumer {
return &Consumer{
ctx: ctx,
name: fmt.Sprintf("consumer%d", tag),
count: 0,
before: time.Now(),
Expand All @@ -111,13 +108,13 @@ func (consumer *Consumer) Consume(delivery rmq.Delivery) {
}

if consumer.count%batchSize > 0 {
if err := delivery.Ack(consumer.ctx); err != nil {
if err := delivery.Ack(); err != nil {
debugf("failed to ack %s: %s", payload, err)
} else {
debugf("acked %s", payload)
}
} else { // reject one per batch
if err := delivery.Reject(consumer.ctx); err != nil {
if err := delivery.Reject(); err != nil {
debugf("failed to reject %s: %s", payload, err)
} else {
debugf("rejected %s", payload)
Expand Down
Loading

0 comments on commit 2c8c791

Please sign in to comment.