From 51c562a4c149b39dfc33340cc779c5e2317d77aa Mon Sep 17 00:00:00 2001 From: Carly de Frondeville Date: Fri, 11 Oct 2024 21:46:05 -0400 Subject: [PATCH] Recycle task tokens for invalid matching tasks (#6599) ## What changed? Add a `RecycleToken` function to our custom rate limiters. It seems that most of the rate limiters just wrap the `ClockedRateLimiter`, so let the `ClockedRateLimiter`'s `WaitN` function unblock the waiter if recycle is called, thus allowing us to recycle tokens and give them to any waiters if they exist. ## Why? Lots of invalid tasks combined with Activity task dispatch rate limiting can cause the actual rate to be noticeably below the maximum rate. This will make the actual rate match the max rps, as long as `time_to_recycle` is reasonably short and `recycle_rate` is reasonable (see discussion in the comments) ## How did you test it? ClockedRateLimiter unit test. ## Potential risks The max Activity RPS could break. ## Documentation ## Is hotfix candidate? --------- Co-authored-by: David Reiss --- common/quotas/clocked_rate_limiter.go | 74 +++++++++++++++++-- common/quotas/clocked_rate_limiter_test.go | 63 ++++++++++++++++ common/quotas/dynamic_rate_limiter_impl.go | 5 ++ common/quotas/multi_rate_limiter_impl.go | 20 +++-- common/quotas/multi_reservation_impl.go | 1 + common/quotas/rate_limiter.go | 5 ++ common/quotas/rate_limiter_impl.go | 11 ++- common/quotas/rate_limiter_mock.go | 12 +++ service/matching/matcher.go | 17 ++++- service/matching/matcher_test.go | 19 ++--- service/matching/matching_engine.go | 28 +++---- service/matching/matching_engine_test.go | 8 +- .../matching/physical_task_queue_manager.go | 4 +- .../physical_task_queue_manager_test.go | 2 +- service/matching/task.go | 12 ++- service/matching/task_validation.go | 12 +-- 16 files changed, 238 insertions(+), 55 deletions(-) diff --git a/common/quotas/clocked_rate_limiter.go b/common/quotas/clocked_rate_limiter.go index 97e2ed77646..8f7ee786056 100644 --- a/common/quotas/clocked_rate_limiter.go +++ b/common/quotas/clocked_rate_limiter.go @@ -39,6 +39,7 @@ import ( type ClockedRateLimiter struct { rateLimiter *rate.Limiter timeSource clock.TimeSource + recycleCh chan struct{} } var ( @@ -51,6 +52,7 @@ func NewClockedRateLimiter(rateLimiter *rate.Limiter, timeSource clock.TimeSourc return ClockedRateLimiter{ rateLimiter: rateLimiter, timeSource: timeSource, + recycleCh: make(chan struct{}), } } @@ -131,12 +133,35 @@ func (l ClockedRateLimiter) WaitN(ctx context.Context, token int) error { close(waitExpired) }) defer timer.Stop() - select { - case <-ctx.Done(): - reservation.Cancel() - return fmt.Errorf("%w: %v", ErrRateLimiterWaitInterrupted, ctx.Err()) - case <-waitExpired: - return nil + + for { + select { + case <-ctx.Done(): + reservation.Cancel() + return fmt.Errorf("%w: %v", ErrRateLimiterWaitInterrupted, ctx.Err()) + case <-waitExpired: + return nil + case <-l.recycleCh: + if token > 1 { + break // recycling 1 token to a process requesting >1 tokens is a no-op + } + + // Cancel() reverses the effects of this Reservation on the rate limit as much as possible, + // considering that other reservations may have already been made. Normally, Cancel() indicates + // that the reservation holder will not perform the reserved action, so it would make the most + // sense to cancel the reservation whose token was just recycled. However, we don't have access + // to the recycled reservation anymore, and even if we did, Cancel on a reservation that + // has fully waited is a no-op, so instead we cancel the current reservation as a proxy. + // + // Since Cancel() just restores tokens to the rate limiter, cancelling the current 1-token + // reservation should have approximately the same effect on the actual rate as cancelling the + // recycled reservation. + // + // If the recycled reservation was for >1 token, cancelling the current 1-token reservation will + // lead to a slower actual rate than cancelling the original, so the approximation is conservative. + reservation.Cancel() + return nil + } } } @@ -151,3 +176,40 @@ func (l ClockedRateLimiter) SetBurstAt(t time.Time, newBurst int) { func (l ClockedRateLimiter) TokensAt(t time.Time) int { return int(l.rateLimiter.TokensAt(t)) } + +// RecycleToken should be called when the action being rate limited was not completed +// for some reason (i.e. a task is not dispatched because it was invalid). +// In this case, we want to immediately unblock another process that is waiting for one token +// so that the actual rate of completed actions is as close to the intended rate limit as possible. +// If no process is waiting for a token when RecycleToken is called, this is a no-op. +// +// Since we don't know how many tokens were reserved by the process calling recycle, we will only unblock +// new reservations that are for one token (otherwise we could recycle a 1-token-reservation and unblock +// a 100-token-reservation). If all waiting processes are waiting for >1 tokens, this is a no-op. +// +// Because recycleCh is an unbuffered channel, the token will be reused for the next waiter as long +// as there exists a waiter at the time RecycleToken is called. Usually the attempted rate is consistently +// above or below the limit for a period of time, so if rate limiting is in effect and recycling matters, +// most likely there will be a waiter. If the actual rate is erratically bouncing to either side of the +// rate limit AND we perform many recycles, this will drop some recycled tokens. +// If that situation turns out to be common, we may want to make it a buffered channel instead. +// +// Our goal is to ensure that each token in our bucket is used every second, meaning the time between +// taking and successfully using a token must be <= 1s. For this to be true, we must have: +// +// time_to_recycle * number_of_recycles_per_second <= 1s +// time_to_recycle * probability_of_recycle * number_of_attempts_per_second <= 1s +// +// Therefore, it is also possible for this strategy to be inaccurate if the delay between taking and +// successfully using a token is greater than one second. +// +// Currently, RecycleToken is called when we take a token to attempt a matching task dispatch and +// then later find out (usually via RPC to History) that the task should not be dispatched. +// If history rpc takes 10ms --> 100 opportunities for the token to be used that second --> 99% recycle probability is ok. +// If recycle probability is 50% --> need at least 2 opportunities for token to be used --> 500ms history rpc time is ok. +func (l ClockedRateLimiter) RecycleToken() { + select { + case l.recycleCh <- struct{}{}: + default: + } +} diff --git a/common/quotas/clocked_rate_limiter_test.go b/common/quotas/clocked_rate_limiter_test.go index 09b6202fbfe..0139875bd21 100644 --- a/common/quotas/clocked_rate_limiter_test.go +++ b/common/quotas/clocked_rate_limiter_test.go @@ -26,6 +26,7 @@ package quotas_test import ( "context" + "sync/atomic" "testing" "time" @@ -151,3 +152,65 @@ func TestClockedRateLimiter_Wait_DeadlineWouldExceed(t *testing.T) { t.Cleanup(cancel) assert.ErrorIs(t, rl.Wait(ctx), quotas.ErrRateLimiterReservationWouldExceedContextDeadline) } + +// test that reservations for 1 token ARE unblocked by RecycleToken +func TestClockedRateLimiter_Wait_Recycle(t *testing.T) { + t.Parallel() + ts := clock.NewEventTimeSource() + rl := quotas.NewClockedRateLimiter(rate.NewLimiter(1, 1), ts) + ctx := context.Background() + + // take first token + assert.NoError(t, rl.Wait(ctx)) + + // wait for next token and report when success + var asserted atomic.Bool + asserted.Store(false) + go func() { + assert.NoError(t, rl.Wait(ctx)) + asserted.Store(true) + }() + // wait for rl.Wait() to start and get to the select statement + time.Sleep(10 * time.Millisecond) // nolint + + // once a waiter exists, recycle the token instead of advancing time + rl.RecycleToken() + + // wait until done so we know assert.NoError was called + assert.Eventually(t, func() bool { return asserted.Load() }, time.Second, time.Millisecond) +} + +// test that reservations for >1 token are NOT unblocked by RecycleToken +func TestClockedRateLimiter_WaitN_NoRecycle(t *testing.T) { + t.Parallel() + ts := clock.NewEventTimeSource() + + // set burst to 2 so that the reservation succeeds and WaitN gets to the select statement + rl := quotas.NewClockedRateLimiter(rate.NewLimiter(1, 2), ts) + ctx, cancel := context.WithCancel(context.Background()) + + // take first token + assert.NoError(t, rl.Wait(ctx)) + + // wait for 2 tokens, which will never get a recycle + // expect a context cancel error instead once we advance time + // wait for next token and report when success + var asserted atomic.Bool + asserted.Store(false) + go func() { + err := rl.WaitN(ctx, 2) + assert.ErrorContains(t, err, quotas.ErrRateLimiterWaitInterrupted.Error()) + asserted.Store(true) + }() + // wait for rl.Wait() to start and get to the select statement + time.Sleep(10 * time.Millisecond) // nolint + + // once a waiter exists, recycle the token instead of advancing time + rl.RecycleToken() + + // cancel the context so that we return an error + cancel() + + // wait until done so we know assert.NoError was called + assert.Eventually(t, func() bool { return asserted.Load() }, time.Second, time.Millisecond) +} diff --git a/common/quotas/dynamic_rate_limiter_impl.go b/common/quotas/dynamic_rate_limiter_impl.go index dd045126410..5ff38951ac8 100644 --- a/common/quotas/dynamic_rate_limiter_impl.go +++ b/common/quotas/dynamic_rate_limiter_impl.go @@ -165,3 +165,8 @@ func (d *DynamicRateLimiterImpl) maybeRefresh() { func (d *DynamicRateLimiterImpl) TokensAt(t time.Time) int { return d.rateLimiter.TokensAt(t) } + +// RecycleToken returns a token to the rate limiter +func (d *DynamicRateLimiterImpl) RecycleToken() { + d.rateLimiter.RecycleToken() +} diff --git a/common/quotas/multi_rate_limiter_impl.go b/common/quotas/multi_rate_limiter_impl.go index 731a5dd0e69..b85dbbcc110 100644 --- a/common/quotas/multi_rate_limiter_impl.go +++ b/common/quotas/multi_rate_limiter_impl.go @@ -95,8 +95,8 @@ func (rl *MultiRateLimiterImpl) Reserve() Reservation { return rl.ReserveN(time.Now(), 1) } -// ReserveN returns a Reservation that indicates how long the caller -// must wait before event happen. +// ReserveN calls ReserveN on its list of rate limiters and returns a MultiReservation that is a list of the +// individual reservation objects indicating how long the caller must wait before the event can happen. func (rl *MultiRateLimiterImpl) ReserveN(now time.Time, numToken int) Reservation { length := len(rl.rateLimiters) reservations := make([]Reservation, 0, length) @@ -116,12 +116,12 @@ func (rl *MultiRateLimiterImpl) ReserveN(now time.Time, numToken int) Reservatio return NewMultiReservation(true, reservations) } -// Wait waits up till deadline for a rate limit token +// Wait waits up till maximum deadline for a rate limit token func (rl *MultiRateLimiterImpl) Wait(ctx context.Context) error { return rl.WaitN(ctx, 1) } -// WaitN waits up till deadline for n rate limit token +// WaitN waits up till maximum deadline for n rate limit tokens func (rl *MultiRateLimiterImpl) WaitN(ctx context.Context, numToken int) error { select { case <-ctx.Done(): @@ -160,7 +160,7 @@ func (rl *MultiRateLimiterImpl) WaitN(ctx context.Context, numToken int) error { } } -// Rate returns the rate per second for this rate limiter +// Rate returns the minimum rate per second for this rate limiter func (rl *MultiRateLimiterImpl) Rate() float64 { result := rl.rateLimiters[0].Rate() for _, rateLimiter := range rl.rateLimiters { @@ -172,7 +172,7 @@ func (rl *MultiRateLimiterImpl) Rate() float64 { return result } -// Burst returns the burst for this rate limiter +// Burst returns the minimum burst for this rate limiter func (rl *MultiRateLimiterImpl) Burst() int { result := rl.rateLimiters[0].Burst() for _, rateLimiter := range rl.rateLimiters { @@ -191,3 +191,11 @@ func (rl *MultiRateLimiterImpl) TokensAt(t time.Time) int { } return tokens } + +// RecycleToken returns a token to each sub-rate-limiter, unblocking each +// sub-rate-limiter's WaitN callers. +func (rl *MultiRateLimiterImpl) RecycleToken() { + for _, rateLimiter := range rl.rateLimiters { + rateLimiter.RecycleToken() + } +} diff --git a/common/quotas/multi_reservation_impl.go b/common/quotas/multi_reservation_impl.go index 06971290f2f..6c396c757f6 100644 --- a/common/quotas/multi_reservation_impl.go +++ b/common/quotas/multi_reservation_impl.go @@ -81,6 +81,7 @@ func (r *MultiReservationImpl) Delay() time.Duration { // DelayFrom returns the duration for which the reservation holder must wait // before taking the reserved action. Zero duration means act immediately. +// MultiReservation DelayFrom returns the maximum delay of all its sub-reservations. func (r *MultiReservationImpl) DelayFrom(now time.Time) time.Duration { if !r.ok { return InfDuration diff --git a/common/quotas/rate_limiter.go b/common/quotas/rate_limiter.go index d8ffd5bd2f2..0ca4a128d28 100644 --- a/common/quotas/rate_limiter.go +++ b/common/quotas/rate_limiter.go @@ -68,5 +68,10 @@ type ( // TokensAt returns the number of tokens that will be available at time t TokensAt(t time.Time) int + + // RecycleToken immediately unblocks another process that is waiting for a token, if + // a waiter exists. A token should be recycled when the action being rate limited was + // not completed for some reason (i.e. a task is not dispatched because it was invalid). + RecycleToken() } ) diff --git a/common/quotas/rate_limiter_impl.go b/common/quotas/rate_limiter_impl.go index b3114a385f4..571a3100359 100644 --- a/common/quotas/rate_limiter_impl.go +++ b/common/quotas/rate_limiter_impl.go @@ -60,12 +60,12 @@ func NewRateLimiter(newRPS float64, newBurst int) *RateLimiterImpl { return rl } -// SetRate set the rate of the rate limiter +// SetRPS sets the rate of the rate limiter func (rl *RateLimiterImpl) SetRPS(rps float64) { rl.refreshInternalRateLimiterImpl(&rps, nil) } -// SetBurst set the burst of the rate limiter +// SetBurst sets the burst of the rate limiter func (rl *RateLimiterImpl) SetBurst(burst int) { rl.refreshInternalRateLimiterImpl(nil, &burst) } @@ -78,7 +78,7 @@ func (rl *RateLimiterImpl) ReserveN(now time.Time, n int) Reservation { return rl.ClockedRateLimiter.ReserveN(now, n) } -// SetRateBurst set the rps & burst of the rate limiter +// SetRateBurst sets the rps & burst of the rate limiter func (rl *RateLimiterImpl) SetRateBurst(rps float64, burst int) { rl.refreshInternalRateLimiterImpl(&rps, &burst) } @@ -132,3 +132,8 @@ func (rl *RateLimiterImpl) refreshInternalRateLimiterImpl( rl.SetBurstAt(now, rl.burst) } } + +// RecycleToken returns a token to the rate limiter +func (rl *RateLimiterImpl) RecycleToken() { + rl.ClockedRateLimiter.RecycleToken() +} diff --git a/common/quotas/rate_limiter_mock.go b/common/quotas/rate_limiter_mock.go index 7f0aa5a4465..61d71de5090 100644 --- a/common/quotas/rate_limiter_mock.go +++ b/common/quotas/rate_limiter_mock.go @@ -120,6 +120,18 @@ func (mr *MockRateLimiterMockRecorder) Rate() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rate", reflect.TypeOf((*MockRateLimiter)(nil).Rate)) } +// RecycleToken mocks base method. +func (m *MockRateLimiter) RecycleToken() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RecycleToken") +} + +// RecycleToken indicates an expected call of RecycleToken. +func (mr *MockRateLimiterMockRecorder) RecycleToken() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecycleToken", reflect.TypeOf((*MockRateLimiter)(nil).RecycleToken)) +} + // Reserve mocks base method. func (m *MockRateLimiter) Reserve() Reservation { m.ctrl.T.Helper() diff --git a/service/matching/matcher.go b/service/matching/matcher.go index c6f73530012..8ea28fa64e6 100644 --- a/service/matching/matcher.go +++ b/service/matching/matcher.go @@ -170,6 +170,11 @@ func (tm *TaskMatcher) Offer(ctx context.Context, task *internalTask) (bool, err metrics.SyncThrottlePerTaskQueueCounter.With(tm.metricsHandler).Record(1) return false, err } + // because we waited on the rate limiter to offer this task, + // attach the rate limiter's RecycleToken func to the task + // so that if the task is later determined to be invalid, + // we can recycle the token it used. + task.recycleToken = tm.rateLimiter.RecycleToken } select { @@ -317,6 +322,12 @@ func (tm *TaskMatcher) MustOffer(ctx context.Context, task *internalTask, interr return err } + // because we waited on the rate limiter to offer this task, + // attach the rate limiter's RecycleToken func to the task + // so that if the task is later determined to be invalid, + // we can recycle the token it used. + task.recycleToken = tm.rateLimiter.RecycleToken + // attempt a match with local poller first. When that // doesn't succeed, try both local match and remote match select { @@ -393,9 +404,9 @@ forLoop: } cancel() // at this point, we forwarded the task to a parent partition which - // in turn dispatched the task to a poller. Make sure we delete the - // task from the database - task.finish(nil) + // in turn dispatched the task to a poller, because there was no error. + // Make sure we delete the task from the database. + task.finish(nil, true) tm.emitDispatchLatency(task, true) return nil case <-ctx.Done(): diff --git a/service/matching/matcher_test.go b/service/matching/matcher_test.go index 1cf5e2dda21..24827b7d303 100644 --- a/service/matching/matcher_test.go +++ b/service/matching/matcher_test.go @@ -117,7 +117,7 @@ func (t *MatcherTestSuite) TestLocalSyncMatch() { task, err := t.childMatcher.Poll(ctx, &pollMetadata{}) cancel() if err == nil { - task.finish(nil) + task.finish(nil, true) } }() @@ -153,7 +153,7 @@ func (t *MatcherTestSuite) testRemoteSyncMatch(taskSource enumsspb.TaskSource) { task, err := t.childMatcher.Poll(ctx, &pollMetadata{}) cancel() if err == nil && !task.isStarted() { - task.finish(nil) + task.finish(nil, true) } }() @@ -165,7 +165,7 @@ func (t *MatcherTestSuite) testRemoteSyncMatch(taskSource enumsspb.TaskSource) { if err != nil { remotePollErr = err } else { - task.finish(nil) + task.finish(nil, true) remotePollResp = matchingservice.PollWorkflowTaskQueueResponse{ WorkflowExecution: task.workflowExecution(), } @@ -421,6 +421,7 @@ func (t *MatcherTestSuite) TestAvoidForwardingWhenBacklogIsOldButReconsider() { } func (t *MatcherTestSuite) TestBacklogAge() { + t.T().Skip("flaky test") t.Equal(emptyBacklogAge, t.rootMatcher.getBacklogAge()) youngBacklogTask := newInternalTaskFromBacklog(randomTaskInfoWithAge(time.Second), nil) @@ -579,7 +580,7 @@ func (t *MatcherTestSuite) TestQueryLocalSyncMatch() { task, err := t.childMatcher.PollForQuery(ctx, &pollMetadata{}) cancel() if err == nil && task.isQuery() { - task.finish(nil) + task.finish(nil, true) } }() @@ -602,7 +603,7 @@ func (t *MatcherTestSuite) TestQueryRemoteSyncMatch() { task, err := t.childMatcher.PollForQuery(ctx, &pollMetadata{}) cancel() if err == nil && task.isQuery() { - task.finish(nil) + task.finish(nil, true) } }() @@ -615,7 +616,7 @@ func (t *MatcherTestSuite) TestQueryRemoteSyncMatch() { if err != nil { remotePollErr = err } else if task.isQuery() { - task.finish(nil) + task.finish(nil, true) querySet.Swap(true) remotePollResp = matchingservice.PollWorkflowTaskQueueResponse{ Query: &querypb.WorkflowQuery{}, @@ -666,7 +667,7 @@ func (t *MatcherTestSuite) TestQueryRemoteSyncMatchError() { cancel() if err == nil && task.isQuery() { matched = true - task.finish(nil) + task.finish(nil, true) } }() @@ -704,7 +705,7 @@ func (t *MatcherTestSuite) TestMustOfferLocalMatch() { task, err := t.childMatcher.Poll(ctx, &pollMetadata{}) cancel() if err == nil { - task.finish(nil) + task.finish(nil, true) } }() @@ -733,7 +734,7 @@ func (t *MatcherTestSuite) TestMustOfferRemoteMatch() { if err != nil { remotePollErr = err } else { - task.finish(nil) + task.finish(nil, true) remotePollResp = matchingservice.PollWorkflowTaskQueueResponse{ WorkflowExecution: task.workflowExecution(), } diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index 7f807ef76bc..43083bbb45d 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -558,7 +558,7 @@ pollLoop: } if task.isQuery() { - task.finish(nil) // this only means query task sync match succeed. + task.finish(nil, true) // this only means query task sync match succeed. // for query task, we don't need to update history to record workflow task started. but we need to know // the NextEventID and the currently set sticky task queue. @@ -618,7 +618,7 @@ pollLoop: case *serviceerror.Internal, *serviceerror.DataLoss: e.nonRetryableErrorsDropTask(task, taskQueueName, err) // drop the task as otherwise task would be stuck in a retry-loop - task.finish(nil) + task.finish(nil, false) case *serviceerror.NotFound: // mutable state not found, workflow not running or workflow task not found e.logger.Info("Workflow task not found", tag.WorkflowTaskQueueName(taskQueueName), @@ -630,10 +630,10 @@ pollLoop: tag.WorkflowEventID(task.event.Data.GetScheduledEventId()), tag.Error(err), ) - task.finish(nil) + task.finish(nil, false) case *serviceerrors.TaskAlreadyStarted: e.logger.Debug("Duplicated workflow task", tag.WorkflowTaskQueueName(taskQueueName), tag.TaskID(task.event.GetTaskId())) - task.finish(nil) + task.finish(nil, false) case *serviceerrors.ObsoleteDispatchBuildId: // history should've scheduled another task on the right build ID. dropping this one. e.logger.Info("dropping workflow task due to invalid build ID", @@ -645,9 +645,9 @@ pollLoop: tag.TaskVisibilityTimestamp(timestamp.TimeValue(task.event.Data.GetCreateTime())), tag.BuildId(requestClone.WorkerVersionCapabilities.GetBuildId()), ) - task.finish(nil) + task.finish(nil, false) default: - task.finish(err) + task.finish(err, false) if err.Error() == common.ErrNamespaceHandover.Error() { // do not keep polling new tasks when namespace is in handover state // as record start request will be rejected by history service @@ -658,7 +658,7 @@ pollLoop: continue pollLoop } - task.finish(nil) + task.finish(nil, true) return e.createPollWorkflowTaskQueueResponse(task, resp, opMetrics), nil } } @@ -783,7 +783,7 @@ pollLoop: case *serviceerror.Internal, *serviceerror.DataLoss: e.nonRetryableErrorsDropTask(task, taskQueueName, err) // drop the task as otherwise task would be stuck in a retry-loop - task.finish(nil) + task.finish(nil, false) case *serviceerror.NotFound: // mutable state not found, workflow not running or activity info not found e.logger.Info("Activity task not found", tag.WorkflowNamespaceID(task.event.Data.GetNamespaceId()), @@ -795,10 +795,10 @@ pollLoop: tag.WorkflowEventID(task.event.Data.GetScheduledEventId()), tag.Error(err), ) - task.finish(nil) + task.finish(nil, false) case *serviceerrors.TaskAlreadyStarted: e.logger.Debug("Duplicated activity task", tag.WorkflowTaskQueueName(taskQueueName), tag.TaskID(task.event.GetTaskId())) - task.finish(nil) + task.finish(nil, false) case *serviceerrors.ObsoleteDispatchBuildId: // history should've scheduled another task on the right build ID. dropping this one. e.logger.Info("dropping activity task due to invalid build ID", @@ -810,9 +810,9 @@ pollLoop: tag.TaskVisibilityTimestamp(timestamp.TimeValue(task.event.Data.GetCreateTime())), tag.BuildId(requestClone.WorkerVersionCapabilities.GetBuildId()), ) - task.finish(nil) + task.finish(nil, false) default: - task.finish(err) + task.finish(err, false) if err.Error() == common.ErrNamespaceHandover.Error() { // do not keep polling new tasks when namespace is in handover state // as record start request will be rejected by history service @@ -822,7 +822,7 @@ pollLoop: continue pollLoop } - task.finish(nil) + task.finish(nil, true) return e.createPollActivityTaskQueueResponse(task, resp, opMetrics), nil } } @@ -1738,7 +1738,7 @@ pollLoop: return task.pollNexusTaskQueueResponse(), nil } - task.finish(err) + task.finish(err, true) if err != nil { continue pollLoop } diff --git a/service/matching/matching_engine_test.go b/service/matching/matching_engine_test.go index 5ef50614fea..0ad8d81b396 100644 --- a/service/matching/matching_engine_test.go +++ b/service/matching/matching_engine_test.go @@ -1990,7 +1990,7 @@ func (s *matchingEngineSuite) TestAddTaskAfterStartFailure() { task, _, err := s.matchingEngine.pollTask(context.Background(), dbq.partition, &pollMetadata{}) s.NoError(err) - task.finish(errors.New("test error")) + task.finish(serviceerror.NewInternal("test error"), true) s.EqualValues(1, s.taskManager.getTaskCount(dbq)) task2, _, err := s.matchingEngine.pollTask(context.Background(), dbq.partition, &pollMetadata{}) s.NoError(err) @@ -2001,7 +2001,7 @@ func (s *matchingEngineSuite) TestAddTaskAfterStartFailure() { s.Equal(task.event.Data.GetRunId(), task2.event.Data.GetRunId()) s.Equal(task.event.Data.GetScheduledEventId(), task2.event.Data.GetScheduledEventId()) - task2.finish(nil) + task2.finish(nil, true) s.EqualValues(0, s.taskManager.getTaskCount(dbq)) } @@ -2660,7 +2660,7 @@ func (s *matchingEngineSuite) TestUnknownBuildId_Match() { s.NoError(err) s.Equal("wf", task.event.Data.WorkflowId) s.Equal(int64(123), task.event.Data.ScheduledEventId) - task.finish(nil) + task.finish(nil, true) wg.Done() }() @@ -2763,7 +2763,7 @@ func (s *matchingEngineSuite) TestDemotedMatch() { s.Require().NoError(err) s.Equal("wf", task.event.Data.WorkflowId) s.Equal(int64(123), task.event.Data.ScheduledEventId) - task.finish(nil) + task.finish(nil, true) } func (s *matchingEngineSuite) TestUnloadOnMembershipChange() { diff --git a/service/matching/physical_task_queue_manager.go b/service/matching/physical_task_queue_manager.go index 13ff77ddd35..cb9e93cd649 100644 --- a/service/matching/physical_task_queue_manager.go +++ b/service/matching/physical_task_queue_manager.go @@ -360,7 +360,7 @@ func (c *physicalTaskQueueManagerImpl) PollTask( // history, but this is more efficient. if task.event != nil && IsTaskExpired(task.event.AllocatedTaskInfo) { c.metricsHandler.Counter(metrics.ExpiredTasksPerTaskQueueCounter.Name()).Record(1) - task.finish(nil) + task.finish(nil, false) continue } @@ -395,7 +395,7 @@ func (c *physicalTaskQueueManagerImpl) ProcessSpooledTask( task *internalTask, ) error { if !c.taskValidator.maybeValidate(task.event.AllocatedTaskInfo, c.queue.TaskType()) { - task.finish(nil) + task.finish(nil, false) c.metricsHandler.Counter(metrics.ExpiredTasksPerTaskQueueCounter.Name()).Record(1) // Don't try to set read level here because it may have been advanced already. return nil diff --git a/service/matching/physical_task_queue_manager_test.go b/service/matching/physical_task_queue_manager_test.go index 3034c6a9e49..fabe212e99e 100644 --- a/service/matching/physical_task_queue_manager_test.go +++ b/service/matching/physical_task_queue_manager_test.go @@ -234,7 +234,7 @@ func runOneShotPoller(ctx context.Context, tqm physicalTaskQueueManager) (*goro. out <- err return nil } - task.finish(err) + task.finish(err, true) out <- task return nil }) diff --git a/service/matching/task.go b/service/matching/task.go index be2fb037542..ea3c3146c65 100644 --- a/service/matching/task.go +++ b/service/matching/task.go @@ -78,6 +78,7 @@ type ( // redirectInfo is only set when redirect rule is applied on the task. for forwarded tasks, this is populated // based on forwardInfo. redirectInfo *taskqueuespb.BuildIdRedirectInfo + recycleToken func() } ) @@ -234,7 +235,16 @@ func (task *internalTask) pollNexusTaskQueueResponse() *matchingservice.PollNexu // finish marks a task as finished. Should be called after a poller picks up a task // and marks it as started. If the task is unable to marked as started, then this // method should be called with a non-nil error argument. -func (task *internalTask) finish(err error) { +// +// If the task took a rate limit token and didn't "use" it by actually dispatching the task, +// finish will be called with wasValid=false and task.recycleToken=clockedRateLimiter.RecycleToken, +// so finish will call the rate limiter's RecycleToken to give the unused token back to any process +// that is waiting on the token, if one exists. +func (task *internalTask) finish(err error, wasValid bool) { + if !wasValid && task.recycleToken != nil { + task.recycleToken() + } + switch { case task.responseC != nil: task.responseC <- err diff --git a/service/matching/task_validation.go b/service/matching/task_validation.go index aff15851194..48ddbc95fe7 100644 --- a/service/matching/task_validation.go +++ b/service/matching/task_validation.go @@ -45,6 +45,12 @@ const ( type ( taskValidator interface { + // maybeValidate checks if a task has expired / is valid + // if return false, then task is invalid and should be discarded + // if return true, then task is *maybe-valid*, and should be dispatched + // + // a task is invalid if this task is already failed; timeout; completed, etc. + // a task is *not invalid* if this task can be started, or caller cannot verify the validity maybeValidate( task *persistencespb.AllocatedTaskInfo, taskType enumspb.TaskQueueType, @@ -80,12 +86,6 @@ func newTaskValidator( } } -// check if a task has expired / is valid -// if return false, then task is invalid and should be discarded -// if return true, then task is *maybe-valid*, and should be dispatched -// -// a task is invalid if this task is already failed; timeout; completed, etc -// a task is *not invalid* if this task can be started, or caller cannot verify the validity func (v *taskValidatorImpl) maybeValidate( task *persistencespb.AllocatedTaskInfo, taskType enumspb.TaskQueueType,