diff --git a/docs/concepts/tasks.mdx b/docs/concepts/tasks.mdx
index e47174b..c96bcfc 100644
--- a/docs/concepts/tasks.mdx
+++ b/docs/concepts/tasks.mdx
@@ -264,6 +264,59 @@ task = cf.Task(
 )
 ```
 
+Note that this setting works in conjunction with the `completion_tools` parameter, described below.
+
+### Completion tools
+
+import { VersionBadge } from '/snippets/version-badge.mdx'
+
+In addition to specifying which agents are automatically given completion tools, you can control which completion tools are generated for a task with the `completion_tools` parameter. This lets you provide the success tool, the failure tool, both, or neither, and even supply custom completion tools of your own.
+
+The `completion_tools` parameter accepts a list of strings, where each string identifies a tool to generate. The available options are:
+
+- `"SUCCEED"`: generates a tool for marking the task as successful.
+- `"FAIL"`: generates a tool for marking the task as failed.
+
+If `completion_tools` is not specified, both the `"SUCCEED"` and `"FAIL"` tools are generated by default.
+
+You can also create completion tools manually and provide them to your agents by calling `task.get_success_tool()` and `task.get_fail_tool()`.
+
+<Warning>
+If you exclude completion tools, agents may be unable to complete the task or to mark it as failed. Without caps on LLM turns or calls, this could lead to runaway LLM usage. Make sure to manually manage how agents complete tasks if you are using a custom set of completion tools.
+</Warning>
+
+Here are some examples:
+
+```python
+# Generate both success and failure tools (default behavior, equivalent to `completion_tools=None`)
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=["SUCCEED", "FAIL"],
+)
+
+# Only generate a success tool
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=["SUCCEED"],
+)
+
+# Only generate a failure tool
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=["FAIL"],
+)
+
+# Don't generate any completion tools
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=[],
+)
+```
+
+By controlling which completion tools are generated, you can tailor the completion process to your workflow. For example, you might prevent agents from marking a task as failed, or replace the default tools with custom completion tools of your own.
+
 ### Name
 
 The name of a task is a string that identifies the task within the workflow. It is used primarily for logging and debugging purposes, though it is also shown to agents during execution to help identify the task they are working on.
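The manual-management path that the warning above calls out can be made concrete. The following is a minimal sketch, not part of the patch: it assumes only APIs shown elsewhere in this change set (`cf.Task`, the `completion_tools` and `tools` parameters, `Task.get_success_tool()`, and running a generated tool with `tool.run(input=...)` as the tests below do); the objectives and the sample result are illustrative.

```python
import controlflow as cf

# A task that generates no completion tools of its own; agents working on it
# cannot mark it successful or failed directly.
poem_task = cf.Task(
    objective="Write a poem about AI",
    result_type=str,
    completion_tools=[],
)

# One option: hand the success tool to a different task, so only that task's
# agent can complete poem_task (the call-routing example below uses the same
# pattern via tools=[main_task.get_success_tool()]).
review_task = cf.Task(
    objective="Approve the poem by marking the writing task successful",
    result_type=None,
    tools=[poem_task.get_success_tool()],
)

# Another option: complete the task programmatically, exactly as the new tests
# do, by running the generated tool yourself.
poem_task.get_success_tool().run(input=dict(result="Circuits hum a quiet song..."))
assert poem_task.is_successful()
```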
diff --git a/docs/examples/call-routing.mdx b/docs/examples/call-routing.mdx
index 6f01917..713e0d0 100644
--- a/docs/examples/call-routing.mdx
+++ b/docs/examples/call-routing.mdx
@@ -80,7 +80,7 @@ def routing_flow():
         ),
         agents=[trainee],
         result_type=None,
-        tools=[main_task.create_success_tool()]
+        tools=[main_task.get_success_tool()]
     )
 
     if main_task.result == target_department:
diff --git a/pyproject.toml b/pyproject.toml
index a2b6ff1..07aec50 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
     "langchain_openai>=0.2",
     "langchain-anthropic>=0.2",
     "markdownify>=0.12.1",
+    "openai<1.47",  # 1.47.0 introduced a bug with attempting to reuse an async client that doesn't have an obvious solution
     "pydantic-settings>=2.2.1",
     "textual>=0.61.1",
     "tiktoken>=0.7.0",
diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py
index 89c2f5d..dcb658f 100644
--- a/src/controlflow/orchestration/orchestrator.py
+++ b/src/controlflow/orchestration/orchestrator.py
@@ -115,8 +115,7 @@ def get_tools(self) -> list[Tool]:
 
         # add completion tools
         if task.completion_agents is None or self.agent in task.completion_agents:
-            tools.append(task.create_success_tool())
-            tools.append(task.create_fail_tool())
+            tools.extend(task.get_completion_tools())
 
         # add turn strategy tools only if there are multiple available agents
         available_agents = self.get_available_agents()
diff --git a/src/controlflow/orchestration/turn_strategies.py b/src/controlflow/orchestration/turn_strategies.py
index 803ffc9..5494f8b 100644
--- a/src/controlflow/orchestration/turn_strategies.py
+++ b/src/controlflow/orchestration/turn_strategies.py
@@ -38,7 +38,7 @@ def should_end_turn(self) -> bool:
         return self.end_turn
 
 
-def create_end_turn_tool(strategy: TurnStrategy) -> Tool:
+def get_end_turn_tool(strategy: TurnStrategy) -> Tool:
     @tool
     def end_turn() -> str:
         """
@@ -51,7 +51,7 @@ def end_turn() -> str:
     return end_turn
 
 
-def create_delegate_tool(
+def get_delegate_tool(
     strategy: TurnStrategy, available_agents: dict[Agent, list[Task]]
 ) -> Tool:
     @tool
@@ -77,7 +77,7 @@ class SingleAgent(TurnStrategy):
     def get_tools(
         self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
     ) -> list[Tool]:
-        return [create_end_turn_tool(self)]
+        return [get_end_turn_tool(self)]
 
     def get_next_agent(
         self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
@@ -93,7 +93,7 @@ class Popcorn(TurnStrategy):
     def get_tools(
         self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
     ) -> list[Tool]:
-        return [create_delegate_tool(self, available_agents)]
+        return [get_delegate_tool(self, available_agents)]
 
     def get_next_agent(
         self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
@@ -107,7 +107,7 @@ class Random(TurnStrategy):
    def get_tools(
         self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
     ) -> list[Tool]:
-        return [create_end_turn_tool(self)]
+        return [get_end_turn_tool(self)]
 
     def get_next_agent(
         self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
@@ -119,7 +119,7 @@ class RoundRobin(TurnStrategy):
     def get_tools(
         self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
     ) -> list[Tool]:
-        return [create_end_turn_tool(self)]
+        return [get_end_turn_tool(self)]
 
     def get_next_agent(
         self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
@@ -136,7 +136,7 @@ class MostBusy(TurnStrategy):
     def get_tools(
         self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
     ) -> list[Tool]:
-        return [create_end_turn_tool(self)]
+        return [get_end_turn_tool(self)]
 
     def get_next_agent(
         self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
@@ -152,9 +152,9 @@ def get_tools(
         self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
     ) -> list[Tool]:
         if current_agent == self.moderator:
-            return [create_delegate_tool(self, available_agents)]
+            return [get_delegate_tool(self, available_agents)]
         else:
-            return [create_end_turn_tool(self)]
+            return [get_end_turn_tool(self)]
 
     def get_next_agent(
         self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py
index 34c1b7e..22f1bfc 100644
--- a/src/controlflow/tasks/task.py
+++ b/src/controlflow/tasks/task.py
@@ -8,6 +8,7 @@
     Any,
     Callable,
     GenericAlias,
+    Literal,
     Optional,
     TypeVar,
     Union,
@@ -53,6 +54,9 @@
 logger = get_logger(__name__)
 
 
+COMPLETION_TOOLS = Literal["SUCCEED", "FAIL"]
+
+
 def get_task_run_name():
     context = TaskRunContext.get()
     task = context.parameters["self"]
@@ -144,6 +148,14 @@ class Task(ControlFlowModel):
         default_factory=list,
         description="Tools available to every agent working on this task.",
     )
+    completion_tools: Optional[list[COMPLETION_TOOLS]] = Field(
+        default=None,
+        description="""
+            Completion tools that will be generated for this task. If None, all
+            tools will be generated; if a list of strings, only the corresponding
+            tools will be generated automatically.
+            """,
+    )
     completion_agents: Optional[list[Agent]] = Field(
         default=None,
         description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.",
@@ -472,19 +484,32 @@ def get_agents(self) -> list[Agent]:
         else:
             return [controlflow.defaults.agent]
 
-    def get_tools(self) -> list[Union[Tool, Callable]]:
+    def get_tools(self) -> list[Tool]:
+        """
+        Return a list of all tools available for the task.
+
+        Note this does not include completion tools, which are handled separately.
+        """
         tools = self.tools.copy()
         if self.interactive:
             tools.append(cli_input)
         for memory in self.memories:
             tools.extend(memory.get_tools())
-        return tools
+        return as_tools(tools)
 
     def get_completion_tools(self) -> list[Tool]:
-        tools = [
-            self.create_success_tool(),
-            self.create_fail_tool(),
-        ]
+        """
+        Return a list of all completion tools available for the task.
+        """
+        tools = []
+        completion_tools = self.completion_tools
+        if completion_tools is None:
+            completion_tools = ["SUCCEED", "FAIL"]
+
+        if "SUCCEED" in completion_tools:
+            tools.append(self.get_success_tool())
+        if "FAIL" in completion_tools:
+            tools.append(self.get_fail_tool())
         return tools
 
     def get_prompt(self) -> str:
@@ -517,7 +542,7 @@ def mark_failed(self, reason: Optional[str] = None):
     def mark_skipped(self):
         self.set_status(TaskStatus.SKIPPED)
 
-    def create_success_tool(self) -> Tool:
+    def get_success_tool(self) -> Tool:
         """
         Create an agent-compatible tool for marking this task as successful.
         """
@@ -587,7 +612,7 @@ def succeed(result: result_schema) -> str:  # type: ignore
 
         return succeed
 
-    def create_fail_tool(self) -> Tool:
+    def get_fail_tool(self) -> Tool:
         """
         Create an agent-compatible tool for failing this task.
         """
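As a reading aid (not part of the patch), here is a minimal sketch of how the reworked `get_tools()` / `get_completion_tools()` pair is meant to be consumed, mirroring the orchestrator change above. It assumes the `controlflow` API from this branch; the objective string is illustrative, and the tool-name pattern is taken from the tests that follow.

```python
import controlflow as cf

# Only the success tool is generated; the FAIL tool is omitted entirely.
task = cf.Task(
    objective="Summarize the quarterly report",
    completion_tools=["SUCCEED"],
)

# Mirrors Orchestrator.get_tools(): regular tools first, then whichever
# completion tools the task is configured to generate.
tools = task.get_tools()
tools.extend(task.get_completion_tools())

tool_names = {t.name for t in tools}
assert f"mark_task_{task.id}_successful" in tool_names
assert f"mark_task_{task.id}_failed" not in tool_names
```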
""" diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index da6e258..4f0a7fe 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -483,27 +483,27 @@ def always_return_none(value: Any) -> None: class TestSuccessTool: def test_success_tool(self): task = Task("choose 5", result_type=int) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=5)) assert task.is_successful() assert task.result == 5 def test_success_tool_with_list_of_options(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=1)) assert task.is_successful() assert task.result == "good" def test_success_tool_with_list_of_options_requires_int(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) - tool = task.create_success_tool() + tool = task.get_success_tool() with pytest.raises(ValueError): tool.run(input=dict(result="good")) def test_tuple_of_ints_result(self): task = Task("choose 5", result_type=(4, 5, 6)) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=1)) assert task.result == 5 @@ -516,7 +516,7 @@ class Person(BaseModel): "Who is the oldest?", result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)), ) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=1)) assert task.result == Person(name="Bob", age=35) assert isinstance(task.result, Person) @@ -549,3 +549,71 @@ async def test_task_run_async_with_handlers(self, default_fake_llm): assert len(handler.events) > 0 assert len(handler.agent_messages) == 1 + + +class TestCompletionTools: + def test_default_completion_tools(self): + task = Task(objective="Test task") + assert task.completion_tools is None + tools = task.get_completion_tools() + assert len(tools) == 2 + assert any(t.name == f"mark_task_{task.id}_successful" for t in tools) + assert any(t.name == f"mark_task_{task.id}_failed" for t in tools) + + def test_only_succeed_tool(self): + task = Task(objective="Test task", completion_tools=["SUCCEED"]) + tools = task.get_completion_tools() + assert len(tools) == 1 + assert tools[0].name == f"mark_task_{task.id}_successful" + + def test_only_fail_tool(self): + task = Task(objective="Test task", completion_tools=["FAIL"]) + tools = task.get_completion_tools() + assert len(tools) == 1 + assert tools[0].name == f"mark_task_{task.id}_failed" + + def test_no_completion_tools(self): + task = Task(objective="Test task", completion_tools=[]) + tools = task.get_completion_tools() + assert len(tools) == 0 + + def test_invalid_completion_tool(self): + with pytest.raises(ValueError): + Task(objective="Test task", completion_tools=["INVALID"]) + + def test_manual_success_tool(self): + task = Task(objective="Test task", completion_tools=[], result_type=int) + success_tool = task.get_success_tool() + success_tool.run(input=dict(result=5)) + assert task.is_successful() + assert task.result == 5 + + def test_manual_fail_tool(self): + task = Task(objective="Test task", completion_tools=[]) + fail_tool = task.get_fail_tool() + assert fail_tool.name == f"mark_task_{task.id}_failed" + fail_tool.run(input=dict(reason="test error")) + assert task.is_failed() + assert task.result == "test error" + + def test_completion_tools_with_run(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=["SUCCEED"]) + result = task.run(max_llm_calls=1) + assert result == 4 
+ assert task.is_successful() + + def test_no_completion_tools_with_run(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=[]) + task.run(max_llm_calls=1) + assert task.is_incomplete() + + async def test_completion_tools_with_run_async(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=["SUCCEED"]) + result = await task.run_async(max_llm_calls=1) + assert result == 4 + assert task.is_successful() + + async def test_no_completion_tools_with_run_async(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=[]) + await task.run_async(max_llm_calls=1) + assert task.is_incomplete()
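Taken together, the tests above imply a manual-completion workflow for tasks created with `completion_tools=[]`. The snippet below is a hedged sketch of that workflow, not part of the change set: it uses only calls exercised by the tests (`Task.run(max_llm_calls=...)`, `get_success_tool()`, `is_incomplete()`, `is_successful()`), assumes a configured default LLM as the test suite does, and the cap of one LLM call is illustrative.

```python
import controlflow as cf

# No completion tools are generated, so the agent cannot finish the task itself.
task = cf.Task("Calculate 2 + 2", result_type=int, completion_tools=[])

# A capped run leaves the task incomplete, as test_no_completion_tools_with_run asserts.
task.run(max_llm_calls=1)
assert task.is_incomplete()

# The caller decides the outcome with the manually fetched completion tool.
task.get_success_tool().run(input=dict(result=4))
assert task.is_successful()
assert task.result == 4
```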