Skip to content

Commit

Permalink
Reuse error channel passed into OpenConnection()
Browse files Browse the repository at this point in the history
Don't inject error channel into StartConsuming() or delivery functions
like Ack().
  • Loading branch information
wellle committed May 23, 2020
1 parent 4fde042 commit b4dae04
Show file tree
Hide file tree
Showing 14 changed files with 144 additions and 125 deletions.
10 changes: 5 additions & 5 deletions cleaner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestCleaner(t *testing.T) {
count, err = queue.unackedCount()
assert.NoError(t, err)
assert.Equal(t, int64(0), count)
assert.NoError(t, queue.StartConsuming(2, time.Millisecond, nil))
assert.NoError(t, queue.StartConsuming(2, time.Millisecond))
time.Sleep(time.Millisecond)
count, err = queue.unackedCount()
assert.NoError(t, err)
Expand All @@ -86,7 +86,7 @@ func TestCleaner(t *testing.T) {

require.NotNil(t, consumer.LastDelivery)
assert.Equal(t, "del1", consumer.LastDelivery.Payload())
assert.NoError(t, consumer.LastDelivery.Ack(nil, nil))
assert.NoError(t, consumer.LastDelivery.Ack(nil))
time.Sleep(10 * time.Millisecond)
count, err = queue.unackedCount()
assert.NoError(t, err)
Expand Down Expand Up @@ -138,7 +138,7 @@ func TestCleaner(t *testing.T) {
count, err = queue.unackedCount()
assert.NoError(t, err)
assert.Equal(t, int64(0), count)
assert.NoError(t, queue.StartConsuming(2, time.Millisecond, nil))
assert.NoError(t, queue.StartConsuming(2, time.Millisecond))
time.Sleep(time.Millisecond)
count, err = queue.unackedCount()
assert.NoError(t, err)
Expand Down Expand Up @@ -172,7 +172,7 @@ func TestCleaner(t *testing.T) {
assert.Equal(t, int64(6), count)

assert.Equal(t, "del5", consumer.LastDelivery.Payload())
assert.NoError(t, consumer.LastDelivery.Ack(nil, nil))
assert.NoError(t, consumer.LastDelivery.Ack(nil))
time.Sleep(10 * time.Millisecond)
count, err = queue.unackedCount()
assert.NoError(t, err)
Expand Down Expand Up @@ -202,7 +202,7 @@ func TestCleaner(t *testing.T) {
assert.NoError(t, err)
queue, err = conn.OpenQueue("q1")
assert.NoError(t, err)
assert.NoError(t, queue.StartConsuming(10, time.Millisecond, nil))
assert.NoError(t, queue.StartConsuming(10, time.Millisecond))
consumer = NewTestConsumer("c-C")

_, err = queue.AddConsumer("consumer3", consumer)
Expand Down
30 changes: 19 additions & 11 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type redisConnection struct {
heartbeatKey string // key to keep alive
queuesKey string // key to list of queues consumed by this connection
redisClient RedisClient
errChan chan<- error
heartbeatTicker *time.Ticker

// list of all queues that have been opened in this connection
Expand All @@ -58,34 +59,35 @@ type redisConnection struct {
}

// OpenConnection opens and returns a new connection
func OpenConnection(tag, network, address string, db int, errors chan<- error) (Connection, error) {
func OpenConnection(tag, network, address string, db int, errChan chan<- error) (Connection, error) {
redisClient := redis.NewClient(&redis.Options{
Network: network,
Addr: address,
DB: db,
})
return OpenConnectionWithRedisClient(tag, redisClient, errors)
return OpenConnectionWithRedisClient(tag, redisClient, errChan)
}

// OpenConnectionWithRedisClient opens and returns a new connection
func OpenConnectionWithRedisClient(tag string, redisClient *redis.Client, errors chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(tag, RedisWrapper{redisClient}, errors)
func OpenConnectionWithRedisClient(tag string, redisClient *redis.Client, errChan chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(tag, RedisWrapper{redisClient}, errChan)
}

// OpenConnectionWithTestRedisClient opens and returns a new connection which
// uses a test redis client internally. This is useful in integration tests.
func OpenConnectionWithTestRedisClient(tag string, errors chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(tag, NewTestRedisClient(), errors)
func OpenConnectionWithTestRedisClient(tag string, errChan chan<- error) (*redisConnection, error) {
return openConnectionWithRedisClient(tag, NewTestRedisClient(), errChan)
}

func openConnectionWithRedisClient(tag string, redisClient RedisClient, errors chan<- error) (*redisConnection, error) {
func openConnectionWithRedisClient(tag string, redisClient RedisClient, errChan chan<- error) (*redisConnection, error) {
name := fmt.Sprintf("%s-%s", tag, uniuri.NewLen(6))

connection := &redisConnection{
Name: name,
heartbeatKey: strings.Replace(connectionHeartbeatTemplate, phConnection, name, 1),
queuesKey: strings.Replace(connectionQueuesTemplate, phConnection, name, 1),
redisClient: redisClient,
errChan: errChan,
heartbeatTicker: time.NewTicker(heartbeatInterval),
}

Expand All @@ -98,7 +100,7 @@ func openConnectionWithRedisClient(tag string, redisClient RedisClient, errors c
return nil, err
}

go connection.heartbeat(errors)
go connection.heartbeat(errChan)
// log.Printf("rmq connection connected to %s %s:%s %d", name, network, address, db)
return connection, nil
}
Expand All @@ -108,7 +110,7 @@ func (connection *redisConnection) updateHeartbeat() error {
}

// heartbeat keeps the heartbeat key alive
func (connection *redisConnection) heartbeat(errors chan<- error) {
func (connection *redisConnection) heartbeat(errChan chan<- error) {
errorCount := 0 // number of consecutive errors
for range connection.heartbeatTicker.C {
err := connection.updateHeartbeat()
Expand All @@ -121,7 +123,7 @@ func (connection *redisConnection) heartbeat(errors chan<- error) {
errorCount++

select { // try to add error to channel, but don't block
case errors <- &HeartbeatError{RedisErr: err, Count: errorCount}:
case errChan <- &HeartbeatError{RedisErr: err, Count: errorCount}:
default:
}

Expand Down Expand Up @@ -246,7 +248,13 @@ func (connection *redisConnection) getConsumingQueues() ([]string, error) {

// openQueue opens a queue without adding it to the set of queues
func (connection *redisConnection) openQueue(name string) Queue {
return newQueue(name, connection.Name, connection.queuesKey, connection.redisClient)
return newQueue(
name,
connection.Name,
connection.queuesKey,
connection.redisClient,
connection.errChan,
)
}

// stopHeartbeat stops the heartbeat of the connection
Expand Down
17 changes: 8 additions & 9 deletions deliveries.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,26 @@ func (deliveries Deliveries) Payloads() []string {

// functions with retry, see comments in delivery.go (recommended)

func (deliveries Deliveries) Ack(ctx context.Context, errChan chan<- error) (errMap map[int]error) {
return deliveries.each(ctx, errChan, Delivery.Ack)
func (deliveries Deliveries) Ack(ctx context.Context) (errMap map[int]error) {
return deliveries.each(ctx, Delivery.Ack)
}

func (deliveries Deliveries) Reject(ctx context.Context, errChan chan<- error) (errMap map[int]error) {
return deliveries.each(ctx, errChan, Delivery.Reject)
func (deliveries Deliveries) Reject(ctx context.Context) (errMap map[int]error) {
return deliveries.each(ctx, Delivery.Reject)
}

func (deliveries Deliveries) Push(ctx context.Context, errChan chan<- error) (errMap map[int]error) {
return deliveries.each(ctx, errChan, Delivery.Push)
func (deliveries Deliveries) Push(ctx context.Context) (errMap map[int]error) {
return deliveries.each(ctx, Delivery.Push)
}

// helper functions

func (deliveries Deliveries) each(
ctx context.Context,
errChan chan<- error,
f func(Delivery, context.Context, chan<- error) error,
f func(Delivery, context.Context) error,
) (errMap map[int]error) {
for i, delivery := range deliveries {
if err := f(delivery, ctx, errChan); err != nil {
if err := f(delivery, ctx); err != nil {
if errMap == nil { // create error map lazily on demand
errMap = map[int]error{}
}
Expand Down
37 changes: 23 additions & 14 deletions delivery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
type Delivery interface {
Payload() string

Ack(context.Context, chan<- error) error
Reject(context.Context, chan<- error) error
Push(context.Context, chan<- error) error
Ack(context.Context) error
Reject(context.Context) error
Push(context.Context) error
}

type redisDelivery struct {
Expand All @@ -20,15 +20,24 @@ type redisDelivery struct {
rejectedKey string
pushKey string
redisClient RedisClient
errChan chan<- error
}

func newDelivery(payload, unackedKey, rejectedKey, pushKey string, redisClient RedisClient) *redisDelivery {
func newDelivery(
payload string,
unackedKey string,
rejectedKey string,
pushKey string,
redisClient RedisClient,
errChan chan<- error,
) *redisDelivery {
return &redisDelivery{
payload: payload,
unackedKey: unackedKey,
rejectedKey: rejectedKey,
pushKey: pushKey,
redisClient: redisClient,
errChan: errChan,
}
}

Expand All @@ -46,7 +55,7 @@ func (delivery *redisDelivery) Payload() string {
// 3. if the context is cancalled or its timeout exceeded, context.Cancelled or
// context.DeadlineExceeded will be returned

func (delivery *redisDelivery) Ack(ctx context.Context, errChan chan<- error) error {
func (delivery *redisDelivery) Ack(ctx context.Context) error {
if ctx == nil { // TODO: remove this
ctx = context.TODO()
}
Expand All @@ -66,7 +75,7 @@ func (delivery *redisDelivery) Ack(ctx context.Context, errChan chan<- error) er
errorCount++

select { // try to add error to channel, but don't block
case errChan <- &DeliveryError{Delivery: delivery, RedisErr: err, Count: errorCount}:
case delivery.errChan <- &DeliveryError{Delivery: delivery, RedisErr: err, Count: errorCount}:
default:
}

Expand All @@ -78,19 +87,19 @@ func (delivery *redisDelivery) Ack(ctx context.Context, errChan chan<- error) er
}
}

func (delivery *redisDelivery) Reject(ctx context.Context, errChan chan<- error) error {
return delivery.move(ctx, errChan, delivery.rejectedKey)
func (delivery *redisDelivery) Reject(ctx context.Context) error {
return delivery.move(ctx, delivery.rejectedKey)
}

func (delivery *redisDelivery) Push(ctx context.Context, errChan chan<- error) error {
func (delivery *redisDelivery) Push(ctx context.Context) error {
if delivery.pushKey == "" {
return delivery.Reject(ctx, errChan) // fall back to rejecting
return delivery.Reject(ctx) // fall back to rejecting
}

return delivery.move(ctx, errChan, delivery.pushKey)
return delivery.move(ctx, delivery.pushKey)
}

func (delivery *redisDelivery) move(ctx context.Context, errChan chan<- error, key string) error {
func (delivery *redisDelivery) move(ctx context.Context, key string) error {
errorCount := 0
for {
_, err := delivery.redisClient.LPush(key, delivery.payload)
Expand All @@ -102,14 +111,14 @@ func (delivery *redisDelivery) move(ctx context.Context, errChan chan<- error, k
errorCount++

select { // try to add error to channel, but don't block
case errChan <- &DeliveryError{Delivery: delivery, RedisErr: err, Count: errorCount}:
case delivery.errChan <- &DeliveryError{Delivery: delivery, RedisErr: err, Count: errorCount}:
default:
}

time.Sleep(time.Second)
}

return delivery.Ack(ctx, errChan)
return delivery.Ack(ctx)
}

// lower level functions which don't retry but just return the first error
18 changes: 8 additions & 10 deletions example/batch_consumer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func main() {
if err != nil {
panic(err)
}
if err := queue.StartConsuming(unackedLimit, pollDuration, nil); err != nil {
if err := queue.StartConsuming(unackedLimit, pollDuration); err != nil {
panic(err)
}
if _, err := queue.AddBatchConsumer(queueName, batchSize, batchTimeout, NewBatchConsumer(ctx, errChan, queueName)); err != nil {
if _, err := queue.AddBatchConsumer(queueName, batchSize, batchTimeout, NewBatchConsumer(ctx, queueName)); err != nil {
panic(err)
}
}
Expand All @@ -82,16 +82,14 @@ func main() {
}

type BatchConsumer struct {
ctx context.Context
errChan chan<- error
tag string
ctx context.Context
tag string
}

func NewBatchConsumer(ctx context.Context, errChan chan<- error, tag string) *BatchConsumer {
func NewBatchConsumer(ctx context.Context, tag string) *BatchConsumer {
return &BatchConsumer{
ctx: ctx,
errChan: errChan,
tag: tag,
ctx: ctx,
tag: tag,
}
}

Expand All @@ -101,7 +99,7 @@ func (consumer *BatchConsumer) Consume(batch rmq.Deliveries) {
time.Sleep(consumeDuration)

log.Printf("%s consumed %d: %s", consumer.tag, len(batch), batch[0])
errors := batch.Ack(consumer.ctx, consumer.errChan)
errors := batch.Ack(consumer.ctx)
if len(errors) == 0 {
debugf("acked %q", payloads)
return
Expand Down
28 changes: 13 additions & 15 deletions example/consumer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func main() {
panic(err)
}

if err := queue.StartConsuming(unackedLimit, 500*time.Millisecond, errChan); err != nil {
if err := queue.StartConsuming(unackedLimit, 500*time.Millisecond); err != nil {
panic(err)
}

ctx, cancel := context.WithCancel(context.Background())

for i := 0; i < numConsumers; i++ {
name := fmt.Sprintf("consumer %d", i)
if _, err := queue.AddConsumer(name, NewConsumer(ctx, errChan, i)); err != nil {
if _, err := queue.AddConsumer(name, NewConsumer(ctx, i)); err != nil {
panic(err)
}
}
Expand All @@ -82,20 +82,18 @@ func main() {
}

type Consumer struct {
ctx context.Context
errChan chan<- error
name string
count int
before time.Time
ctx context.Context
name string
count int
before time.Time
}

func NewConsumer(ctx context.Context, errChan chan<- error, tag int) *Consumer {
func NewConsumer(ctx context.Context, tag int) *Consumer {
return &Consumer{
ctx: ctx,
errChan: errChan,
name: fmt.Sprintf("consumer%d", tag),
count: 0,
before: time.Now(),
ctx: ctx,
name: fmt.Sprintf("consumer%d", tag),
count: 0,
before: time.Now(),
}
}

Expand All @@ -113,13 +111,13 @@ func (consumer *Consumer) Consume(delivery rmq.Delivery) {
}

if consumer.count%batchSize > 0 {
if err := delivery.Ack(consumer.ctx, consumer.errChan); err != nil {
if err := delivery.Ack(consumer.ctx); err != nil {
debugf("failed to ack %s: %s", payload, err)
} else {
debugf("acked %s", payload)
}
} else { // reject one per batch
if err := delivery.Reject(consumer.ctx, consumer.errChan); err != nil {
if err := delivery.Reject(consumer.ctx); err != nil {
debugf("failed to reject %s: %s", payload, err)
} else {
debugf("rejected %s", payload)
Expand Down
Loading

0 comments on commit b4dae04

Please sign in to comment.