diff --git a/pkg/providers/kinesis/consumer/consumer.go b/pkg/providers/kinesis/consumer/consumer.go index b2d87fa4..0a583ac2 100644 --- a/pkg/providers/kinesis/consumer/consumer.go +++ b/pkg/providers/kinesis/consumer/consumer.go @@ -2,6 +2,7 @@ package consumer import ( "context" + "github.com/aws/aws-sdk-go/aws/request" "sync" "time" @@ -9,13 +10,19 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" "github.com/doublecloud/transfer/internal/logger" "github.com/doublecloud/transfer/library/go/core/xerrors" "github.com/doublecloud/transfer/library/go/slices" "go.ytsaurus.tech/library/go/core/log" ) +// KinesisReader is a lightweight interface that narrow down usage to just what really needed by this code +type KinesisReader interface { + ListShards(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) + GetRecords(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) + GetShardIteratorWithContext(aws.Context, *kinesis.GetShardIteratorInput, ...request.Option) (*kinesis.GetShardIteratorOutput, error) +} + // Record wraps the record returned from the Kinesis library and // extends to include the shard id. type Record struct { @@ -65,7 +72,7 @@ type Consumer struct { streamName string initialShardIteratorType string initialTimestamp *time.Time - client kinesisiface.KinesisAPI + client KinesisReader group Group logger log.Logger store Store diff --git a/pkg/providers/kinesis/consumer/group_all.go b/pkg/providers/kinesis/consumer/group_all.go index cf59c0e1..7d20afa2 100644 --- a/pkg/providers/kinesis/consumer/group_all.go +++ b/pkg/providers/kinesis/consumer/group_all.go @@ -7,14 +7,13 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" "github.com/doublecloud/transfer/library/go/core/xerrors" "go.ytsaurus.tech/library/go/core/log" ) // NewAllGroup returns an intitialized AllGroup for consuming // all shards on a stream -func NewAllGroup(ksis kinesisiface.KinesisAPI, store Store, streamName string, logger log.Logger) *AllGroup { +func NewAllGroup(ksis KinesisReader, store Store, streamName string, logger log.Logger) *AllGroup { return &AllGroup{ Store: store, ksis: ksis, @@ -31,7 +30,7 @@ func NewAllGroup(ksis kinesisiface.KinesisAPI, store Store, streamName string, l type AllGroup struct { Store - ksis kinesisiface.KinesisAPI + ksis KinesisReader streamName string logger log.Logger @@ -88,7 +87,7 @@ func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) { } // listShards pulls a list of shard IDs from the kinesis api -func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) { +func listShards(ksis KinesisReader, streamName string) ([]*kinesis.Shard, error) { var ss []*kinesis.Shard var listShardsInput = &kinesis.ListShardsInput{ StreamName: aws.String(streamName), diff --git a/pkg/providers/kinesis/consumer/options.go b/pkg/providers/kinesis/consumer/options.go index 7a68b732..8bcfa9a3 100644 --- a/pkg/providers/kinesis/consumer/options.go +++ b/pkg/providers/kinesis/consumer/options.go @@ -2,8 +2,6 @@ package consumer import ( "time" - - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" ) // Option is used to override defaults when creating a new Consumer @@ -24,7 +22,7 @@ func WithStore(store Store) Option { } // WithClient overrides the default client -func WithClient(client kinesisiface.KinesisAPI) Option { +func WithClient(client KinesisReader) Option { return func(c *Consumer) { c.client = client } diff --git a/pkg/providers/kinesis/source_test.go b/pkg/providers/kinesis/source_test.go new file mode 100644 index 00000000..9278956e --- /dev/null +++ b/pkg/providers/kinesis/source_test.go @@ -0,0 +1,85 @@ +package kinesis + +import ( + "context" + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/doublecloud/transfer/internal/logger" + "github.com/doublecloud/transfer/library/go/core/metrics/solomon" + "github.com/doublecloud/transfer/library/go/core/xerrors" + "github.com/doublecloud/transfer/pkg/abstract" + "github.com/doublecloud/transfer/pkg/abstract/coordinator" + "github.com/doublecloud/transfer/pkg/parsequeue" + "github.com/doublecloud/transfer/pkg/providers/kinesis/consumer" + "github.com/doublecloud/transfer/pkg/stats" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +type fakeClient struct { + cntr int +} + +func (f *fakeClient) ListShards(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{Shards: []*kinesis.Shard{ + {ShardId: aws.String("s-1")}, + {ShardId: aws.String("s-2")}, + {ShardId: aws.String("s-3")}, + }}, nil +} + +func (f *fakeClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + f.cntr++ + if f.cntr < 3 { + return &kinesis.GetRecordsOutput{ + Records: []*kinesis.Record{ + { + ApproximateArrivalTimestamp: aws.Time(time.Now()), + Data: []byte("test"), + EncryptionType: nil, + PartitionKey: nil, + SequenceNumber: aws.String(fmt.Sprintf("s1-%v", f.cntr)), + }, + }, + NextShardIterator: aws.String("next-1"), + }, nil + } + return nil, awserr.New("non-retryable-code", "asd", xerrors.New("demo error")) +} + +func (f *fakeClient) GetShardIteratorWithContext(a aws.Context, input *kinesis.GetShardIteratorInput, option ...request.Option) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("s1"), + }, nil +} + +type mockSync struct { +} + +func (m mockSync) Close() error { + return nil +} + +func (m mockSync) AsyncPush(items []abstract.ChangeItem) chan error { + resCh := make(chan error) + return resCh +} + +func TestFailure(t *testing.T) { + var err error + s := new(Source) + s.cp = coordinator.NewFakeClient() + s.logger = logger.Log + s.ctx = context.Background() + s.config = new(KinesisSource) + s.config.WithDefaults() + s.metrics = stats.NewSourceStats(solomon.NewRegistry(solomon.NewRegistryOpts())) + s.consumer, err = consumer.New("abc", consumer.WithClient(&fakeClient{})) + require.NoError(t, err) + parseQ := parsequeue.NewWaitable(s.logger, 10, &mockSync{}, s.parse, s.ack) + require.Error(t, s.run(parseQ)) +}