Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Sep 9, 2024
1 parent f400d59 commit f3fd37d
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 236 deletions.
8 changes: 4 additions & 4 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def run(
task.mark_failed(
reason="Max LLM calls reached for this task."
)
else:
task._llm_calls += 1
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
Expand Down Expand Up @@ -296,8 +296,8 @@ async def run_async(
task.mark_failed(
reason="Max LLM calls reached for this task."
)
else:
task._llm_calls += 1
else:
task._llm_calls += 1

# Check if there are any ready tasks left
if not any(t.is_ready() for t in assigned_tasks):
Expand Down
185 changes: 79 additions & 106 deletions tests/orchestration/test_orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,113 @@
import pytest

import controlflow
from controlflow.agents import Agent
from controlflow.flows import Flow
from controlflow.orchestration.orchestrator import Orchestrator
from controlflow.orchestration.turn_strategies import Popcorn # Add this import
from controlflow.orchestration.turn_strategies import ( # Add this import
Popcorn,
TurnStrategy,
)
from controlflow.tasks.task import Task


@pytest.fixture
def mocked_orchestrator(monkeypatch):
agent = Agent()
task = Task("Test task", agents=[agent])
flow = Flow()
orchestrator = Orchestrator(tasks=[task], flow=flow, agent=agent)

class TestOrchestratorLimits:
call_count = 0
turn_count = 0
original_run_model = Agent._run_model
original_run_turn = Orchestrator._run_turn

def mock_run_model(*args, **kwargs):
nonlocal call_count
call_count += 1
return original_run_model(*args, **kwargs)
@pytest.fixture
def mocked_orchestrator(self, default_fake_llm):
# Reset counts at the start of each test
self.call_count = 0
self.turn_count = 0

def mock_run_turn(*args, **kwargs):
nonlocal turn_count
turn_count += 1
return original_run_turn(*args, **kwargs)
class TwoCallTurnStrategy(TurnStrategy):
calls: int = 0

monkeypatch.setattr(Agent, "_run_model", mock_run_model)
monkeypatch.setattr(Orchestrator, "_run_turn", mock_run_turn)
def get_tools(self, *args, **kwargs):
return []

return orchestrator, lambda: call_count, lambda: turn_count
def get_next_agent(self, current_agent, available_agents):
return current_agent

def begin_turn(ts_instance):
self.turn_count += 1
super().begin_turn()

class TestOrchestratorLimits:
def test_default_limits(self, mocked_orchestrator, default_fake_llm, monkeypatch):
monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm)
orchestrator, get_call_count, get_turn_count = mocked_orchestrator

orchestrator.run()
def should_end_turn(ts_instance):
ts_instance.calls += 1
# if this would be the third call, end the turn
if ts_instance.calls >= 3:
ts_instance.calls = 0
return True
# record a new call for the unit test
self.call_count += 1
return False

assert get_turn_count() == controlflow.settings.orchestrator_max_turns
assert (
get_call_count()
== controlflow.settings.orchestrator_max_turns
* controlflow.settings.orchestrator_max_calls
agent = Agent()
task = Task("Test task", agents=[agent])
flow = Flow()
orchestrator = Orchestrator(
tasks=[task], flow=flow, agent=agent, turn_strategy=TwoCallTurnStrategy()
)

return orchestrator

def test_default_limits(self, mocked_orchestrator):
mocked_orchestrator.run()

assert self.turn_count == 5
assert self.call_count == 10

@pytest.mark.parametrize(
"max_agent_turns, max_llm_calls, expected_calls",
"max_agent_turns, max_llm_calls, expected_turns, expected_calls",
[
(1, 1, 1),
(1, 2, 2),
(2, 1, 2),
(3, 2, 6),
(1, 1, 1, 1),
(1, 2, 1, 2),
(5, 3, 2, 3),
(3, 12, 3, 6),
],
)
def test_custom_limits(
self,
mocked_orchestrator,
default_fake_llm,
monkeypatch,
max_agent_turns,
max_llm_calls,
expected_turns,
expected_calls,
):
monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm)
orchestrator, get_call_count, _ = mocked_orchestrator

orchestrator.run(max_agent_turns=max_agent_turns, max_llm_calls=max_llm_calls)

assert get_call_count() == expected_calls

def test_max_turns_reached(
self, mocked_orchestrator, default_fake_llm, monkeypatch
):
monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm)
orchestrator, _, get_turn_count = mocked_orchestrator

orchestrator.run(max_agent_turns=5)

assert get_turn_count() == 5

def test_max_calls_reached(
self, mocked_orchestrator, default_fake_llm, monkeypatch
):
monkeypatch.setattr(controlflow.defaults, "model", default_fake_llm)
orchestrator, get_call_count, _ = mocked_orchestrator

orchestrator.run(max_llm_calls=3)
mocked_orchestrator.run(
max_agent_turns=max_agent_turns, max_llm_calls=max_llm_calls
)

assert get_call_count() == 3 * controlflow.settings.orchestrator_max_turns
assert self.turn_count == expected_turns
assert self.call_count == expected_calls

