diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py
index cf0c18b..3dc1b1d 100644
--- a/src/controlflow/agents/agent.py
+++ b/src/controlflow/agents/agent.py
@@ -29,6 +29,7 @@
 )
 from controlflow.utilities.context import ctx
 from controlflow.utilities.general import ControlFlowModel, hash_objects
+from controlflow.utilities.prefect import create_markdown_artifact, prefect_task

 from .memory import Memory

@@ -247,6 +248,7 @@ def plan(
             context=context,
         )

+    @prefect_task(task_run_name="Call LLM")
     def _run_model(
         self,
         messages: list[BaseMessage],
@@ -280,6 +282,19 @@ def _run_model(

         yield AgentMessage(agent=self, message=response)

+        create_markdown_artifact(
+            markdown=f"""
+{response.content or '(No content)'}
+
+#### Payload
+```json
+{response.json(indent=2)}
+```
+""",
+            description=f"LLM Response for Agent {self.name}",
+            key="agent-message",
+        )
+
         if controlflow.settings.log_all_messages:
             logger.debug(f"Response: {response}")

@@ -288,6 +303,7 @@ def _run_model(
             result = handle_tool_call(tool_call, tools=tools)
             yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)

+    @prefect_task(task_run_name="Call LLM")
     async def _run_model_async(
         self,
         messages: list[BaseMessage],
@@ -321,6 +337,19 @@ async def _run_model_async(

         yield AgentMessage(agent=self, message=response)

+        create_markdown_artifact(
+            markdown=f"""
+{response.content or '(No content)'}
+
+#### Payload
+```json
+{response.json(indent=2)}
+```
+""",
+            description=f"LLM Response for Agent {self.name}",
+            key="agent-message",
+        )
+
         if controlflow.settings.log_all_messages:
             logger.debug(f"Response: {response}")

diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py
index 6cc1b84..6c0cc07 100644
--- a/src/controlflow/flows/flow.py
+++ b/src/controlflow/flows/flow.py
@@ -1,7 +1,8 @@
 import uuid
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

+from prefect.context import FlowRunContext
 from pydantic import Field

 import controlflow
@@ -11,6 +12,7 @@
 from controlflow.utilities.context import ctx
 from controlflow.utilities.general import ControlFlowModel
 from controlflow.utilities.logging import get_logger
+from controlflow.utilities.prefect import prefect_flow_context

 if TYPE_CHECKING:
     pass
@@ -109,10 +111,17 @@ def add_events(self, events: list[Event]):
         self.history.add_events(thread_id=self.thread_id, events=events)

     @contextmanager
-    def create_context(self):
-        # creating a new flow will reset any parent task tracking
-        with ctx(flow=self, tasks=None):
-            yield self
+    def create_context(self, **prefect_kwargs):
+        # create a new Prefect flow if we're not already in a flow run
+        if FlowRunContext.get() is None:
+            prefect_context = prefect_flow_context(**prefect_kwargs)
+        else:
+            prefect_context = nullcontext()
+
+        with prefect_context:
+            # creating a new flow will reset any parent task tracking
+            with ctx(flow=self, tasks=None):
+                yield self


 def get_flow() -> Optional[Flow]:
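Taken together, the agent.py and flow.py changes above mean that `Flow.create_context()` now opens a Prefect flow run whenever one is not already active, and each LLM call made inside it is recorded as a "Call LLM" Prefect task run with an `agent-message` markdown artifact attached. A minimal usage sketch of the new behavior (hedged: it assumes the public `controlflow` import and default settings; nothing here beyond `Flow.create_context()` comes from the diff):

```python
import controlflow as cf

flow = cf.Flow()

# Outside of Prefect, create_context() opens a new Prefect flow run
# (via prefect_flow_context); inside an existing flow run it is a no-op
# (nullcontext), so nested flows should not spawn duplicate runs.
with flow.create_context():
    # Run tasks here; each agent LLM call should appear in the Prefect UI
    # as a "Call LLM" task run carrying the response payload artifact.
    pass
```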
diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py
index 7e3f047..69fb809 100644
--- a/src/controlflow/orchestration/orchestrator.py
+++ b/src/controlflow/orchestration/orchestrator.py
@@ -16,6 +16,7 @@
 from controlflow.tasks.task import Task
 from controlflow.tools.tools import Tool, as_tools
 from controlflow.utilities.general import ControlFlowModel
+from controlflow.utilities.prefect import prefect_task

 logger = logging.getLogger(__name__)

@@ -124,17 +125,10 @@ def get_tools(self) -> list[Tool]:
         tools = as_tools(tools)
         return tools

+    @prefect_task(task_run_name="Orchestrator.run()")
     def run(
         self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
     ):
-        """
-        Run the orchestration process until completion or limits are reached.
-
-        Args:
-            max_llm_calls (int, optional): Maximum number of LLM calls to make.
-            max_agent_turns (int, optional): Maximum number of agent turns to run
-                (each turn can consist of multiple LLM calls)
-        """
         import controlflow.events.orchestrator_events

         call_count = 0
@@ -163,7 +157,6 @@ def run(
                 logger.debug(f"Max agent turns reached: {max_agent_turns}")
                 break

-            # this check seems redundant to the check below, but this one exits the outer loop
             if max_llm_calls is not None and call_count >= max_llm_calls:
                 break

@@ -171,43 +164,20 @@ def run(
             self.turn_strategy.begin_turn()

             # Mark assigned tasks as running
-            for task in (assigned_tasks := self.get_tasks("assigned")):
+            for task in self.get_tasks("assigned"):
                 if not task.is_running():
                     task.mark_running()
                     self.flow.add_events(
                         [
                             OrchestratorMessage(
-                                content=f"Starting task {task.name} (ID {task.id}) with objective: {task.objective}"
+                                content=f"Starting task {task.name} (ID {task.id}) "
+                                f"with objective: {task.objective}"
                             )
                         ]
                     )

-            # Execute LLM calls until the turn should end
-            while not self.turn_strategy.should_end_turn():
-                for task in assigned_tasks:
-                    if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
-                        task.mark_failed(
-                            reason="Max LLM calls reached for this task."
-                        )
-                    else:
-                        task._llm_calls += 1
-
-                # Check if there are any ready tasks left
-                if not any(t.is_ready() for t in assigned_tasks):
-                    logger.debug("No `ready` tasks to run")
-                    break
-
-                call_count += 1
-                messages = self.compile_messages()
-                tools = self.get_tools()
-
-                for event in self.agent._run_model(messages=messages, tools=tools):
-                    self.handle_event(event)
-
-                # Check if we've reached the call limit within a turn
-                if max_llm_calls is not None and call_count >= max_llm_calls:
-                    logger.debug(f"Max LLM calls reached: {max_llm_calls}")
-                    break
+            # Run the agent's turn
+            call_count += self.run_agent_turn(max_llm_calls - call_count)

             # Select the next agent for the following turn
             if available_agents := self.get_available_agents():
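The refactor above replaces the inlined turn loop with a call to `run_agent_turn()` (added further down) and keeps only the budget bookkeeping in `run()`: each turn receives the remaining budget, `max_llm_calls - call_count`, and returns how many calls it actually consumed. Note the subtraction assumes `max_llm_calls` is non-None by this point, presumably defaulted upstream outside the hunks shown. A self-contained sketch of the accounting, with the None case guarded explicitly; `run_one_turn` is a hypothetical stand-in for `run_agent_turn` that always ends a turn after at most two calls:

```python
from typing import Optional


def run_one_turn(remaining: Optional[int]) -> int:
    """Hypothetical stand-in for run_agent_turn(): consumes up to two calls."""
    calls = 0
    while calls < 2 and (remaining is None or calls < remaining):
        calls += 1  # one simulated LLM call
    return calls


def run(max_llm_calls: Optional[int], max_agent_turns: Optional[int]) -> int:
    """Mirror of Orchestrator.run()'s limit checks and budget arithmetic."""
    call_count = turn_count = 0
    while True:  # loops until a limit is hit (assumes at least one is set)
        if max_agent_turns is not None and turn_count >= max_agent_turns:
            break
        if max_llm_calls is not None and call_count >= max_llm_calls:
            break
        turn_count += 1
        # Hand the turn only what is left of the overall budget.
        remaining = None if max_llm_calls is None else max_llm_calls - call_count
        call_count += run_one_turn(remaining)
    return call_count


assert run(max_llm_calls=5, max_agent_turns=10) == 5  # turns of 2 + 2 + 1 calls
```

The final turn in the assertion is handed a budget of one and stops there, which is why `run_agent_turn` checks the limit after every call rather than once per turn.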
@@ -231,6 +201,7 @@
             )
         )

+    @prefect_task
     async def run_async(
         self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
     ):
@@ -270,7 +241,6 @@
                 logger.debug(f"Max agent turns reached: {max_agent_turns}")
                 break

-            # this check seems redundant to the check below, but this one exits the outer loop
             if max_llm_calls is not None and call_count >= max_llm_calls:
                 break

@@ -278,7 +248,7 @@
             self.turn_strategy.begin_turn()

             # Mark assigned tasks as running
-            for task in (assigned_tasks := self.get_tasks("assigned")):
+            for task in self.get_tasks("assigned"):
                 if not task.is_running():
                     task.mark_running()
                     self.flow.add_events(
@@ -289,34 +259,10 @@
                         ]
                     )

-            # Execute LLM calls until the turn should end
-            while not self.turn_strategy.should_end_turn():
-                for task in assigned_tasks:
-                    if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
-                        task.mark_failed(
-                            reason="Max LLM calls reached for this task."
-                        )
-                    else:
-                        task._llm_calls += 1
-
-                # Check if there are any ready tasks left
-                if not any(t.is_ready() for t in assigned_tasks):
-                    logger.debug("No `ready` tasks to run")
-                    break
-
-                call_count += 1
-                messages = self.compile_messages()
-                tools = self.get_tools()
-
-                async for event in self.agent._run_model_async(
-                    messages=messages, tools=tools
-                ):
-                    self.handle_event(event)
-
-                # Check if we've reached the call limit within a turn
-                if max_llm_calls is not None and call_count >= max_llm_calls:
-                    logger.debug(f"Max LLM calls reached: {max_llm_calls}")
-                    break
+            # Run the agent's turn
+            call_count += await self.run_agent_turn_async(
+                max_llm_calls - call_count
+            )

             # Select the next agent for the following turn
             if available_agents := self.get_available_agents():
@@ -340,6 +286,88 @@
             )
         )

+    @prefect_task(task_run_name="Agent turn: {self.agent.name}")
+    def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
+        """
+        Run a single agent turn, which may consist of multiple LLM calls.
+
+        Args:
+            max_llm_calls (Optional[int]): The number of LLM calls allowed.
+
+        Returns:
+            int: The number of LLM calls made during this turn.
+        """
+        call_count = 0
+        assigned_tasks = self.get_tasks("assigned")
+
+        while not self.turn_strategy.should_end_turn():
+            for task in assigned_tasks:
+                if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
+                    task.mark_failed(reason="Max LLM calls reached for this task.")
+                else:
+                    task._llm_calls += 1
+
+            # Check if there are any ready tasks left
+            if not any(t.is_ready() for t in assigned_tasks):
+                logger.debug("No `ready` tasks to run")
+                break
+
+            call_count += 1
+            messages = self.compile_messages()
+            tools = self.get_tools()
+
+            for event in self.agent._run_model(messages=messages, tools=tools):
+                self.handle_event(event)
+
+            # Check if we've reached the call limit within a turn
+            if max_llm_calls is not None and call_count >= max_llm_calls:
+                logger.debug(f"Max LLM calls reached: {max_llm_calls}")
+                break
+
+        return call_count
+
+    @prefect_task
+    async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int:
+        """
+        Run a single agent turn asynchronously, which may consist of multiple LLM calls.
+
+        Args:
+            max_llm_calls (Optional[int]): The number of LLM calls allowed.
+
+        Returns:
+            int: The number of LLM calls made during this turn.
+        """
+        call_count = 0
+        assigned_tasks = self.get_tasks("assigned")
+
+        while not self.turn_strategy.should_end_turn():
+            for task in assigned_tasks:
+                if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
+                    task.mark_failed(reason="Max LLM calls reached for this task.")
+                else:
+                    task._llm_calls += 1
+
+            # Check if there are any ready tasks left
+            if not any(t.is_ready() for t in assigned_tasks):
+                logger.debug("No `ready` tasks to run")
+                break
+
+            call_count += 1
+            messages = self.compile_messages()
+            tools = self.get_tools()
+
+            async for event in self.agent._run_model_async(
+                messages=messages, tools=tools
+            ):
+                self.handle_event(event)
+
+            # Check if we've reached the call limit within a turn
+            if max_llm_calls is not None and call_count >= max_llm_calls:
+                logger.debug(f"Max LLM calls reached: {max_llm_calls}")
+                break
+
+        return call_count
+
     def compile_prompt(self) -> str:
         """
         Compile the prompt for the current turn.
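Within `run_agent_turn`, every assigned task is charged for the upcoming LLM call before the model is invoked, and a task that has already exhausted its own `max_llm_calls` is failed instead. A stripped-down sketch of just that accounting rule; `FakeTask` is a hypothetical stand-in, not controlflow's `Task` model:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTask:
    """Hypothetical stand-in for controlflow's Task, reduced to call accounting."""

    max_llm_calls: Optional[int] = None
    _llm_calls: int = 0
    failed_reason: Optional[str] = None

    def mark_failed(self, reason: str) -> None:
        self.failed_reason = reason


def charge_tasks_for_call(tasks: list[FakeTask]) -> None:
    # Mirrors the loop at the top of run_agent_turn: a task past its own
    # budget is failed; every other task is charged for the upcoming call.
    # Because the test is truthy, max_llm_calls=None (or 0) means "no cap".
    for task in tasks:
        if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
            task.mark_failed(reason="Max LLM calls reached for this task.")
        else:
            task._llm_calls += 1


capped = FakeTask(max_llm_calls=1)
charge_tasks_for_call([capped])  # first call: charged
charge_tasks_for_call([capped])  # second call: budget exhausted -> failed
assert capped.failed_reason == "Max LLM calls reached for this task."
```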
diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py
index ae91dac..4950078 100644
--- a/src/controlflow/tasks/task.py
+++ b/src/controlflow/tasks/task.py
@@ -16,6 +16,7 @@
     _SpecialGenericAlias,
 )

+from prefect.context import TaskRunContext
 from pydantic import (
     Field,
     PydanticSchemaGenerationError,
@@ -47,6 +48,12 @@
 logger = get_logger(__name__)


+def get_task_run_name():
+    context = TaskRunContext.get()
+    task = context.parameters["self"]
+    return f"Task.run() ({task.friendly_name()})"
+
+
 class Labels(RootModel):
     root: tuple[Any, ...]

@@ -296,7 +303,7 @@ def friendly_name(self):
             name = f'"{self.objective[:50]}..."'
         else:
             name = f'"{self.objective}"'
-        return f"Task {self.id} ({name})"
+        return f"Task #{self.id} ({name})"

     def serialize_for_prompt(self) -> dict:
         """
@@ -331,6 +338,7 @@ def add_dependency(self, task: "Task"):
         self.depends_on.add(task)
         task._downstreams.add(self)

+    @prefect_task(task_run_name=get_task_run_name)
     def run(
         self,
         agent: Optional[Agent] = None,
@@ -358,6 +366,7 @@ def run(
         elif self.is_failed():
             raise ValueError(f"{self.friendly_name()} failed: {self.result}")

+    @prefect_task(task_run_name=get_task_run_name)
     async def run_async(
         self,
         agent: Optional[Agent] = None,
diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py
index 8fd2b46..0d411da 100644
--- a/tests/utilities/test_testing.py
+++ b/tests/utilities/test_testing.py
@@ -45,6 +45,6 @@ def test_record_task_events(default_fake_llm):
     }
     assert events[3].tool_result.model_dump() == dict(
         tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe",
-        str_result='Task 12345 ("say hello") marked successful.',
+        str_result='Task #12345 ("say hello") marked successful.',
        is_error=False,
    )
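`get_task_run_name` works because Prefect accepts a zero-argument callable for `task_run_name` and resolves it inside the task run context, where `TaskRunContext.get().parameters` holds the bound arguments of the call, including `self` for methods; the string template `"Agent turn: {self.agent.name}"` used earlier relies on the same parameter binding. A standalone sketch of the callable form; the `greet`/`demo` names are hypothetical, not part of this diff:

```python
from prefect import flow, task
from prefect.context import TaskRunContext


def name_from_parameters() -> str:
    # Inside the naming hook, the task run's bound parameters are available,
    # just as get_task_run_name reads parameters["self"] above.
    context = TaskRunContext.get()
    return f"greet() ({context.parameters['who']})"


@task(task_run_name=name_from_parameters)
def greet(who: str) -> str:
    return f"hello, {who}"


@flow
def demo() -> str:
    # The task run should be named 'greet() (world)' in the Prefect UI.
    return greet("world")
```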