diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index e89040d..7e3f047 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -189,8 +189,8 @@ def run( task.mark_failed( reason="Max LLM calls reached for this task." ) - else: - task._llm_calls += 1 + else: + task._llm_calls += 1 # Check if there are any ready tasks left if not any(t.is_ready() for t in assigned_tasks): @@ -296,8 +296,8 @@ async def run_async( task.mark_failed( reason="Max LLM calls reached for this task." ) - else: - task._llm_calls += 1 + else: + task._llm_calls += 1 # Check if there are any ready tasks left if not any(t.is_ready() for t in assigned_tasks): diff --git a/tests/orchestration/test_orchestrator.py b/tests/orchestration/test_orchestrator.py index f68b8d4..522329f 100644 --- a/tests/orchestration/test_orchestrator.py +++ b/tests/orchestration/test_orchestrator.py @@ -1,99 +1,113 @@ import pytest -import controlflow from controlflow.agents import Agent from controlflow.flows import Flow from controlflow.orchestration.orchestrator import Orchestrator -from controlflow.orchestration.turn_strategies import Popcorn # Add this import +from controlflow.orchestration.turn_strategies import ( # Add this import + Popcorn, + TurnStrategy, +) from controlflow.tasks.task import Task -@pytest.fixture -def mocked_orchestrator(monkeypatch): - agent = Agent() - task = Task("Test task", agents=[agent]) - flow = Flow() - orchestrator = Orchestrator(tasks=[task], flow=flow, agent=agent) - +class TestOrchestratorLimits: call_count = 0 turn_count = 0 - original_run_model = Agent._run_model - original_run_turn = Orchestrator._run_turn - def mock_run_model(*args, **kwargs): - nonlocal call_count - call_count += 1 - return original_run_model(*args, **kwargs) + @pytest.fixture + def mocked_orchestrator(self, default_fake_llm): + # Reset counts at the start of each test + self.call_count = 0 + self.turn_count = 0 - def mock_run_turn(*args, **kwargs): - nonlocal turn_count - turn_count += 1 - return original_run_turn(*args, **kwargs) + class TwoCallTurnStrategy(TurnStrategy): + calls: int = 0 - monkeypatch.setattr(Agent, "_run_model", mock_run_model) - monkeypatch.setattr(Orchestrator, "_run_turn", mock_run_turn) + def get_tools(self, *args, **kwargs): + return [] - return orchestrator, lambda: call_count, lambda: turn_count + def get_next_agent(self, current_agent, available_agents): + return current_agent + def begin_turn(ts_instance): + self.turn_count += 1 + super().begin_turn() -class TestOrchestratorLimits: - def test_default_limits(self, mocked_orchestrator, default_fake_llm, monkeypatch): - monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm) - orchestrator, get_call_count, get_turn_count = mocked_orchestrator - - orchestrator.run() + def should_end_turn(ts_instance): + ts_instance.calls += 1 + # if this would be the third call, end the turn + if ts_instance.calls >= 3: + ts_instance.calls = 0 + return True + # record a new call for the unit test + self.call_count += 1 + return False - assert get_turn_count() == controlflow.settings.orchestrator_max_turns - assert ( - get_call_count() - == controlflow.settings.orchestrator_max_turns - * controlflow.settings.orchestrator_max_calls + agent = Agent() + task = Task("Test task", agents=[agent]) + flow = Flow() + orchestrator = Orchestrator( + tasks=[task], flow=flow, agent=agent, turn_strategy=TwoCallTurnStrategy() ) + return orchestrator + + def test_default_limits(self, mocked_orchestrator): + mocked_orchestrator.run() + + assert self.turn_count == 5 + assert self.call_count == 10 + @pytest.mark.parametrize( - "max_agent_turns, max_llm_calls, expected_calls", + "max_agent_turns, max_llm_calls, expected_turns, expected_calls", [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), + (1, 1, 1, 1), + (1, 2, 1, 2), + (5, 3, 2, 3), + (3, 12, 3, 6), ], ) def test_custom_limits( self, mocked_orchestrator, - default_fake_llm, - monkeypatch, max_agent_turns, max_llm_calls, + expected_turns, expected_calls, ): - monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm) - orchestrator, get_call_count, _ = mocked_orchestrator - - orchestrator.run(max_agent_turns=max_agent_turns, max_llm_calls=max_llm_calls) - - assert get_call_count() == expected_calls - - def test_max_turns_reached( - self, mocked_orchestrator, default_fake_llm, monkeypatch - ): - monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm) - orchestrator, _, get_turn_count = mocked_orchestrator - - orchestrator.run(max_agent_turns=5) - - assert get_turn_count() == 5 - - def test_max_calls_reached( - self, mocked_orchestrator, default_fake_llm, monkeypatch - ): - monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm) - orchestrator, get_call_count, _ = mocked_orchestrator - - orchestrator.run(max_llm_calls=3) + mocked_orchestrator.run( + max_agent_turns=max_agent_turns, max_llm_calls=max_llm_calls + ) - assert get_call_count() == 3 * controlflow.settings.orchestrator_max_turns + assert self.turn_count == expected_turns + assert self.call_count == expected_calls + + def test_task_limit(self, mocked_orchestrator): + task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent]) + mocked_orchestrator.tasks = [task] + mocked_orchestrator.run() + assert task.is_failed() + assert self.turn_count == 3 + # Note: the call count will be 6 because the orchestrator call count is + # incremented in "should_end_turn" which is called before the task's + # call count is evaluated + assert self.call_count == 6 + + def test_task_lifetime_limit(self, mocked_orchestrator): + task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent]) + mocked_orchestrator.tasks = [task] + mocked_orchestrator.run(max_agent_turns=1) + assert task.is_incomplete() + mocked_orchestrator.run(max_agent_turns=1) + assert task.is_incomplete() + mocked_orchestrator.run(max_agent_turns=1) + assert task.is_failed() + + assert self.turn_count == 3 + # Note: the call count will be 6 because the orchestrator call count is + # incremented in "should_end_turn" which is called before the task's + # call count is evaluated + assert self.call_count == 6 class TestOrchestratorCreation: @@ -148,44 +162,3 @@ def test_run_keeps_existing_agent_if_set(self): orchestrator.run(max_agent_turns=0) assert orchestrator.agent == agent1 - - -def test_run(): - result = controlflow.run("what's 2 + 2", result_type=int) - assert result == 4 - - -async def test_run_async(): - result = await controlflow.run_async("what's 2 + 2", result_type=int) - assert result == 4 - - -@pytest.mark.parametrize( - "max_agent_turns, max_llm_calls, expected_calls", - [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), - ], -) -def test_run_with_limits( - monkeypatch, default_fake_llm, max_agent_turns, max_llm_calls, expected_calls -): - call_count = 0 - original_run_model = Agent._run_model - - def mock_run_model(self, *args, **kwargs): - nonlocal call_count - call_count += 1 - return original_run_model(self, *args, **kwargs) - - monkeypatch.setattr(Agent, "_run_model", mock_run_model) - - controlflow.run( - "send messages", - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - ) - - assert call_count == expected_calls diff --git a/tests/test_run.py b/tests/test_run.py index b253b64..007d556 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,8 +1,4 @@ -import pytest - -import controlflow -from controlflow.agents.agent import Agent -from controlflow.run import run, run_async, run_tasks, run_tasks_async +from controlflow.run import run, run_async def test_run(): @@ -13,124 +9,3 @@ def test_run(): async def test_run_async(): result = await run_async("what's 2 + 2", result_type=int) assert result == 4 - - -class TestLimits: - call_count = 0 - - @pytest.fixture(autouse=True) - def setup(self, monkeypatch, default_fake_llm): - self.call_count = 0 - - original_run_model = Agent._run_model - original_run_model_async = Agent._run_model_async - - def mock_run_model(*args, **kwargs): - self.call_count += 1 - return original_run_model(*args, **kwargs) - - async def mock_run_model_async(*args, **kwargs): - self.call_count += 1 - async for event in original_run_model_async(*args, **kwargs): - yield event - - monkeypatch.setattr(Agent, "_run_model", mock_run_model) - monkeypatch.setattr(Agent, "_run_model_async", mock_run_model_async) - - @pytest.mark.parametrize( - "max_agent_turns, max_llm_calls, expected_calls", - [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), - ], - ) - def test_run_with_limits( - self, - max_agent_turns, - max_llm_calls, - expected_calls, - ): - run( - "send messages", - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - ) - - assert self.call_count == expected_calls - - @pytest.mark.parametrize( - "max_agent_turns, max_llm_calls, expected_calls", - [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), - ], - ) - async def test_run_async_with_limits( - self, - max_agent_turns, - max_llm_calls, - expected_calls, - ): - await run_async( - "send messages", - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - ) - - assert self.call_count == expected_calls - - @pytest.mark.parametrize( - "max_agent_turns, max_llm_calls, expected_calls", - [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), - ], - ) - def test_run_task_with_limits( - self, - max_agent_turns, - max_llm_calls, - expected_calls, - ): - run_tasks( - tasks=[ - controlflow.Task("send messages"), - controlflow.Task("send messages"), - ], - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - ) - - assert self.call_count == expected_calls - - @pytest.mark.parametrize( - "max_agent_turns, max_llm_calls, expected_calls", - [ - (1, 1, 1), - (1, 2, 2), - (2, 1, 2), - (3, 2, 6), - ], - ) - async def test_run_task_async_with_limits( - self, - max_agent_turns, - max_llm_calls, - expected_calls, - ): - await run_tasks_async( - tasks=[ - controlflow.Task("send messages"), - controlflow.Task("send messages"), - ], - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - ) - - assert self.call_count == expected_calls