def test_task_limit(self, mocked_orchestrator):
task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent])
mocked_orchestrator.tasks = [task]
mocked_orchestrator.run()
assert task.is_failed()
assert self.turn_count == 3
# Note: the call count will be 6 because the orchestrator call count is
# incremented in "should_end_turn" which is called before the task's
# call count is evaluated
assert self.call_count == 6

def test_task_lifetime_limit(self, mocked_orchestrator):
task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent])
mocked_orchestrator.tasks = [task]
mocked_orchestrator.run(max_agent_turns=1)
assert task.is_incomplete()
mocked_orchestrator.run(max_agent_turns=1)
assert task.is_incomplete()
mocked_orchestrator.run(max_agent_turns=1)
assert task.is_failed()

assert self.turn_count == 3
# Note: the call count will be 6 because the orchestrator call count is
# incremented in "should_end_turn" which is called before the task's
# call count is evaluated
assert self.call_count == 6


class TestOrchestratorCreation:
Expand Down Expand Up @@ -148,44 +162,3 @@ def test_run_keeps_existing_agent_if_set(self):
orchestrator.run(max_agent_turns=0)

assert orchestrator.agent == agent1


def test_run():
result = controlflow.run("what's 2 + 2", result_type=int)
assert result == 4


async def test_run_async():
result = await controlflow.run_async("what's 2 + 2", result_type=int)
assert result == 4


@pytest.mark.parametrize(
"max_agent_turns, max_llm_calls, expected_calls",
[
(1, 1, 1),
(1, 2, 2),
(2, 1, 2),
(3, 2, 6),
],
)
def test_run_with_limits(
monkeypatch, default_fake_llm, max_agent_turns, max_llm_calls, expected_calls
):
call_count = 0
original_run_model = Agent._run_model

def mock_run_model(self, *args, **kwargs):
nonlocal call_count
call_count += 1
return original_run_model(self, *args, **kwargs)

monkeypatch.setattr(Agent, "_run_model", mock_run_model)

controlflow.run(
"send messages",
max_llm_calls=max_llm_calls,
max_agent_turns=max_agent_turns,
)

assert call_count == expected_calls
127 changes: 1 addition & 126 deletions tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import pytest

import controlflow
from controlflow.agents.agent import Agent
from controlflow.run import run, run_async, run_tasks, run_tasks_async
from controlflow.run import run, run_async


def test_run():
Expand All @@ -13,124 +9,3 @@ def test_run():
async def test_run_async():
result = await run_async("what's 2 + 2", result_type=int)
assert result == 4


class TestLimits:
call_count = 0

@pytest.fixture(autouse=True)
def setup(self, monkeypatch, default_fake_llm):
self.call_count = 0

original_run_model = Agent._run_model
original_run_model_async = Agent._run_model_async

def mock_run_model(*args, **kwargs):
self.call_count += 1
return original_run_model(*args, **kwargs)

async def mock_run_model_async(*args, **kwargs):
self.call_count += 1
async for event in original_run_model_async(*args, **kwargs):
yield event

monkeypatch.setattr(Agent, "_run_model", mock_run_model)
monkeypatch.setattr(Agent, "_run_model_async", mock_run_model_async)

@pytest.mark.parametrize(
"max_agent_turns, max_llm_calls, expected_calls",
[
(1, 1, 1),
(1, 2, 2),
(2, 1, 2),
(3, 2, 6),
],
)
def test_run_with_limits(
self,
max_agent_turns,
max_llm_calls,
expected_calls,
):
run(
"send messages",
max_llm_calls=max_llm_calls,
max_agent_turns=max_agent_turns,
)

assert self.call_count == expected_calls

@pytest.mark.parametrize(
"max_agent_turns, max_llm_calls, expected_calls",
[
(1, 1, 1),
(1, 2, 2),
(2, 1, 2),
(3, 2, 6),
],
)
async def test_run_async_with_limits(
self,
max_agent_turns,
max_llm_calls,
expected_calls,
):
await run_async(
"send messages",
max_llm_calls=max_llm_calls,
max_agent_turns=max_agent_turns,
)

assert self.call_count == expected_calls

@pytest.mark.parametrize(
"max_agent_turns, max_llm_calls, expected_calls",
[
(1, 1, 1),
(1, 2, 2),
(2, 1, 2),
(3, 2, 6),
],
)
def test_run_task_with_limits(
self,
max_agent_turns,
max_llm_calls,
expected_calls,
):
run_tasks(
tasks=[
controlflow.Task("send messages"),
controlflow.Task("send messages"),
],
max_llm_calls=max_llm_calls,
max_agent_turns=max_agent_turns,
)

assert self.call_count == expected_calls

@pytest.mark.parametrize(
"max_agent_turns, max_llm_calls, expected_calls",
[
(1, 1, 1),
(1, 2, 2),
(2, 1, 2),
(3, 2, 6),
],
)
async def test_run_task_async_with_limits(
self,
max_agent_turns,
max_llm_calls,
expected_calls,
):
await run_tasks_async(
tasks=[
controlflow.Task("send messages"),
controlflow.Task("send messages"),
],
max_llm_calls=max_llm_calls,
max_agent_turns=max_agent_turns,
)

assert self.call_count == expected_calls

0 comments on commit f3fd37d

Please sign in to comment.