From 247983864a96eba458271adcd5e2c7d1c6b78bc2 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 4 Sep 2024 09:39:36 -0400 Subject: [PATCH 1/4] Support orchestrators with `None` as the agent --- src/controlflow/orchestration/orchestrator.py | 37 +++++++++---- .../orchestration/turn_strategies.py | 39 +++++-------- src/controlflow/tasks/task.py | 9 +++ tests/orchestration/test_orchestrator.py | 55 +++++++++++++++++++ tests/orchestration/test_turn_strategies.py | 6 +- 5 files changed, 106 insertions(+), 40 deletions(-) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 3c95d8b..fd8ebf8 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -30,7 +30,10 @@ class Orchestrator(ControlFlowModel): model_config = dict(arbitrary_types_allowed=True) flow: "Flow" = Field(description="The flow that the orchestrator is managing") - agent: Agent = Field(description="The currently active agent") + agent: Optional[Agent] = Field( + None, + description="The currently active agent. If not provided, the turn strategy will select one.", + ) tasks: list[Task] = Field(description="Tasks to be executed by the agent.") turn_strategy: TurnStrategy = Field( default=None, @@ -124,8 +127,11 @@ def _run_turn(self, max_calls_per_turn: Optional[int] = None): Run a single turn of the orchestration process. Args: - calls_per_turn (int, optional): Maximum number of LLM calls to run per turn. + max_calls_per_turn (int, optional): Maximum number of LLM calls to run per turn. """ + if not self.agent: + raise ValueError("No agent set.") + if max_calls_per_turn is None: max_calls_per_turn = controlflow.settings.orchestrator_max_calls_per_turn @@ -159,8 +165,11 @@ async def _run_turn_async(self, max_calls_per_turn: Optional[int] = None): Run a single turn of the orchestration process asynchronously. Args: - calls_per_turn (int, optional): Maximum number of LLM calls to run per turn. + max_calls_per_turn (int, optional): Maximum number of LLM calls to run per turn. """ + if not self.agent: + raise ValueError("No agent set.") + if max_calls_per_turn is None: max_calls_per_turn = controlflow.settings.orchestrator_max_calls_per_turn @@ -198,10 +207,15 @@ def run( Args: turns (int, optional): Maximum number of turns to run. - calls_per_turn (int, optional): Maximum number of LLM calls per turn. + max_calls_per_turn (int, optional): Maximum number of LLM calls per turn. """ import controlflow.events.orchestrator_events + if not self.agent: + self.agent = self.turn_strategy.get_next_agent( + None, self.get_available_agents() + ) + if max_turns is None: max_turns = controlflow.settings.orchestrator_max_turns @@ -211,9 +225,7 @@ def run( turn = 0 try: - while ( - self.get_tasks("ready") and not self.turn_strategy.should_end_session() - ): + while self.get_tasks("ready"): if max_turns is not None and turn >= max_turns: break self._run_turn(max_calls_per_turn=max_calls_per_turn) @@ -240,10 +252,15 @@ async def run_async( Args: turns (int, optional): Maximum number of turns to run. - calls_per_turn (int, optional): Maximum number of LLM calls per turn. + max_calls_per_turn (int, optional): Maximum number of LLM calls per turn. """ import controlflow.events.orchestrator_events + if not self.agent: + self.agent = self.turn_strategy.get_next_agent( + None, self.get_available_agents() + ) + if max_turns is None: max_turns = controlflow.settings.orchestrator_max_turns @@ -253,9 +270,7 @@ async def run_async( turn = 0 try: - while ( - self.get_tasks("ready") and not self.turn_strategy.should_end_session() - ): + while self.get_tasks("ready"): if max_turns is not None and turn >= max_turns: break await self._run_turn_async(max_calls_per_turn=max_calls_per_turn) diff --git a/src/controlflow/orchestration/turn_strategies.py b/src/controlflow/orchestration/turn_strategies.py index 4bfe1b2..00feb12 100644 --- a/src/controlflow/orchestration/turn_strategies.py +++ b/src/controlflow/orchestration/turn_strategies.py @@ -20,7 +20,7 @@ def get_tools( @abstractmethod def get_next_agent( - self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] + self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] ) -> Agent: pass @@ -37,16 +37,6 @@ def should_end_turn(self) -> bool: """ return self.end_turn - def should_end_session(self) -> bool: - """ - Determine if the session should end. The session is the collection of - all turns for all agents. - - Returns: - bool: True if the session should end, False otherwise. - """ - return False - def create_end_turn_tool(strategy: TurnStrategy) -> Tool: @tool @@ -79,18 +69,19 @@ def delegate_to_agent(agent_id: str, message: str = None) -> str: class Single(TurnStrategy): + agent: Agent + def get_tools( self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] ) -> List[Tool]: return [create_end_turn_tool(self)] def get_next_agent( - self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] + self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] ) -> Agent: - return current_agent - - def should_end_session(self) -> bool: - return self.end_turn + if self.agent not in available_agents: + raise ValueError(f"The specified agent {self.agent.id} is not available.") + return self.agent class Popcorn(TurnStrategy): @@ -103,11 +94,11 @@ def get_tools( return [create_end_turn_tool(self)] def get_next_agent( - self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] + self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] ) -> Agent: if self.next_agent and self.next_agent in available_agents: return self.next_agent - return next(iter(available_agents)) # Always return an available agent + return next(iter(available_agents)) class Random(TurnStrategy): @@ -117,7 +108,7 @@ def get_tools( return [create_end_turn_tool(self)] def get_next_agent( - self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] + self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] ) -> Agent: return random.choice(list(available_agents.keys())) @@ -129,10 +120,10 @@ def get_tools( return [create_end_turn_tool(self)] def get_next_agent( - self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] + self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] ) -> Agent: agents = list(available_agents.keys()) - if current_agent not in agents: + if current_agent is None or current_agent not in agents: return agents[0] current_index = agents.index(current_agent) next_index = (current_index + 1) % len(agents) @@ -146,7 +137,7 @@ def get_tools( return [create_end_turn_tool(self)] def get_next_agent( - self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] + self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] ) -> Agent: # Select the agent with the most tasks return max(available_agents, key=lambda agent: len(available_agents[agent])) @@ -164,9 +155,9 @@ def get_tools( return [create_end_turn_tool(self)] def get_next_agent( - self, current_agent: Agent, available_agents: Dict[Agent, List[Task]] + self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] ) -> Agent: - if current_agent is self.moderator: + if current_agent is None or current_agent is self.moderator: return ( self.next_agent if self.next_agent in available_agents diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 039b940..a887805 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -705,3 +705,12 @@ async def run_async( max_calls_per_turn=max_calls_per_turn, max_turns=max_turns, ) + + +def run_tasks(tasks: list[Task], *args, agent: Agent = None, **kwargs): + orchestrator = controlflow.orchestration.Orchestrator( + tasks=tasks, + agent=agent or controlflow.defaults.agent, + flow=controlflow.flows.get_flow() or controlflow.flows.Flow(), + ) + return orchestrator.run() diff --git a/tests/orchestration/test_orchestrator.py b/tests/orchestration/test_orchestrator.py index b02e49a..852bfbd 100644 --- a/tests/orchestration/test_orchestrator.py +++ b/tests/orchestration/test_orchestrator.py @@ -4,6 +4,7 @@ from controlflow.agents import Agent from controlflow.flows import Flow from controlflow.orchestration.orchestrator import Orchestrator +from controlflow.orchestration.turn_strategies import Popcorn # Add this import from controlflow.tasks.task import Task @@ -95,6 +96,60 @@ def test_max_calls_per_turn_reached( assert get_call_count() == 3 * controlflow.settings.orchestrator_max_turns +class TestOrchestratorCreation: + def test_create_orchestrator_with_agent(self): + agent = Agent() + task = Task("Test task", agents=[agent]) + flow = Flow() + orchestrator = Orchestrator(tasks=[task], flow=flow, agent=agent) + + assert orchestrator.agent == agent + assert orchestrator.flow == flow + assert orchestrator.tasks == [task] + + def test_create_orchestrator_without_agent(self): + task = Task("Test task") + flow = Flow() + orchestrator = Orchestrator(tasks=[task], flow=flow, agent=None) + + assert orchestrator.agent is None + assert orchestrator.flow == flow + assert orchestrator.tasks == [task] + + def test_run_sets_agent_if_none(self): + agent1 = Agent(id="agent1") + agent2 = Agent(id="agent2") + task = Task("Test task", agents=[agent1, agent2]) + flow = Flow() + turn_strategy = Popcorn() + orchestrator = Orchestrator( + tasks=[task], flow=flow, agent=None, turn_strategy=turn_strategy + ) + + assert orchestrator.agent is None + + orchestrator.run(max_turns=0) + + assert orchestrator.agent is not None + assert orchestrator.agent in [agent1, agent2] + + def test_run_keeps_existing_agent_if_set(self): + agent1 = Agent(id="agent1") + agent2 = Agent(id="agent2") + task = Task("Test task", agents=[agent1, agent2]) + flow = Flow() + turn_strategy = Popcorn() + orchestrator = Orchestrator( + tasks=[task], flow=flow, agent=agent1, turn_strategy=turn_strategy + ) + + assert orchestrator.agent == agent1 + + orchestrator.run(max_turns=0) + + assert orchestrator.agent == agent1 + + def test_run(): result = controlflow.run("what's 2 + 2", result_type=int) assert result == 4 diff --git a/tests/orchestration/test_turn_strategies.py b/tests/orchestration/test_turn_strategies.py index b6bc87d..28f6c9c 100644 --- a/tests/orchestration/test_turn_strategies.py +++ b/tests/orchestration/test_turn_strategies.py @@ -34,7 +34,7 @@ def available_agents(agents: list[Agent], tasks: list[Task]): def test_single_strategy(agents, available_agents): - strategy = Single() + strategy = Single(agents[0]) current_agent = agents[0] tools = strategy.get_tools(current_agent, available_agents) @@ -44,10 +44,6 @@ def test_single_strategy(agents, available_agents): next_agent = strategy.get_next_agent(current_agent, available_agents) assert next_agent == current_agent - assert not strategy.should_end_session() - strategy.end_turn = True - assert strategy.should_end_session() - def test_popcorn_strategy(agents, available_agents): strategy = Popcorn() From 5f7a87a043fb19ce91c375732c4d0d166818daf6 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:18:31 -0400 Subject: [PATCH 2/4] Add top-level run fns --- src/controlflow/__init__.py | 7 +- src/controlflow/run.py | 69 +++++++++++++++ src/controlflow/tasks/__init__.py | 2 +- src/controlflow/tasks/task.py | 49 ----------- tests/tasks/test_tasks.py | 37 -------- tests/test_run.py | 136 +++++++++++++++++++++++++----- 6 files changed, 187 insertions(+), 113 deletions(-) create mode 100644 src/controlflow/run.py diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index 324bfbe..dad19a1 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -6,15 +6,16 @@ from controlflow.defaults import defaults +# base classes from .agents import Agent -from .tasks import Task, run, run_async +from .tasks import Task from .flows import Flow -from .orchestration import turn_strategies - +# functions and decorators from .instructions import instructions from .decorators import flow, task from .tools import tool +from .run import run, run_async, run_tasks, run_tasks_async # --- Version --- diff --git a/src/controlflow/run.py b/src/controlflow/run.py new file mode 100644 index 0000000..66951a3 --- /dev/null +++ b/src/controlflow/run.py @@ -0,0 +1,69 @@ +from controlflow.flows import Flow, get_flow +from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy +from controlflow.tasks.task import Task + + +def run( + objective: str, + *, + turn_strategy: TurnStrategy = None, + max_calls_per_turn: int = None, + max_turns: int = None, + **task_kwargs, +): + task = Task( + objective=objective, + **task_kwargs, + ) + return task.run( + turn_strategy=turn_strategy, + max_calls_per_turn=max_calls_per_turn, + max_turns=max_turns, + ) + + +async def run_async( + objective: str, + *, + turn_strategy: TurnStrategy = None, + max_calls_per_turn: int = None, + max_turns: int = None, + **task_kwargs, +): + task = Task( + objective=objective, + **task_kwargs, + ) + return await task.run_async( + turn_strategy=turn_strategy, + max_calls_per_turn=max_calls_per_turn, + max_turns=max_turns, + ) + + +def run_tasks( + tasks: list[Task], + flow: Flow = None, + turn_strategy: TurnStrategy = None, + **run_kwargs, +): + """ + Convenience function to run a list of tasks to completion. + """ + flow = flow or get_flow() or Flow() + orchestrator = Orchestrator(tasks=tasks, flow=flow, turn_strategy=turn_strategy) + return orchestrator.run(**run_kwargs) + + +async def run_tasks_async( + tasks: list[Task], + flow: Flow = None, + turn_strategy: TurnStrategy = None, + **run_kwargs, +): + """ + Convenience function to run a list of tasks to completion asynchronously. + """ + flow = flow or get_flow() or Flow() + orchestrator = Orchestrator(tasks=tasks, flow=flow, turn_strategy=turn_strategy) + return await orchestrator.run_async(**run_kwargs) diff --git a/src/controlflow/tasks/__init__.py b/src/controlflow/tasks/__init__.py index 8d9d8f2..280e6a2 100644 --- a/src/controlflow/tasks/__init__.py +++ b/src/controlflow/tasks/__init__.py @@ -1 +1 @@ -from .task import Task, run, run_async +from .task import Task diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index a887805..5256497 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -665,52 +665,3 @@ def _generate_result_schema(result_type: type[T]) -> type[T]: "Please use a custom type or add compatibility." ) return result_schema - - -def run( - objective: str, - *task_args, - turn_strategy: "TurnStrategy" = None, - max_calls_per_turn: int = None, - max_turns: int = None, - **task_kwargs, -): - task = Task( - objective=objective, - *task_args, - **task_kwargs, - ) - return task.run( - turn_strategy=turn_strategy, - max_calls_per_turn=max_calls_per_turn, - max_turns=max_turns, - ) - - -async def run_async( - objective: str, - *task_args, - turn_strategy: "TurnStrategy" = None, - max_calls_per_turn: int = None, - max_turns: int = None, - **task_kwargs, -): - task = Task( - objective=objective, - *task_args, - **task_kwargs, - ) - return await task.run_async( - turn_strategy=turn_strategy, - max_calls_per_turn=max_calls_per_turn, - max_turns=max_turns, - ) - - -def run_tasks(tasks: list[Task], *args, agent: Agent = None, **kwargs): - orchestrator = controlflow.orchestration.Orchestrator( - tasks=tasks, - agent=agent or controlflow.defaults.agent, - flow=controlflow.flows.get_flow() or controlflow.flows.Flow(), - ) - return orchestrator.run() diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index b535f43..8300716 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -441,40 +441,3 @@ class Person(BaseModel): tool.run(input=dict(result=1)) assert task.result == Person(name="Bob", age=35) assert isinstance(task.result, Person) - - -class TestRun: - @pytest.mark.parametrize( - "turns, calls_per_turn, expected_calls", - [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), - ], - ) - def test_run_with_limits( - self, - monkeypatch, - default_fake_llm, - turns, - calls_per_turn, - expected_calls, - ): - call_count = 0 - original_run_model = Agent._run_model - - def mock_run_model(*args, **kwargs): - nonlocal call_count - call_count += 1 - return original_run_model(*args, **kwargs) - - monkeypatch.setattr(Agent, "_run_model", mock_run_model) - - task = Task("send messages") - task.run( - max_calls_per_turn=calls_per_turn, - max_turns=turns, - ) - - assert call_count == expected_calls diff --git a/tests/test_run.py b/tests/test_run.py index ab06059..2f82a61 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -14,32 +14,122 @@ async def test_run_async(): assert result == 4 -@pytest.mark.parametrize( - "turns, calls_per_turn, expected_calls", - [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), - ], -) -def test_run_with_limits( - monkeypatch, default_fake_llm, turns, calls_per_turn, expected_calls -): +class TestLimits: call_count = 0 - original_run_model = Agent._run_model - def mock_run_model(*args, **kwargs): - nonlocal call_count - call_count += 1 - return original_run_model(*args, **kwargs) + @pytest.fixture(autouse=True) + def setup(self, monkeypatch, default_fake_llm): + self.call_count = 0 - monkeypatch.setattr(Agent, "_run_model", mock_run_model) + original_run_model = Agent._run_model + original_run_model_async = Agent._run_model_async - controlflow.run( - "send messages", - max_calls_per_turn=calls_per_turn, - max_turns=turns, + def mock_run_model(*args, **kwargs): + self.call_count += 1 + return original_run_model(*args, **kwargs) + + async def mock_run_model_async(*args, **kwargs): + self.call_count += 1 + async for event in original_run_model_async(*args, **kwargs): + yield event + + monkeypatch.setattr(Agent, "_run_model", mock_run_model) + monkeypatch.setattr(Agent, "_run_model_async", mock_run_model_async) + + @pytest.mark.parametrize( + "max_turns, max_calls_per_turn, expected_calls", + [ + (1, 1, 1), + (1, 2, 2), + (2, 1, 2), + (3, 2, 6), + ], + ) + def test_run_with_limits( + self, + max_turns, + max_calls_per_turn, + expected_calls, + ): + controlflow.run( + "send messages", + max_calls_per_turn=max_calls_per_turn, + max_turns=max_turns, + ) + + assert self.call_count == expected_calls + + @pytest.mark.parametrize( + "max_turns, max_calls_per_turn, expected_calls", + [ + (1, 1, 1), + (1, 2, 2), + (2, 1, 2), + (3, 2, 6), + ], + ) + async def test_run_async_with_limits( + self, + max_turns, + max_calls_per_turn, + expected_calls, + ): + await controlflow.run_async( + "send messages", + max_calls_per_turn=max_calls_per_turn, + max_turns=max_turns, + ) + + assert self.call_count == expected_calls + + @pytest.mark.parametrize( + "max_turns, max_calls_per_turn, expected_calls", + [ + (1, 1, 1), + (1, 2, 2), + (2, 1, 2), + (3, 2, 6), + ], + ) + def test_run_task_with_limits( + self, + max_turns, + max_calls_per_turn, + expected_calls, + ): + controlflow.run_tasks( + tasks=[ + controlflow.Task("send messages"), + controlflow.Task("send messages"), + ], + max_calls_per_turn=max_calls_per_turn, + max_turns=max_turns, + ) + + assert self.call_count == expected_calls + + @pytest.mark.parametrize( + "max_turns, max_calls_per_turn, expected_calls", + [ + (1, 1, 1), + (1, 2, 2), + (2, 1, 2), + (3, 2, 6), + ], ) + async def test_run_task_async_with_limits( + self, + max_turns, + max_calls_per_turn, + expected_calls, + ): + await controlflow.run_tasks_async( + tasks=[ + controlflow.Task("send messages"), + controlflow.Task("send messages"), + ], + max_calls_per_turn=max_calls_per_turn, + max_turns=max_turns, + ) - assert call_count == expected_calls + assert self.call_count == expected_calls From 71341778cea288b9555d5c8260f324b3dde6777b Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:20:18 -0400 Subject: [PATCH 3/4] Update labeler.yml --- .github/labeler.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index d3265a0..bb670b7 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,3 +1,7 @@ documentation: -- changed-files: - - any-glob-to-any-file: 'docs/*' \ No newline at end of file + - changed-files: + - any-glob-to-any-file: "docs/*" + +tests: + - changed-files: + - any-glob-to-any-file: "tests/*" From ca5ee08fa7267f8d8e45610abe28e4c910d3ff7c Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:22:26 -0400 Subject: [PATCH 4/4] Update test_turn_strategies.py --- tests/orchestration/test_turn_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/orchestration/test_turn_strategies.py b/tests/orchestration/test_turn_strategies.py index 28f6c9c..2a77c20 100644 --- a/tests/orchestration/test_turn_strategies.py +++ b/tests/orchestration/test_turn_strategies.py @@ -34,7 +34,7 @@ def available_agents(agents: list[Agent], tasks: list[Task]): def test_single_strategy(agents, available_agents): - strategy = Single(agents[0]) + strategy = Single(agent=agents[0]) current_agent = agents[0] tools = strategy.get_tools(current_agent, available_agents)