Update prefect integration #301

Merged
2 commits merged on Sep 11, 2024
src/controlflow/agents/agent.py (29 additions, 0 deletions)
@@ -29,6 +29,7 @@
)
from controlflow.utilities.context import ctx
from controlflow.utilities.general import ControlFlowModel, hash_objects
from controlflow.utilities.prefect import create_markdown_artifact, prefect_task

from .memory import Memory

@@ -247,6 +248,7 @@ def plan(
context=context,
)

@prefect_task(task_run_name="Call LLM")
def _run_model(
self,
messages: list[BaseMessage],
@@ -280,6 +282,19 @@ def _run_model(

yield AgentMessage(agent=self, message=response)

create_markdown_artifact(
markdown=f"""
{response.content or '(No content)'}

#### Payload
```json
{response.json(indent=2)}
```
""",
description=f"LLM Response for Agent {self.name}",
key="agent-message",
)

if controlflow.settings.log_all_messages:
logger.debug(f"Response: {response}")

@@ -288,6 +303,7 @@ def _run_model(
result = handle_tool_call(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)

@prefect_task(task_run_name="Call LLM")
async def _run_model_async(
self,
messages: list[BaseMessage],
@@ -321,6 +337,19 @@ async def _run_model_async(

yield AgentMessage(agent=self, message=response)

create_markdown_artifact(
markdown=f"""
{response.content or '(No content)'}

#### Payload
```json
{response.json(indent=2)}
```
""",
description=f"LLM Response for Agent {self.name}",
key="agent-message",
)

if controlflow.settings.log_all_messages:
logger.debug(f"Response: {response}")

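Both `_run_model` and `_run_model_async` now run as Prefect tasks and publish each LLM response as a markdown artifact. The helpers come from `controlflow.utilities.prefect`, which is not part of this diff; the sketch below shows one plausible shape for `prefect_task`, assuming it is a thin wrapper over Prefect's own `task` decorator that supports both the bare and parameterized forms used above, and that `create_markdown_artifact` is re-exported from `prefect.artifacts`. This is an illustration, not the actual implementation.

```python
# Hypothetical sketch of controlflow/utilities/prefect.py (not included in this PR).
from typing import Any, Callable, Optional

import prefect
from prefect.artifacts import create_markdown_artifact  # noqa: F401  (re-exported)


def prefect_task(fn: Optional[Callable] = None, **task_kwargs: Any):
    """Wrap a function as a Prefect task.

    Supports both forms used in this diff:
        @prefect_task
        @prefect_task(task_run_name="Call LLM")
    """
    if fn is not None:
        # Bare decorator form
        return prefect.task(fn)

    def decorator(func: Callable):
        # Parameterized form; task_kwargs may include task_run_name, etc.
        return prefect.task(func, **task_kwargs)

    return decorator
```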
src/controlflow/flows/flow.py (14 additions, 5 deletions)
@@ -1,7 +1,8 @@
import uuid
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from prefect.context import FlowRunContext
from pydantic import Field

import controlflow
@@ -11,6 +12,7 @@
from controlflow.utilities.context import ctx
from controlflow.utilities.general import ControlFlowModel
from controlflow.utilities.logging import get_logger
from controlflow.utilities.prefect import prefect_flow_context

if TYPE_CHECKING:
pass
@@ -109,10 +111,17 @@ def add_events(self, events: list[Event]):
self.history.add_events(thread_id=self.thread_id, events=events)

@contextmanager
def create_context(self):
# creating a new flow will reset any parent task tracking
with ctx(flow=self, tasks=None):
yield self
def create_context(self, **prefect_kwargs):
# create a new Prefect flow if we're not already in a flow run
if FlowRunContext.get() is None:
prefect_context = prefect_flow_context(**prefect_kwargs)
else:
prefect_context = nullcontext()

with prefect_context:
# creating a new flow will reset any parent task tracking
with ctx(flow=self, tasks=None):
yield self


def get_flow() -> Optional[Flow]:
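The effect of the `create_context` change is that a ControlFlow flow opens its own Prefect flow run only when one is not already active; inside an existing Prefect flow run it falls through to `nullcontext()` and reuses that run. A rough usage sketch, calling `create_context()` directly (the pipeline name below is made up):

```python
from prefect import flow

import controlflow as cf

# Outside Prefect: FlowRunContext.get() is None, so create_context()
# starts a new Prefect flow run via prefect_flow_context().
with cf.Flow().create_context() as standalone:
    ...


@flow
def my_pipeline():
    # Inside an existing Prefect flow run, create_context() uses
    # nullcontext() and attaches to the current run instead.
    with cf.Flow().create_context() as nested:
        ...
```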
src/controlflow/orchestration/orchestrator.py (95 additions, 67 deletions)
@@ -16,6 +16,7 @@
from controlflow.tasks.task import Task
from controlflow.tools.tools import Tool, as_tools
from controlflow.utilities.general import ControlFlowModel
from controlflow.utilities.prefect import prefect_task

logger = logging.getLogger(__name__)

@@ -124,17 +125,10 @@ def get_tools(self) -> list[Tool]:
tools = as_tools(tools)
return tools

@prefect_task(task_run_name="Orchestrator.run()")
def run(
self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
):
"""
Run the orchestration process until completion or limits are reached.

Args:
max_llm_calls (int, optional): Maximum number of LLM calls to make.
max_agent_turns (int, optional): Maximum number of agent turns to run
(each turn can consist of multiple LLM calls)
"""
import controlflow.events.orchestrator_events

call_count = 0
@@ -163,51 +157,27 @@ def run(
logger.debug(f"Max agent turns reached: {max_agent_turns}")
break

# this check seems redundant to the check below, but this one exits the outer loop
if max_llm_calls is not None and call_count >= max_llm_calls:
break

turn_count += 1
self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in (assigned_tasks := self.get_tasks("assigned")):
for task in self.get_tasks("assigned"):
if not task.is_running():
task.mark_running()
self.flow.add_events(
[
OrchestratorMessage(
content=f"Starting task {task.name} (ID {task.id}) with objective: {task.objective}"
content=f"Starting task {task.name} (ID {task.id}) "
f"with objective: {task.objective}"
)
]
)

# Execute LLM calls until the turn should end
while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(
reason="Max LLM calls reached for this task."
)
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

for event in self.agent._run_model(messages=messages, tools=tools):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break
# Run the agent's turn
call_count += self.run_agent_turn(max_llm_calls - call_count)

# Select the next agent for the following turn
if available_agents := self.get_available_agents():
@@ -231,6 +201,7 @@ def run(
)
)

@prefect_task
async def run_async(
self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None
):
@@ -270,15 +241,14 @@ async def run_async(
logger.debug(f"Max agent turns reached: {max_agent_turns}")
break

# this check seems redundant to the check below, but this one exits the outer loop
if max_llm_calls is not None and call_count >= max_llm_calls:
break

turn_count += 1
self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in (assigned_tasks := self.get_tasks("assigned")):
for task in self.get_tasks("assigned"):
if not task.is_running():
task.mark_running()
self.flow.add_events(
@@ -289,34 +259,10 @@ async def run_async(
]
)

# Execute LLM calls until the turn should end
while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(
reason="Max LLM calls reached for this task."
)
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

async for event in self.agent._run_model_async(
messages=messages, tools=tools
):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break
# Run the agent's turn
call_count += await self.run_agent_turn_async(
max_llm_calls - call_count
)

# Select the next agent for the following turn
if available_agents := self.get_available_agents():
@@ -340,6 +286,88 @@ async def run_async(
)
)

@prefect_task(task_run_name="Agent turn: {self.agent.name}")
def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
"""
Run a single agent turn, which may consist of multiple LLM calls.

Args:
max_llm_calls (Optional[int]): The number of LLM calls allowed.

Returns:
int: The number of LLM calls made during this turn.
"""
call_count = 0
assigned_tasks = self.get_tasks("assigned")

while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(reason="Max LLM calls reached for this task.")
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

for event in self.agent._run_model(messages=messages, tools=tools):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break

return call_count

@prefect_task
async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int:
"""
Run a single agent turn asynchronously, which may consist of multiple LLM calls.

Args:
max_llm_calls (Optional[int]): The number of LLM calls allowed.

Returns:
int: The number of LLM calls made during this turn.
"""
call_count = 0
assigned_tasks = self.get_tasks("assigned")

while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
task.mark_failed(reason="Max LLM calls reached for this task.")
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
logger.debug("No `ready` tasks to run")
break

call_count += 1
messages = self.compile_messages()
tools = self.get_tools()

async for event in self.agent._run_model_async(
messages=messages, tools=tools
):
self.handle_event(event)

# Check if we've reached the call limit within a turn
if max_llm_calls is not None and call_count >= max_llm_calls:
logger.debug(f"Max LLM calls reached: {max_llm_calls}")
break

return call_count

def compile_prompt(self) -> str:
"""
Compile the prompt for the current turn.
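The refactor splits the orchestration loop in two: `run()` / `run_async()` own the turn loop and the overall budget, while `run_agent_turn()` / `run_agent_turn_async()` execute a single turn and report how many LLM calls they made, so the outer loop only ever passes in the remaining budget. A simplified sketch of that accounting, with the task bookkeeping, event handling, and agent rotation stripped out (the function names here are placeholders, not the real API):

```python
from typing import Callable, Optional


def run_loop(
    run_turn: Callable[[Optional[int]], int],
    max_llm_calls: Optional[int],
    max_agent_turns: int,
) -> int:
    """Outer loop: hand each turn the remaining call budget, accumulate usage."""
    call_count = 0
    for _ in range(max_agent_turns):
        if max_llm_calls is not None and call_count >= max_llm_calls:
            break
        remaining = None if max_llm_calls is None else max_llm_calls - call_count
        # The turn runs one or more LLM calls and returns how many it made.
        call_count += run_turn(remaining)
    return call_count
```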
src/controlflow/tasks/task.py (10 additions, 1 deletion)
@@ -16,6 +16,7 @@
_SpecialGenericAlias,
)

from prefect.context import TaskRunContext
from pydantic import (
Field,
PydanticSchemaGenerationError,
@@ -47,6 +48,12 @@
logger = get_logger(__name__)


def get_task_run_name():
context = TaskRunContext.get()
task = context.parameters["self"]
return f"Task.run() ({task.friendly_name()})"


class Labels(RootModel):
root: tuple[Any, ...]

@@ -296,7 +303,7 @@ def friendly_name(self):
name = f'"{self.objective[:50]}..."'
else:
name = f'"{self.objective}"'
return f"Task {self.id} ({name})"
return f"Task #{self.id} ({name})"

def serialize_for_prompt(self) -> dict:
"""
@@ -331,6 +338,7 @@ def add_dependency(self, task: "Task"):
self.depends_on.add(task)
task._downstreams.add(self)

@prefect_task(task_run_name=get_task_run_name)
def run(
self,
agent: Optional[Agent] = None,
@@ -358,6 +366,7 @@ def run(
elif self.is_failed():
raise ValueError(f"{self.friendly_name()} failed: {self.result}")

@prefect_task(task_run_name=get_task_run_name)
async def run_async(
self,
agent: Optional[Agent] = None,
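`get_task_run_name` relies on Prefect's support for a callable `task_run_name`: the callable is evaluated while the task run's bound parameters are available, so `TaskRunContext.get().parameters["self"]` yields the `Task` instance and the run is named, e.g., `Task.run() (Task #12345 ("say hello"))`. A standalone sketch of the same pattern with an ordinary Prefect task (the example task is hypothetical, and it assumes the name callable resolves inside the task run context, as the helper above relies on):

```python
from prefect import task
from prefect.context import TaskRunContext


def name_from_parameters() -> str:
    # Read the task run's bound parameters to build a descriptive run name.
    context = TaskRunContext.get()
    return f"process ({context.parameters['item']!r})"


@task(task_run_name=name_from_parameters)
def process(item: str) -> str:
    return item.upper()
```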
tests/utilities/test_testing.py (1 addition, 1 deletion)
@@ -45,6 +45,6 @@ def test_record_task_events(default_fake_llm):
}
assert events[3].tool_result.model_dump() == dict(
tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe",
str_result='Task 12345 ("say hello") marked successful.',
str_result='Task #12345 ("say hello") marked successful.',
is_error=False,
)