Skip to content

Commit

Permalink
Merge pull request #280 from PrefectHQ/orchestrator
Browse files Browse the repository at this point in the history
Add top level `run` fns and support orchestrators with no agent
  • Loading branch information
jlowin authored Sep 4, 2024
2 parents aaf6802 + ca5ee08 commit 673c462
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 146 deletions.
8 changes: 6 additions & 2 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
documentation:
- changed-files:
- any-glob-to-any-file: 'docs/*'
- changed-files:
- any-glob-to-any-file: "docs/*"

tests:
- changed-files:
- any-glob-to-any-file: "tests/*"
7 changes: 4 additions & 3 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

from controlflow.defaults import defaults

# base classes
from .agents import Agent
from .tasks import Task, run, run_async
from .tasks import Task
from .flows import Flow
from .orchestration import turn_strategies


# functions and decorators
from .instructions import instructions
from .decorators import flow, task
from .tools import tool
from .run import run, run_async, run_tasks, run_tasks_async


# --- Version ---
Expand Down
37 changes: 26 additions & 11 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ class Orchestrator(ControlFlowModel):

model_config = dict(arbitrary_types_allowed=True)
flow: "Flow" = Field(description="The flow that the orchestrator is managing")
agent: Agent = Field(description="The currently active agent")
agent: Optional[Agent] = Field(
None,
description="The currently active agent. If not provided, the turn strategy will select one.",
)
tasks: list[Task] = Field(description="Tasks to be executed by the agent.")
turn_strategy: TurnStrategy = Field(
default=None,
Expand Down Expand Up @@ -124,8 +127,11 @@ def _run_turn(self, max_calls_per_turn: Optional[int] = None):
Run a single turn of the orchestration process.
Args:
calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
"""
if not self.agent:
raise ValueError("No agent set.")

if max_calls_per_turn is None:
max_calls_per_turn = controlflow.settings.orchestrator_max_calls_per_turn

Expand Down Expand Up @@ -159,8 +165,11 @@ async def _run_turn_async(self, max_calls_per_turn: Optional[int] = None):
Run a single turn of the orchestration process asynchronously.
Args:
calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
"""
if not self.agent:
raise ValueError("No agent set.")

if max_calls_per_turn is None:
max_calls_per_turn = controlflow.settings.orchestrator_max_calls_per_turn

Expand Down Expand Up @@ -198,10 +207,15 @@ def run(
Args:
turns (int, optional): Maximum number of turns to run.
calls_per_turn (int, optional): Maximum number of LLM calls per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls per turn.
"""
import controlflow.events.orchestrator_events

if not self.agent:
self.agent = self.turn_strategy.get_next_agent(
None, self.get_available_agents()
)

if max_turns is None:
max_turns = controlflow.settings.orchestrator_max_turns

Expand All @@ -211,9 +225,7 @@ def run(

turn = 0
try:
while (
self.get_tasks("ready") and not self.turn_strategy.should_end_session()
):
while self.get_tasks("ready"):
if max_turns is not None and turn >= max_turns:
break
self._run_turn(max_calls_per_turn=max_calls_per_turn)
Expand All @@ -240,10 +252,15 @@ async def run_async(
Args:
turns (int, optional): Maximum number of turns to run.
calls_per_turn (int, optional): Maximum number of LLM calls per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls per turn.
"""
import controlflow.events.orchestrator_events

if not self.agent:
self.agent = self.turn_strategy.get_next_agent(
None, self.get_available_agents()
)

if max_turns is None:
max_turns = controlflow.settings.orchestrator_max_turns

Expand All @@ -253,9 +270,7 @@ async def run_async(

turn = 0
try:
while (
self.get_tasks("ready") and not self.turn_strategy.should_end_session()
):
while self.get_tasks("ready"):
if max_turns is not None and turn >= max_turns:
break
await self._run_turn_async(max_calls_per_turn=max_calls_per_turn)
Expand Down
39 changes: 15 additions & 24 deletions src/controlflow/orchestration/turn_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_tools(

@abstractmethod
def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
pass

Expand All @@ -37,16 +37,6 @@ def should_end_turn(self) -> bool:
"""
return self.end_turn

def should_end_session(self) -> bool:
"""
Determine if the session should end. The session is the collection of
all turns for all agents.
Returns:
bool: True if the session should end, False otherwise.
"""
return False


def create_end_turn_tool(strategy: TurnStrategy) -> Tool:
@tool
Expand Down Expand Up @@ -79,18 +69,19 @@ def delegate_to_agent(agent_id: str, message: str = None) -> str:


class Single(TurnStrategy):
agent: Agent

def get_tools(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
) -> List[Tool]:
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
return current_agent

def should_end_session(self) -> bool:
return self.end_turn
if self.agent not in available_agents:
raise ValueError(f"The specified agent {self.agent.id} is not available.")
return self.agent


class Popcorn(TurnStrategy):
Expand All @@ -103,11 +94,11 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
if self.next_agent and self.next_agent in available_agents:
return self.next_agent
return next(iter(available_agents)) # Always return an available agent
return next(iter(available_agents))


class Random(TurnStrategy):
Expand All @@ -117,7 +108,7 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
return random.choice(list(available_agents.keys()))

Expand All @@ -129,10 +120,10 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
agents = list(available_agents.keys())
if current_agent not in agents:
if current_agent is None or current_agent not in agents:
return agents[0]
current_index = agents.index(current_agent)
next_index = (current_index + 1) % len(agents)
Expand All @@ -146,7 +137,7 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
# Select the agent with the most tasks
return max(available_agents, key=lambda agent: len(available_agents[agent]))
Expand All @@ -164,9 +155,9 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
if current_agent is self.moderator:
if current_agent is None or current_agent is self.moderator:
return (
self.next_agent
if self.next_agent in available_agents
Expand Down
69 changes: 69 additions & 0 deletions src/controlflow/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from controlflow.flows import Flow, get_flow
from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy
from controlflow.tasks.task import Task


def run(
objective: str,
*,
turn_strategy: TurnStrategy = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
**task_kwargs,
)
return task.run(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)


async def run_async(
objective: str,
*,
turn_strategy: TurnStrategy = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
**task_kwargs,
)
return await task.run_async(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)


def run_tasks(
tasks: list[Task],
flow: Flow = None,
turn_strategy: TurnStrategy = None,
**run_kwargs,
):
"""
Convenience function to run a list of tasks to completion.
"""
flow = flow or get_flow() or Flow()
orchestrator = Orchestrator(tasks=tasks, flow=flow, turn_strategy=turn_strategy)
return orchestrator.run(**run_kwargs)


async def run_tasks_async(
tasks: list[Task],
flow: Flow = None,
turn_strategy: TurnStrategy = None,
**run_kwargs,
):
"""
Convenience function to run a list of tasks to completion asynchronously.
"""
flow = flow or get_flow() or Flow()
orchestrator = Orchestrator(tasks=tasks, flow=flow, turn_strategy=turn_strategy)
return await orchestrator.run_async(**run_kwargs)
2 changes: 1 addition & 1 deletion src/controlflow/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .task import Task, run, run_async
from .task import Task
40 changes: 0 additions & 40 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,43 +665,3 @@ def _generate_result_schema(result_type: type[T]) -> type[T]:
"Please use a custom type or add compatibility."
)
return result_schema


def run(
objective: str,
*task_args,
turn_strategy: "TurnStrategy" = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
*task_args,
**task_kwargs,
)
return task.run(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)


async def run_async(
objective: str,
*task_args,
turn_strategy: "TurnStrategy" = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
*task_args,
**task_kwargs,
)
return await task.run_async(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)
Loading

0 comments on commit 673c462

Please sign in to comment.