Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add top level run fns and support orchestrators with no agent #280

Merged
merged 4 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
documentation:
- changed-files:
- any-glob-to-any-file: 'docs/*'
- changed-files:
- any-glob-to-any-file: "docs/*"

tests:
- changed-files:
- any-glob-to-any-file: "tests/*"
7 changes: 4 additions & 3 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

from controlflow.defaults import defaults

# base classes
from .agents import Agent
from .tasks import Task, run, run_async
from .tasks import Task
from .flows import Flow
from .orchestration import turn_strategies


# functions and decorators
from .instructions import instructions
from .decorators import flow, task
from .tools import tool
from .run import run, run_async, run_tasks, run_tasks_async


# --- Version ---
Expand Down
37 changes: 26 additions & 11 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ class Orchestrator(ControlFlowModel):

model_config = dict(arbitrary_types_allowed=True)
flow: "Flow" = Field(description="The flow that the orchestrator is managing")
agent: Agent = Field(description="The currently active agent")
agent: Optional[Agent] = Field(
None,
description="The currently active agent. If not provided, the turn strategy will select one.",
)
tasks: list[Task] = Field(description="Tasks to be executed by the agent.")
turn_strategy: TurnStrategy = Field(
default=None,
Expand Down Expand Up @@ -124,8 +127,11 @@ def _run_turn(self, max_calls_per_turn: Optional[int] = None):
Run a single turn of the orchestration process.
Args:
calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
"""
if not self.agent:
raise ValueError("No agent set.")

if max_calls_per_turn is None:
max_calls_per_turn = controlflow.settings.orchestrator_max_calls_per_turn

Expand Down Expand Up @@ -159,8 +165,11 @@ async def _run_turn_async(self, max_calls_per_turn: Optional[int] = None):
Run a single turn of the orchestration process asynchronously.
Args:
calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls to run per turn.
"""
if not self.agent:
raise ValueError("No agent set.")

if max_calls_per_turn is None:
max_calls_per_turn = controlflow.settings.orchestrator_max_calls_per_turn

Expand Down Expand Up @@ -198,10 +207,15 @@ def run(
Args:
turns (int, optional): Maximum number of turns to run.
calls_per_turn (int, optional): Maximum number of LLM calls per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls per turn.
"""
import controlflow.events.orchestrator_events

if not self.agent:
self.agent = self.turn_strategy.get_next_agent(
None, self.get_available_agents()
)

if max_turns is None:
max_turns = controlflow.settings.orchestrator_max_turns

Expand All @@ -211,9 +225,7 @@ def run(

turn = 0
try:
while (
self.get_tasks("ready") and not self.turn_strategy.should_end_session()
):
while self.get_tasks("ready"):
if max_turns is not None and turn >= max_turns:
break
self._run_turn(max_calls_per_turn=max_calls_per_turn)
Expand All @@ -240,10 +252,15 @@ async def run_async(
Args:
turns (int, optional): Maximum number of turns to run.
calls_per_turn (int, optional): Maximum number of LLM calls per turn.
max_calls_per_turn (int, optional): Maximum number of LLM calls per turn.
"""
import controlflow.events.orchestrator_events

if not self.agent:
self.agent = self.turn_strategy.get_next_agent(
None, self.get_available_agents()
)

if max_turns is None:
max_turns = controlflow.settings.orchestrator_max_turns

Expand All @@ -253,9 +270,7 @@ async def run_async(

turn = 0
try:
while (
self.get_tasks("ready") and not self.turn_strategy.should_end_session()
):
while self.get_tasks("ready"):
if max_turns is not None and turn >= max_turns:
break
await self._run_turn_async(max_calls_per_turn=max_calls_per_turn)
Expand Down
39 changes: 15 additions & 24 deletions src/controlflow/orchestration/turn_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_tools(

@abstractmethod
def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
pass

Expand All @@ -37,16 +37,6 @@ def should_end_turn(self) -> bool:
"""
return self.end_turn

def should_end_session(self) -> bool:
"""
Determine if the session should end. The session is the collection of
all turns for all agents.
Returns:
bool: True if the session should end, False otherwise.
"""
return False


def create_end_turn_tool(strategy: TurnStrategy) -> Tool:
@tool
Expand Down Expand Up @@ -79,18 +69,19 @@ def delegate_to_agent(agent_id: str, message: str = None) -> str:


class Single(TurnStrategy):
agent: Agent

def get_tools(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
) -> List[Tool]:
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
return current_agent

def should_end_session(self) -> bool:
return self.end_turn
if self.agent not in available_agents:
raise ValueError(f"The specified agent {self.agent.id} is not available.")
return self.agent


class Popcorn(TurnStrategy):
Expand All @@ -103,11 +94,11 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
if self.next_agent and self.next_agent in available_agents:
return self.next_agent
return next(iter(available_agents)) # Always return an available agent
return next(iter(available_agents))


class Random(TurnStrategy):
Expand All @@ -117,7 +108,7 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
return random.choice(list(available_agents.keys()))

Expand All @@ -129,10 +120,10 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
agents = list(available_agents.keys())
if current_agent not in agents:
if current_agent is None or current_agent not in agents:
return agents[0]
current_index = agents.index(current_agent)
next_index = (current_index + 1) % len(agents)
Expand All @@ -146,7 +137,7 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
# Select the agent with the most tasks
return max(available_agents, key=lambda agent: len(available_agents[agent]))
Expand All @@ -164,9 +155,9 @@ def get_tools(
return [create_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Agent, available_agents: Dict[Agent, List[Task]]
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
) -> Agent:
if current_agent is self.moderator:
if current_agent is None or current_agent is self.moderator:
return (
self.next_agent
if self.next_agent in available_agents
Expand Down
69 changes: 69 additions & 0 deletions src/controlflow/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from controlflow.flows import Flow, get_flow
from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy
from controlflow.tasks.task import Task


def run(
objective: str,
*,
turn_strategy: TurnStrategy = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
**task_kwargs,
)
return task.run(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)


async def run_async(
objective: str,
*,
turn_strategy: TurnStrategy = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
**task_kwargs,
)
return await task.run_async(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)


def run_tasks(
tasks: list[Task],
flow: Flow = None,
turn_strategy: TurnStrategy = None,
**run_kwargs,
):
"""
Convenience function to run a list of tasks to completion.
"""
flow = flow or get_flow() or Flow()
orchestrator = Orchestrator(tasks=tasks, flow=flow, turn_strategy=turn_strategy)
return orchestrator.run(**run_kwargs)


async def run_tasks_async(
tasks: list[Task],
flow: Flow = None,
turn_strategy: TurnStrategy = None,
**run_kwargs,
):
"""
Convenience function to run a list of tasks to completion asynchronously.
"""
flow = flow or get_flow() or Flow()
orchestrator = Orchestrator(tasks=tasks, flow=flow, turn_strategy=turn_strategy)
return await orchestrator.run_async(**run_kwargs)
2 changes: 1 addition & 1 deletion src/controlflow/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .task import Task, run, run_async
from .task import Task
40 changes: 0 additions & 40 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,43 +665,3 @@ def _generate_result_schema(result_type: type[T]) -> type[T]:
"Please use a custom type or add compatibility."
)
return result_schema


def run(
objective: str,
*task_args,
turn_strategy: "TurnStrategy" = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
*task_args,
**task_kwargs,
)
return task.run(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)


async def run_async(
objective: str,
*task_args,
turn_strategy: "TurnStrategy" = None,
max_calls_per_turn: int = None,
max_turns: int = None,
**task_kwargs,
):
task = Task(
objective=objective,
*task_args,
**task_kwargs,
)
return await task.run_async(
turn_strategy=turn_strategy,
max_calls_per_turn=max_calls_per_turn,
max_turns=max_turns,
)
Loading