Skip to content

Commit

Permalink
fix(langchain): langgraph traces were broken (#1895)
Browse files Browse the repository at this point in the history
  • Loading branch information
nirga authored Aug 24, 2024
1 parent 27daeff commit 3a0bdf3
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,14 @@
from opentelemetry.instrumentation.langchain.version import __version__


from opentelemetry.instrumentation.langchain.callback_wrapper import callback_wrapper
from opentelemetry.instrumentation.langchain.callback_handler import (
TraceloopCallbackHandler,
)

logger = logging.getLogger(__name__)

_instruments = ("langchain >= 0.0.346", "langchain-core > 0.1.0")

ASYNC_CALLBACK_FUNCTIONS = ("ainvoke", "astream", "atransform")
SYNC_CALLBACK_FUNCTIONS = ("invoke", "stream", "transform")
WRAPPED_METHODS = [
{"package": "langchain.agents", "class": "AgentExecutor"},
{
"package": "langchain.chains.base",
"class": "Chain",
},
{
"package": "langchain_core.runnables.base",
"class": "RunnableSequence",
},
{
"package": "langchain.prompts.base",
"class": "BasePromptTemplate",
},
{
"package": "langchain.chat_models.base",
"class": "BaseChatModel",
},
{
"package": "langchain.schema",
"class": "BaseOutputParser",
},
]


class LangchainInstrumentor(BaseInstrumentor):
"""An instrumentor for Langchain SDK."""
Expand All @@ -59,19 +35,31 @@ def instrumentation_dependencies(self) -> Collection[str]:
def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, __version__, tracer_provider)
for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
wrap_class = wrapped_method.get("class")
for func_name in SYNC_CALLBACK_FUNCTIONS + ASYNC_CALLBACK_FUNCTIONS:
wrap_function_wrapper(
wrap_package,
f"{wrap_class}.{func_name}",
callback_wrapper(tracer, wrapped_method),
)

wrap_function_wrapper(
module="langchain_core.callbacks",
name="BaseCallbackManager.__init__",
wrapper=_BaseCallbackManagerInitWrapper(TraceloopCallbackHandler(tracer)),
)

def _uninstrument(self, **kwargs):
for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
wrap_class = wrapped_method.get("class")
for func_name in SYNC_CALLBACK_FUNCTIONS + ASYNC_CALLBACK_FUNCTIONS:
unwrap(wrap_package, f"{wrap_class}.{func_name}")
unwrap("langchain_core.callbacks", "BaseCallbackManager.__init__")


class _BaseCallbackManagerInitWrapper:
def __init__(self, callback_manager: "TraceloopCallbackHandler"):
self._callback_manager = callback_manager

