diff --git a/xkafka/consumer.go b/xkafka/consumer.go index 5cd54de..6c14bc1 100644 --- a/xkafka/consumer.go +++ b/xkafka/consumer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "strings" + "sync/atomic" "time" "github.com/confluentinc/confluent-kafka-go/v2/kafka" @@ -18,6 +19,7 @@ type Consumer struct { handler Handler middlewares []middleware config options + cancelCtx atomic.Pointer[context.CancelFunc] } // NewConsumer creates a new Consumer instance. @@ -65,23 +67,53 @@ func (c *Consumer) Use(mwf ...MiddlewareFunc) { } } -// Run manages starting and stopping the consumer. +// Run starts running the Consumer. The component will stop running +// when the context is closed. Run blocks until the context is closed or +// an error occurs. func (c *Consumer) Run(ctx context.Context) error { - defer c.Close() + if err := c.subscribe(); err != nil { + return err + } + + if err := c.start(ctx); err != nil { + return err + } - return c.Start(ctx) + return c.close() } // Start subscribes to the configured topics and starts consuming messages. // It runs the handler for each message in a separate goroutine. -// It blocks until the context is cancelled or an error occurs. +// +// This method is non-blocking and returns immediately post subscribe. +// Instead, use Run if you want to block until the context is closed or an error occurs. +// // Errors are handled by the ErrorHandler if set, otherwise they stop the consumer // and are returned. -func (c *Consumer) Start(ctx context.Context) error { +func (c *Consumer) Start() error { if err := c.subscribe(); err != nil { return err } + ctx, cancel := context.WithCancel(context.Background()) + c.cancelCtx.Store(&cancel) + + go func() { _ = c.start(ctx) }() + + return nil +} + +// Close closes the consumer. +func (c *Consumer) Close() { + cancel := c.cancelCtx.Load() + if cancel != nil { + (*cancel)() + } + + _ = c.close() +} + +func (c *Consumer) start(ctx context.Context) error { c.handler = c.concatMiddlewares(c.handler) if c.config.concurrency > 1 { @@ -237,9 +269,8 @@ func (c *Consumer) unsubscribe() error { return c.kafka.Unsubscribe() } -// Close closes the consumer. -func (c *Consumer) Close() { +func (c *Consumer) close() error { <-time.After(c.config.shutdownTimeout) - _ = c.kafka.Close() + return c.kafka.Close() } diff --git a/xkafka/consumer_test.go b/xkafka/consumer_test.go index 6fab372..2f4f8bc 100644 --- a/xkafka/consumer_test.go +++ b/xkafka/consumer_test.go @@ -85,6 +85,66 @@ func TestNewConsumerError(t *testing.T) { assert.Equal(t, expectError, err) } +func TestConsumerLifecycle(t *testing.T) { + t.Parallel() + + t.Run("StartSubscribeError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, + testTopics, + testBrokers, + PollTimeout(testTimeout), + ) + + expectError := errors.New("error in subscribe") + + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(expectError) + + assert.Error(t, consumer.Start()) + + mockKafka.AssertExpectations(t) + }) + + t.Run("StartSuccessCloseError", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, + testTopics, + testBrokers, + PollTimeout(testTimeout), + ) + + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("Unsubscribe").Return(nil) + mockKafka.On("ReadMessage", testTimeout).Return(newFakeKafkaMessage(), nil) + mockKafka.On("Commit").Return(nil, nil) + mockKafka.On("Close").Return(errors.New("error in close")) + + assert.NoError(t, consumer.Start()) + <-time.After(100 * time.Millisecond) + consumer.Close() + + mockKafka.AssertExpectations(t) + }) + + t.Run("StartCloseSuccess", func(t *testing.T) { + consumer, mockKafka := newTestConsumer(t, + testTopics, + testBrokers, + PollTimeout(testTimeout), + ) + + mockKafka.On("SubscribeTopics", []string(testTopics), mock.Anything).Return(nil) + mockKafka.On("Unsubscribe").Return(nil) + mockKafka.On("ReadMessage", testTimeout).Return(newFakeKafkaMessage(), nil) + mockKafka.On("Commit").Return(nil, nil) + mockKafka.On("Close").Return(nil) + + assert.NoError(t, consumer.Start()) + <-time.After(100 * time.Millisecond) + consumer.Close() + + mockKafka.AssertExpectations(t) + }) +} + func TestConsumerGetMetadata(t *testing.T) { t.Parallel() @@ -95,6 +155,7 @@ func TestConsumerGetMetadata(t *testing.T) { ) mockKafka.On("GetMetadata", mock.Anything, false, 10000).Return(&kafka.Metadata{}, nil) + mockKafka.On("Close").Return(nil) metadata, err := consumer.GetMetadata() assert.NoError(t, err) @@ -180,6 +241,7 @@ func TestConsumerHandleMessage(t *testing.T) { mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) + mockKafka.On("Close").Return(nil) consumer.handler = handler err := consumer.Run(ctx) @@ -339,6 +401,7 @@ func TestConsumerReadMessageTimeout(t *testing.T) { mockKafka.On("ReadMessage", testTimeout).Return(km, nil).Once() mockKafka.On("ReadMessage", testTimeout).Return(nil, expect).Once() mockKafka.On("ReadMessage", testTimeout).Return(km, nil) + mockKafka.On("Close").Return(nil) consumer.handler = handler @@ -408,6 +471,7 @@ func TestConsumerMiddlewareExecutionOrder(t *testing.T) { mockKafka.On("Unsubscribe").Return(nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) + mockKafka.On("Close").Return(nil) handler := HandlerFunc(func(ctx context.Context, msg *Message) error { cancel() @@ -458,6 +522,7 @@ func TestConsumerManualCommit(t *testing.T) { mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) mockKafka.On("Commit").Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) + mockKafka.On("Close").Return(nil) handler := HandlerFunc(func(ctx context.Context, msg *Message) error { cancel() @@ -493,6 +558,7 @@ func TestConsumerAsync(t *testing.T) { mockKafka.On("StoreOffsets", mock.Anything).Return(nil, nil) mockKafka.On("ReadMessage", testTimeout).Return(km, nil) mockKafka.On("Commit").Return(nil, nil) + mockKafka.On("Close").Return(nil) var recv []*Message var mu sync.Mutex @@ -640,8 +706,6 @@ func testMiddleware(name string, pre, post *[]string) MiddlewareFunc { func newTestConsumer(t *testing.T, opts ...Option) (*Consumer, *MockConsumerClient) { mockConsumer := &MockConsumerClient{} - mockConsumer.On("Close").Return(nil) - opts = append(opts, mockConsumerFunc(mockConsumer)) consumer, err := NewConsumer("consumer-id", noopHandler(), opts...)