Skip to content

Commit

Permalink
Merge pull request #330 from PrefectHQ/completion-tools
Browse files Browse the repository at this point in the history
Allow completion tools to be customized per-task
  • Loading branch information
jlowin authored Sep 24, 2024
2 parents 168b90e + 3121f4b commit 176a4bd
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 25 deletions.
53 changes: 53 additions & 0 deletions docs/concepts/tasks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,59 @@ task = cf.Task(
)
```

Note that this setting only controls which agents receive completion tools; the tools they receive are determined by the `completion_tools` parameter, described below.

### Completion tools

import { VersionBadge } from '/snippets/version-badge.mdx'

<VersionBadge version="0.10.0" />

In addition to specifying which agents are automatically given completion tools, you can control which completion tools are generated for a task using the `completion_tools` parameter. This allows you to specify whether you want to generate success and/or failure tools, or to turn off automatic generation entirely so that you can provide your own custom completion tools.

The `completion_tools` parameter accepts a list of strings, where each string represents a tool to be generated. The available options are:

- `"SUCCEED"`: Generates a tool for marking the task as successful.
- `"FAIL"`: Generates a tool for marking the task as failed.

If `completion_tools` is not specified, both `"SUCCEED"` and `"FAIL"` tools will be generated by default.

You can manually create completion tools and provide them to your agents by calling `task.get_success_tool()` and `task.get_fail_tool()`.

<Warning>
If you exclude one or both completion tools (for example, by passing `completion_tools=[]` or a list that omits `"SUCCEED"`), agents may be unable to complete the task or become stuck in a failure state. Without caps on LLM turns or calls, this could lead to runaway LLM usage. Make sure to manually manage how agents complete tasks if you are using a custom set of completion tools.
</Warning>

Here are some examples:

```
# Generate both success and failure tools (default behavior, equivalent to `completion_tools=None`)
task = cf.Task(
objective="Write a poem about AI",
completion_tools=["SUCCEED", "FAIL"],
)
# Only generate a success tool
task = cf.Task(
objective="Write a poem about AI",
completion_tools=["SUCCEED"],
)
# Only generate a failure tool
task = cf.Task(
objective="Write a poem about AI",
completion_tools=["FAIL"],
)
# Don't generate any completion tools
task = cf.Task(
objective="Write a poem about AI",
completion_tools=[],
)
```

By controlling which completion tools are generated, you can customize the task completion process to better suit your workflow needs. For example, you might want to prevent agents from marking a task as failed, or you might want to provide your own custom completion tools instead of using the default ones.

### Name

The name of a task is a string that identifies the task within the workflow. It is used primarily for logging and debugging purposes, though it is also shown to agents during execution to help identify the task they are working on.
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/call-routing.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def routing_flow():
),
agents=[trainee],
result_type=None,
tools=[main_task.create_success_tool()]
tools=[main_task.get_success_tool()]
)

if main_task.result == target_department:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"langchain_openai>=0.2",
"langchain-anthropic>=0.2",
"markdownify>=0.12.1",
"openai<1.47", # 1.47.0 introduced a bug with attempting to reuse an async client that doesnt have an obvious solution
"pydantic-settings>=2.2.1",
"textual>=0.61.1",
"tiktoken>=0.7.0",
Expand Down
3 changes: 1 addition & 2 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def get_tools(self) -> list[Tool]:

# add completion tools
if task.completion_agents is None or self.agent in task.completion_agents:
tools.append(task.create_success_tool())
tools.append(task.create_fail_tool())
tools.extend(task.get_completion_tools())

# add turn strategy tools only if there are multiple available agents
available_agents = self.get_available_agents()
Expand Down
18 changes: 9 additions & 9 deletions src/controlflow/orchestration/turn_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def should_end_turn(self) -> bool:
return self.end_turn


def create_end_turn_tool(strategy: TurnStrategy) -> Tool:
def get_end_turn_tool(strategy: TurnStrategy) -> Tool:
@tool
def end_turn() -> str:
"""
Expand All @@ -51,7 +51,7 @@ def end_turn() -> str:
return end_turn


def create_delegate_tool(
def get_delegate_tool(
strategy: TurnStrategy, available_agents: dict[Agent, list[Task]]
) -> Tool:
@tool
Expand All @@ -77,7 +77,7 @@ class SingleAgent(TurnStrategy):
def get_tools(
self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
) -> list[Tool]:
return [create_end_turn_tool(self)]
return [get_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
Expand All @@ -93,7 +93,7 @@ class Popcorn(TurnStrategy):
def get_tools(
self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
) -> list[Tool]:
return [create_delegate_tool(self, available_agents)]
return [get_delegate_tool(self, available_agents)]

def get_next_agent(
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
Expand All @@ -107,7 +107,7 @@ class Random(TurnStrategy):
def get_tools(
self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
) -> list[Tool]:
return [create_end_turn_tool(self)]
return [get_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
Expand All @@ -119,7 +119,7 @@ class RoundRobin(TurnStrategy):
def get_tools(
self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
) -> list[Tool]:
return [create_end_turn_tool(self)]
return [get_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
Expand All @@ -136,7 +136,7 @@ class MostBusy(TurnStrategy):
def get_tools(
self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
) -> list[Tool]:
return [create_end_turn_tool(self)]
return [get_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
Expand All @@ -152,9 +152,9 @@ def get_tools(
self, current_agent: Agent, available_agents: dict[Agent, list[Task]]
) -> list[Tool]:
if current_agent == self.moderator:
return [create_delegate_tool(self, available_agents)]
return [get_delegate_tool(self, available_agents)]
else:
return [create_end_turn_tool(self)]
return [get_end_turn_tool(self)]

def get_next_agent(
self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]]
Expand Down
41 changes: 33 additions & 8 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
Callable,
GenericAlias,
Literal,
Optional,
TypeVar,
Union,
Expand Down Expand Up @@ -53,6 +54,9 @@
logger = get_logger(__name__)


COMPLETION_TOOLS = Literal["SUCCEED", "FAIL"]


def get_task_run_name():
context = TaskRunContext.get()
task = context.parameters["self"]
Expand Down Expand Up @@ -144,6 +148,14 @@ class Task(ControlFlowModel):
default_factory=list,
description="Tools available to every agent working on this task.",
)
completion_tools: Optional[list[COMPLETION_TOOLS]] = Field(
default=None,
description="""
Completion tools that will be generated for this task. If None, all
tools will be generated; if a list of strings, only the corresponding
tools will be generated automatically.
""",
)
completion_agents: Optional[list[Agent]] = Field(
default=None,
description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.",
Expand Down Expand Up @@ -472,19 +484,32 @@ def get_agents(self) -> list[Agent]:
else:
return [controlflow.defaults.agent]

def get_tools(self) -> list[Union[Tool, Callable]]:
def get_tools(self) -> list[Tool]:
"""
Return a list of all tools available for the task.
Note this does not include completion tools, which are handled separately.
"""
tools = self.tools.copy()
if self.interactive:
tools.append(cli_input)
for memory in self.memories:
tools.extend(memory.get_tools())
return tools
return as_tools(tools)

def get_completion_tools(self) -> list[Tool]:
tools = [
self.create_success_tool(),
self.create_fail_tool(),
]
"""
Return a list of all completion tools available for the task.
"""
tools = []
completion_tools = self.completion_tools
if completion_tools is None:
completion_tools = ["SUCCEED", "FAIL"]

if "SUCCEED" in completion_tools:
tools.append(self.get_success_tool())
if "FAIL" in completion_tools:
tools.append(self.get_fail_tool())
return tools

def get_prompt(self) -> str:
Expand Down Expand Up @@ -517,7 +542,7 @@ def mark_failed(self, reason: Optional[str] = None):
def mark_skipped(self):
self.set_status(TaskStatus.SKIPPED)

def create_success_tool(self) -> Tool:
def get_success_tool(self) -> Tool:
"""
Create an agent-compatible tool for marking this task as successful.
"""
Expand Down Expand Up @@ -587,7 +612,7 @@ def succeed(result: result_schema) -> str: # type: ignore

return succeed

def create_fail_tool(self) -> Tool:
def get_fail_tool(self) -> Tool:
"""
Create an agent-compatible tool for failing this task.
"""
Expand Down
78 changes: 73 additions & 5 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,27 +483,27 @@ def always_return_none(value: Any) -> None:
class TestSuccessTool:
def test_success_tool(self):
task = Task("choose 5", result_type=int)
tool = task.create_success_tool()
tool = task.get_success_tool()
tool.run(input=dict(result=5))
assert task.is_successful()
assert task.result == 5

def test_success_tool_with_list_of_options(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.create_success_tool()
tool = task.get_success_tool()
tool.run(input=dict(result=1))
assert task.is_successful()
assert task.result == "good"

def test_success_tool_with_list_of_options_requires_int(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.create_success_tool()
tool = task.get_success_tool()
with pytest.raises(ValueError):
tool.run(input=dict(result="good"))

def test_tuple_of_ints_result(self):
task = Task("choose 5", result_type=(4, 5, 6))
tool = task.create_success_tool()
tool = task.get_success_tool()
tool.run(input=dict(result=1))
assert task.result == 5

Expand All @@ -516,7 +516,7 @@ class Person(BaseModel):
"Who is the oldest?",
result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)),
)
tool = task.create_success_tool()
tool = task.get_success_tool()
tool.run(input=dict(result=1))
assert task.result == Person(name="Bob", age=35)
assert isinstance(task.result, Person)
Expand Down Expand Up @@ -549,3 +549,71 @@ async def test_task_run_async_with_handlers(self, default_fake_llm):

assert len(handler.events) > 0
assert len(handler.agent_messages) == 1


class TestCompletionTools:
    """Tests for the per-task ``completion_tools`` setting."""

    def test_default_completion_tools(self):
        """With no setting, the field is None and both tools are generated."""
        task = Task(objective="Test task")
        assert task.completion_tools is None

        generated = task.get_completion_tools()
        assert len(generated) == 2

        names = [tool.name for tool in generated]
        assert f"mark_task_{task.id}_successful" in names
        assert f"mark_task_{task.id}_failed" in names

    def test_only_succeed_tool(self):
        """Requesting only "SUCCEED" yields just the success tool."""
        task = Task(objective="Test task", completion_tools=["SUCCEED"])
        generated = task.get_completion_tools()
        assert len(generated) == 1
        (only_tool,) = generated
        assert only_tool.name == f"mark_task_{task.id}_successful"

    def test_only_fail_tool(self):
        """Requesting only "FAIL" yields just the failure tool."""
        task = Task(objective="Test task", completion_tools=["FAIL"])
        generated = task.get_completion_tools()
        assert len(generated) == 1
        (only_tool,) = generated
        assert only_tool.name == f"mark_task_{task.id}_failed"

    def test_no_completion_tools(self):
        """An empty list generates no completion tools."""
        task = Task(objective="Test task", completion_tools=[])
        generated = task.get_completion_tools()
        assert len(generated) == 0

    def test_invalid_completion_tool(self):
        """An unrecognized tool name is rejected when the task is constructed."""
        with pytest.raises(ValueError):
            Task(objective="Test task", completion_tools=["INVALID"])

    def test_manual_success_tool(self):
        """The success tool can still be built by hand when none are generated."""
        task = Task(objective="Test task", completion_tools=[], result_type=int)
        manual_tool = task.get_success_tool()
        manual_tool.run(input={"result": 5})
        assert task.is_successful()
        assert task.result == 5

    def test_manual_fail_tool(self):
        """The fail tool can still be built by hand when none are generated."""
        task = Task(objective="Test task", completion_tools=[])
        manual_tool = task.get_fail_tool()
        assert manual_tool.name == f"mark_task_{task.id}_failed"
        manual_tool.run(input={"reason": "test error"})
        assert task.is_failed()
        assert task.result == "test error"

    def test_completion_tools_with_run(self):
        """A task with only the success tool completes via ``run``."""
        task = Task("Calculate 2 + 2", result_type=int, completion_tools=["SUCCEED"])
        outcome = task.run(max_llm_calls=1)
        assert outcome == 4
        assert task.is_successful()

    def test_no_completion_tools_with_run(self):
        """With no completion tools, the task is still incomplete after ``run``."""
        task = Task("Calculate 2 + 2", result_type=int, completion_tools=[])
        task.run(max_llm_calls=1)
        assert task.is_incomplete()

    async def test_completion_tools_with_run_async(self):
        """A task with only the success tool completes via ``run_async``."""
        task = Task("Calculate 2 + 2", result_type=int, completion_tools=["SUCCEED"])
        outcome = await task.run_async(max_llm_calls=1)
        assert outcome == 4
        assert task.is_successful()

    async def test_no_completion_tools_with_run_async(self):
        """With no completion tools, the task is still incomplete after ``run_async``."""
        task = Task("Calculate 2 + 2", result_type=int, completion_tools=[])
        await task.run_async(max_llm_calls=1)
        assert task.is_incomplete()

0 comments on commit 176a4bd

Please sign in to comment.