
Commit

Fix Prometheus metrics
This fixes an issue with the metric for failed tasks and also adds test coverage for the metrics.

Follow up to #181
tillprochaska committed Jul 4, 2024
1 parent f12e13e commit 2be9117
Showing 2 changed files with 113 additions and 13 deletions.
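
For context, the metric names and labels asserted in the new tests suggest declarations along the following lines in servicelayer.metrics. This is a sketch inferred from the diff below rather than the repository's actual code; the variable names TASKS_STARTED, TASKS_SUCCEEDED and TASK_DURATION are assumptions (only TASKS_FAILED appears in the diff itself).

# Hypothetical sketch of servicelayer/metrics.py, inferred from the metric
# names and labels used in this commit. Not verbatim from the repository.
from prometheus_client import Counter, Histogram

TASKS_STARTED = Counter(
    "servicelayer_tasks_started", "Number of task executions started", ["stage"]
)
TASKS_SUCCEEDED = Counter(
    "servicelayer_tasks_succeeded",
    "Number of task executions that succeeded",
    ["stage", "retries"],
)
TASKS_FAILED = Counter(
    "servicelayer_tasks_failed",
    "Number of task executions that failed",
    ["stage", "retries", "failed_permanently"],
)
TASK_DURATION = Histogram(
    "servicelayer_task_duration_seconds", "Task runtime in seconds", ["stage"]
)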
21 changes: 8 additions & 13 deletions servicelayer/taskqueue.py
@@ -558,19 +558,14 @@ def handle(self, task: Task, channel) -> Tuple[bool, bool]:
         Returns a tuple of (success, retry)."""
         success = True
         retry = True
+
+        task_retry_count = task.get_retry_count(self.conn)
+
         try:
             dataset = Dataset(
                 conn=self.conn, name=dataset_from_collection_id(task.collection_id)
             )
             if dataset.should_execute(task.task_id):
-                task_retry_count = task.get_retry_count(self.conn)
-                if task_retry_count:
-                    metrics.TASKS_FAILED.labels(
-                        stage=task.operation,
-                        retries=task_retry_count,
-                        failed_permanently=False,
-                    ).inc()
-
                 if task_retry_count > settings.WORKER_RETRY:
                     raise MaxRetriesExceededError(
                         f"Max retries reached for task {task.task_id}. Aborting."
@@ -604,11 +599,6 @@ def handle(self, task: Task, channel) -> Tuple[bool, bool]:
                 # In this case, a task ID was found neither in the
                 # list of Pending, nor the list of Running tasks
                 # in Redis. It was never attempted.
-                metrics.TASKS_FAILED.labels(
-                    stage=task.operation,
-                    retries=0,
-                    failed_permanently=True,
-                ).inc()
                 success = False
         except MaxRetriesExceededError:
             log.exception(
@@ -618,6 +608,11 @@ def handle(self, task: Task, channel) -> Tuple[bool, bool]:
             retry = False
         except Exception:
             log.exception("Error in task handling")
+            metrics.TASKS_FAILED.labels(
+                stage=task.operation,
+                retries=task_retry_count,
+                failed_permanently=task_retry_count >= settings.WORKER_RETRY,
+            ).inc()
             success = False
         finally:
             self.after_task(task)
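
To make the effect of the change above concrete: the retry count is now read once before the try block, and the failed-task counter is incremented in the generic exception handler, labelling an attempt as permanently failed once the retry budget is exhausted. A small standalone illustration of that label computation, assuming WORKER_RETRY is 3 as the new test asserts:

# Standalone illustration (not part of the commit) of how the
# failed_permanently label is derived from the retry count, assuming a retry
# budget of 3 as asserted in the new test below.
WORKER_RETRY = 3

for task_retry_count in range(WORKER_RETRY + 1):  # initial attempt plus 3 retries
    failed_permanently = task_retry_count >= WORKER_RETRY
    print(task_retry_count, failed_permanently)

# Output: only the fourth and final attempt (retry count 3) is labelled
# failed_permanently=True; earlier attempts are recorded as transient failures.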
105 changes: 105 additions & 0 deletions tests/test_taskqueue.py
@@ -5,6 +5,9 @@
 from random import randrange

 import pika
+from prometheus_client import REGISTRY
+from prometheus_client.metrics import MetricWrapperBase
+import pytest

 from servicelayer import settings
 from servicelayer.cache import get_fakeredis
@@ -32,6 +35,11 @@ def dispatch_task(self, task: Task) -> Task:
         return task


+class FailingWorker(Worker):
+    def dispatch_task(self, task: Task) -> Task:
+        raise Exception("Woops")
+
+
 class TaskQueueTest(TestCase):
     def test_task_queue(self):
         test_queue_name = "sls-queue-ingest"
Expand Down Expand Up @@ -184,6 +192,103 @@ def did_nack():
assert dataset.is_task_tracked(Task(**body))


@pytest.fixture
def prom_registry():
# This relies on internal implementation details of the client to reset
# previously collected metrics before every test execution. Unfortunately,
# there is no clean way of achieving the same thing that doesn't add a lot
# of complexity to the test and application code.
collectors = REGISTRY._collector_to_names.keys()
for collector in collectors:
if isinstance(collector, MetricWrapperBase):
collector._metrics.clear()
collector._metric_init()

yield REGISTRY


def test_prometheus_metrics_succeeded(prom_registry):
conn = get_fakeredis()
rmq_channel = get_rabbitmq_channel()
worker = CountingWorker(conn=conn, queues=["ingest"], num_threads=1)
declare_rabbitmq_queue(channel=rmq_channel, queue="ingest")

queue_task(
rmq_channel=rmq_channel,
redis_conn=conn,
collection_id=123,
stage="ingest",
)
worker.process(blocking=False)

started = prom_registry.get_sample_value(
"servicelayer_tasks_started_total",
{"stage": "ingest"},
)
assert started == 1

succeeded = prom_registry.get_sample_value(
"servicelayer_tasks_succeeded_total",
{"stage": "ingest", "retries": "0"},
)
assert succeeded == 1

# Under the hood, histogram metrics create multiple time series tracking
# the number and sum of observations, as well as individual histogram buckets.
duration_sum = prom_registry.get_sample_value(
"servicelayer_task_duration_seconds_sum",
{"stage": "ingest"},
)
duration_count = prom_registry.get_sample_value(
"servicelayer_task_duration_seconds_count",
{"stage": "ingest"},
)
assert duration_sum > 0
assert duration_count == 1


def test_prometheus_metrics_failed(prom_registry):
conn = get_fakeredis()
rmq_channel = get_rabbitmq_channel()
worker = FailingWorker(conn=conn, queues=["ingest"], num_threads=1)
declare_rabbitmq_queue(channel=rmq_channel, queue="ingest")

queue_task(
rmq_channel=rmq_channel,
redis_conn=conn,
collection_id=123,
stage="ingest",
)
worker.process(blocking=False)

started = prom_registry.get_sample_value(
"servicelayer_tasks_started_total",
{"stage": "ingest"},
)
assert settings.WORKER_RETRY == 3
assert started == 4 # Initial attempt + 3 retries

first_attempt = REGISTRY.get_sample_value(
"servicelayer_tasks_failed_total",
{
"stage": "ingest",
"retries": "0",
"failed_permanently": "False",
},
)
assert first_attempt == 1

last_attempt = REGISTRY.get_sample_value(
"servicelayer_tasks_failed_total",
{
"stage": "ingest",
"retries": "3",
"failed_permanently": "True",
},
)
assert last_attempt == 1


def test_get_priority_bucket():
redis = get_fakeredis()
rmq_channel = get_rabbitmq_channel()
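
As the comment in test_prometheus_metrics_succeeded points out, a single histogram is exported as several time series. A self-contained sketch, independent of servicelayer and using a throwaway registry with a made-up metric name, showing how get_sample_value reads those series:

# Self-contained sketch, not part of the commit: a prometheus_client Histogram
# fans out into *_sum, *_count and per-bucket *_bucket samples.
from prometheus_client import CollectorRegistry, Histogram

registry = CollectorRegistry()
duration = Histogram(
    "demo_task_duration_seconds", "Demo task runtime", ["stage"], registry=registry
)
duration.labels(stage="ingest").observe(1.5)

assert registry.get_sample_value(
    "demo_task_duration_seconds_sum", {"stage": "ingest"}
) == 1.5
assert registry.get_sample_value(
    "demo_task_duration_seconds_count", {"stage": "ingest"}
) == 1.0

# Bucket samples carry an additional "le" label; the 1.5s observation falls
# into the default 2.5s bucket.
assert registry.get_sample_value(
    "demo_task_duration_seconds_bucket", {"stage": "ingest", "le": "2.5"}
) == 1.0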
