Skip to content

Commit

Permalink
Merge pull request #429 from majanjua-amzn/master
Browse files Browse the repository at this point in the history
[Lambda] Create dummy segment when trace header is incomplete
  • Loading branch information
wangzlei authored May 8, 2024
2 parents a6a3e86 + d174f8d commit 164b3bb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
32 changes: 19 additions & 13 deletions aws_xray_sdk/core/lambda_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import threading

from aws_xray_sdk import global_sdk_config
from .models.dummy_entities import DummySegment
from .models.facade_segment import FacadeSegment
from .models.trace_header import TraceHeader
from .context import Context
Expand Down Expand Up @@ -44,7 +45,7 @@ class LambdaContext(Context):
"""
Lambda service will generate a segment for each function invocation which
cannot be mutated. The context doesn't keep any manually created segment
but instead every time ``get_trace_entity()`` gets called it refresh the facade
but instead every time ``get_trace_entity()`` gets called it refresh the
segment based on environment variables set by Lambda worker.
"""
def __init__(self):
Expand All @@ -65,12 +66,12 @@ def end_segment(self, end_time=None):

def put_subsegment(self, subsegment):
"""
Refresh the facade segment every time this function is invoked to prevent
Refresh the segment every time this function is invoked to prevent
a new subsegment from being attached to a leaked segment/subsegment.
"""
current_entity = self.get_trace_entity()

if not self._is_subsegment(current_entity) and current_entity.initializing:
if not self._is_subsegment(current_entity) and (getattr(current_entity, 'initializing', None) or isinstance(current_entity, DummySegment)):
if global_sdk_config.sdk_enabled():
log.warning("Subsegment %s discarded due to Lambda worker still initializing" % subsegment.name)
return
Expand Down Expand Up @@ -99,9 +100,9 @@ def get_trace_entity(self):

def _refresh_context(self):
"""
Get current facade segment. To prevent resource leaking in Lambda worker,
Get current segment. To prevent resource leaking in Lambda worker,
every time there is segment present, we compare its trace id to current
environment variables. If it is different we create a new facade segment
environment variables. If it is different we create a new segment
and clean up subsegments stored.
"""
header_str = os.getenv(LAMBDA_TRACE_HEADER_KEY)
Expand Down Expand Up @@ -136,8 +137,8 @@ def handle_context_missing(self):

def _initialize_context(self, trace_header):
"""
Create a facade segment based on environment variables
set by AWS Lambda and initialize storage for subsegments.
Create a segment based on environment variables set by
AWS Lambda and initialize storage for subsegments.
"""
sampled = None
if not global_sdk_config.sdk_enabled():
Expand All @@ -148,12 +149,17 @@ def _initialize_context(self, trace_header):
elif trace_header.sampled == 1:
sampled = True

segment = FacadeSegment(
name='facade',
traceid=trace_header.root,
entityid=trace_header.parent,
sampled=sampled,
)
segment = None
if not trace_header.root or not trace_header.parent or trace_header.sampled is None:
segment = DummySegment()
log.debug("Creating NoOp/Dummy parent segment")
else:
segment = FacadeSegment(
name='facade',
traceid=trace_header.root,
entityid=trace_header.parent,
sampled=sampled,
)
segment.save_origin_trace_header(trace_header)
setattr(self._local, 'segment', segment)
setattr(self._local, 'entities', [])
34 changes: 30 additions & 4 deletions tests/test_lambda_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from aws_xray_sdk import global_sdk_config
import pytest
from aws_xray_sdk.core import lambda_launcher
from aws_xray_sdk.core.models.dummy_entities import DummySegment
from aws_xray_sdk.core.models.subsegment import Subsegment


Expand Down Expand Up @@ -67,23 +68,48 @@ def test_disable():


def test_non_initialized():
# Context that hasn't been initialized by lambda container should not add subsegments to the facade segment.
# Context that hasn't been initialized by lambda container should not add subsegments to the dummy segment.
temp_header_var = os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY]
del os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY]

temp_context = lambda_launcher.LambdaContext()
facade_segment = temp_context.get_trace_entity()
subsegment = Subsegment("TestSubsegment", "local", facade_segment)
dummy_segment = temp_context.get_trace_entity()
subsegment = Subsegment("TestSubsegment", "local", dummy_segment)
temp_context.put_subsegment(subsegment)

assert temp_context.get_trace_entity() == facade_segment
assert temp_context.get_trace_entity() == dummy_segment

# "Lambda" container added metadata now. Should see subsegment now.
# The following put_segment call will overwrite the dummy segment in the context with an intialized facade segment that accepts a subsegment.
os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = temp_header_var
temp_context.put_subsegment(subsegment)

assert temp_context.get_trace_entity() == subsegment

def test_lambda_passthrough():
# Hold previous environment value
temp_header_var = os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY]
del os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY]

# Set header to lambda passthrough style header
os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = "Root=%s;Lineage=10:1234abcd:3" % TRACE_ID

temp_context = lambda_launcher.LambdaContext()
dummy_segment = temp_context.get_trace_entity()
subsegment = Subsegment("TestSubsegment", "local", dummy_segment)
temp_context.put_subsegment(subsegment)

# Resulting entity is not the same dummy segment, so simply check that it is a dummy segment
assert isinstance(temp_context.get_trace_entity(), DummySegment)

# Reset header value and ensure behaviour returns to normal
del os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY]
os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = temp_header_var
temp_context.put_subsegment(subsegment)

assert temp_context.get_trace_entity() == subsegment



def test_set_trace_entity():
segment = context.get_trace_entity()
Expand Down

0 comments on commit 164b3bb

Please sign in to comment.