Skip to content

Commit

Permalink
Merge pull request #301 from PrefectHQ/prefect
Browse files Browse the repository at this point in the history
Update prefect integration
  • Loading branch information
jlowin authored Sep 11, 2024
2 parents c345086 + 6a3ed57 commit 3478c29
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 74 deletions.
29 changes: 29 additions & 0 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from controlflow.utilities.context import ctx
from controlflow.utilities.general import ControlFlowModel, hash_objects
from controlflow.utilities.prefect import create_markdown_artifact, prefect_task

from .memory import Memory

Expand Down Expand Up @@ -247,6 +248,7 @@ def plan(
context=context,
)

@prefect_task(task_run_name="Call LLM")
def _run_model(
self,
messages: list[BaseMessage],
Expand Down Expand Up @@ -280,6 +282,19 @@ def _run_model(

yield AgentMessage(agent=self, message=response)

create_markdown_artifact(
markdown=f"""
{response.content or '(No content)'}
#### Payload
```json
{response.json(indent=2)}
```
""",
description=f"LLM Response for Agent {self.name}",
key="agent-message",
)

if controlflow.settings.log_all_messages:
logger.debug(f"Response: {response}")

Expand All @@ -288,6 +303,7 @@ def _run_model(
result = handle_tool_call(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)

@prefect_task(task_run_name="Call LLM")
async def _run_model_async(
self,
messages: list[BaseMessage],
Expand Down Expand Up @@ -321,6 +337,19 @@ async def _run_model_async(

yield AgentMessage(agent=self, message=response)

create_markdown_artifact(
markdown=f"""
{response.content or '(No content)'}
#### Payload
```json
{response.json(indent=2)}
```
""",
description=f"LLM Response for Agent {self.name}",
key="agent-message",
)

if controlflow.settings.log_all_messages:
logger.debug(f"Response: {response}")

Expand Down
19 changes: 14 additions & 5 deletions src/controlflow/flows/flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from prefect.context import FlowRunContext
from pydantic import Field

import controlflow
Expand All @@ -11,6 +12,7 @@
from controlflow.utilities.context import ctx
from controlflow.utilities.general import ControlFlowModel
from controlflow.utilities.logging import get_logger
from controlflow.utilities.prefect import prefect_flow_context

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -109,10 +111,17 @@ def add_events(self, events: list[Event]):
self.history.add_events(thread_id=self.thread_id, events=events)

@contextmanager
def create_context(self):
# creating a new flow will reset any parent task tracking
with ctx(flow=self, tasks=None):
yield self
def create_context(self, **prefect_kwargs):
# create a new Prefect flow if we're not already in a flow run
if FlowRunContext.get() is None:
prefect_context = prefect_flow_context(**prefect_kwargs)
else:
prefect_context = nullcontext()

with prefect_context:
# creating a new flow will reset any parent task tracking
with ctx(flow=self, tasks=None):
yield self


def get_flow() -> Optional[Flow]:
Expand Down
162 changes: 95 additions & 67 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from controlflow.tasks.task import Task
from controlflow.tools.tools import Tool, as_tools
from controlflow.utilities.general import ControlFlowModel
from controlflow.utilities.prefect import prefect_task

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -124,17 +125,10 @@ def get_tools(self) -> list[Tool]:
tools = as_tools(tools)
return tools

@prefect_task(task_run_name="Orchestrator.run()")
def run(
self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
):
"""
Run the orchestration process until completion or limits are reached.
Args:
max_llm_calls (int, optional): Maximum number of LLM calls to make.
max_agent_turns (int, optional): Maximum number of agent turns to run
(each turn can consist of multiple LLM calls)
"""
import controlflow.events.orchestrator_events

call_count = 0
Expand Down Expand Up @@ -163,51 +157,27 @@ def run(
logger.debug(f"Max agent turns reached: {max_agent_turns}")
break

# this check seems redundant to the check below, but this one exits the outer loop
if max_llm_calls is not None and call_count >= max_llm_calls:
break

turn_count += 1
self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in (assigned_tasks := self.get_tasks("assigned")):
for task in self.get_tasks("assigned"):
if not task.is_running():
task.mark_running()
self.flow.add_events(
[
OrchestratorMessage(
content=f"Starting task {task.name} (ID {task.id}) with objective: {task.objective}"
content=f"Starting task {task.name} (ID {task.id}) "
f"with objective: {task.objective}"
)
]
)

# Execute LLM calls until the turn should end
while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(
reason="Max LLM calls reached for this task."
)
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

for event in self.agent._run_model(messages=messages, tools=tools):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break
# Run the agent's turn
call_count += self.run_agent_turn(max_llm_calls - call_count)

# Select the next agent for the following turn
if available_agents := self.get_available_agents():
Expand All @@ -231,6 +201,7 @@ def run(
)
)

@prefect_task
async def run_async(
self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
):
Expand Down Expand Up @@ -270,15 +241,14 @@ async def run_async(
logger.debug(f"Max agent turns reached: {max_agent_turns}")
break

# this check seems redundant to the check below, but this one exits the outer loop
if max_llm_calls is not None and call_count >= max_llm_calls:
break

turn_count += 1
self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in (assigned_tasks := self.get_tasks("assigned")):
for task in self.get_tasks("assigned"):
if not task.is_running():
task.mark_running()
self.flow.add_events(
Expand All @@ -289,34 +259,10 @@ async def run_async(
]
)

# Execute LLM calls until the turn should end
while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(
reason="Max LLM calls reached for this task."
)
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

async for event in self.agent._run_model_async(
messages=messages, tools=tools
):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break
# Run the agent's turn
call_count += await self.run_agent_turn_async(
max_llm_calls - call_count
)

# Select the next agent for the following turn
if available_agents := self.get_available_agents():
Expand All @@ -340,6 +286,88 @@ async def run_async(
)
)

@prefect_task(task_run_name="Agent turn: {self.agent.name}")
def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
"""
Run a single agent turn, which may consist of multiple LLM calls.
Args:
max_llm_calls (Optional[int]): The number of LLM calls allowed.
Returns:
int: The number of LLM calls made during this turn.
"""
call_count = 0
assigned_tasks = self.get_tasks("assigned")

while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(reason="Max LLM calls reached for this task.")
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

for event in self.agent._run_model(messages=messages, tools=tools):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break

return call_count

@prefect_task
async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int:
"""
Run a single agent turn asynchronously, which may consist of multiple LLM calls.
Args:
max_llm_calls (Optional[int]): The number of LLM calls allowed.
Returns:
int: The number of LLM calls made during this turn.
"""
call_count = 0
assigned_tasks = self.get_tasks("assigned")

while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(reason="Max LLM calls reached for this task.")
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

async for event in self.agent._run_model_async(
messages=messages, tools=tools
):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break

return call_count

def compile_prompt(self) -> str:
"""
Compile the prompt for the current turn.
Expand Down
11 changes: 10 additions & 1 deletion src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_SpecialGenericAlias,
)

from prefect.context import TaskRunContext
from pydantic import (
Field,
PydanticSchemaGenerationError,
Expand Down Expand Up @@ -47,6 +48,12 @@
logger = get_logger(__name__)


def get_task_run_name():
context = TaskRunContext.get()
task = context.parameters["self"]
return f"Task.run() ({task.friendly_name()})"


class Labels(RootModel):
root: tuple[Any, ...]

Expand Down Expand Up @@ -296,7 +303,7 @@ def friendly_name(self):
name = f'"{self.objective[:50]}..."'
else:
name = f'"{self.objective}"'
return f"Task {self.id} ({name})"
return f"Task #{self.id} ({name})"

def serialize_for_prompt(self) -> dict:
"""
Expand Down Expand Up @@ -331,6 +338,7 @@ def add_dependency(self, task: "Task"):
self.depends_on.add(task)
task._downstreams.add(self)

@prefect_task(task_run_name=get_task_run_name)
def run(
self,
agent: Optional[Agent] = None,
Expand Down Expand Up @@ -358,6 +366,7 @@ def run(
elif self.is_failed():
raise ValueError(f"{self.friendly_name()} failed: {self.result}")

@prefect_task(task_run_name=get_task_run_name)
async def run_async(
self,
agent: Optional[Agent] = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/utilities/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ def test_record_task_events(default_fake_llm):
}
assert events[3].tool_result.model_dump() == dict(
tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe",
str_result='Task 12345 ("say hello") marked successful.',
str_result='Task #12345 ("say hello") marked successful.',
is_error=False,
)

0 comments on commit 3478c29

Please sign in to comment.