def __call__(
self,
wrapped,
instance,
args,
kwargs,
) -> None:
wrapped(*args, **kwargs)
for handler in instance.inheritable_handlers:
if isinstance(handler, type(self._callback_manager)):
break
else:
instance.add_handler(self._callback_manager, True)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from langchain_core.callbacks import (
BaseCallbackHandler,
BaseCallbackManager,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
Expand All @@ -22,7 +21,6 @@

from opentelemetry import context as context_api
from opentelemetry.instrumentation.langchain.utils import (
_with_tracer_wrapper,
dont_throw,
should_send_prompts,
)
Expand All @@ -47,57 +45,6 @@ class SpanHolder:
entity_path: str


@dont_throw
def _add_callback(
    tracer, callbacks: Union[List[BaseCallbackHandler], BaseCallbackManager]
):
    """Ensure a SyncSpanCallbackHandler is registered on *callbacks*.

    *callbacks* is either a BaseCallbackManager or a plain handler list;
    if one of our handlers is already present it is reused, otherwise a
    fresh one (built from *tracer*) is appended.
    """
    handler = SyncSpanCallbackHandler(tracer)
    if isinstance(callbacks, BaseCallbackManager):
        existing = [
            h for h in callbacks.handlers if isinstance(h, SyncSpanCallbackHandler)
        ]
        if existing:
            handler = existing[0]
        else:
            callbacks.add_handler(handler)
    elif isinstance(callbacks, list):
        existing = [h for h in callbacks if isinstance(h, SyncSpanCallbackHandler)]
        if existing:
            handler = existing[0]
        else:
            callbacks.append(handler)


@_with_tracer_wrapper
def callback_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
    """Inject our callback handler into the call's ``config`` before invoking.

    The runnable ``config`` (which may carry a ``callbacks`` entry) arrives
    either positionally as ``args[1]`` or as the ``config`` keyword.
    sources:
    https://python.langchain.com/v0.2/docs/how_to/callbacks_attach/
    https://python.langchain.com/v0.2/docs/how_to/callbacks_runtime/
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return wrapped(*args, **kwargs)

    # args[1], when present, is the config dict that (may) hold callbacks.
    config_is_positional = len(args) > 1
    if config_is_positional:
        callbacks = args[1].get("callbacks", [])
    elif kwargs.get("config"):
        callbacks = kwargs.get("config", {}).get("callbacks", [])
    else:
        callbacks = []

    _add_callback(tracer, callbacks)

    # Write the (possibly extended) callbacks back where they came from.
    if config_is_positional:
        args[1]["callbacks"] = callbacks
    elif kwargs.get("config"):
        kwargs["config"]["callbacks"] = callbacks
    else:
        kwargs["config"] = {"callbacks": callbacks}

    return wrapped(*args, **kwargs)


def _message_type_to_role(message_type: str) -> str:
if message_type == "human":
return "user"
Expand Down Expand Up @@ -260,7 +207,7 @@ def _set_chat_response(span: Span, response: LLMResult) -> None:
i += 1


class SyncSpanCallbackHandler(BaseCallbackHandler):
class TraceloopCallbackHandler(BaseCallbackHandler):
def __init__(self, tracer: Tracer) -> None:
super().__init__()
self.tracer = tracer
Expand Down Expand Up @@ -314,21 +261,23 @@ def _create_span(

if parent_run_id is not None and parent_run_id in self.spans:
span = self.tracer.start_span(
span_name, context=self.spans[parent_run_id].context, kind=kind
span_name,
context=set_span_in_context(self.spans[parent_run_id].span),
kind=kind,
)
else:
span = self.tracer.start_span(span_name)

span.set_attribute(SpanAttributes.TRACELOOP_WORKFLOW_NAME, workflow_name)
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_PATH, entity_path)

current_context = set_span_in_context(span)

token = context_api.attach(
context_api.set_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY, True)
)

self.spans[run_id] = SpanHolder(span, token, current_context, [], workflow_name, entity_name, entity_path)
self.spans[run_id] = SpanHolder(
span, token, None, [], workflow_name, entity_name, entity_path
)

if parent_run_id is not None and parent_run_id in self.spans:
self.spans[parent_run_id].children.append(run_id)
Expand All @@ -354,7 +303,7 @@ def _create_task_span(
workflow_name=workflow_name,
entity_name=entity_name,
entity_path=entity_path,
metadata=metadata
metadata=metadata,
)

span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, kind.value)
Expand Down Expand Up @@ -400,6 +349,9 @@ def on_chain_start(
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

workflow_name = ""
entity_path = ""

Expand All @@ -424,7 +376,7 @@ def on_chain_start(
workflow_name,
name,
entity_path,
metadata
metadata,
)
if should_send_prompts():
span.set_attribute(
Expand All @@ -450,6 +402,9 @@ def on_chain_end(
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

span = self._get_span(run_id)
if should_send_prompts():
span.set_attribute(
Expand Down Expand Up @@ -479,6 +434,9 @@ def on_chat_model_start(
**kwargs: Any,
) -> Any:
"""Run when Chat Model starts running."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

name = self._get_name_from_callback(serialized, kwargs=kwargs)
span = self._create_llm_span(
run_id, parent_run_id, name, LLMRequestTypeValues.CHAT, metadata=metadata
Expand All @@ -498,6 +456,9 @@ def on_llm_start(
**kwargs: Any,
) -> Any:
"""Run when Chat Model starts running."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

name = self._get_name_from_callback(serialized, kwargs=kwargs)
span = self._create_llm_span(
run_id, parent_run_id, name, LLMRequestTypeValues.COMPLETION
Expand All @@ -513,6 +474,9 @@ def on_llm_end(
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
):
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

span = self._get_span(run_id)

token_usage = (response.llm_output or {}).get("token_usage")
Expand Down Expand Up @@ -560,12 +524,21 @@ def on_tool_start(
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

name = self._get_name_from_callback(serialized, kwargs=kwargs)
workflow_name = self.get_workflow_name(parent_run_id)
entity_path = self.get_entity_path(parent_run_id)

span = self._create_task_span(
run_id, parent_run_id, name, TraceloopSpanKindValues.TOOL, workflow_name, name, entity_path
run_id,
parent_run_id,
name,
TraceloopSpanKindValues.TOOL,
workflow_name,
name,
entity_path,
)
if should_send_prompts():
span.set_attribute(
Expand All @@ -592,6 +565,9 @@ def on_tool_end(
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

span = self._get_span(run_id)
if should_send_prompts():
span.set_attribute(
Expand All @@ -618,7 +594,10 @@ def get_entity_path(self, parent_run_id: str):

if parent_span is None:
return ""
elif parent_span.entity_path == "" and parent_span.entity_name == parent_span.workflow_name:
elif (
parent_span.entity_path == ""
and parent_span.entity_name == parent_span.workflow_name
):
return ""
elif parent_span.entity_path == "":
return f"{parent_span.entity_name}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,6 @@ def default(self, o):
return super().default(o)


def _with_tracer_wrapper(func):
"""Helper for providing tracer for wrapper functions."""

def _with_tracer(tracer, to_wrap):
def wrapper(wrapped, instance, args, kwargs):
return func(tracer, to_wrap, wrapped, instance, args, kwargs)

return wrapper

return _with_tracer


def should_send_prompts():
return (
os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
Expand Down
18 changes: 6 additions & 12 deletions packages/sample-app/sample_app/langgraph_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from langgraph.prebuilt import ToolNode

from traceloop.sdk import Traceloop
from traceloop.sdk.decorators import workflow as traceloop_workflow

Traceloop.init(app_name="langgraph_example")

Expand Down Expand Up @@ -83,15 +82,10 @@ def call_model(state: MessagesState):
app = workflow.compile(checkpointer=checkpointer)


@traceloop_workflow()
def run_app():
# Use the Runnable
final_state = app.invoke(
{"messages": [HumanMessage(content="what is the weather in sf in Celsius")]},
config={"configurable": {"thread_id": 42}},
)

print(final_state["messages"][-1].content)

# Use the Runnable
final_state = app.invoke(
{"messages": [HumanMessage(content="what is the weather in sf in Celsius")]},
config={"configurable": {"thread_id": 42}},
)

run_app()
print(final_state["messages"][-1].content)

0 comments on commit 3a0bdf3

Please sign in to comment.