diff --git a/.github/labeler.yml b/.github/labeler.yml
index bb670b7e..5a21f99e 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -1,7 +1,13 @@
 documentation:
 - changed-files:
-  - any-glob-to-any-file: "docs/*"
+  - any-glob-to-any-file: "docs/**"
+
+example:
+  - changed-files:
+    - any-glob-to-any-file:
+      - "examples/**"
+      - "docs/examples/**"
 
 tests:
 - changed-files:
-  - any-glob-to-any-file: "tests/*"
+  - any-glob-to-any-file: "tests/**"
diff --git a/README.md b/README.md
index 4af57423..36329b9f 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ Next, configure your LLM provider. ControlFlow's default provider is OpenAI, whi
 export OPENAI_API_KEY=your-api-key
 ```
 
-To use a different LLM provider, [see the LLM configuration docs](https://controlflow.ai/guides/llms).
+To use a different LLM provider, [see the LLM configuration docs](https://controlflow.ai/guides/configure-llms).
 
 ## Workflow Example
 
diff --git a/docs/concepts/agents.mdx b/docs/concepts/agents.mdx
index 0576d716..0790abf3 100644
--- a/docs/concepts/agents.mdx
+++ b/docs/concepts/agents.mdx
@@ -4,6 +4,8 @@ description: The intelligent workers in your AI workflows.
 icon: robot
 ---
 
+import { VersionBadge } from '/snippets/version-badge.mdx'
+
 Agents are the intelligent, autonomous entities that power your AI workflows in ControlFlow. They represent AI models capable of understanding instructions, making decisions, and completing tasks.
 
 ```python
@@ -67,7 +69,36 @@ Tools are Python functions that the agent can call to perform specific actions o
 
 Each agent has a model, which is the LLM that powers the agent responses. This allows you to choose the most suitable model for your needs, based on factors such as performance, latency, and cost.
 
-ControlFlow supports any LangChain LLM that supports chat and function calling. For more details on how to configure models, see the [LLMs guide](/guides/llms).
+ControlFlow supports any LangChain LLM that supports chat and function calling. For more details on how to configure models, see the [LLMs guide](/guides/configure-llms).
+
+```python
+import controlflow as cf
+
+
+agent1 = cf.Agent(model="openai/gpt-4o")
+agent2 = cf.Agent(model="anthropic/claude-3-5-sonnet-20240620")
+```
+
+### LLM rules
+
+
+Each LLM provider may have different requirements for how messages are formatted or presented. For example, OpenAI permits system messages to be interspersed between user messages, but Anthropic requires them to be at the beginning of the conversation. ControlFlow uses provider-specific rules to properly compile messages for each agent's API.
+
+For common providers like OpenAI and Anthropic, LLM rules can be automatically inferred from the agent's model. However, you can use a custom `LLMRules` object to override these rules or provide rules for non-standard providers.
+
+Here is an example of how to tell the agent to use the Anthropic compilation rules with a custom model that can't be automatically inferred:
+
+```python
+import controlflow as cf
+
+# note: this is just an example
+llm_model = CustomAnthropicModel()
+
+agent = cf.Agent(
+    model=llm_model,
+    llm_rules=cf.llm.rules.AnthropicRules(model=llm_model)
+)
+```
 
 ### Interactivity
 
diff --git a/docs/concepts/tasks.mdx b/docs/concepts/tasks.mdx
index e47174b9..79164ec1 100644
--- a/docs/concepts/tasks.mdx
+++ b/docs/concepts/tasks.mdx
@@ -264,6 +264,59 @@ task = cf.Task(
 )
 ```
 
+Note that this setting determines which agents receive the tools generated by the `completion_tools` parameter, described below.
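+
+For example, here is a minimal sketch combining the two settings (the agent name and objective are illustrative): only a designated reviewer can complete the task, and only a success tool is generated, so no agent can mark the task as failed.
+
+```python
+import controlflow as cf
+
+reviewer = cf.Agent(name="Reviewer")
+
+# only the reviewer receives completion tools, and only a success tool is generated
+task = cf.Task(
+    objective="Summarize the quarterly report",
+    completion_agents=[reviewer],
+    completion_tools=["SUCCEED"],
+)
+```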
+
+### Completion tools
+
+import { VersionBadge } from '/snippets/version-badge.mdx'
+
+
+
+In addition to specifying which agents are automatically given completion tools, you can control which completion tools are generated for a task using the `completion_tools` parameter. This allows you to specify whether you want to provide success and/or failure tools, or even provide custom completion tools.
+
+The `completion_tools` parameter accepts a list of strings, where each string represents a tool to be generated. The available options are:
+
+- `"SUCCEED"`: Generates a tool for marking the task as successful.
+- `"FAIL"`: Generates a tool for marking the task as failed.
+
+If `completion_tools` is not specified, both `"SUCCEED"` and `"FAIL"` tools will be generated by default.
+
+You can manually create completion tools and provide them to your agents by calling `task.get_success_tool()` and `task.get_fail_tool()`.
+
+
+If you exclude `completion_tools`, agents may be unable to complete the task or become stuck in a failure state. Without caps on LLM turns or calls, this could lead to runaway LLM usage. Make sure to manually manage how agents complete tasks if you are using a custom set of completion tools.
+
+
+Here are some examples:
+
+```python
+# Generate both success and failure tools (default behavior, equivalent to `completion_tools=None`)
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=["SUCCEED", "FAIL"],
+)
+
+# Only generate a success tool
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=["SUCCEED"],
+)
+
+# Only generate a failure tool
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=["FAIL"],
+)
+
+# Don't generate any completion tools
+task = cf.Task(
+    objective="Write a poem about AI",
+    completion_tools=[],
+)
+```
+
+By controlling which completion tools are generated, you can customize the task completion process to better suit your workflow needs. For example, you might want to prevent agents from marking a task as failed, or you might want to provide your own custom completion tools instead of using the default ones.
+
 ### Name
 
 The name of a task is a string that identifies the task within the workflow. It is used primarily for logging and debugging purposes, though it is also shown to agents during execution to help identify the task they are working on.
diff --git a/docs/examples/call-routing.mdx b/docs/examples/call-routing.mdx
index 6f019174..713e0d0b 100644
--- a/docs/examples/call-routing.mdx
+++ b/docs/examples/call-routing.mdx
@@ -80,7 +80,7 @@ def routing_flow():
             ),
             agents=[trainee],
             result_type=None,
-            tools=[main_task.create_success_tool()]
+            tools=[main_task.get_success_tool()]
         )
 
         if main_task.result == target_department:
diff --git a/docs/examples/features/early-termination.mdx b/docs/examples/features/early-termination.mdx
new file mode 100644
index 00000000..ba98d5b8
--- /dev/null
+++ b/docs/examples/features/early-termination.mdx
@@ -0,0 +1,102 @@
+---
+title: Early Termination
+description: Control workflow execution with flexible termination logic.
+icon: circle-stop
+---
+
+import { VersionBadge } from "/snippets/version-badge.mdx"
+
+
+
+This example demonstrates how to use termination conditions with the `run_until` parameter to control the execution of a ControlFlow workflow. We'll create a simple research workflow that stops under various conditions, showcasing the flexibility of this feature.
In this case, we'll allow research to continue until either two topics are researched or 15 LLM calls are made. + +## Code + +```python +import controlflow as cf +from controlflow.orchestration.conditions import AnyComplete, MaxLLMCalls +from pydantic import BaseModel + + +class ResearchPoint(BaseModel): + topic: str + key_findings: list[str] + + +@cf.flow +def research_workflow(topics: list[str]): + if len(topics) < 2: + raise ValueError("At least two topics are required") + + research_tasks = [ + cf.Task(f"Research {topic}", result_type=ResearchPoint) + for topic in topics + ] + + # Run tasks with termination conditions + results = cf.run_tasks( + research_tasks, + instructions="Research only one topic at a time.", + run_until=( + AnyComplete(min_complete=2) # stop after two tasks (if there are more than two topics) + | MaxLLMCalls(15) # or stop after 15 LLM calls, whichever comes first + ) + ) + + completed_research = [r for r in results if isinstance(r, ResearchPoint)] + return completed_research +``` + + + +Now, if we run this workflow on 4 topics, it will stop after two topics are researched: + +```python Example Usage +# Example usage +topics = [ + "Artificial Intelligence", + "Quantum Computing", + "Biotechnology", + "Renewable Energy", +] +results = research_workflow(topics) + +print(f"Completed research on {len(results)} topics:") +for research in results: + print(f"\nTopic: {research.topic}") + print("Key Findings:") + for finding in research.key_findings: + print(f"- {finding}") +``` + +```text Result +Completed research on 2 topics: + +Topic: Artificial Intelligence +Key Findings: +- Machine Learning and Deep Learning: These are subsets of AI that involve training models on large datasets to make predictions or decisions without being explicitly programmed. They are widely used in various applications, including image and speech recognition, natural language processing, and autonomous vehicles. +- AI Ethics and Bias: As AI systems become more prevalent, ethical concerns such as bias in AI algorithms, data privacy, and the impact on employment are increasingly significant. Ensuring fairness, transparency, and accountability in AI systems is a growing area of focus. +- AI in Healthcare: AI technologies are revolutionizing healthcare through applications in diagnostics, personalized medicine, and patient monitoring. AI can analyze medical data to assist in early disease detection and treatment planning. +- Natural Language Processing (NLP): NLP is a field of AI focused on the interaction between computers and humans through natural language. Recent advancements include transformers and large language models, which have improved the ability of machines to understand and generate human language. +- AI in Autonomous Systems: AI is a crucial component in developing autonomous systems, such as self-driving cars and drones, which require perception, decision-making, and control capabilities to navigate and operate in real-world environments. + +Topic: Quantum Computing +Key Findings: +- Quantum Bits (Qubits): Unlike classical bits, qubits can exist in multiple states simultaneously due to superposition. This allows quantum computers to process a vast amount of information at once, offering a potential exponential speed-up over classical computers for certain tasks. +- Quantum Entanglement: This phenomenon allows qubits that are entangled to be correlated with each other, even when separated by large distances. 
Entanglement is a key resource in quantum computing and quantum communication. +- Quantum Algorithms: Quantum algorithms, such as Shor's algorithm for factoring large numbers and Grover's algorithm for searching unsorted databases, demonstrate the potential power of quantum computing over classical approaches. +- Quantum Error Correction: Quantum systems are prone to errors due to decoherence and noise from the environment. Quantum error correction methods are essential for maintaining the integrity of quantum computations. +- Applications and Challenges: Quantum computing holds promise for solving complex problems in cryptography, material science, and optimization. However, significant technological challenges remain, including maintaining qubit coherence, scaling up the number of qubits, and developing practical quantum software. +``` + +## Key Concepts + +1. **Custom Termination Conditions**: We use a combination of `AnyComplete` and `MaxLLMCalls` conditions to control when the workflow should stop. + +2. **Flexible Workflow Control**: By using termination conditions with the `run_until` parameter, we can create more dynamic workflows that adapt to different scenarios. In this case, we're balancing between getting enough research done and limiting resource usage. + +3. **Partial Results**: The workflow can end before all tasks are complete, so we handle partial results by filtering for completed `ResearchPoint` objects. + +4. **Combining Conditions**: We use the `|` operator to combine multiple termination conditions. ControlFlow also supports `&` for more complex logic. + +This example demonstrates how termination conditions provide fine-grained control over workflow execution, allowing you to balance between task completion and resource usage. This can be particularly useful for managing costs, handling time-sensitive operations, or creating more responsive AI workflows. diff --git a/docs/examples/features/memory.mdx b/docs/examples/features/memory.mdx new file mode 100644 index 00000000..7005b9ca --- /dev/null +++ b/docs/examples/features/memory.mdx @@ -0,0 +1,101 @@ +--- +title: Using Memory +description: How to use memory to persist information across different conversations +icon: brain +--- +import { VersionBadge } from '/snippets/version-badge.mdx' + + + + +Memory in ControlFlow allows agents to store and retrieve information across different conversations or workflow executions. This is particularly useful for maintaining context over time or sharing information between separate interactions. + +## Setup + +In order to use memory, you'll need to configure a [memory provider](/patterns/memory#provider). For this example, we'll use the default Chroma provider. You'll need to `pip install chromadb` to install its dependencies. + +## Code + +In this example, we'll create a simple workflow that remembers a user's favorite color across different conversations. For simplicity, we'll demonstrate the memory by using two different flows, which represent two different threads. + +```python +import controlflow as cf + + +# Create a memory module for user preferences +user_preferences = cf.Memory( + key="user_preferences", + instructions="Store and retrieve user preferences." 
+)
+
+
+# Create an agent with access to the memory
+agent = cf.Agent(memories=[user_preferences])
+
+
+# Create a flow to ask for the user's favorite color
+@cf.flow
+def remember_color():
+    return cf.run(
+        "Ask the user for their favorite color and store it in memory",
+        agents=[agent],
+        interactive=True,
+    )
+
+
+# Create a flow to recall the user's favorite color
+@cf.flow
+def recall_color():
+    return cf.run(
+        "What is the user's favorite color?",
+        agents=[agent],
+    )
+```
+
+Ordinarily, running the flows above would result in two separate -- unconnected -- conversations. The agent in the `recall_color` flow would have no way of knowing about the information from the first flow, even though it's the same agent, because the conversation histories are not shared.
+
+However, because we gave the agent a memory module and instructions for how to use it, the agent *will* be able to recall the information from the first flow.
+
+Run the first flow:
+
+```python First flow
+remember_color()
+```
+```text Result
+Agent: Hello! What is your favorite color?
+User: I really like a blue-purple shade.
+Agent: Great, thank you.
+```
+
+
+When we run the second flow, the agent correctly recalls the favorite color:
+
+```python Second flow
+result = recall_color()
+print(result)
+```
+```text Result
+The user's favorite color is a blue-purple shade.
+```
+
+
+## Key concepts
+
+1. **[Memory creation](/patterns/memory#creating-memory-modules)**: We create a `Memory` object with a unique key and instructions for its use.
+
+    ```python
+    user_preferences = cf.Memory(
+        key="user_preferences",
+        instructions="Store and retrieve user preferences."
+    )
+    ```
+
+2. **[Assigning memory to agents](/patterns/memory#assigning-memories)**: We assign the memory to an agent, allowing it to access and modify the stored information.
+
+    ```python
+    agent = cf.Agent(name="PreferenceAgent", memories=[user_preferences])
+    ```
+
+3. **[Using memory across flows](/patterns/memory#sharing-memories)**: By using the same memory in different flows, we can access information across separate conversations.
+
+This example demonstrates how ControlFlow's memory feature allows information to persist across different workflow executions, enabling more context-aware and personalized interactions.
diff --git a/docs/examples/features/private-flows.mdx b/docs/examples/features/private-flows.mdx
index e0c0db4d..fdbe4594 100644
--- a/docs/examples/features/private-flows.mdx
+++ b/docs/examples/features/private-flows.mdx
@@ -1,5 +1,5 @@
 ---
-title: Private flows
+title: Private Flows
 description: Create isolated execution environments within your workflows.
 icon: lock
 ---
diff --git a/docs/examples/features/tools.mdx b/docs/examples/features/tools.mdx
index f60dc1c7..d095feb2 100644
--- a/docs/examples/features/tools.mdx
+++ b/docs/examples/features/tools.mdx
@@ -1,5 +1,5 @@
 ---
-title: Custom tools
+title: Custom Tools
 description: Provide tools to expand agent capabilities.
 icon: wrench
 ---
diff --git a/docs/examples/library.mdx b/docs/examples/library.mdx
deleted file mode 100644
index 7940293a..00000000
--- a/docs/examples/library.mdx
+++ /dev/null
@@ -1,18 +0,0 @@
----
-title: Library
----
-
-
-
-    Play a game of rock, paper, scissors against an AI - without letting it cheat.
-
-
-    Two agents cooperate to route customer calls to the correct department.
-
-
-    An autonomous software engineer that creates applications based on your input.
-
-
-    More examples are on the way!
-
-
diff --git a/docs/examples/seinfeld-conversation.mdx b/docs/examples/seinfeld-conversation.mdx
new file mode 100644
index 00000000..1a837194
--- /dev/null
+++ b/docs/examples/seinfeld-conversation.mdx
@@ -0,0 +1,127 @@
+---
+title: Seinfeld Conversation
+description: Simulate a conversation between Seinfeld characters using multiple AI agents.
+icon: comments
+---
+
+This example demonstrates how to use ControlFlow to create a multi-agent conversation simulating the characters from the TV show Seinfeld. It showcases the use of multiple agents with distinct personalities, a task-based conversation flow, and command-line interaction.
+
+## Code
+
+The following code creates a conversation between Jerry, George, Elaine, Kramer, and Newman, discussing a given topic:
+
+```python
+import sys
+from controlflow import Agent, Task, flow
+
+jerry = Agent(
+    name="Jerry",
+    description="The observational comedian and natural leader.",
+    instructions="""
+    You are Jerry from the show Seinfeld. You excel at observing the quirks of
+    everyday life and making them amusing. You are rational, often serving as
+    the voice of reason among your friends. Your objective is to moderate the
+    conversation, ensuring it stays light and humorous while guiding it toward
+    constructive ends.
+    """,
+)
+
+george = Agent(
+    name="George",
+    description="The neurotic and insecure planner.",
+    instructions="""
+    You are George from the show Seinfeld. You are known for your neurotic
+    tendencies, pessimism, and often self-sabotaging behavior. Despite these
+    traits, you occasionally offer surprising wisdom. Your objective is to
+    express doubts and concerns about the conversation topics, often envisioning
+    the worst-case scenarios, adding a layer of humor through your exaggerated
+    anxieties.
+    """,
+)
+
+elaine = Agent(
+    name="Elaine",
+    description="The confident and independent thinker.",
+    instructions="""
+    You are Elaine from the show Seinfeld. You are bold, witty, and unafraid to
+    challenge social norms. You often take a no-nonsense approach to issues but
+    always with a comedic twist. Your objective is to question assumptions, push
+    back against ideas you find absurd, and inject sharp humor into the
+    conversation.
+    """,
+)
+
+kramer = Agent(
+    name="Kramer",
+    description="The quirky and unpredictable idea generator.",
+    instructions="""
+    You are Kramer from the show Seinfeld. Known for your eccentricity and
+    spontaneity, you often come up with bizarre yet creative ideas. Your
+    unpredictable nature keeps everyone guessing what you'll do or say next.
+    Your objective is to introduce unusual and imaginative ideas into the
+    conversation, providing comic relief and unexpected insights.
+    """,
+)
+
+newman = Agent(
+    name="Newman",
+    description="The antagonist and foil to Jerry.",
+    instructions="""
+    You are Newman from the show Seinfeld. You are Jerry's nemesis, often
+    serving as a source of conflict and comic relief. Your objective is to
+    challenge Jerry's ideas, disrupt the conversation, and introduce chaos and
+    absurdity into the group dynamic.
+    """,
+)
+
+@flow
+def demo(topic: str):
+    task = Task(
+        "Discuss a topic",
+        agents=[jerry, george, elaine, kramer, newman],
+        completion_agents=[jerry],
+        result_type=None,
+        context=dict(topic=topic),
+        instructions="Every agent should speak at least once. Only one agent per turn.
Keep responses 1-2 paragraphs max.", + ) + task.run() + +if __name__ == "__main__": + if len(sys.argv) > 1: + topic = sys.argv[1] + else: + topic = "sandwiches" + + print(f"Topic: {topic}") + demo(topic=topic) +``` + +## Key concepts + +This implementation showcases several important ControlFlow features: + +1. **Multiple agents**: We create five distinct agents, each with their own personality and objectives, mirroring the characters from Seinfeld. + +2. **Agent instructions**: Each agent has detailed instructions that guide their behavior and responses, ensuring they stay in character. + +3. **Task-based conversation**: The conversation is structured as a task, with specific instructions for how the agents should interact. + +4. **Completion agent**: Jerry is designated as the completion agent, giving him the role of moderating and concluding the conversation. + +5. **Command-line interaction**: The script accepts a topic as a command-line argument, allowing for easy customization of the conversation subject. + +## Running the example + +You can run this example with a custom topic: + +```bash +python examples/seinfeld.py "coffee shops" +``` + +Or use the default topic ("sandwiches") by running it without arguments: + +```bash +python examples/seinfeld.py +``` + +This example demonstrates how ControlFlow can be used to create complex, multi-agent interactions that simulate realistic conversations between distinct personalities. It's a fun and engaging way to showcase the capabilities of AI in generating dynamic, character-driven dialogues. \ No newline at end of file diff --git a/docs/guides/llms.mdx b/docs/guides/configure-llms.mdx similarity index 99% rename from docs/guides/llms.mdx rename to docs/guides/configure-llms.mdx index f40128c6..8df6e55d 100644 --- a/docs/guides/llms.mdx +++ b/docs/guides/configure-llms.mdx @@ -1,5 +1,5 @@ --- -title: Configuring LLM models +title: LLM Models description: ControlFlow supports a variety of LLMs and model providers. icon: sliders --- diff --git a/docs/guides/default-agent.mdx b/docs/guides/default-agent.mdx index 11a6fe2f..64b31dd5 100644 --- a/docs/guides/default-agent.mdx +++ b/docs/guides/default-agent.mdx @@ -1,5 +1,6 @@ --- -title: Configuring the Default Agent +title: Configuring a Default Agent +sidebarTitle: Default Agent description: Set global and flow-specific defaults. icon: robot --- @@ -8,7 +9,7 @@ ControlFlow uses a default agent when no specific agents are assigned to a task. ## Changing the Global Default Agent -The global default agent (whose name, of course, is Marvin) uses whatever [default model](/guides/llms#changing-the-default-model) you've configured. It has a basic set of general-purpose instructions and is not equipped with any tools. +The global default agent (whose name, of course, is Marvin) uses whatever [default model](/guides/configure-llms#changing-the-default-model) you've configured. It has a basic set of general-purpose instructions and is not equipped with any tools. 
To change the global default agent, assign a new agent to `controlflow.defaults.agent`: diff --git a/docs/guides/default-memory.mdx b/docs/guides/default-memory.mdx new file mode 100644 index 00000000..39eceded --- /dev/null +++ b/docs/guides/default-memory.mdx @@ -0,0 +1,62 @@ +--- +title: Configure a Default Memory Provider +sidebarTitle: Default Memory Provider +description: Set up a default persistent memory provider for your agents +icon: brain +--- +import { VersionBadge } from '/snippets/version-badge.mdx' + + +ControlFlow's [memory](/patterns/memory) feature allows agents to store and retrieve information across multiple workflows. Memory modules are backed by a vector database, configured using a `MemoryProvider`. + +Setting up a default provider simplifies the process of creating memory objects throughout your application. Once configured, you can create memory objects without specifying a provider each time. + + +While ControlFlow does not include any vector database dependencies by default, the default provider is set to `"chroma-db"`. This means that if you install the `chromadb` package, your memory modules will work without any additional configuration. + + +## Install dependencies + +To use a provider, you must first install its dependencies. Please refer to the [Memory doc](/patterns/memory) to see all supported providers and their required dependencies. + +For example, to use the default [Chroma](https://trychroma.com/) provider, you need to install `chromadb`: + +```bash +pip install chromadb +``` + +## Configure a default provider + +There are two ways to set up a default provider: using a string setting for common defaults, or instantiating a custom provider. Here, we'll use a persistent Chroma database as our example. + +### String configurations + +For simple provider setups, you can modify ControlFlow's default settings using a string value. The default value is `"chroma-db"`, which will create a persistent Chroma database. To change it: + + +```bash Environment variable +export CONTROLFLOW_MEMORY_PROVIDER="chroma-ephemeral" +``` +```python Runtime +import controlflow as cf + +cf.settings.memory_provider = "chroma-ephemeral" +``` + + +For a list of available string configurations, see the [Memory pattern guide](/patterns/memory). + +### Custom provider configuration + +For more advanced setups, instantiate a provider with custom settings and assign it to the ControlFlow default. Note this must be done at runtime. + +```python +import controlflow as cf +from controlflow.memory.providers.chroma import ChromaMemory +import chromadb + +# Set the default provider +cf.defaults.memory_provider = ChromaMemory( + client=chromadb.PersistentClient(path="/custom/path"), +) +``` diff --git a/docs/installation.mdx b/docs/installation.mdx index a380b93b..82bc5b6f 100644 --- a/docs/installation.mdx +++ b/docs/installation.mdx @@ -35,7 +35,7 @@ export CONTROLFLOW_LLM_MODEL="anthropic/claude-3-5-sonnet-20240620" ### Other providers -ControlFlow supports many other LLM providers as well, though you'll need to install their respective packages and configure the default LLM appropriately. See the [LLM documentation](/guides/llms) for more information. +ControlFlow supports many other LLM providers as well, though you'll need to install their respective packages and configure the default LLM appropriately. See the [LLM documentation](/guides/configure-llms) for more information. 
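+
+For instance, here is a hedged sketch of switching to Google's Gemini models; the package name and `google/` provider prefix follow the LLM configuration guide, and the specific model name is illustrative:
+
+```bash
+pip install langchain-google-genai
+export GOOGLE_API_KEY="your-api-key"
+export CONTROLFLOW_LLM_MODEL="google/gemini-1.5-pro"
+```
+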
 ## Next steps
 
diff --git a/docs/mint.json b/docs/mint.json
index 945df45f..4a44f92d 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -57,6 +57,7 @@
         "patterns/tools",
         "patterns/interactivity",
         "patterns/dependencies",
+        "patterns/memory",
         "patterns/instructions",
         "patterns/planning",
         "patterns/history"
@@ -66,8 +67,9 @@
       "group": "Configuration",
       "pages": [
         "guides/settings",
-        "guides/llms",
-        "guides/default-agent"
+        "guides/configure-llms",
+        "guides/default-agent",
+        "guides/default-memory"
       ]
     },
     {
@@ -76,7 +78,9 @@
         "examples/features/dependent-tasks",
         "examples/features/tools",
         "examples/features/multi-llm",
-        "examples/features/private-flows"
+        "examples/features/private-flows",
+        "examples/features/memory",
+        "examples/features/early-termination"
       ]
     },
     {
@@ -98,8 +102,9 @@
       "pages": [
         "examples/language-tutor",
         "examples/rock-paper-scissors",
-        "examples/agent-engineer",
-        "examples/call-routing"
+        "examples/seinfeld-conversation",
+        "examples/call-routing",
+        "examples/agent-engineer"
       ]
     },
     {
diff --git a/docs/patterns/memory.mdx b/docs/patterns/memory.mdx
new file mode 100644
index 00000000..9728935b
--- /dev/null
+++ b/docs/patterns/memory.mdx
@@ -0,0 +1,283 @@
+---
+title: Memory
+description: Enhance your agents with persistent memories.
+icon: bookmark
+---
+import { VersionBadge } from '/snippets/version-badge.mdx'
+
+
+
+Within an agentic workflow, information is naturally added to the [thread history](/patterns/history) over time, making it available to all agents who participate in the workflow. However, that information is not accessible from other threads, even if they relate to the same objective or resources.
+
+ControlFlow has a memory feature that allows agents to selectively store and recall information across multiple interactions. This feature is useful for creating more capable and context-aware agents. For example:
+
+- Remembering a user's name or other personal details across conversations
+- Retaining facts from one session for use in another
+- Keeping details about a repository's style guide for later reference
+- Maintaining project-specific information across multiple interactions
+- Enabling "soft" collaboration between agents through a shared knowledge base
+
+Memory modules provide this functionality, allowing agents to build up and access a persistent knowledge base over time.
+
+## How Memory Works
+
+ControlFlow memories are implemented as context-specific vector stores that permit agents to add and query information using natural language. Each memory object has a "key" that uniquely identifies it and partitions its contents from other vector stores for easy retrieval. For example, you might have a different memory store for each user, agent, project, or even task, used to persist information across multiple interactions with that object. Agents can be provided with multiple memory modules, allowing them to access different sets of memories simultaneously.
+
+## Creating Memory Modules
+
+To create a memory object, you need to provide a `key` and `instructions`. The `key` uniquely identifies the memory module so it can be accessed later. The `instructions` explain what kind of information should be stored, and how it should be used.
+
+
+ControlFlow does not include any vector database dependencies by default to keep the library lightweight, so you must [install and configure](#provider) a provider before creating a memory object.
+ +To run the examples with minimal configuration, run `pip install chromadb` to install the dependency for the default Chroma provider. To change the default, see the [default provider guide](/guides/default-memory). + + + +```python +import controlflow as cf + +# Create a Memory module for storing weather information +memory = cf.Memory( + key="weather", + instructions="Stores information about the weather." +) +``` + +### Assigning Memories + +Like tools, memory modules can be provided to either agents or tasks. When provided to an agent, it will be able to access the memories when executing any task. When provided to a task, the memories will be available to any assigned agents. The choice of where to assign a memory module depends entirely on your preference and the design of your application; when the workflow is compiled the behavior is identical. + + +#### Assigning to an Agent + +```python +agent = cf.Agent( + name="Weather Agent", + memories=[memory] +) +``` + +#### Assigning to a Task + +```python +task = cf.Task( + name="Weather Task", + memories=[memory] +) +``` + +### Assigning Multiple Memories + +You can assign multiple memories to an agent or task. When this happens, the agent or task will have access to all of the modules and be able to store and retrieve information from each of them separately. + + +### Sharing Memories + +Remember that you can provide the same memory module to multiple agents or tasks. When this happens, the memories are shared across all of the agents and tasks. + + +Memories are partitioned by `key`, so you can provide different instructions to different agents for the same module. For example, you might have one agent that you encourage to record information to a memory module, and another that you encourage to read memories from the same module. + + + +## Configuration + +### Key + +The `key` is crucial for accessing the correct set of memories. It must be provided exactly the same way each time to access an existing memory. Keys should be descriptive and unique for each distinct memory set you want to maintain. + +### Instructions + +The `instructions` field is important because it tells the agent when and how to access or add to the memory. Unlike the `key`, instructions can be different for the same memory key across different Memory objects. This allows for flexibility in how agents interact with the same set of memories. + +Good instructions should explain: +- What kind of information the memory is used to store +- When the agent should read from or write to the memory +- Any specific guidelines for formatting or categorizing the stored information + +For example: + +```python +project_memory = cf.Memory( + key="project_alpha", + instructions=""" + This memory stores important details about Project Alpha. + - Read from this memory when you need information about project goals, timelines, or team members. + - Write to this memory when you learn new, important facts about the project. + - Always include dates when adding new information. + """ +) +``` + +### Provider + +The `provider` is the underlying storage mechanism for the memory. It is responsible for storing and retrieving the memory objects. + + +The default provider is "chroma-db", which uses a local persistent [Chroma](https://trychroma.com/) database. Run `pip install chromadb` to install its dependencies, after which you can start using memories with no additional configuration. 
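+
+As a quick sketch (assuming `chromadb` is installed), passing the provider string explicitly is equivalent to the default behavior:
+
+```python
+import controlflow as cf
+
+# identical to omitting `provider`, since "chroma-db" is the default
+memory = cf.Memory(
+    key="weather",
+    instructions="Stores information about the weather.",
+    provider="chroma-db",
+)
+```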
+
+#### Installing provider dependencies
+
+To configure a provider, you need to install its package and either configure the provider with a string value or create an instance of the provider and pass it to the memory module.
+
+ControlFlow does not include any vector database dependencies by default, in order to keep the library lightweight.
+
+
+This table shows the supported providers and their respective dependencies:
+
+| Provider | Required dependencies |
+| -------- | ----------------- |
+| [Chroma](https://trychroma.com/) | `chromadb` |
+| [LanceDB](https://lancedb.com/) | `lancedb` |
+
+You can install the dependencies for a provider with pip, for example `pip install chromadb` to use the Chroma provider.
+
+#### Configuring a provider with a string
+
+For straightforward provider configurations, you can pass a string value to the `provider` parameter that will instantiate a provider with default settings. The following strings are recognized:
+
+| Provider | Provider string | Description |
+| -------- | -------- | ----------------- |
+| Chroma | `chroma-ephemeral` | An ephemeral (in-memory) database. |
+| Chroma | `chroma-db` | Uses a persistent, local-file-based database, with a default path of `~/.controlflow/memory/chroma`. |
+| Chroma | `chroma-cloud` | Uses the Chroma Cloud service. The `CONTROLFLOW_CHROMA_CLOUD_API_KEY`, `CONTROLFLOW_CHROMA_CLOUD_TENANT`, and `CONTROLFLOW_CHROMA_CLOUD_DATABASE` settings are required. |
+| LanceDB | `lancedb` | Uses a persistent, local-file-based database, with a default path of `~/.controlflow/memory/lancedb`. |
+
+For example, if `chromadb` is installed, the following code will create a memory module that uses an ephemeral Chroma database:
+
+```python
+import controlflow as cf
+
+cf.Memory(..., provider="chroma-ephemeral")
+```
+
+#### Configuring a Provider instance
+
+For more complex configurations, you can instantiate a provider directly and pass it to the memory module.
+
+For example, the Chroma provider accepts a `client` parameter that allows you to customize how the Chroma client connects, as well as a `collection_name` parameter to specify the name of the collection to use.
+
+```python
+import controlflow as cf
+from controlflow.memory.providers.chroma import ChromaMemory
+import chromadb
+
+provider = ChromaMemory(
+    client=chromadb.PersistentClient(path="/path/to/save/to"),
+    collection_name="custom-{key}",
+)
+
+memory = cf.Memory(..., provider=provider)
+```
+
+#### Configuring a default provider
+
+You can configure a default provider to avoid having to specify a provider each time you create a memory module. Please see the guide on [default providers](/guides/default-memory) for more information.
+
+
+
+## Example: Storing Weather Information
+
+In this example, we'll create a memory module for weather information and use it to retrieve that information in a different conversation. Begin by creating a memory module, assigning it to a task, and informing the task that it is 70 degrees today:
+
+
+```python Code
+import controlflow as cf
+
+# Create a Memory module
+weather_memory = cf.Memory(
+    key="weather",
+    instructions="Store and retrieve information about the weather."
+)
+
+cf.run("It is 70 degrees today.", memories=[weather_memory])
+```
+
+```text Result
+"The weather information has been stored: It is 70 degrees today."
+```
+
+
+Now, in a different conversation, we can retrieve that information. Note that the setup is almost identical, except that the task asks the agent to answer a question about the weather.
+
+
+```python Code
+import controlflow as cf
+
+# Create a Memory module
+weather_memory = cf.Memory(
+    key="weather",
+    instructions="Store and retrieve information about the weather."
+)
+
+cf.run("What is the weather today?", memories=[weather_memory])
+```
+
+```text Result
+"It is 70 degrees today."
+```
+
+
+
+### Example: Slack Customer Service
+
+Suppose we have an agent that answers questions in Slack. We are going to equip the agent with the following memory modules:
+- One for each user in the thread
+- One for common problems that users encounter
+
+Since we always invoke the agent with these memories, it will be able to access persistent information about any user it's assisting, as well as issues they frequently encounter, even if that information wasn't shared in the current thread.
+
+Here is example code for how this might work:
+
+```python
+import controlflow as cf
+
+
+@cf.flow
+def customer_service_flow(slack_thread_id: str):
+
+    # create a memory module for each user
+    user_memories = [
+        cf.Memory(
+            key=user_id,
+            instructions=f"Store and retrieve any information about user {user_id}.",
+        )
+        for user_id in get_user_ids(slack_thread_id)
+    ]
+
+    # create a memory module for problems
+    problems_memory = cf.Memory(
+        key="problems",
+        instructions="Store and retrieve important information about common user problems.",
+    )
+
+    # create an agent with access to the memory modules
+    agent = cf.Agent(
+        name="Customer Service Agent",
+        instructions="""
+            Help users by answering their questions. Use available
+            memories to personalize your response.
+            """,
+        memories=user_memories + [problems_memory]
+    )
+
+    # use the agent to respond
+    cf.run(
+        "Respond to the users' latest message",
+        agents=[agent],
+        context=dict(messages=get_messages(slack_thread_id)),
+    )
+```
+
+
+## Best Practices
+
+1. Use descriptive, unique keys for different memory sets
+2. Provide clear, specific instructions to guide agents in using the memory effectively
+3. Consider the lifespan of memories - some may be relevant only for a single session, while others may persist across multiple runs
+4. Use multiple memory objects when an agent needs to access different sets of information
+5. Leverage shared memories for collaborative scenarios where multiple agents need access to the same knowledge base
+6. Regularly review and update memory instructions to ensure they remain relevant and useful
+
+By leveraging ControlFlow's Memory feature effectively, you can create more sophisticated agents that maintain context, learn from past interactions, and make more informed decisions based on accumulated knowledge across multiple conversations or sessions.
\ No newline at end of file
diff --git a/docs/patterns/running-tasks.mdx b/docs/patterns/running-tasks.mdx
index 6ac5344c..d45d4099 100644
--- a/docs/patterns/running-tasks.mdx
+++ b/docs/patterns/running-tasks.mdx
@@ -1,9 +1,11 @@
 ---
-title: Running tasks
+title: Running Tasks
 description: Control task execution and manage how agents collaborate.
 icon: play
 ---
 
+import { VersionBadge } from "/snippets/version-badge.mdx"
+
 Tasks represent a unit of work that needs to be completed by your agents. To execute that work and retrieve its result, you need to instruct your agents to run the task.
 
@@ -356,6 +358,36 @@ Note that the setting `max_llm_calls` on the task results in the task failing if
 
 
 
+#### Early termination conditions
+
+
+
+ControlFlow supports more flexible control over when an orchestration run should end through the use of `run_until` conditions.
These conditions allow you to specify complex termination logic based on various factors such as task completion, failure, or custom criteria.
+
+To use a termination condition, you can pass it to the `run_until` parameter when calling `run`, `run_async`, `run_tasks`, or `run_tasks_async`. For example, the following tasks will run until either one of them is complete or 10 LLM calls have been made:
+
+```python
+import controlflow as cf
+from controlflow.orchestration.conditions import AnyComplete, MaxLLMCalls
+
+result = cf.run_tasks(
+    tasks=[cf.Task("write a poem about AI"), cf.Task("write a poem about ML")],
+    run_until=AnyComplete() | MaxLLMCalls(10)
+)
+```
+
+(Note that because tasks can be run in parallel, it's possible for both tasks to be completed.)
+
+Termination conditions can be combined using boolean logic: `|` indicates "or" and `&` indicates "and". A variety of built-in conditions are available:
+
+- `AllComplete()`: stop when all tasks are complete (this is the default behavior)
+- `MaxLLMCalls(n: int)`: stop when `n` LLM calls have been made (equivalent to providing `max_llm_calls`)
+- `MaxAgentTurns(n: int)`: stop when `n` agent turns have been made (equivalent to providing `max_agent_turns`)
+- `AnyComplete(tasks: list[Task], min_complete: int=1)`: stop when at least `min_complete` tasks are complete. If no tasks are provided, all of the orchestrator's tasks will be used.
+- `AnyFailed(tasks: list[Task], min_failed: int=1)`: stop when at least `min_failed` tasks have failed. If no tasks are provided, all of the orchestrator's tasks will be used.
+
+
+
 ### Accessing an orchestrator directly
 
 If you want to "step" through the agentic loop yourself, you can create and invoke an `Orchestrator` directly.
diff --git a/docs/patterns/task-results.mdx b/docs/patterns/task-results.mdx
index 686f1457..66fccb2e 100644
--- a/docs/patterns/task-results.mdx
+++ b/docs/patterns/task-results.mdx
@@ -128,9 +128,12 @@ print(result)
 
 Note that annotated types are not validated; the annotation is provided as part of the agent's natural language instructions. You could additionally provide a custom [result validator](#result-validators) to enforce the constraint.
 
-### Specific values
+### Labeling / classification
 
-You can limit the result to one of a specific set of values, in order to label or classify a response. To do this, specify a list or tuple of allowed values for the result type. Here, we classify the media type of "Star Wars: Return of the Jedi":
+Often, you may want an agent to choose a value from a specific set of options, in order to label or classify a response as one of potentially many choices.
+
+
+To do this, specify a list, tuple, `Literal`, or enum of allowed values for the result type. Here, we classify the media type of "Star Wars: Return of the Jedi" from a list of options:
 
 ```python Code
@@ -149,11 +152,78 @@ movie
 ```
 
-
-For classification tasks, ControlFlow asks agents to choose a value by index rather than writing out the entire response. This optimization significantly improves latency while also conserving output tokens.
+ControlFlow optimizes single-choice constrained selection by asking agents to choose a value by index rather than writing out the entire response. This optimization significantly improves latency while also conserving output tokens.
 
+You can provide almost any Python object as a constrained choice, and ControlFlow will return *that object* as the result. Note that objects must be serialized in order to be shown to the agent.
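+
+For example, here is a sketch using plain Python dictionaries as the choices (the department objects are illustrative). The agent sees each option in serialized form, but the original object is returned:
+
+```python
+import controlflow as cf
+
+departments = [
+    {"name": "Support", "extension": 100},
+    {"name": "Billing", "extension": 200},
+]
+
+dept = cf.run(
+    "Choose the department that should handle a refund request",
+    result_type=departments,
+)
+
+# `dept` is one of the original dictionaries, not a string
+print(dept["extension"])
+```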
+ +#### A list of labels + + +When you provide a set of constrained choices, the agent will choose **one and only one** as the task's result. Often, you will want to produce a list of labels, either because you want to classify multiple items at once OR because you want to allow the agent to choose multiple values for a single input. To do so, you must indicate that your expected result type is a list of either `Literal` values or enum members. + +In the following example, two media types are provided as context, and because the result type is a list, the agent is able to produce two responses: + + +```python Code +import controlflow as cf +from typing import Literal + +media = cf.run( + ["Star Wars: Return of the Jedi", "LOST"], + result_type=list[Literal["movie", "tv show", "book", "comic", "other"]] +) + +print(media) +``` + +```text Result +['movie', 'tv show'] +``` + + +In this example, the agent is able to choose multiple values for a single input, and the result is a list of strings: + + +```python Code +import controlflow as cf +from typing import Literal + +tags = cf.run( + 'Star Wars: Return of the Jedi', + instructions="Choose all that apply", + result_type=list[Literal["Star Wars", "space", "pirates", "adventure", "musical"]] +) + +print(tags) +``` + +```text Result +['Star Wars', 'space', 'adventure'] +``` + + + +Labeling multiple inputs at once relies on Python's built-in type annotations and does not provide the same list- and tuple-aware optimizations and sugar that ControlFlow provides for single-choice constrained selection. Therefore the following syntax, which is not considered proper Python, will error: + +```python +cf.run( + ..., + result_type=list[["A", "B"]] +) +``` + +but using a `Literal` or enum will work: + +```python +cf.run( + ..., + result_type=list[Literal["A", "B"]] +) +``` + + ### Pydantic models For complex, structured results, you can use a Pydantic model as the `result_type`. Pydantic models provide a powerful way to define data schemas and validate input data. diff --git a/docs/patterns/tools.mdx b/docs/patterns/tools.mdx index 87c141f2..fd73734d 100644 --- a/docs/patterns/tools.mdx +++ b/docs/patterns/tools.mdx @@ -139,7 +139,7 @@ LangChain has many [pre-built tools](https://python.langchain.com/v0.2/docs/inte First, install the dependencies: ```bash -pip install -U langchain-community, duckduckgo-search +pip install -U langchain-community duckduckgo-search ``` Then import the tool for use: @@ -251,4 +251,4 @@ Tools are particularly useful in scenarios where: By providing appropriate tools, you can significantly enhance the problem-solving capabilities of your AI agents and create more powerful and flexible workflows. -While tools are powerful, they should be used judiciously. Provide only the tools that are necessary for the task at hand to avoid overwhelming the agent with too many options. \ No newline at end of file +While tools are powerful, they should be used judiciously. Provide only the tools that are necessary for the task at hand to avoid overwhelming the agent with too many options. diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index 02af1b9d..daeaaa8e 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -27,7 +27,7 @@ Next, set up your LLM provider. By default, ControlFlow uses OpenAI, so you'll n export OPENAI_API_KEY="your-api-key" ``` -To use another provider, see the docs on [configuring LLMs](/guides/llms). +To use another provider, see the docs on [configuring LLMs](/guides/configure-llms). 
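+
+For example, here is a minimal sketch of switching the default model at runtime (assuming the `langchain-anthropic` package is installed):
+
+```python
+import controlflow as cf
+from langchain_anthropic import ChatAnthropic
+
+# all agents without an explicit model will now use Claude
+cf.defaults.model = ChatAnthropic(model="claude-3-5-sonnet-20240620")
+```
+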
 ## Create some data
 
@@ -110,7 +110,7 @@ In addition, note that the `result_type` of this task is a list of labels, indic
 
-This example uses an OpenAI model, but you can use any LangChain-compatible LLM here. Follow the instructions in the [LLM docs](/guides/llms) to learn more.
+This example uses an OpenAI model, but you can use any LangChain-compatible LLM here. Follow the instructions in the [LLM docs](/guides/configure-llms) to learn more.
 
diff --git a/docs/snippets/version-badge.mdx b/docs/snippets/version-badge.mdx
new file mode 100644
index 00000000..47e362b9
--- /dev/null
+++ b/docs/snippets/version-badge.mdx
@@ -0,0 +1,9 @@
+export const VersionBadge = ({ version }) => {
+  return (
+    <span className="version-badge">
+      New in version {version}
+    </span>
+  );
+};
\ No newline at end of file
diff --git a/docs/style.css b/docs/style.css
new file mode 100644
index 00000000..06b6a616
--- /dev/null
+++ b/docs/style.css
@@ -0,0 +1,17 @@
+.version-badge {
+  display: inline-block;
+  padding: 0.2em 0.5em;
+  font-size: 0.75em;
+  font-weight: bold;
+  color: #e44bf4;
+  background-color: #fce8fd;
+  border: 1px solid #f2a5f9;
+  border-radius: 4px;
+  vertical-align: middle;
+}
+
+.dark .version-badge {
+  color: #f17afc;
+  background-color: rgba(228, 75, 244, 0.2);
+  border-color: #783d7e;
+}
diff --git a/examples/anonymization.py b/examples/anonymization.py
new file mode 100644
index 00000000..c06a759c
--- /dev/null
+++ b/examples/anonymization.py
@@ -0,0 +1,30 @@
+from pydantic import BaseModel, Field
+
+import controlflow as cf
+
+
+class AnonymizationResult(BaseModel):
+    original: str
+    anonymized: str
+    replacements: dict[str, str] = Field(
+        description=r"The replacements made during anonymization, {original} -> {placeholder}"
+    )
+
+
+def anonymize_text(text: str) -> AnonymizationResult:
+    return cf.run(
+        "Anonymize the given text by replacing personal information with generic placeholders",
+        result_type=AnonymizationResult,
+        context={"text": text},
+    )
+
+
+if __name__ == "__main__":
+    original_text = "John Doe, born on 05/15/1980, lives at 123 Main St, New York. His email is john.doe@example.com."
+
+    result = anonymize_text(original_text)
+    print(f"Original: {result.original}")
+    print(f"Anonymized: {result.anonymized}")
+    print("Replacements:")
+    for original, placeholder in result.replacements.items():
+        print(f"  {original} -> {placeholder}")
diff --git a/examples/business_headline_sentiment.py b/examples/business_headline_sentiment.py
deleted file mode 100644
index 995cb8fb..00000000
--- a/examples/business_headline_sentiment.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# uv pip install langchain-community, duckduckgo-search
-
-from langchain_community.tools import DuckDuckGoSearchRun
-
-import controlflow as cf
-
-summarizer = cf.Agent(
-    name="Headline Summarizer",
-    description="An AI agent that fetches and summarizes current events",
-    tools=[DuckDuckGoSearchRun()],
-)
-
-extractor = cf.Agent(
-    name="Entity Extractor",
-    description="An AI agent that does named entity recognition",
-)
-
-
-@cf.flow
-def get_headlines():
-    summarizer_task = cf.Task(
-        "Retrieve and summarize today's two top business headlines",
-        agent=summarizer,
-        result_type=list[str],
-    )
-
-    extractor_task = cf.Task(
-        "Extract any fortune 500 companies mentioned in the headlines and whether the sentiment is positive, neutral, or negative",
-        agent=extractor,
-        depends_on=[summarizer_task],
-    )
-
-    return summarizer_task, extractor_task
-
-
-if __name__ == "__main__":
-    headlines, entity_sentiment = get_headlines()
-    print(headlines, entity_sentiment)
diff --git a/examples/call_routing.py b/examples/call_routing.py
new file mode 100644
index 00000000..965bad95
--- /dev/null
+++ b/examples/call_routing.py
@@ -0,0 +1,78 @@
+import random
+
+import controlflow as cf
+
+DEPARTMENTS = [
+    "Sales",
+    "Support",
+    "Billing",
+    "Returns",
+]
+
+
+@cf.flow
+def routing_flow():
+    target_department = random.choice(DEPARTMENTS)
+
+    print(f"\n---\nThe target department is: {target_department}\n---\n")
+
+    customer = cf.Agent(
+        name="Customer",
+        instructions=f"""
+            You are training customer reps by pretending to be a customer
+            calling into a call center. You need to be routed to the
+            {target_department} department. Come up with a good backstory.
+            """,
+    )
+
+    trainee = cf.Agent(
+        name="Trainee",
+        instructions="""
+            You are a trainee customer service representative. You need to
+            listen to the customer's story and route them to the correct
+            department. Note that the customer is another agent training you.
+            """,
+    )
+
+    with cf.Task(
+        "Route the customer to the correct department.",
+        agents=[trainee],
+        result_type=DEPARTMENTS,
+    ) as main_task:
+        while main_task.is_incomplete():
+            cf.run(
+                "Talk to the trainee.",
+                instructions=(
+                    "Post a message to talk. In order to help the trainee "
+                    "learn, don't be direct about the department you want. "
+                    "Instead, share a story that will let them practice. "
+                    "After you speak, mark this task as complete."
+                ),
+                agents=[customer],
+                result_type=None,
+            )
+
+            cf.run(
+                "Talk to the customer.",
+                instructions=(
+                    "Post a message to talk. Ask questions to learn more "
+                    "about the customer. After you speak, mark this task as "
+                    "complete. When you have enough information, use the main "
+                    "task tool to route the customer to the correct department."
+                ),
+                agents=[trainee],
+                result_type=None,
+                tools=[main_task.get_success_tool()],
+            )
+
+    if main_task.result == target_department:
+        print("Success! The customer was routed to the correct department.")
+    else:
+        print(
+            f"Failed. The customer was routed to the wrong department. 
" + f"The correct department was {target_department}." + ) + + +if __name__ == "__main__": + routing_flow() diff --git a/examples/choose_a_number.py b/examples/choose_a_number.py deleted file mode 100644 index fe8f5e82..00000000 --- a/examples/choose_a_number.py +++ /dev/null @@ -1,16 +0,0 @@ -from controlflow import Agent, Task, flow - -a1 = Agent(name="A1", instructions="You struggle to make decisions.") -a2 = Agent( - name="A2", - instructions="You like to make decisions.", -) - - -@flow -def demo(): - task = Task("choose a number between 1 and 100", agents=[a1, a2], result_type=int) - return task.run() - - -demo() diff --git a/examples/code_explanation.py b/examples/code_explanation.py new file mode 100644 index 00000000..0d47c260 --- /dev/null +++ b/examples/code_explanation.py @@ -0,0 +1,31 @@ +from pydantic import BaseModel + +import controlflow as cf + + +class CodeExplanation(BaseModel): + code: str + explanation: str + language: str + + +def explain_code(code: str, language: str = None) -> CodeExplanation: + return cf.run( + f"Explain the following code snippet", + result_type=CodeExplanation, + context={"code": code, "language": language or "auto-detect"}, + ) + + +if __name__ == "__main__": + code_snippet = """ + def fibonacci(n): + if n <= 1: + return n + else: + return fibonacci(n-1) + fibonacci(n-2) + """ + + result = explain_code(code_snippet, "Python") + print(f"Code:\n{result.code}\n") + print(f"Explanation:\n{result.explanation}") diff --git a/examples/controlflow_docs.py b/examples/controlflow_docs.py deleted file mode 100644 index cedbf95c..00000000 --- a/examples/controlflow_docs.py +++ /dev/null @@ -1,96 +0,0 @@ -from pathlib import Path - -from langchain_openai import OpenAIEmbeddings - -import controlflow as cf -from controlflow.tools import tool - -try: - from langchain_community.document_loaders import DirectoryLoader - from langchain_community.vectorstores import LanceDB - from langchain_text_splitters import ( - MarkdownTextSplitter, - PythonCodeTextSplitter, - ) -except ImportError: - raise ImportError( - "Missing requirements: `pip install lancedb langchain-community langchain-text-splitters unstructured`" - ) - - -def create_code_db(): - # .py files - py_loader = DirectoryLoader( - Path(cf.__file__).parents[2] / "src/controlflow/", glob="**/*.py" - ) - py_raw_documents = py_loader.load() - py_splitter = PythonCodeTextSplitter(chunk_size=1400, chunk_overlap=200) - documents = py_splitter.split_documents(py_raw_documents) - return LanceDB.from_documents(documents, OpenAIEmbeddings()) - - -def create_docs_db(): - # .mdx files - mdx_loader = DirectoryLoader(Path(cf.__file__).parents[2] / "docs", glob="**/*.mdx") - mdx_raw_documents = mdx_loader.load() - mdx_splitter = MarkdownTextSplitter(chunk_size=1400, chunk_overlap=200) - documents = mdx_splitter.split_documents(mdx_raw_documents) - return LanceDB.from_documents(documents, OpenAIEmbeddings()) - - -code_db = create_code_db() -docs_db = create_docs_db() - - -@tool -def search_code(query: str, n=50) -> list[dict]: - """ - Semantic search over the current ControlFlow documentation - - Returns the top `n` results. - """ - results = docs_db.similarity_search(query, k=n) - return [ - dict(content=r.page_content, metadata=r.metadata["metadata"]) for r in results - ] - - -@tool -def search_docs(query: str, n=50) -> list[dict]: - """ - Semantic search over the current ControlFlow documentation - - Returns the top `n` results. 
- """ - results = docs_db.similarity_search(query, k=n) - return [ - dict(content=r.page_content, metadata=r.metadata["metadata"]) for r in results - ] - - -@tool -def read_file(path: str) -> str: - """ - Read a file from a path. - """ - with open(path) as f: - return f.read() - - -agent = cf.Agent( - "DocsAgent", - description="The agent for the ControlFlow documentation", - instructions="Use your tools to explore the ControlFlow code and documentation. If you find something interesting but only see a snippet with the search tools, use the read_file tool to get the full text.", - tools=[search_code, search_docs, read_file], -) - - -@cf.flow -def write_docs(topic: str): - task = cf.Task( - "Research the provided topic, then produce world-class documentation in the style of the existing docs.", - context=dict(topic=topic), - agents=[agent], - ) - task.generate_subtasks() - return task diff --git a/examples/early_termination.py b/examples/early_termination.py new file mode 100644 index 00000000..45d14ee0 --- /dev/null +++ b/examples/early_termination.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel + +import controlflow as cf +from controlflow.orchestration.conditions import AnyComplete, MaxLLMCalls + + +class ResearchPoint(BaseModel): + topic: str + key_findings: list[str] + + +@cf.flow +def research_workflow(topics: list[str]): + if len(topics) < 2: + raise ValueError("At least two topics are required") + + research_tasks = [ + cf.Task(f"Research {topic}", result_type=ResearchPoint) for topic in topics + ] + + # Run tasks until either two topics are researched or 15 LLM calls are made + results = cf.run_tasks( + research_tasks, + instructions="Research only one topic at a time.", + run_until=( + AnyComplete( + min_complete=2 + ) # stop after two tasks (if there are more than two topics) + | MaxLLMCalls(15) # or stop after 15 LLM calls, whichever comes first + ), + ) + + completed_research = [r for r in results if isinstance(r, ResearchPoint)] + return completed_research + + +if __name__ == "__main__": + # Example usage + topics = [ + "Artificial Intelligence", + "Quantum Computing", + "Biotechnology", + "Renewable Energy", + ] + results = research_workflow(topics) + + print(f"Completed research on {len(results)} topics:") + for research in results: + print(f"\nTopic: {research.topic}") + print("Key Findings:") + for finding in research.key_findings: + print(f"- {finding}") diff --git a/examples/engineer/engineer.py b/examples/engineer/engineer.py deleted file mode 100644 index 3dadfeb9..00000000 --- a/examples/engineer/engineer.py +++ /dev/null @@ -1,76 +0,0 @@ -from pathlib import Path - -from pydantic import BaseModel - -import controlflow as cf -import controlflow.tools.code -import controlflow.tools.filesystem - -# load the instructions -instructions = open(Path(__file__).parent / "instructions.md").read() - -# create the agent -agent = cf.Agent( - "Engineer", - instructions=instructions, - tools=[ - *controlflow.tools.filesystem.ALL_TOOLS, - controlflow.tools.code.python, - controlflow.tools.code.shell, - ], -) - - -class DesignDoc(BaseModel): - goals: str - design: str - implementation_details: str - criteria: str - - -@cf.flow -def run_engineer(): - # the first task is to work with the user to create a design doc - design_doc = cf.Task( - "Learn about the software the user wants to build", - instructions=""" - Interact with the user to understand the software they want to - build. What is its purpose? What language should you use? What does - it need to do? 
Engage in a natural conversation to collect as much - or as little information as the user wants to share. Once you have - enough, write out a design document to complete the task. - """, - interactive=True, - result_type=DesignDoc, - ) - - # next we create a directory for any files - mkdir = cf.Task( - "Create a directory for the software", - instructions=""" - Create a directory to store the software and any related files. The - directory should be named after the software. Return the path. - """, - result_type=str, - tools=[controlflow.tools.filesystem.mkdir], - agents=[agent], - ) - - # the final task is to write the software - software = cf.Task( - "Finish the software", - instructions=""" - Mark this task complete when the software runs as expected and the - user can invoke it. Until then, continue to build the software. - - All files must be written to the provided root directory. - """, - result_type=None, - context=dict(design_doc=design_doc, root_dir=mkdir), - agents=[agent], - ) - return software - - -if __name__ == "__main__": - run_engineer() diff --git a/examples/engineer/instructions.md b/examples/engineer/instructions.md deleted file mode 100644 index 92384940..00000000 --- a/examples/engineer/instructions.md +++ /dev/null @@ -1,37 +0,0 @@ -# Software Engineer Agent - -## Role and Purpose -You are a software engineer specialized in leveraging large language models (LLMs) to transform user ideas into fully functional software projects. Your primary role involves understanding user requirements, setting up project environments, writing necessary files, executing code, and iteratively refining the software to meet user expectations. - -## Process Overview -1. **Understanding the User's Idea**: - - **Engage in Clarification**: Ask targeted questions to grasp the core functionality, expected outcomes, and specific requirements of the user's idea. - - **Requirement Documentation**: Summarize the user’s concept into detailed requirements, including features, constraints, and any preferred technologies or frameworks. - -2. **Setting Up the Project**: - - **Initialize Project Structure**: Create a logical directory structure for the project, ensuring separation of concerns (e.g., `src/` for source code, `docs/` for documentation). - - **Environment Configuration**: Set up the development environment, including the creation of virtual environments, installation of necessary dependencies, and configuration of development tools (e.g., linters, formatters). - -3. **Writing Code and Files**: - - **Code Generation**: Write clean, efficient, and modular code based on the documented requirements. Ensure that code adheres to best practices and coding standards. - - **Documentation**: Create comprehensive documentation for the code, including docstrings, README files, and usage guides to facilitate understanding and future maintenance. - -4. **Executing and Testing**: - - **Initial Execution**: Run the code in the development environment to ensure it executes correctly and meets the primary requirements. - - **Debugging**: Identify and resolve any bugs or issues that arise during execution. Ensure the code runs smoothly and performs as expected. - -5. **Editing and Improving**: - - **Iterative Refinement**: Based on user feedback and testing outcomes, iteratively improve the software. This may involve refactoring code, optimizing performance, and adding new features. - - **Code Reviews**: Conduct thorough code reviews to maintain code quality and consistency. 
Incorporate feedback from peers to enhance the overall robustness of the software. - - **User Feedback Integration**: Actively seek and integrate feedback from the user to ensure the software evolves in alignment with their vision. - -## Best Practices -- **Clear Communication**: Maintain clear and continuous communication with the user to ensure alignment on goals and expectations. -- **Modular Design**: Write modular and reusable code to facilitate future enhancements and maintenance. - -## Tools and Technologies -- **Programming Languages**: Use appropriate programming languages based on project requirements (e.g., Python, JavaScript). -- **Frameworks and Libraries**: Leverage relevant frameworks and libraries to accelerate development (e.g., Django, React, TensorFlow). -- **Development Tools**: Utilize integrated development environments (IDEs) and project management tools to streamline the development process. - -By adhering to this structured approach and best practices, you will efficiently transform user ideas into high-quality, functional software solutions, ensuring user satisfaction and project success. \ No newline at end of file diff --git a/examples/generate_people.py b/examples/generate_people.py new file mode 100644 index 00000000..a136ea3a --- /dev/null +++ b/examples/generate_people.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field + +import controlflow as cf + + +class UserProfile(BaseModel): + name: str = Field(description="The full name of the user") + age: int = Field(description="The age of the user, 20-60") + occupation: str = Field(description="The occupation of the user") + hobby: str + + +def generate_profiles(count: int) -> list[UserProfile]: + return cf.run( + f"Generate {count} user profiles", + result_type=list[UserProfile], + context={"count": count}, + ) + + +if __name__ == "__main__": + test_data = generate_profiles(count=5) + + from rich import print + + print(test_data) diff --git a/examples/headline_categorization.py b/examples/headline_categorization.py new file mode 100644 index 00000000..5693256c --- /dev/null +++ b/examples/headline_categorization.py @@ -0,0 +1,24 @@ +import controlflow as cf + +classifier = cf.Agent(model="openai/gpt-4o-mini") + + +def classify_news(headline: str) -> str: + return cf.run( + "Classify the news headline into the most appropriate category", + agents=[classifier], + result_type=["Politics", "Technology", "Sports", "Entertainment", "Science"], + context={"headline": headline}, + ) + + +if __name__ == "__main__": + headline = "New AI Model Breaks Records in Language Understanding" + category = classify_news(headline) + print(f"Headline: {headline}") + print(f"Category: {category}") + + headline = "Scientists Discover Potentially Habitable Exoplanet" + category = classify_news(headline) + print(f"\nHeadline: {headline}") + print(f"Category: {category}") diff --git a/examples/language_tutor.py b/examples/language_tutor.py new file mode 100644 index 00000000..d9b4905b --- /dev/null +++ b/examples/language_tutor.py @@ -0,0 +1,75 @@ +from pydantic import BaseModel + +import controlflow as cf + + +class Lesson(BaseModel): + topic: str + content: str + exercises: list[str] + + +def language_learning_session(language: str) -> None: + tutor = cf.Agent( + name="Tutor", + instructions=""" + You are a friendly and encouraging language tutor. Your goal is to create an + engaging and supportive learning environment. Always maintain a warm tone, + offer praise for efforts, and provide gentle corrections. 
Adapt your teaching + style to the user's needs and pace. Use casual language to keep the + conversation light and fun. When working through exercises: + - Present one exercise at a time. + - Provide hints if the user is struggling. + - Offer the correct answer if the user can't solve it after a few attempts. + - Use encouraging language throughout the process. + """, + ) + + @cf.flow(default_agent=tutor) + def learning_flow(): + user_name = cf.run( + f"Greet the user, learn their name, and introduce the {language} learning session", + interactive=True, + result_type=str, + ) + + print(f"\nWelcome, {user_name}! Let's start your {language} lesson.\n") + + while True: + lesson = cf.run( + "Create a fun and engaging language lesson", result_type=Lesson + ) + + print(f"\nToday's topic: {lesson.topic}") + print(f"Lesson content: {lesson.content}\n") + + for exercise in lesson.exercises: + print(f"Exercise: {exercise}") + cf.run( + "Work through the exercise with the user", + interactive=True, + context={"exercise": exercise}, + ) + + continue_learning = cf.run( + "Check if the user wants to continue learning", + result_type=bool, + interactive=True, + ) + + if not continue_learning: + break + + summary = cf.run( + "Summarize the learning session and provide encouragement", + context={"user_name": user_name}, + result_type=str, + ) + print(f"\nSession summary: {summary}") + + learning_flow() + + +if __name__ == "__main__": + language = input("Which language would you like to learn? ") + language_learning_session(language) diff --git a/examples/memory.py b/examples/memory.py new file mode 100644 index 00000000..53d810a6 --- /dev/null +++ b/examples/memory.py @@ -0,0 +1,37 @@ +import controlflow as cf + +# Create a memory module for user preferences +user_preferences = cf.Memory( + key="user_preferences", instructions="Store and retrieve user preferences." 
+) + +# Create an agent with access to the memory +agent = cf.Agent(memories=[user_preferences]) + + +# Create a flow to ask for the user's favorite color +@cf.flow +def remember_color(): + return cf.run( + "Ask the user for their favorite color and store it in memory", + agents=[agent], + interactive=True, + ) + + +# Create a flow to recall the user's favorite color +@cf.flow +def recall_color(): + return cf.run( + "What is the user's favorite color?", + agents=[agent], + ) + + +if __name__ == "__main__": + print("First flow:") + remember_color() + + print("\nSecond flow:") + result = recall_color() + print(result) diff --git a/examples/memory_between_flows.py b/examples/memory_between_flows.py deleted file mode 100644 index 75067273..00000000 --- a/examples/memory_between_flows.py +++ /dev/null @@ -1,20 +0,0 @@ -import controlflow as cf - -thread_id = "test-thread" - - -@cf.flow(thread=thread_id) -def flow_1(): - task = cf.Task("get the user's name", result_type=str, interactive=True) - return task - - -@cf.flow(thread=thread_id) -def flow_2(): - task = cf.Task("write the user's name backwards, if you don't know it, say so") - return task - - -if __name__ == "__main__": - flow_1() - flow_2() diff --git a/examples/named_entity_recognition.py b/examples/named_entity_recognition.py new file mode 100644 index 00000000..91d9c828 --- /dev/null +++ b/examples/named_entity_recognition.py @@ -0,0 +1,47 @@ +from typing import Dict, List + +import controlflow as cf + +extractor = cf.Agent( + name="Named Entity Recognizer", + model="openai/gpt-4o-mini", +) + + +def extract_entities(text: str) -> List[str]: + return cf.run( + "Extract all named entities from the text", + agents=[extractor], + result_type=List[str], + context={"text": text}, + ) + + +def extract_categorized_entities(text: str) -> Dict[str, List[str]]: + return cf.run( + "Extract named entities from the text and categorize them", + instructions=""" + Return a dictionary with the following keys: + - 'persons': List of person names + - 'organizations': List of organization names + - 'locations': List of location names + - 'dates': List of date references + - 'events': List of event names + Only include keys if entities of that type are found in the text. + """, + agents=[extractor], + result_type=Dict[str, List[str]], + context={"text": text}, + ) + + +if __name__ == "__main__": + text = "Apple Inc. is planning to open a new store in New York City next month." + entities = extract_entities(text) + print("Simple extraction:") + print(entities) + + text = "In 1969, Neil Armstrong became the first person to walk on the Moon during the Apollo 11 mission." 
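+    # Per the task instructions above, the categorized variant returns a dict
+    # keyed by entity type (e.g. "persons", "locations"); a key is only
+    # present when entities of that type are found in the text.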
+ categorized_entities = extract_categorized_entities(text) + print("\nCategorized extraction:") + print(categorized_entities) diff --git a/examples/pineapple_pizza.py b/examples/pineapple_pizza.py new file mode 100644 index 00000000..fc019944 --- /dev/null +++ b/examples/pineapple_pizza.py @@ -0,0 +1,33 @@ +import controlflow as cf + +optimist = cf.Agent( + name="Half-full", + instructions="You are an eternal optimist.", +) +pessimist = cf.Agent( + name="Half-empty", + instructions="You are an eternal pessimist.", +) +moderator = cf.Agent(name="Moderator") + + +@cf.flow +def demo(topic: str): + cf.run( + "Have a debate about the topic.", + instructions="Each agent should take at least two turns.", + agents=[optimist, pessimist], + context={"topic": topic}, + ) + + winner: cf.Agent = cf.run( + "Whose argument do you find more compelling?", + agents=[moderator], + result_type=[optimist, pessimist], + ) + + print(f"{winner.name} wins the debate!") + + +if __name__ == "__main__": + demo("pineapple on pizza") diff --git a/examples/poem.py b/examples/poem.py deleted file mode 100644 index 60329d8f..00000000 --- a/examples/poem.py +++ /dev/null @@ -1,43 +0,0 @@ -from pydantic import BaseModel - -from controlflow import Agent, Task, flow, instructions, task - - -class Name(BaseModel): - first_name: str - last_name: str - - -@task(interactive=True) -def get_user_name() -> Name: - pass - - -@task(agents=[Agent(name="poetry-bot", instructions="loves limericks")]) -def write_poem_about_user(name: Name, interests: list[str]) -> str: - """write a poem based on the provided `name` and `interests`""" - pass - - -@flow() -def demo(): - # set instructions that will be used for multiple tasks - with instructions("talk like a pirate"): - # define an AI task as a function - name = get_user_name() - - # define an AI task imperatively - interests = Task( - "ask user for three interests", result_type=list[str], interactive=True - ) - interests.run() - - # set instructions for just the next task - with instructions("no more than 8 lines"): - poem = write_poem_about_user(name, interests.result) - - return poem - - -if __name__ == "__main__": - demo() diff --git a/examples/private_flows.py b/examples/private_flows.py new file mode 100644 index 00000000..4a768e02 --- /dev/null +++ b/examples/private_flows.py @@ -0,0 +1,30 @@ +import controlflow as cf + + +@cf.flow(args_as_context=False) +def process_user_data(user_name: str, sensitive_info: str): + # Main flow context + print(f"Processing data for user: {user_name}") + + # Create a private flow to handle sensitive information + with cf.Flow() as private_flow: + # This task runs in an isolated context + masked_info = cf.run( + "Mask the sensitive information", + context={"sensitive_info": sensitive_info}, + result_type=str, + ) + + # Task in the main flow can be provided the masked_info as context + summary = cf.run( + "Summarize the data processing result", + context={"user_name": user_name, "masked_info": masked_info}, + result_type=str, + ) + + return summary + + +if __name__ == "__main__": + result = process_user_data("Alice", "SSN: 123-45-6789") + print(result) diff --git a/examples/reasoning.py b/examples/reasoning.py new file mode 100644 index 00000000..a5c242fa --- /dev/null +++ b/examples/reasoning.py @@ -0,0 +1,118 @@ +""" +This example implements a reasoning loop that lets a relatively simple model +solve difficult problems. + +Here, gpt-4o-mini is used to solve a problem that typically requires o1's +reasoning ability. 
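+
+The loop asks the agent for one small, documented reasoning step at a time,
+and independently double-checks any claimed solution before accepting it.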
+""" + +import argparse + +from pydantic import BaseModel, Field + +import controlflow as cf +from controlflow.utilities.general import unwrap + + +class ReasoningStep(BaseModel): + explanation: str = Field( + description=""" + A brief (<5 words) description of what you intend to + achieve in this step, to display to the user. + """ + ) + reasoning: str = Field( + description="A single step of reasoning, not more than 1 or 2 sentences." + ) + found_validated_solution: bool + + +REASONING_INSTRUCTIONS = """ + You are working on solving a difficult problem (the `goal`). Based + on your previous thoughts and the overall goal, please perform **one + reasoning step** that advances you closer to a solution. Document + your thought process and any intermediate steps you take. + + After marking this task complete for a single step, you will be + given a new reasoning task to continue working on the problem. The + loop will continue until you have a valid solution. + + Complete the task as soon as you have a valid solution. + + **Guidelines** + + - You will not be able to brute force a solution exhaustively. You + must use your reasoning ability to make a plan that lets you make + progress. + - Each step should be focused on a specific aspect of the problem, + either advancing your understanding of the problem or validating a + solution. + - You should build on previous steps without repeating them. + - Since you will iterate your reasoning, you can explore multiple + approaches in different steps. + - Use logical and analytical thinking to reason through the problem. + - Ensure that your solution is valid and meets all requirements. + - If you find yourself spinning your wheels, take a step back and + re-evaluate your approach. +""" + + +@cf.flow +def solve_with_reasoning(goal: str, agent: cf.Agent) -> str: + while True: + response: ReasoningStep = cf.run( + objective=""" + Carefully read the `goal` and analyze the problem. + + Produce a single step of reasoning that advances you closer to a solution. + """, + instructions=REASONING_INSTRUCTIONS, + result_type=ReasoningStep, + agents=[agent], + context=dict(goal=goal), + model_kwargs=dict(tool_choice="required"), + ) + + if response.found_validated_solution: + if cf.run( + """ + Check your solution to be absolutely sure that it is correct and meets all requirements of the goal. Return True if it does. + """, + result_type=bool, + context=dict(goal=goal), + ): + break + + return cf.run(objective=goal, agents=[agent]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Solve a reasoning problem.") + parser.add_argument("--goal", type=str, help="Custom goal to solve", default=None) + args = parser.parse_args() + + agent = cf.Agent(name="Definitely not GPT-4o mini", model="openai/gpt-4o-mini") + + # Default goal via https://www.reddit.com/r/singularity/comments/1fggo1e/comment/ln3ymsu/ + default_goal = """ + Using only four instances of the digit 9 and any combination of the following + mathematical operations: the decimal point, parentheses, addition (+), + subtraction (-), multiplication (*), division (/), factorial (!), and square + root (sqrt), create an equation that equals 24. + + In order to validate your result, you should test that you have followed the rules: + + 1. You have used the correct number of variables + 2. You have only used 9s and potentially a leading 0 for a decimal + 3. You have used valid mathematical symbols + 4. Your equation truly equates to 24. 
+ """ + + # Use the provided goal if available, otherwise use the default + goal = args.goal if args.goal is not None else default_goal + goal = unwrap(goal) + print(f"The goal is:\n\n{goal}") + + result = solve_with_reasoning(goal=goal, agent=agent) + + print(f"The solution is:\n\n{result}") diff --git a/examples/restaurant_recs.py b/examples/restaurant_recs.py deleted file mode 100644 index e8258a20..00000000 --- a/examples/restaurant_recs.py +++ /dev/null @@ -1,34 +0,0 @@ -from pydantic import BaseModel - -from controlflow import Task, flow - - -class Restaurant(BaseModel): - name: str - description: str - - -@flow -def restaurant_recs(n: int) -> list[Restaurant]: - """ - An agentic workflow that asks the user for their location and - cuisine preference, then recommends n restaurants based on their input. - """ - - # get the user's location - location = Task("Get a location", interactive=True) - - # get the user's preferred cuisine - cuisine = Task("Get a preferred cuisine", interactive=True) - - # generate the recommendations from the user's input - recs = Task( - f"Recommend {n} restaurants to the user", - context=dict(location=location, cuisine=cuisine), - result_type=list[Restaurant], - ) - return recs - - -if __name__ == "__main__": - restaurant_recs(5) diff --git a/examples/rock_paper_scissors.py b/examples/rock_paper_scissors.py new file mode 100644 index 00000000..72078642 --- /dev/null +++ b/examples/rock_paper_scissors.py @@ -0,0 +1,35 @@ +import controlflow as cf + + +@cf.flow +def rock_paper_scissors(): + """Play rock, paper, scissors against an AI.""" + play_again = True + + while play_again: + # Get the user's choice on a private thread + with cf.Flow(): + user_choice = cf.run( + "Get the user's choice", + result_type=["rock", "paper", "scissors"], + interactive=True, + ) + + # Get the AI's choice on a private thread + with cf.Flow(): + ai_choice = cf.run( + "Choose rock, paper, or scissors", + result_type=["rock", "paper", "scissors"], + ) + + # Report the score and ask if the user wants to play again + play_again = cf.run( + "Report the score to the user and see if they want to play again.", + interactive=True, + context={"user_choice": user_choice, "ai_choice": ai_choice}, + result_type=bool, + ) + + +if __name__ == "__main__": + rock_paper_scissors() diff --git a/examples/multi_agent_conversation.py b/examples/seinfeld.py similarity index 91% rename from examples/multi_agent_conversation.py rename to examples/seinfeld.py index 15d4669e..1bad21cd 100644 --- a/examples/multi_agent_conversation.py +++ b/examples/seinfeld.py @@ -1,3 +1,5 @@ +import sys + from controlflow import Agent, Task, flow jerry = Agent( @@ -66,12 +68,19 @@ def demo(topic: str): task = Task( "Discuss a topic", agents=[jerry, george, elaine, kramer, newman], + completion_agents=[jerry], result_type=None, context=dict(topic=topic), - instructions="every agent should speak at least once. only one agent per turn. Keep responses 1-2 paragraphs max.", + instructions="Every agent should speak at least once. only one agent per turn. 
Keep responses 1-2 paragraphs max.", ) task.run() if __name__ == "__main__": - demo(topic="sandwiches") + if len(sys.argv) > 1: + topic = sys.argv[1] + else: + topic = "sandwiches" + + print(f"Topic: {topic}") + demo(topic=topic) diff --git a/examples/sentiment_classifier.py b/examples/sentiment_classifier.py new file mode 100644 index 00000000..b65a6dcc --- /dev/null +++ b/examples/sentiment_classifier.py @@ -0,0 +1,29 @@ +import controlflow as cf +from controlflow.tasks.validators import between + +optimist = cf.Agent(model="openai/gpt-4o-mini") + + +def sentiment(text: str) -> float: + return cf.run( + "Classify the sentiment of the text as a value between 0 and 1", + agents=[optimist], + result_type=float, + result_validator=between(0, 1), + context={"text": text}, + ) + + +if __name__ == "__main__": + print(sentiment("I love ControlFlow!")) + + long_text = """ + Far out in the uncharted backwaters of the unfashionable end of + the western spiral arm of the Galaxy lies a small unregarded yellow sun. + Orbiting this at a distance of roughly ninety-two million miles is an utterly + insignificant little blue-green planet whose ape-descended life forms are so + amazingly primitive that they still think digital watches are a pretty neat + idea. This planet has – or rather had – a problem, which was this: most of + the people living on it were unhappy for pretty much of the time. + """ + print(sentiment(long_text)) diff --git a/examples/standardize_addresses.py b/examples/standardize_addresses.py new file mode 100644 index 00000000..b8626ef9 --- /dev/null +++ b/examples/standardize_addresses.py @@ -0,0 +1,38 @@ +from typing import List + +from pydantic import BaseModel + +import controlflow as cf + + +class StandardAddress(BaseModel): + city: str + state: str + country: str = "USA" + + +def standardize_addresses(place_names: List[str]) -> List[StandardAddress]: + return cf.run( + "Standardize the given place names into consistent postal addresses", + result_type=List[StandardAddress], + context={"place_names": place_names}, + ) + + +if __name__ == "__main__": + place_names = [ + "NYC", + "New York, NY", + "Big Apple", + "Los Angeles, California", + "LA", + "San Fran", + "The Windy City", + ] + + standardized_addresses = standardize_addresses(place_names) + + for original, standard in zip(place_names, standardized_addresses): + print(f"Original: {original}") + print(f"Standardized: {standard}") + print() diff --git a/examples/summarization.py b/examples/summarization.py new file mode 100644 index 00000000..9f9fa639 --- /dev/null +++ b/examples/summarization.py @@ -0,0 +1,40 @@ +from pydantic import BaseModel + +import controlflow as cf + + +class Summary(BaseModel): + summary: str + key_points: list[str] + + +def summarize_text(text: str, max_words: int = 100) -> Summary: + return cf.run( + f"Summarize the given text in no more than {max_words} words and list key points", + result_type=Summary, + context={"text": text}, + ) + + +if __name__ == "__main__": + long_text = """ + The Internet of Things (IoT) is transforming the way we interact with our + environment. It refers to the vast network of connected devices that collect + and share data in real-time. These devices range from simple sensors to + sophisticated wearables and smart home systems. The IoT has applications in + various fields, including healthcare, agriculture, and urban planning. In + healthcare, IoT devices can monitor patients remotely, improving care and + reducing hospital visits. 
In agriculture, sensors can track soil moisture and + crop health, enabling more efficient farming practices. Smart cities use IoT to + manage traffic, reduce energy consumption, and enhance public safety. However, + the IoT also raises concerns about data privacy and security, as these + interconnected devices can be vulnerable to cyber attacks. As the technology + continues to evolve, addressing these challenges will be crucial for the + widespread adoption and success of IoT. + """ + + result = summarize_text(long_text) + print(f"Summary:\n{result.summary}\n") + print("Key Points:") + for point in result.key_points: + print(f"- {point}") diff --git a/examples/task_dag.py b/examples/task_dag.py deleted file mode 100644 index 25b71994..00000000 --- a/examples/task_dag.py +++ /dev/null @@ -1,34 +0,0 @@ -import controlflow -from controlflow import Task, flow - -controlflow.settings.enable_experimental_tui = True - - -@flow -def book_ideas(): - genre = Task("pick a genre") - - ideas = Task( - "generate three short ideas for a book", - list[str], - context=dict(genre=genre), - ) - - abstract = Task( - "pick one idea and write a short abstract", - result_type=str, - context=dict(ideas=ideas, genre=genre), - ) - - title = Task( - "pick a title", - result_type=str, - context=dict(abstract=abstract), - ) - - return dict(genre=genre, ideas=ideas, abstract=abstract, title=title) - - -if __name__ == "__main__": - result = book_ideas() - print(result) diff --git a/examples/teacher_student.py b/examples/teacher_student.py deleted file mode 100644 index 549c195a..00000000 --- a/examples/teacher_student.py +++ /dev/null @@ -1,34 +0,0 @@ -from controlflow import Agent, Task, flow -from controlflow.instructions import instructions - -teacher = Agent(name="Teacher") -student = Agent(name="Student") - - -@flow -def demo(): - with Task("Teach a class by asking and answering 3 questions", agents=[teacher]): - for _ in range(3): - question = Task( - "Ask the student a question.", result_type=str, agents=[teacher] - ) - - with instructions("One sentence max"): - answer = Task( - "Answer the question.", - agents=[student], - context=dict(question=question), - ) - - grade = Task( - "Assess the answer.", - result_type=["pass", "fail"], - agents=[teacher], - context=dict(answer=answer), - ) - - # run each qa session, one at a time - grade.run() - - -t = demo() diff --git a/examples/translation.py b/examples/translation.py new file mode 100644 index 00000000..f4b1d312 --- /dev/null +++ b/examples/translation.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel + +import controlflow as cf + + +class TranslationResult(BaseModel): + translated: str + target_language: str + + +def translate_text(text: str, target_language: str) -> TranslationResult: + return cf.run( + f"Translate the given text to {target_language}", + result_type=TranslationResult, + context={"text": text, "target_language": target_language}, + ) + + +if __name__ == "__main__": + original_text = "Hello, how are you?" 
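+    # The target language is interpolated into the task objective and echoed
+    # back on the structured TranslationResult.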
+    target_language = "French"
+
+    result = translate_text(original_text, target_language)
+    print(f"Original: {original_text}")
+    print(f"Translated ({result.target_language}): {result.translated}")
diff --git a/examples/write_and_critique_paper.py b/examples/write_and_critique_paper.py
deleted file mode 100644
index 522341be..00000000
--- a/examples/write_and_critique_paper.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from controlflow import Agent, Task
-
-writer = Agent(name="writer")
-editor = Agent(name="editor", instructions="you always find at least one problem")
-critic = Agent(name="critic")
-
-
-# ai tasks:
-# - automatically supply context from kwargs
-# - automatically wrap sub tasks in parent
-# - automatically iterate over sub tasks if they are all completed but the parent isn't?
-
-
-def write_paper(topic: str) -> str:
-    """
-    Write a paragraph on the topic
-    """
-    draft = Task(
-        "produce a 3-sentence draft on the topic",
-        str,
-        # agents=[writer],
-        context=dict(topic=topic),
-    )
-    edits = Task("edit the draft", str, agents=[editor], depends_on=[draft])
-    critique = Task("is it good enough?", bool, agents=[critic], depends_on=[edits])
-    return critique
-
-
-task = write_paper("AI and the future of work")
-task.run()
diff --git a/pyproject.toml b/pyproject.toml
index bff67093..07aec50d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,14 +8,15 @@ authors = [
 dependencies = [
     "prefect>=3.0",
     "jinja2>=3.1.4",
-    "langchain_core>=0.2.9",
-    "langchain_openai>=0.1.8",
-    "langchain-anthropic>=0.1.19",
+    "langchain_core>=0.3",
+    "langchain_openai>=0.2",
+    "langchain-anthropic>=0.2",
     "markdownify>=0.12.1",
+    "openai<1.47", # 1.47.0 introduced a bug with attempting to reuse an async client that doesn't have an obvious solution
     "pydantic-settings>=2.2.1",
     "textual>=0.61.1",
     "tiktoken>=0.7.0",
-    "typer[all]>=0.10",
+    "typer>=0.10",
 ]
 readme = "README.md"
 requires-python = ">= 3.9"
@@ -44,6 +45,11 @@ Code = "https://github.com/PrefectHQ/ControlFlow"
 
 [project.optional-dependencies]
 tests = [
+    "chromadb",
+    "duckduckgo-search",
+    "langchain_community",
+    "langchain_google_genai",
+    "langchain_groq",
    "pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0",
    "pytest-env>=0.8,<2.0",
    "pytest-rerunfailures>=10,<14",
@@ -51,10 +57,6 @@ tests = [
    "pytest>=7.0",
    "pytest-timeout",
    "pytest-xdist",
-    "langchain_community",
-    "langchain_google_genai",
-    "langchain_groq",
-    "duckduckgo-search",
 ]
 dev = [
    "controlflow[tests]",
@@ -82,7 +84,7 @@ managed = true
 # ruff configuration
 [tool.ruff]
 target-version = "py311"
-lint.extend-select = ["I"]
+lint.select = ["I"] # Changed from lint.extend-select to select
 lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # default, but here in case we want to change it
 
 [tool.ruff.format]
diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py
index c8a86d11..27a0218c 100644
--- a/src/controlflow/__init__.py
+++ b/src/controlflow/__init__.py
@@ -11,12 +11,14 @@ from .tasks import Task
 
 from .flows import Flow
 
-# functions and decorators
+# functions, utilities, and decorators
+from .memory import Memory
 from .instructions import instructions
 from .decorators import flow, task
 from .tools import tool
 from .run import run, run_async, run_tasks, run_tasks_async
 from .plan import plan
+import controlflow.orchestration
 
 # --- Version ---
 
diff --git a/src/controlflow/agents/__init__.py b/src/controlflow/agents/__init__.py
index 1437945a..d2361b7a 100644
--- a/src/controlflow/agents/__init__.py
+++ b/src/controlflow/agents/__init__.py
@@ -1,2 +1 @@
-from . 
import memory from .agent import Agent diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 0b391950..5a38fd9e 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -14,6 +14,7 @@ from langchain_core.language_models import BaseChatModel from pydantic import Field, field_serializer, field_validator +from typing_extensions import Self import controlflow from controlflow.agents.names import AGENT_NAMES @@ -22,6 +23,7 @@ from controlflow.llm.messages import AIMessage, BaseMessage from controlflow.llm.models import get_model as get_model_from_string from controlflow.llm.rules import LLMRules +from controlflow.memory import Memory from controlflow.tools.tools import ( Tool, as_lc_tools, @@ -30,11 +32,9 @@ handle_tool_call_async, ) from controlflow.utilities.context import ctx -from controlflow.utilities.general import ControlFlowModel, hash_objects +from controlflow.utilities.general import ControlFlowModel, hash_objects, unwrap from controlflow.utilities.prefect import create_markdown_artifact, prefect_task -from .memory import Memory - if TYPE_CHECKING: from controlflow.orchestration.handler import Handler from controlflow.orchestration.turn_strategies import TurnStrategy @@ -71,20 +71,20 @@ class Agent(ControlFlowModel, abc.ABC): False, description="If True, the agent is given tools for interacting with a human user.", ) - memory: Optional[Memory] = Field( - default=None, - # default_factory=ThreadMemory, - description="The memory object used by the agent. If not specified, an in-memory memory object will be used. Pass None to disable memory.", - exclude=True, + memories: list[Memory] = Field( + default=[], + description="A list of memory modules for the agent to use.", ) - # note: `model` should be typed as Optional[BaseChatModel] but V2 models can't have - # V1 attributes without erroring, so we have to use Any. - model: Optional[Union[str, Any]] = Field( + model: Optional[Union[str, BaseChatModel]] = Field( None, description="The LangChain BaseChatModel used by the agent. If not provided, the default model will be used. A compatible string can be passed to automatically retrieve the model.", exclude=True, ) + llm_rules: Optional[LLMRules] = Field( + None, + description="The LLM rules for the agent. 
If not provided, the rules will be inferred from the model (if possible).", + ) _cm_stack: list[contextmanager] = [] @@ -128,12 +128,18 @@ def _generate_id(self): ) ) + @field_validator("instructions") + def _validate_instructions(cls, v): + if v: + v = unwrap(v) + return v + @field_validator("tools", mode="before") def _validate_tools(cls, tools: list[Tool]): return as_tools(tools or []) @field_validator("model", mode="before") - def _validate_model(cls, model: Optional[Union[str, Any]]): + def _validate_model(cls, model: Optional[Union[str, BaseChatModel]]): if isinstance(model, str): return get_model_from_string(model) return model @@ -141,8 +147,7 @@ def _validate_model(cls, model: Optional[Union[str, Any]]): @field_serializer("tools") def _serialize_tools(self, tools: list[Tool]): tools = controlflow.tools.as_tools(tools) - # tools are Pydantic 1 objects - return [t.dict(include={"name", "description"}) for t in tools] + return [t.model_dump(include={"name", "description"}) for t in tools] def serialize_for_prompt(self) -> dict: dct = self.model_dump( @@ -169,7 +174,10 @@ def get_llm_rules(self) -> LLMRules: """ Retrieve the LLM rules for this agent's model """ - return controlflow.llm.rules.rules_for_model(self.get_model()) + if self.llm_rules is None: + return controlflow.llm.rules.rules_for_model(self.get_model()) + else: + return self.llm_rules def get_tools(self) -> list["Tool"]: from controlflow.tools.input import cli_input @@ -177,8 +185,8 @@ def get_tools(self) -> list["Tool"]: tools = self.tools.copy() if self.interactive: tools.append(cli_input) - if self.memory is not None: - tools.extend(self.memory.get_tools()) + for memory in self.memories: + tools.extend(memory.get_tools()) return as_tools(tools) @@ -189,11 +197,11 @@ def get_prompt(self) -> str: return template.render() @contextmanager - def create_context(self): + def create_context(self) -> Generator[Self, None, None]: with ctx(agent=self): yield self - def __enter__(self): + def __enter__(self) -> Self: self._cm_stack.append(self.create_context()) return self._cm_stack[-1].__enter__() @@ -269,6 +277,7 @@ def _run_model( messages: list[BaseMessage], tools: list["Tool"], stream: bool = True, + model_kwargs: Optional[dict] = None, ) -> Generator[Event, None, None]: from controlflow.events.events import ( AgentMessage, @@ -286,7 +295,7 @@ def _run_model( if stream: response = None - for delta in model.stream(messages): + for delta in model.stream(messages, **(model_kwargs or {})): if response is None: response = delta else: @@ -298,14 +307,13 @@ def _run_model( response: AIMessage = model.invoke(messages) yield AgentMessage(agent=self, message=response) - create_markdown_artifact( markdown=f""" {response.content or '(No content)'} #### Payload ```json -{response.json(indent=2)} +{response.model_dump_json(indent=2)} ``` """, description=f"LLM Response for Agent {self.name}", @@ -326,6 +334,7 @@ async def _run_model_async( messages: list[BaseMessage], tools: list["Tool"], stream: bool = True, + model_kwargs: Optional[dict] = None, ) -> AsyncGenerator[Event, None]: from controlflow.events.events import ( AgentMessage, @@ -343,7 +352,7 @@ async def _run_model_async( if stream: response = None - async for delta in model.astream(messages): + async for delta in model.astream(messages, **(model_kwargs or {})): if response is None: response = delta else: @@ -362,7 +371,7 @@ async def _run_model_async( #### Payload ```json -{response.json(indent=2)} +{response.model_dump_json(indent=2)} ``` """, description=f"LLM Response for 
Agent {self.name}", diff --git a/src/controlflow/agents/memory.py b/src/controlflow/agents/memory.py deleted file mode 100644 index 0a4841af..00000000 --- a/src/controlflow/agents/memory.py +++ /dev/null @@ -1,99 +0,0 @@ -import abc -import uuid -from typing import TYPE_CHECKING, ClassVar, Optional, cast - -from pydantic import Field - -from controlflow.utilities.context import ctx -from controlflow.utilities.general import ControlFlowModel - -if TYPE_CHECKING: - from controlflow.tools import Tool - - -class Memory(ControlFlowModel, abc.ABC): - id: str = Field(default_factory=lambda: uuid.uuid4().hex) - - def load(self) -> dict[int, str]: - """ - Load all memories as a dictionary of index to value. - """ - raise NotImplementedError() - - def update(self, value: str, index: int = None): - """ - Store a value, optionally overwriting an existing value at the given index. - """ - raise NotImplementedError() - - def delete(self, index: int): - raise NotImplementedError() - - def get_tools(self) -> list["Tool"]: - from controlflow.tools import Tool - - update_tool = Tool.from_function( - self.update, - name="update_memory", - description="Privately remember an idea or fact, optionally updating the existing memory at `index`", - ) - delete_tool = Tool.from_function( - self.delete, - name="delete_memory", - description="Forget the private memory at `index`", - ) - - tools = [update_tool, delete_tool] - - return tools - - -class AgentMemory(Memory): - """ - In-memory store for an agent. Memories are scoped to the agent. - - Note memories may persist across flows. - """ - - _memory: list[str] = [] - - def update(self, value: str, index: int = None): - if index is not None: - self._memory[index] = value - else: - self._memory.append(value) - - def load(self, thread_id: str) -> dict[int, str]: - return dict(enumerate(self._memory)) - - def delete(self, index: int): - del self._memory[index] - - -class ThreadMemory(Memory): - """ - In-memory store for an agent. Memories are scoped to each thread. 
- """ - - _memory: ClassVar[dict[str, list[str]]] = {} - - def _get_thread_id(self) -> Optional[str]: - from controlflow.flows import Flow - - if flow := ctx.get("flow", None): # type: Flow - flow = cast(Flow, flow) - return flow.thread_id - - def update(self, value: str, index: int = None): - thread_id = self._get_thread_id() - if index is not None: - self._memory[thread_id][index] = value - else: - self._memory[thread_id].append(value) - - def load(self, thread_id: str) -> dict[int, str]: - return dict(enumerate(self._memory.get(thread_id, []))) - - def delete(self, index: int): - thread_id = self._get_thread_id() - del self._memory[thread_id][index] diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index 4305a42c..5f175948 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -1,3 +1,4 @@ +import asyncio import functools import inspect from typing import Any, Callable, Optional, Union @@ -67,19 +68,7 @@ def flow( sig = inspect.signature(fn) - # the flow decorator creates a proper prefect flow - @prefect_flow( - timeout_seconds=timeout_seconds, - retries=retries, - retry_delay_seconds=retry_delay_seconds, - **(prefect_kwargs or {}), - ) - @functools.wraps(fn) - def wrapper( - *wrapper_args, - flow_kwargs: dict = None, - **wrapper_kwargs, - ): + def _inner_wrapper(*wrapper_args, flow_kwargs: dict = None, **wrapper_kwargs): # first process callargs bound = sig.bind(*wrapper_args, **wrapper_kwargs) bound.apply_defaults() @@ -108,6 +97,23 @@ def wrapper( ): return fn(*wrapper_args, **wrapper_kwargs) + if asyncio.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def wrapper(*wrapper_args, **wrapper_kwargs): + return await _inner_wrapper(*wrapper_args, **wrapper_kwargs) + else: + + @functools.wraps(fn) + def wrapper(*wrapper_args, **wrapper_kwargs): + return _inner_wrapper(*wrapper_args, **wrapper_kwargs) + + wrapper = prefect_flow( + timeout_seconds=timeout_seconds, + retries=retries, + retry_delay_seconds=retry_delay_seconds, + **(prefect_kwargs or {}), + )(wrapper) return wrapper @@ -195,18 +201,24 @@ def _get_task(*args, **kwargs) -> Task: **task_kwargs, ) - @functools.wraps(fn) - @prefect_task( + if asyncio.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + task = _get_task(*args, **kwargs) + return await task.run_async() + else: + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + task = _get_task(*args, **kwargs) + return task.run() + + wrapper = prefect_task( timeout_seconds=timeout_seconds, retries=retries, retry_delay_seconds=retry_delay_seconds, - ) - def wrapper( - *args, - **kwargs, - ): - task = _get_task(*args, **kwargs) - return task.run() + )(wrapper) # store the `as_task` method for loading the task object wrapper.as_task = _get_task diff --git a/src/controlflow/defaults.py b/src/controlflow/defaults.py index 53204294..39e52ad9 100644 --- a/src/controlflow/defaults.py +++ b/src/controlflow/defaults.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Union from pydantic import field_validator @@ -6,10 +6,11 @@ import controlflow.utilities import controlflow.utilities.logging from controlflow.llm.models import BaseChatModel +from controlflow.memory.memory import MemoryProvider, get_memory_provider from controlflow.utilities.general import ControlFlowModel from .agents import Agent -from .events.history import History, InMemoryHistory +from .events.history import FileHistory, History, InMemoryHistory from .llm.models import 
_get_initial_default_model, get_model
 
 __all__ = ["defaults"]
 
@@ -19,6 +20,10 @@
 _default_model = _get_initial_default_model()
 _default_history = InMemoryHistory()
 _default_agent = Agent(name="Marvin")
+try:
+    _default_memory_provider = get_memory_provider(controlflow.settings.memory_provider)
+except Exception:
+    _default_memory_provider = controlflow.settings.memory_provider
 
 
 class Defaults(ControlFlowModel):
@@ -34,18 +39,20 @@ class Defaults(ControlFlowModel):
     model: Optional[Any]
     history: History
     agent: Agent
+    memory_provider: Optional[Union[MemoryProvider, str]]
     # add more defaults here
 
     def __repr__(self) -> str:
         fields = ", ".join(self.model_fields.keys())
         return f"<Defaults: {fields}>"
 
-    @field_validator("model")
+    @field_validator("model", mode="before")
     def _model(cls, v):
         if isinstance(v, str):
             v = get_model(v)
-        elif v is not None and not isinstance(v, BaseChatModel):
-            raise ValueError("Input must be an instance of BaseChatModel")
+        # the model validator in langchain forcibly expects a dictionary
+        elif v is not None and not isinstance(v, (dict, BaseChatModel)):
+            raise ValueError("Input must be an instance of dict or BaseChatModel")
         return v
 
 
@@ -53,4 +60,5 @@ def _model(cls, v):
     model=_default_model,
     history=_default_history,
     agent=_default_agent,
+    memory_provider=_default_memory_provider,
 )
diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py
index 74eaa2bf..a28383e6 100644
--- a/src/controlflow/events/events.py
+++ b/src/controlflow/events/events.py
@@ -57,7 +57,7 @@ class AgentMessage(Event):
     @field_validator("message", mode="before")
     def _message(cls, v):
         if isinstance(v, BaseMessage):
-            v = v.dict()
+            v = v.model_dump()
         v["type"] = "ai"
         return v
 
@@ -93,7 +93,7 @@ class AgentMessageDelta(UnpersistedEvent):
     @field_validator("delta", "snapshot", mode="before")
     def _message(cls, v):
         if isinstance(v, BaseMessage):
-            v = v.dict()
+            v = v.model_dump()
         v["type"] = "AIMessageChunk"
         return v
 
diff --git a/src/controlflow/events/history.py b/src/controlflow/events/history.py
index 8c26affe..e62cc660 100644
--- a/src/controlflow/events/history.py
+++ b/src/controlflow/events/history.py
@@ -147,19 +147,12 @@ def get_events(
 class FileHistory(History):
     base_path: Path = Field(
-        default_factory=lambda: controlflow.settings.home_path / "filestore_events"
+        default_factory=lambda: controlflow.settings.home_path / "history/FileHistory"
     )
 
     def path(self, thread_id: str) -> Path:
         return self.base_path / f"{thread_id}.json"
 
-    @field_validator("base_path", mode="before")
-    def _validate_path(cls, v):
-        v = Path(v).expanduser()
-        if not v.exists():
-            v.mkdir(parents=True, exist_ok=True)
-        return v
-
     def get_events(
         self,
         thread_id: str,
@@ -173,7 +166,6 @@
 
         Args:
             thread_id (str): The ID of the thread to retrieve events from.
-            tags (Optional[list[str]]): The tags associated with the events (default: None).
             types (Optional[list[str]]): The list of event types to filter by (default: None).
             before_id (Optional[str]): The ID of the event before which to stop retrieving events (default: None).
             after_id (Optional[str]): The ID of the event after which to start retrieving events (default: None).
 
         Returns:
             list[Event]: A list of events that match the specified criteria. 
""" - if not self.path(thread_id).exists(): + file_path = self.path(thread_id) + + if not file_path.exists(): return [] - with open(self.path(thread_id), "r") as f: + with file_path.open("r") as f: raw_data = f.read() validator = get_event_validator() @@ -200,11 +194,22 @@ def get_events( ) def add_events(self, thread_id: str, events: list[Event]): - if self.path(thread_id).exists(): - with open(self.path(thread_id), "r") as f: + # TODO: this is pretty inefficient because we read / write the entire file + # every time instead of doing it incrementally. Need to switch to JSONL + # if we want to improve performance. + file_path = self.path(thread_id) + + if not file_path.exists(): + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.touch() + + with file_path.open("r") as f: + try: all_events = json.load(f) - else: - all_events = [] + except json.JSONDecodeError: + all_events = [] + all_events.extend([event.model_dump(mode="json") for event in events]) - with open(self.path(thread_id), "w") as f: + + with file_path.open("w") as f: json.dump(all_events, f) diff --git a/src/controlflow/events/message_compiler.py b/src/controlflow/events/message_compiler.py index f9fa054a..6ed4d54d 100644 --- a/src/controlflow/events/message_compiler.py +++ b/src/controlflow/events/message_compiler.py @@ -146,7 +146,9 @@ def format_message_name( def count_tokens(message: BaseMessage) -> int: # always use gpt-3.5 token counter with the entire message object; we only need to be approximate here - return len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.json())) + return len( + tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.model_dump_json()) + ) def trim_messages( diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index 6c0cc07c..f76c0bc1 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -1,16 +1,17 @@ import uuid from contextlib import contextmanager, nullcontext -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Union from prefect.context import FlowRunContext -from pydantic import Field +from pydantic import Field, field_validator +from typing_extensions import Self import controlflow from controlflow.agents import Agent from controlflow.events.base import Event from controlflow.events.history import History from controlflow.utilities.context import ctx -from controlflow.utilities.general import ControlFlowModel +from controlflow.utilities.general import ControlFlowModel, unwrap from controlflow.utilities.logging import get_logger from controlflow.utilities.prefect import prefect_flow_context @@ -54,7 +55,7 @@ class Flow(ControlFlowModel): context: dict[str, Any] = {} _cm_stack: list[contextmanager] = [] - def __enter__(self): + def __enter__(self) -> Self: # use stack so we can enter the context multiple times cm = self.create_context() self._cm_stack.append(cm) @@ -69,6 +70,12 @@ def __init__(self, **kwargs): kwargs["parent"] = get_flow() super().__init__(**kwargs) + @field_validator("description") + def _validate_description(cls, v): + if v: + v = unwrap(v) + return v + def get_prompt(self) -> str: """ Generate a prompt to share information about the flow with an agent. 
@@ -111,7 +118,7 @@ def add_events(self, events: list[Event]): self.history.add_events(thread_id=self.thread_id, events=events) @contextmanager - def create_context(self, **prefect_kwargs): + def create_context(self, **prefect_kwargs) -> Generator[Self, None, None]: # create a new Prefect flow if we're not already in a flow run if FlowRunContext.get() is None: prefect_context = prefect_flow_context(**prefect_kwargs) diff --git a/src/controlflow/flows/graph.py b/src/controlflow/flows/graph.py index 2c2af8df..d200958a 100644 --- a/src/controlflow/flows/graph.py +++ b/src/controlflow/flows/graph.py @@ -73,7 +73,7 @@ def add_task(self, task: Task): ) # add the task's subtasks - for subtask in task._subtasks: + for subtask in task.subtasks: self.add_edge( Edge( upstream=subtask, diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py index ea5428f2..6938e197 100644 --- a/src/controlflow/llm/models.py +++ b/src/controlflow/llm/models.py @@ -5,6 +5,7 @@ from pydantic import ValidationError import controlflow +from controlflow.utilities.general import unwrap from controlflow.utilities.logging import get_logger logger = get_logger(__name__) @@ -75,21 +76,21 @@ def _get_initial_default_model() -> BaseChatModel: if isinstance(exc, ValidationError) and "Did not find openai_api_key" in str( exc ): - msg = inspect.cleandoc(""" + msg = unwrap(""" The default LLM model could not be created because the OpenAI API key was not found. ControlFlow will continue to work, but you must manually provide an LLM model for each agent. Please set the OPENAI_API_KEY environment variable or choose a different default LLM model. For more information, please see - https://controlflow.ai/guides/llms. + https://controlflow.ai/guides/configure-llms. """).replace("\n", " ") else: msg = ( - inspect.cleandoc(""" + unwrap(""" The default LLM model could not be created. ControlFlow will continue to work, but you must manually provide an LLM model for each agent. For more information, please see - https://controlflow.ai/guides/llms. The error was: + https://controlflow.ai/guides/configure-llms. The error was: """).replace("\n", " ") + f"\n{exc}" ) diff --git a/src/controlflow/llm/rules.py b/src/controlflow/llm/rules.py index 259a02d6..21f7596a 100644 --- a/src/controlflow/llm/rules.py +++ b/src/controlflow/llm/rules.py @@ -1,10 +1,11 @@ +import textwrap from typing import Optional from langchain_anthropic import ChatAnthropic from langchain_openai import AzureChatOpenAI, ChatOpenAI from controlflow.llm.models import BaseChatModel -from controlflow.utilities.general import ControlFlowModel +from controlflow.utilities.general import ControlFlowModel, unwrap class LLMRules(ControlFlowModel): @@ -16,6 +17,8 @@ class LLMRules(ControlFlowModel): necessary. 
""" + model: Optional[BaseChatModel] + # require at least one non-system message require_at_least_one_message: bool = False @@ -41,10 +44,19 @@ class LLMRules(ControlFlowModel): # the name associated with a message must conform to a specific format require_message_name_format: Optional[str] = None + def model_instructions(self) -> Optional[list[str]]: + pass + class OpenAIRules(LLMRules): require_message_name_format: str = r"[^a-zA-Z0-9_-]" + model: ChatOpenAI + + def model_instructions(self) -> list[str]: + instructions = [] + return instructions + class AnthropicRules(LLMRules): require_at_least_one_message: bool = True @@ -56,8 +68,17 @@ class AnthropicRules(LLMRules): def rules_for_model(model: BaseChatModel) -> LLMRules: if isinstance(model, (ChatOpenAI, AzureChatOpenAI)): - return OpenAIRules() - elif isinstance(model, ChatAnthropic): - return AnthropicRules() - else: - return LLMRules() + return OpenAIRules(model=model) + if isinstance(model, ChatAnthropic): + return AnthropicRules(model=model) + + try: + from langchain_google_vertexai.model_garden import ChatAnthropicVertex + + if isinstance(model, ChatAnthropicVertex): + return AnthropicRules(model=model) + except ImportError: + pass + + # catchall + return LLMRules(model=model) diff --git a/src/controlflow/memory/__init__.py b/src/controlflow/memory/__init__.py new file mode 100644 index 00000000..db0261eb --- /dev/null +++ b/src/controlflow/memory/__init__.py @@ -0,0 +1 @@ +from .memory import Memory diff --git a/src/controlflow/memory/memory.py b/src/controlflow/memory/memory.py new file mode 100644 index 00000000..c675cbae --- /dev/null +++ b/src/controlflow/memory/memory.py @@ -0,0 +1,164 @@ +import abc +import re +from typing import Dict, List, Optional, Union + +from pydantic import Field, field_validator, model_validator + +import controlflow +from controlflow.tools.tools import Tool +from controlflow.utilities.general import ControlFlowModel, unwrap + + +def sanitize_memory_key(key: str) -> str: + # Remove any characters that are not alphanumeric or underscore + return re.sub(r"[^a-zA-Z0-9_]", "", key) + + +class MemoryProvider(ControlFlowModel, abc.ABC): + def configure(self, memory_key: str) -> None: + """Configure the provider for a specific memory.""" + pass + + @abc.abstractmethod + def add(self, memory_key: str, content: str) -> str: + """Create a new memory and return its ID.""" + pass + + @abc.abstractmethod + def delete(self, memory_key: str, memory_id: str) -> None: + """Delete a memory by its ID.""" + pass + + @abc.abstractmethod + def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: + """Search for n memories using a string query.""" + pass + + +class Memory(ControlFlowModel): + """ + A memory module is a partitioned collection of memories that are stored in a + vector database, configured by a MemoryProvider. + """ + + key: str + instructions: str = Field( + description="Explain what this memory is for and how it should be used." + ) + provider: MemoryProvider = Field( + default_factory=lambda: controlflow.defaults.memory_provider, + validate_default=True, + ) + + def __hash__(self) -> int: + return id(self) + + @field_validator("provider", mode="before") + @classmethod + def validate_provider( + cls, v: Optional[Union[MemoryProvider, str]] + ) -> MemoryProvider: + if isinstance(v, str): + return get_memory_provider(v) + if v is None: + raise ValueError( + unwrap( + """ + Memory modules require a MemoryProvider to configure the + underlying vector database. 
No provider was passed as an + argument, and no default value has been configured. + + For more information on configuring a memory provider, see + the [Memory + documentation](https://controlflow.ai/patterns/memory), and + please review the [default provider + guide](https://controlflow.ai/guides/default-memory) for + information on configuring a default provider. + + Please note that if you are using ControlFlow for the first + time, this error is expected because ControlFlow does not include + vector dependencies by default. + """ + ) + ) + return v + + @field_validator("key") + @classmethod + def validate_key(cls, v: str) -> str: + sanitized = sanitize_memory_key(v) + if sanitized != v: + raise ValueError( + "Memory key must contain only alphanumeric characters and underscores" + ) + return sanitized + + @model_validator(mode="after") + def _configure_provider(self): + self.provider.configure(self.key) + return self + + def add(self, content: str) -> str: + return self.provider.add(self.key, content) + + def delete(self, memory_id: str) -> None: + self.provider.delete(self.key, memory_id) + + def search(self, query: str, n: int = 20) -> Dict[str, str]: + return self.provider.search(self.key, query, n) + + def get_tools(self) -> List[Tool]: + return [ + Tool.from_function( + self.add, + name=f"store_memory_{self.key}", + description=f'Create a new memory in Memory: "{self.key}".', + ), + Tool.from_function( + self.delete, + name=f"delete_memory_{self.key}", + description=f'Delete a memory by its ID from Memory: "{self.key}".', + ), + Tool.from_function( + self.search, + name=f"search_memories_{self.key}", + description=f'Search for memories relevant to a string query in Memory: "{self.key}". Returns a dictionary of memory IDs and their contents.', + ), + ] + + +def get_memory_provider(provider: str) -> MemoryProvider: + # --- CHROMA --- + + if provider.startswith("chroma"): + try: + import chromadb + except ImportError: + raise ImportError( + "To use Chroma as a memory provider, please install the `chromadb` package." + ) + + import controlflow.memory.providers.chroma as chroma_providers + + if provider == "chroma-ephemeral": + return chroma_providers.ChromaEphemeralMemory() + elif provider == "chroma-db": + return chroma_providers.ChromaPersistentMemory() + elif provider == "chroma-cloud": + return chroma_providers.ChromaCloudMemory() + + # --- LanceDB --- + + elif provider.startswith("lancedb"): + try: + import lancedb + except ImportError: + raise ImportError( + "To use LanceDB as a memory provider, please install the `lancedb` package." 
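A short usage sketch for the string shortcuts resolved by `get_memory_provider` above, assuming the corresponding extras (`chromadb`, `lancedb`) are installed:

```python
import controlflow as cf
from controlflow.memory.providers.chroma import ChromaPersistentMemory

# resolved from a string via get_memory_provider
m1 = cf.Memory(
    key="user_prefs",
    instructions="Track user preferences.",
    provider="chroma-ephemeral",
)

# or pass a configured provider instance directly
m2 = cf.Memory(
    key="research",
    instructions="Store research findings.",
    provider=ChromaPersistentMemory(path="/tmp/cf-memory"),
)
```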
+ ) + + import controlflow.memory.providers.lance as lance_providers + + return lance_providers.LanceMemory() + + raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.') diff --git a/src/controlflow/memory/providers/__init__.py b/src/controlflow/memory/providers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/controlflow/memory/providers/chroma.py b/src/controlflow/memory/providers/chroma.py new file mode 100644 index 00000000..18291c3a --- /dev/null +++ b/src/controlflow/memory/providers/chroma.py @@ -0,0 +1,77 @@ +import uuid +from typing import Dict, Optional + +import chromadb +from pydantic import Field, PrivateAttr + +import controlflow +from controlflow.memory.memory import MemoryProvider + + +class ChromaMemory(MemoryProvider): + model_config = dict(arbitrary_types_allowed=True) + client: chromadb.ClientAPI = Field( + default_factory=lambda: chromadb.PersistentClient( + path=str(controlflow.settings.home_path / "memory/chroma") + ) + ) + collection_name: str = Field( + "memory-{key}", + description=""" + Optional; the name of the collection to use. This should be a + string optionally formatted with the variable `key`, which + will be provided by the memory module. The default is `"memory-{{key}}"`. + """, + ) + + def get_collection(self, memory_key: str) -> chromadb.Collection: + return self.client.get_or_create_collection( + self.collection_name.format(key=memory_key) + ) + + def add(self, memory_key: str, content: str) -> str: + collection = self.get_collection(memory_key) + memory_id = str(uuid.uuid4()) + collection.add( + documents=[content], metadatas=[{"id": memory_id}], ids=[memory_id] + ) + return memory_id + + def delete(self, memory_key: str, memory_id: str) -> None: + collection = self.get_collection(memory_key) + collection.delete(ids=[memory_id]) + + def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: + results = self.get_collection(memory_key).query( + query_texts=[query], n_results=n + ) + return dict(zip(results["ids"][0], results["documents"][0])) + + +def ChromaEphemeralMemory(**kwargs) -> ChromaMemory: + return ChromaMemory(client=chromadb.EphemeralClient(**kwargs)) + + +def ChromaPersistentMemory(path: str = None, **kwargs) -> ChromaMemory: + return ChromaMemory( + client=chromadb.PersistentClient( + path=path or str(controlflow.settings.home_path / "memory" / "chroma"), + **kwargs, + ) + ) + + +def ChromaCloudMemory( + tenant: Optional[str] = None, + database: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs, +) -> ChromaMemory: + return ChromaMemory( + client=chromadb.CloudClient( + api_key=api_key or controlflow.settings.chroma_cloud_api_key, + tenant=tenant or controlflow.settings.chroma_cloud_tenant, + database=database or controlflow.settings.chroma_cloud_database, + **kwargs, + ) + ) diff --git a/src/controlflow/memory/providers/lance.py b/src/controlflow/memory/providers/lance.py new file mode 100644 index 00000000..3ea167c9 --- /dev/null +++ b/src/controlflow/memory/providers/lance.py @@ -0,0 +1,74 @@ +import functools +import uuid +from pathlib import Path +from typing import Callable, Dict, Optional + +import lancedb +from lancedb.embeddings import get_registry +from lancedb.pydantic import LanceModel, Vector +from pydantic import Field, PrivateAttr + +import controlflow +from controlflow.memory.memory import MemoryProvider + + +class LanceMemory(MemoryProvider): + uri: Path = Field( + default=controlflow.settings.home_path / "memory" / "lancedb", + 
description="The URI of the Lance database to use.", + ) + table_name: str = Field( + "memory-{key}", + description=""" + Optional; the name of the table to use. This should be a + string optionally formatted with the variable `key`, which + will be provCallablethe memory module. The default is `"memory-{{key}}"`. + """, + ) + embedding_fn: Callable = Field( + default_factory=lambda: get_registry() + .get("openai") + .create(name="text-embedding-ada-002"), + description="The LanceDB embedding function to use. Defaults to `get_registry().get('openai').create(name='text-embedding-ada-002')`.", + ) + _cached_model: Optional[LanceModel] = None + + def get_model(self) -> LanceModel: + if self._cached_model is None: + fn = self.embedding_fn + + class Memory(LanceModel): + id: str = Field(..., description="The ID of the memory.") + text: str = fn.SourceField() + vector: Vector(fn.ndims()) = fn.VectorField() # noqa + + self._cached_model = Memory + + return self._cached_model + + def get_db(self) -> lancedb.DBConnection: + return lancedb.connect(self.uri) + + def get_table(self, memory_key: str) -> lancedb.table.Table: + table_name = self.table_name.format(key=memory_key) + db = self.get_db() + model = self.get_model() + try: + return db.open_table(table_name) + except FileNotFoundError: + return db.create_table(table_name, schema=model) + + def add(self, memory_key: str, content: str) -> str: + memory_id = str(uuid.uuid4()) + table = self.get_table(memory_key) + table.add([{"id": memory_id, "text": content}]) + return memory_id + + def delete(self, memory_key: str, memory_id: str) -> None: + table = self.get_table(memory_key) + table.delete(f'id = "{memory_id}"') + + def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: + table = self.get_table(memory_key) + results = table.search(query).limit(n).to_pydantic(self.get_model()) + return {r.id: r.text for r in results} diff --git a/src/controlflow/orchestration/__init__.py b/src/controlflow/orchestration/__init__.py index 8f3ed651..e4870f81 100644 --- a/src/controlflow/orchestration/__init__.py +++ b/src/controlflow/orchestration/__init__.py @@ -1,2 +1,3 @@ +from . import conditions from .orchestrator import Orchestrator from .handler import Handler diff --git a/src/controlflow/orchestration/conditions.py b/src/controlflow/orchestration/conditions.py new file mode 100644 index 00000000..aee1c852 --- /dev/null +++ b/src/controlflow/orchestration/conditions.py @@ -0,0 +1,166 @@ +import logging +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from pydantic import BaseModel, field_validator + +from controlflow.tasks.task import Task +from controlflow.utilities.general import ControlFlowModel +from controlflow.utilities.logging import get_logger + +if TYPE_CHECKING: + from controlflow.orchestration.orchestrator import Orchestrator + +logger = get_logger(__name__) + + +class RunContext(ControlFlowModel): + """ + Context for a run. + """ + + model_config = dict(arbitrary_types_allowed=True) + + orchestrator: "Orchestrator" + llm_calls: int = 0 + agent_turns: int = 0 + run_end_condition: "RunEndCondition" + + @field_validator("run_end_condition", mode="before") + def validate_condition(cls, v: Any) -> "RunEndCondition": + if not isinstance(v, RunEndCondition): + v = FnCondition(v) + return v + + def should_end(self) -> bool: + return self.run_end_condition.should_end(self) + + +class RunEndCondition: + def should_end(self, context: RunContext) -> bool: + """ + Returns True if the run should end, False otherwise. 
+ """ + return False + + def __or__( + self, other: Union["RunEndCondition", Callable[[RunContext], bool]] + ) -> "RunEndCondition": + if isinstance(other, RunEndCondition): + return OR_(self, other) + elif callable(other): + return OR_(self, FnCondition(other)) + else: + raise NotImplementedError( + f"Cannot combine RunEndCondition with {type(other)}" + ) + + def __and__( + self, other: Union["RunEndCondition", Callable[[RunContext], bool]] + ) -> "RunEndCondition": + if isinstance(other, RunEndCondition): + return AND_(self, other) + elif callable(other): + return AND_(self, FnCondition(other)) + else: + raise NotImplementedError( + f"Cannot combine RunEndCondition with {type(other)}" + ) + + +class FnCondition(RunEndCondition): + def __init__(self, fn: Callable[[RunContext], bool]): + self.fn = fn + + def should_end(self, context: RunContext) -> bool: + result = self.fn(context) + if result: + logger.debug("Custom function condition met; ending run.") + return result + + +class OR_(RunEndCondition): + def __init__(self, *conditions: RunEndCondition): + self.conditions = conditions + + def should_end(self, context: RunContext) -> bool: + result = any(condition.should_end(context) for condition in self.conditions) + if result: + logger.debug("At least one condition in OR clause met.") + return result + + +class AND_(RunEndCondition): + def __init__(self, *conditions: RunEndCondition): + self.conditions = conditions + + def should_end(self, context: RunContext) -> bool: + result = all(condition.should_end(context) for condition in self.conditions) + if result: + logger.debug("All conditions in AND clause met.") + return result + + +class AllComplete(RunEndCondition): + def __init__(self, tasks: Optional[list[Task]] = None): + self.tasks = tasks + + def should_end(self, context: RunContext) -> bool: + tasks = self.tasks if self.tasks is not None else context.orchestrator.tasks + result = all(t.is_complete() for t in tasks) + if result: + logger.debug("All tasks are complete; ending run.") + return result + + +class AnyComplete(RunEndCondition): + def __init__(self, tasks: Optional[list[Task]] = None, min_complete: int = 1): + self.tasks = tasks + if min_complete < 1: + raise ValueError("min_complete must be at least 1") + self.min_complete = min_complete + + def should_end(self, context: RunContext) -> bool: + tasks = self.tasks if self.tasks is not None else context.orchestrator.tasks + result = sum(t.is_complete() for t in tasks) >= self.min_complete + if result: + logger.debug("At least one task is complete; ending run.") + return result + + +class AnyFailed(RunEndCondition): + def __init__(self, tasks: Optional[list[Task]] = None, min_failed: int = 1): + self.tasks = tasks + if min_failed < 1: + raise ValueError("min_failed must be at least 1") + self.min_failed = min_failed + + def should_end(self, context: RunContext) -> bool: + tasks = self.tasks if self.tasks is not None else context.orchestrator.tasks + result = sum(t.is_failed() for t in tasks) >= self.min_failed + if result: + logger.debug("At least one task has failed; ending run.") + return result + + +class MaxAgentTurns(RunEndCondition): + def __init__(self, n: int): + self.n = n + + def should_end(self, context: RunContext) -> bool: + result = context.agent_turns >= self.n + if result: + logger.debug( + f"Maximum number of agent turns ({self.n}) reached; ending run." 
+ ) + return result + + +class MaxLLMCalls(RunEndCondition): + def __init__(self, n: int): + self.n = n + + def should_end(self, context: RunContext) -> bool: + result = context.llm_calls >= self.n + if result: + logger.debug(f"Maximum number of LLM calls ({self.n}) reached; ending run.") + return result diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index b11c970b..c6fff6fb 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,7 +1,7 @@ import logging -from typing import Optional, TypeVar +from typing import Callable, Optional, TypeVar, Union -from pydantic import Field, field_validator +from pydantic import BaseModel, Field, field_validator import controlflow from controlflow.agents.agent import Agent @@ -11,6 +11,15 @@ from controlflow.flows import Flow from controlflow.instructions import get_instructions from controlflow.llm.messages import BaseMessage +from controlflow.memory import Memory +from controlflow.orchestration.conditions import ( + AllComplete, + FnCondition, + MaxAgentTurns, + MaxLLMCalls, + RunContext, + RunEndCondition, +) from controlflow.orchestration.handler import Handler from controlflow.orchestration.turn_strategies import Popcorn, TurnStrategy from controlflow.tasks.task import Task @@ -114,8 +123,7 @@ def get_tools(self) -> list[Tool]: # add completion tools if task.completion_agents is None or self.agent in task.completion_agents: - tools.append(task.create_success_tool()) - tools.append(task.create_fail_tool()) + tools.extend(task.get_completion_tools()) # add turn strategy tools only if there are multiple available agents available_agents = self.get_available_agents() @@ -125,14 +133,45 @@ def get_tools(self) -> list[Tool]: tools = as_tools(tools) return tools + def get_memories(self) -> list[Memory]: + memories = set() + + memories.update(self.agent.memories) + + for task in self.get_tasks("assigned"): + memories.update(task.memories) + + return memories + @prefect_task(task_run_name="Orchestrator.run()") def run( - self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None - ): + self, + max_llm_calls: Optional[int] = None, + max_agent_turns: Optional[int] = None, + model_kwargs: Optional[dict] = None, + run_until: Optional[ + Union[RunEndCondition, Callable[[RunContext], bool]] + ] = None, + ) -> RunContext: import controlflow.events.orchestrator_events - call_count = 0 - turn_count = 0 + # Create the base termination condition + if run_until is None: + run_until = AllComplete() + elif not isinstance(run_until, RunEndCondition): + run_until = FnCondition(run_until) + + # Add max_llm_calls condition + if max_llm_calls is None: + max_llm_calls = controlflow.settings.orchestrator_max_llm_calls + run_until = run_until | MaxLLMCalls(max_llm_calls) + + # Add max_agent_turns condition + if max_agent_turns is None: + max_agent_turns = controlflow.settings.orchestrator_max_agent_turns + run_until = run_until | MaxAgentTurns(max_agent_turns) + + run_context = RunContext(orchestrator=self, run_end_condition=run_until) # Initialize the agent if not already set if not self.agent: @@ -140,24 +179,14 @@ def run( None, self.get_available_agents() ) - if max_agent_turns is None: - max_agent_turns = controlflow.settings.orchestrator_max_agent_turns - if max_llm_calls is None: - max_llm_calls = controlflow.settings.orchestrator_max_llm_calls - # Signal the start of orchestration self.handle_event( 
controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) ) try: - while any(t.is_incomplete() for t in self.tasks): - # Check if we've reached the turn or call limit - if max_agent_turns is not None and turn_count >= max_agent_turns: - logger.debug(f"Max agent turns reached: {max_agent_turns}") - break - - if max_llm_calls is not None and call_count >= max_llm_calls: + while True: + if run_context.should_end(): break self.handle_event( @@ -165,8 +194,10 @@ def run( orchestrator=self, agent=self.agent ) ) - turn_count += 1 - call_count += self.run_agent_turn(max_llm_calls - call_count) + self.run_agent_turn( + run_context=run_context, + model_kwargs=model_kwargs, + ) self.handle_event( controlflow.events.orchestrator_events.AgentTurnEnd( orchestrator=self, agent=self.agent @@ -194,23 +225,37 @@ def run( orchestrator=self ) ) + return run_context @prefect_task async def run_async( - self, max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None - ): - """ - Run the orchestration process asynchronously until completion or limits are reached. - - Args: - max_llm_calls (int, optional): Maximum number of LLM calls to make. - max_agent_turns (int, optional): Maximum number of agent turns to run - (each turn can consist of multiple LLM calls) - """ + self, + max_llm_calls: Optional[int] = None, + max_agent_turns: Optional[int] = None, + model_kwargs: Optional[dict] = None, + run_until: Optional[ + Union[RunEndCondition, Callable[[RunContext], bool]] + ] = None, + ) -> RunContext: import controlflow.events.orchestrator_events - call_count = 0 - turn_count = 0 + # Create the base termination condition + if run_until is None: + run_until = AllComplete() + elif not isinstance(run_until, RunEndCondition): + run_until = FnCondition(run_until) + + # Add max_llm_calls condition + if max_llm_calls is None: + max_llm_calls = controlflow.settings.orchestrator_max_llm_calls + run_until = run_until | MaxLLMCalls(max_llm_calls) + + # Add max_agent_turns condition + if max_agent_turns is None: + max_agent_turns = controlflow.settings.orchestrator_max_agent_turns + run_until = run_until | MaxAgentTurns(max_agent_turns) + + run_context = RunContext(orchestrator=self, run_end_condition=run_until) # Initialize the agent if not already set if not self.agent: @@ -218,24 +263,15 @@ async def run_async( None, self.get_available_agents() ) - if max_agent_turns is None: - max_agent_turns = controlflow.settings.orchestrator_max_agent_turns - if max_llm_calls is None: - max_llm_calls = controlflow.settings.orchestrator_max_llm_calls - # Signal the start of orchestration self.handle_event( controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) ) try: - while any(t.is_incomplete() for t in self.tasks): - # Check if we've reached the turn or call limit - if max_agent_turns is not None and turn_count >= max_agent_turns: - logger.debug(f"Max agent turns reached: {max_agent_turns}") - break - - if max_llm_calls is not None and call_count >= max_llm_calls: + while True: + # Check termination condition + if run_context.should_end(): break self.handle_event( @@ -243,9 +279,9 @@ async def run_async( orchestrator=self, agent=self.agent ) ) - turn_count += 1 - call_count += await self.run_agent_turn_async( - max_llm_calls - call_count + await self.run_agent_turn_async( + run_context=run_context, + model_kwargs=model_kwargs, ) self.handle_event( controlflow.events.orchestrator_events.AgentTurnEnd( @@ -274,19 +310,17 @@ async def run_async( orchestrator=self ) ) + return 
run_context @prefect_task(task_run_name="Agent turn: {self.agent.name}") - def run_agent_turn(self, max_llm_calls: Optional[int]) -> int: + def run_agent_turn( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> int: """ Run a single agent turn, which may consist of multiple LLM calls. - - Args: - max_llm_calls (Optional[int]): The number of LLM calls allowed. - - Returns: - int: The number of LLM calls made during this turn. """ - call_count = 0 assigned_tasks = self.get_tasks("assigned") self.turn_strategy.begin_turn() @@ -297,39 +331,47 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int: task.mark_running() self.handle_event( OrchestratorMessage( - content=f"Starting task {task.name} (ID {task.id}) " + content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " f"with objective: {task.objective}" ) ) while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls for task in assigned_tasks: if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: task.mark_failed(reason="Max LLM calls reached for this task.") - else: - task._llm_calls += 1 # Check if there are any ready tasks left if not any(t.is_ready() for t in assigned_tasks): logger.debug("No `ready` tasks to run") break - call_count += 1 + if run_context.should_end(): + break + messages = self.compile_messages() tools = self.get_tools() - for event in self.agent._run_model(messages=messages, tools=tools): + for event in self.agent._run_model( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): self.handle_event(event) - # Check if we've reached the call limit within a turn - if max_llm_calls is not None and call_count >= max_llm_calls: - logger.debug(f"Max LLM calls reached: {max_llm_calls}") - break + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 - return call_count + run_context.agent_turns += 1 @prefect_task - async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int: + async def run_agent_turn_async( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> int: """ Run a single agent turn asynchronously, which may consist of multiple LLM calls. @@ -339,7 +381,6 @@ async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int: Returns: int: The number of LLM calls made during this turn. 
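Plain callables are wrapped in `FnCondition` automatically, so any function of `RunContext` can serve as a termination condition:

```python
import controlflow as cf
from controlflow.orchestration.conditions import RunContext


def budget_exhausted(context: RunContext) -> bool:
    # these counters are incremented by the orchestrator as the run proceeds
    return context.llm_calls >= 3 or context.agent_turns >= 2


cf.run("Summarize the latest findings", run_until=budget_exhausted)
```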
""" - call_count = 0 assigned_tasks = self.get_tasks("assigned") self.turn_strategy.begin_turn() @@ -356,32 +397,34 @@ async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int: ) while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls for task in assigned_tasks: if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: task.mark_failed(reason="Max LLM calls reached for this task.") - else: - task._llm_calls += 1 # Check if there are any ready tasks left if not any(t.is_ready() for t in assigned_tasks): logger.debug("No `ready` tasks to run") break - call_count += 1 + if run_context.should_end(): + break + messages = self.compile_messages() tools = self.get_tools() async for event in self.agent._run_model_async( - messages=messages, tools=tools + messages=messages, + tools=tools, + model_kwargs=model_kwargs, ): self.handle_event(event) - # Check if we've reached the call limit within a turn - if max_llm_calls is not None and call_count >= max_llm_calls: - logger.debug(f"Max LLM calls reached: {max_llm_calls}") - break + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 - return call_count + run_context.agent_turns += 1 def compile_prompt(self) -> str: """ @@ -392,19 +435,26 @@ def compile_prompt(self) -> str: """ from controlflow.orchestration.prompt_templates import ( InstructionsTemplate, + LLMInstructionsTemplate, + MemoryTemplate, TasksTemplate, ToolTemplate, ) - tools = self.get_tools() + llm_rules = self.agent.get_llm_rules() prompts = [ self.agent.get_prompt(), self.flow.get_prompt(), TasksTemplate(tasks=self.get_tasks("ready")).render(), - ToolTemplate(tools=tools).render(), + ToolTemplate(tools=self.get_tools()).render(), + MemoryTemplate(memories=self.get_memories()).render(), InstructionsTemplate(instructions=get_instructions()).render(), + LLMInstructionsTemplate( + instructions=llm_rules.model_instructions() + ).render(), ] + prompt = "\n\n".join([p for p in prompts if p]) return prompt @@ -457,6 +507,10 @@ def collect_tasks(task: Task): for dependency in task.depends_on: collect_tasks(dependency) + # Collect parent + if task.parent and not task.parent.wait_for_subtasks: + collect_tasks(task.parent) + # Check if the task is ready if task.is_ready(): ready_tasks.append(task) @@ -501,3 +555,6 @@ def get_task_hierarchy(self) -> dict: hierarchy[task.id] = task_dict_map[task.id] return hierarchy + + +RunContext.model_rebuild() diff --git a/src/controlflow/orchestration/print_handler.py b/src/controlflow/orchestration/print_handler.py index f76523aa..e121474d 100644 --- a/src/controlflow/orchestration/print_handler.py +++ b/src/controlflow/orchestration/print_handler.py @@ -59,7 +59,9 @@ def update_live(self, latest: BaseMessage = None): cf_console.print(format_event(latest)) def on_orchestrator_start(self, event: OrchestratorStart): - self.live: Live = Live(auto_refresh=False, console=cf_console) + self.live: Live = Live( + auto_refresh=False, console=cf_console, vertical_overflow="visible" + ) self.events.clear() try: self.live.start() diff --git a/src/controlflow/orchestration/prompt_templates.py b/src/controlflow/orchestration/prompt_templates.py index 47b73b5e..58de81eb 100644 --- a/src/controlflow/orchestration/prompt_templates.py +++ b/src/controlflow/orchestration/prompt_templates.py @@ -4,6 +4,7 @@ from controlflow.agents.agent import Agent from controlflow.flows import Flow +from controlflow.memory.memory import Memory from controlflow.tasks.task import Task from 
controlflow.tools.tools import Tool from controlflow.utilities.general import ControlFlowModel @@ -78,6 +79,14 @@ def should_render(self) -> bool: return bool(self.instructions) +class LLMInstructionsTemplate(Template): + template_path: str = "llm_instructions.jinja" + instructions: Optional[list[str]] = None + + def should_render(self) -> bool: + return bool(self.instructions) + + class ToolTemplate(Template): template_path: str = "tools.jinja" tools: list[Tool] @@ -86,6 +95,14 @@ def should_render(self) -> bool: return any(t.instructions for t in self.tools) +class MemoryTemplate(Template): + template_path: str = "memories.jinja" + memories: list[Memory] + + def should_render(self) -> bool: + return bool(self.memories) + + def build_task_hierarchy(provided_tasks: List[Task]) -> List[Dict[str, Any]]: """ Builds a hierarchical structure of tasks, including all descendants of provided tasks diff --git a/src/controlflow/orchestration/prompt_templates/flow.jinja b/src/controlflow/orchestration/prompt_templates/flow.jinja index 1d23fb87..b0902ecc 100644 --- a/src/controlflow/orchestration/prompt_templates/flow.jinja +++ b/src/controlflow/orchestration/prompt_templates/flow.jinja @@ -1,6 +1,6 @@ # Flow -Here is context about the flow you are participating in. +Here is context about the flow/thread you are participating in. - Name: {{ flow.name }} {% if flow.description %} diff --git a/src/controlflow/orchestration/prompt_templates/instructions.jinja b/src/controlflow/orchestration/prompt_templates/instructions.jinja index b6d068ec..8151ec15 100644 --- a/src/controlflow/orchestration/prompt_templates/instructions.jinja +++ b/src/controlflow/orchestration/prompt_templates/instructions.jinja @@ -1,6 +1,6 @@ # Instructions -You must follow these instructions. Note that instructions can be changed at any time. +You must follow these instructions at all times. Note that instructions can be changed at any time. {% for instruction in instructions %} - {{ instruction }} diff --git a/src/controlflow/orchestration/prompt_templates/llm_instructions.jinja b/src/controlflow/orchestration/prompt_templates/llm_instructions.jinja new file mode 100644 index 00000000..e1c0f346 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/llm_instructions.jinja @@ -0,0 +1,9 @@ +# LLM Instructions + +These instructions are specific to your LLM model. They must be followed to ensure compliance with the orchestrator and +other agents. + +{% for instruction in instructions %} +- {{ instruction }} + +{% endfor %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/memories.jinja b/src/controlflow/orchestration/prompt_templates/memories.jinja new file mode 100644 index 00000000..4886f395 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/memories.jinja @@ -0,0 +1,12 @@ +# Memories + +You have the following memory modules installed. Consult your memory whenever +you think it would be helpful, and make sure to add new memories when you learn +something new. You should not refer directly to memories in your responses. 
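This template renders whenever the current agent or its assigned tasks carry memory modules. A sketch of how a module reaches it, assuming `Agent` accepts a `memories` list the way `Task` does and a default memory provider is configured:

```python
import controlflow as cf

memory = cf.Memory(
    key="project_facts",
    instructions="Facts about the current project.",
)

# the agent also receives store/delete/search tools for this module,
# named store_memory_project_facts, etc.
agent = cf.Agent(name="Researcher", memories=[memory])
```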
+ +{% for memory in memories %} +## Memory: {{ memory.key }} + +Instructions: {{ memory.instructions }} + +{% endfor %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/task.jinja b/src/controlflow/orchestration/prompt_templates/task.jinja index 8abc34b8..23fb868f 100644 --- a/src/controlflow/orchestration/prompt_templates/task.jinja +++ b/src/controlflow/orchestration/prompt_templates/task.jinja @@ -3,4 +3,10 @@ - objective: {{ task.objective }} {% if task.instructions %}- instructions: {{ task.instructions }}{% endif %} {% if task.result_type %}- result type: {{ task.result_type }}{% endif %} -{% if task.context %}- context: {{ task.context }}{% endif %} \ No newline at end of file +{% if task.context %}- context: {{ task.context }}{% endif %} +{% if task.parent %}- parent task ID: {{ task.parent.id }}{%endif %} +{% if task._subtasks%}- this task has the following subtask IDs: {{ task._subtasks | map(attribute='id') | join(', ') }} +{% if not task.wait_for_subtasks %}- complete this task as soon as you meet its objective, even if you haven't completed +its subtasks{% endif%}{% endif %} +{% if task.depends_on %}- this task depends on these upstream task IDs (includes subtasks): {{ task.depends_on | +map(attribute='id') | join(', ') }}{% endif %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/tasks.jinja b/src/controlflow/orchestration/prompt_templates/tasks.jinja index 338fbde2..36868957 100644 --- a/src/controlflow/orchestration/prompt_templates/tasks.jinja +++ b/src/controlflow/orchestration/prompt_templates/tasks.jinja @@ -1,10 +1,11 @@ -{% macro render_task_hierarchy(task_info, indent='') -%} +{% macro render_task_hierarchy(task_info, indent='') %} {{ indent }}- {{ task_info.task.id }} ({{ task_info.task.status.value }}){% if task_info['is_active'] %} -(active){% endif %} +(active){%endif%} {%- if task_info.children %} + {% for child in task_info.children %} -{{ render_task_hierarchy(child, indent + ' ') }} -{%- endfor %} +{{ render_task_hierarchy(child, indent + '-') }} +{% endfor %} {%- endif %} {%- endmacro -%} @@ -21,15 +22,29 @@ The following tasks are active: {{ task.get_prompt() }} -{% endfor %} -Only agents assigned to a task are able to mark the task as complete. You must use a tool to end your turn to let other -agents participate. If you are asked to talk to other agents, post messages. Do not impersonate another agent! Do not -impersonate the orchestrator! -Only mark a task failed if there is a technical error or issue preventing completion. +{% endfor %} -## Task hierarchy +Only agents assigned to a task are able to mark the task as complete. You must +use a tool to end your turn to let other agents participate. If you are asked to +talk to other agents, post messages. Do not impersonate another agent! Do not +impersonate the orchestrator! If you have been assigned a task, then you (and +other agents) must have the resources, knowledge, or tools required to complete +it. + +A task can only be marked complete one time. Do not attempt to mark a task +successful more than once. Even if the `result_type` does not appear to match +the objective, you must supply a single compatible result. Only mark a task +failed if there is a technical error or issue preventing completion. + +When a parent task must wait for subtasks, it means that all of its subtasks are +treated as upstream dependencies and must be completed before the parent task +can be marked as complete. 
However, if the parent task has +`wait_for_subtasks=False`, then it can and should be marked as complete as soon +as you can, regardless of the status of its subtasks. + +## Subtask hierarchy {% for task in task_hierarchy %} {{ render_task_hierarchy(task) }} diff --git a/src/controlflow/orchestration/turn_strategies.py b/src/controlflow/orchestration/turn_strategies.py index 803ffc9f..5494f8b2 100644 --- a/src/controlflow/orchestration/turn_strategies.py +++ b/src/controlflow/orchestration/turn_strategies.py @@ -38,7 +38,7 @@ def should_end_turn(self) -> bool: return self.end_turn -def create_end_turn_tool(strategy: TurnStrategy) -> Tool: +def get_end_turn_tool(strategy: TurnStrategy) -> Tool: @tool def end_turn() -> str: """ @@ -51,7 +51,7 @@ def end_turn() -> str: return end_turn -def create_delegate_tool( +def get_delegate_tool( strategy: TurnStrategy, available_agents: dict[Agent, list[Task]] ) -> Tool: @tool @@ -77,7 +77,7 @@ class SingleAgent(TurnStrategy): def get_tools( self, current_agent: Agent, available_agents: dict[Agent, list[Task]] ) -> list[Tool]: - return [create_end_turn_tool(self)] + return [get_end_turn_tool(self)] def get_next_agent( self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] @@ -93,7 +93,7 @@ class Popcorn(TurnStrategy): def get_tools( self, current_agent: Agent, available_agents: dict[Agent, list[Task]] ) -> list[Tool]: - return [create_delegate_tool(self, available_agents)] + return [get_delegate_tool(self, available_agents)] def get_next_agent( self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] @@ -107,7 +107,7 @@ class Random(TurnStrategy): def get_tools( self, current_agent: Agent, available_agents: dict[Agent, list[Task]] ) -> list[Tool]: - return [create_end_turn_tool(self)] + return [get_end_turn_tool(self)] def get_next_agent( self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] @@ -119,7 +119,7 @@ class RoundRobin(TurnStrategy): def get_tools( self, current_agent: Agent, available_agents: dict[Agent, list[Task]] ) -> list[Tool]: - return [create_end_turn_tool(self)] + return [get_end_turn_tool(self)] def get_next_agent( self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] @@ -136,7 +136,7 @@ class MostBusy(TurnStrategy): def get_tools( self, current_agent: Agent, available_agents: dict[Agent, list[Task]] ) -> list[Tool]: - return [create_end_turn_tool(self)] + return [get_end_turn_tool(self)] def get_next_agent( self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] @@ -152,9 +152,9 @@ def get_tools( self, current_agent: Agent, available_agents: dict[Agent, list[Task]] ) -> list[Tool]: if current_agent == self.moderator: - return [create_delegate_tool(self, available_agents)] + return [get_delegate_tool(self, available_agents)] else: - return [create_end_turn_tool(self)] + return [get_end_turn_tool(self)] def get_next_agent( self, current_agent: Optional[Agent], available_agents: Dict[Agent, List[Task]] diff --git a/src/controlflow/plan.py b/src/controlflow/plan.py index 9875e9f1..e861c3de 100644 --- a/src/controlflow/plan.py +++ b/src/controlflow/plan.py @@ -61,7 +61,7 @@ def plan( agent_dict = dict(enumerate(agents)) tool_dict = dict( - enumerate([t.dict(include={"name", "description"}) for t in tools]) + enumerate([t.model_dump(include={"name", "description"}) for t in tools]) ) def validate_plan(plan: list[PlanTask]): diff --git a/src/controlflow/run.py b/src/controlflow/run.py index 1f01a350..59c2fe58 100644 
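The `wait_for_subtasks` behavior described in the template above is enforced in `Task.is_ready` later in this diff: subtasks count as dependencies only when the flag is set. A sketch:

```python
import controlflow as cf

parent = cf.Task("Compile a research brief", wait_for_subtasks=False)
subtask = cf.Task("Collect three sources", parent=parent)

# with wait_for_subtasks=False, the parent is ready immediately and
# agents are told to close it as soon as its own objective is met
assert parent.is_ready()
```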
--- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -1,9 +1,11 @@ -from typing import Any +from typing import Any, Callable, Optional, Union from prefect.context import TaskRunContext +import controlflow from controlflow.agents.agent import Agent from controlflow.flows import Flow, get_flow +from controlflow.orchestration.conditions import RunContext, RunEndCondition from controlflow.orchestration.handler import Handler from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy from controlflow.tasks.task import Task @@ -20,13 +22,16 @@ def get_task_run_name() -> str: @prefect_task(task_run_name=get_task_run_name) def run_tasks( tasks: list[Task], + instructions: str = None, flow: Flow = None, agent: Agent = None, turn_strategy: TurnStrategy = None, - raise_on_error: bool = True, + raise_on_failure: bool = True, max_llm_calls: int = None, max_agent_turns: int = None, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, ) -> list[Any]: """ Run a list of tasks. @@ -42,12 +47,16 @@ def run_tasks( turn_strategy=turn_strategy, handlers=handlers, ) - orchestrator.run( - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - ) - if raise_on_error and any(t.is_failed() for t in tasks): + with controlflow.instructions(instructions): + orchestrator.run( + max_llm_calls=max_llm_calls, + max_agent_turns=max_agent_turns, + model_kwargs=model_kwargs, + run_until=run_until, + ) + + if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] if errors: raise ValueError( @@ -61,16 +70,19 @@ def run_tasks( @prefect_task(task_run_name=get_task_run_name) async def run_tasks_async( tasks: list[Task], + instructions: str = None, flow: Flow = None, agent: Agent = None, turn_strategy: TurnStrategy = None, - raise_on_error: bool = True, + raise_on_failure: bool = True, max_llm_calls: int = None, max_agent_turns: int = None, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, ): """ - Run a list of tasks. + Run a list of tasks asynchronously. 
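A sketch of the updated entrypoint, showing the new `instructions` parameter (applied for the duration of the run via `controlflow.instructions()`) and the `raise_on_failure` rename:

```python
import controlflow as cf
from controlflow.run import run_tasks

tasks = [cf.Task("Draft a tagline"), cf.Task("List three competitor taglines")]

# failed tasks no longer raise; inspect task.is_failed() afterwards instead
results = run_tasks(
    tasks,
    instructions="Keep every answer under ten words.",
    raise_on_failure=False,
)
```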
""" flow = flow or get_flow() or Flow() orchestrator = Orchestrator( @@ -80,12 +92,16 @@ async def run_tasks_async( turn_strategy=turn_strategy, handlers=handlers, ) - await orchestrator.run_async( - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - ) - if raise_on_error and any(t.is_failed() for t in tasks): + with controlflow.instructions(instructions): + await orchestrator.run_async( + max_llm_calls=max_llm_calls, + max_agent_turns=max_agent_turns, + model_kwargs=model_kwargs, + run_until=run_until, + ) + + if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] if errors: raise ValueError( @@ -102,18 +118,22 @@ def run( turn_strategy: TurnStrategy = None, max_llm_calls: int = None, max_agent_turns: int = None, - raise_on_error: bool = True, + raise_on_failure: bool = True, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) results = run_tasks( tasks=[task], - raise_on_error=raise_on_error, + raise_on_failure=raise_on_failure, turn_strategy=turn_strategy, max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, handlers=handlers, + model_kwargs=model_kwargs, + run_until=run_until, ) return results[0] @@ -126,8 +146,10 @@ async def run_async( turn_strategy: TurnStrategy = None, max_llm_calls: int = None, max_agent_turns: int = None, - raise_on_error: bool = True, + raise_on_failure: bool = True, handlers: list[Handler] = None, + model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -138,7 +160,9 @@ async def run_async( turn_strategy=turn_strategy, max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, - raise_on_error=raise_on_error, + raise_on_failure=raise_on_failure, handlers=handlers, + model_kwargs=model_kwargs, + run_until=run_until, ) return results[0] diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index 266d9ae8..79a4b945 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -81,6 +81,28 @@ class Settings(ControlFlowSettings): 100_000, description="The maximum number of tokens to send to an LLM." 
) + # ------------ Memory settings ------------ + + memory_provider: Optional[str] = Field( + default="chroma-db", + description="The default memory provider for agents.", + ) + + # ------------ Memory settings: ChromaDB ------------ + + chroma_cloud_tenant: Optional[str] = Field( + None, + description="The tenant for Chroma Cloud.", + ) + chroma_cloud_database: Optional[str] = Field( + None, + description="The database for Chroma Cloud.", + ) + chroma_cloud_api_key: Optional[str] = Field( + None, + description="The API key for Chroma Cloud.", + ) + # ------------ Debug settings ------------ debug_messages: bool = Field( diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index ff0f6221..a0560923 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -1,4 +1,5 @@ import datetime +import textwrap import warnings from contextlib import ExitStack, contextmanager from enum import Enum @@ -6,7 +7,9 @@ TYPE_CHECKING, Any, Callable, + Generator, GenericAlias, + Literal, Optional, TypeVar, Union, @@ -15,6 +18,7 @@ _LiteralGenericAlias, _SpecialGenericAlias, ) +from uuid import uuid4 from prefect.context import TaskRunContext from pydantic import ( @@ -25,10 +29,12 @@ field_serializer, field_validator, ) +from typing_extensions import Self import controlflow from controlflow.agents import Agent from controlflow.instructions import get_instructions +from controlflow.memory.memory import Memory from controlflow.tools import Tool, tool from controlflow.tools.input import cli_input from controlflow.tools.tools import as_tools @@ -37,6 +43,7 @@ NOTSET, ControlFlowModel, hash_objects, + unwrap, ) from controlflow.utilities.logging import get_logger from controlflow.utilities.prefect import prefect_task as prefect_task @@ -50,6 +57,9 @@ logger = get_logger(__name__) +COMPLETION_TOOLS = Literal["SUCCEED", "FAIL"] + + def get_task_run_name(): context = TaskRunContext.get() task = context.parameters["self"] @@ -97,8 +107,7 @@ class Task(ControlFlowModel): ) context: dict = Field( default_factory=dict, - description="Additional context for the task. If tasks are provided as " - "context, they are automatically added as `depends_on`", + description="Additional context for the task.", ) parent: Optional["Task"] = Field( NOTSET, @@ -141,11 +150,23 @@ class Task(ControlFlowModel): default_factory=list, description="Tools available to every agent working on this task.", ) + completion_tools: Optional[list[COMPLETION_TOOLS]] = Field( + default=None, + description=""" + Completion tools that will be generated for this task. If None, all + tools will be generated; if a list of strings, only the corresponding + tools will be generated automatically. + """, + ) completion_agents: Optional[list[Agent]] = Field( default=None, description="Agents that are allowed to mark this task as complete. If None, all agents are allowed.", ) interactive: bool = False + memories: list[Memory] = Field( + default=[], + description="A list of memory modules for the task to use.", + ) max_llm_calls: Optional[int] = Field( default_factory=lambda: controlflow.settings.task_max_llm_calls, description="Maximum number of LLM calls to make before the task should be marked as failed. 
" @@ -153,6 +174,10 @@ class Task(ControlFlowModel): "which this task is considered `assigned`.", ) created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) + wait_for_subtasks: bool = Field( + default=True, + description="If True, the task will not be considered ready until all subtasks are complete.", + ) _subtasks: set["Task"] = set() _downstreams: set["Task"] = set() _cm_stack: list[contextmanager] = [] @@ -207,16 +232,18 @@ def __init__( self.id = self._generate_id() def _generate_id(self): - return hash_objects( - ( - type(self).__name__, - self.objective, - self.instructions, - str(self.result_type), - self.prompt, - str(self.context), - ) - ) + return str(uuid4())[:8] + # generate a short, semi-stable ID for a task + # return hash_objects( + # ( + # type(self).__name__, + # self.objective, + # self.instructions, + # str(self.result_type), + # self.prompt, + # str(self.context), + # ) + # ) def __hash__(self) -> int: return id(self) @@ -231,9 +258,16 @@ def __eq__(self, other): if type(self) is type(other): d1 = dict(self) d2 = dict(other) + + for attr in ["id", "created_at"]: + d1.pop(attr) + d2.pop(attr) + # conver sets to lists for comparison d1["depends_on"] = list(d1["depends_on"]) d2["depends_on"] = list(d2["depends_on"]) + d1["subtasks"] = list(self.subtasks) + d2["subtasks"] = list(other.subtasks) return d1 == d2 return False @@ -241,6 +275,18 @@ def __repr__(self) -> str: serialized = self.model_dump(include={"id", "objective"}) return f"{self.__class__.__name__}({', '.join(f'{key}={repr(value)}' for key, value in serialized.items())})" + @field_validator("objective") + def _validate_objective(cls, v): + if v: + v = unwrap(v) + return v + + @field_validator("instructions") + def _validate_instructions(cls, v): + if v: + v = unwrap(v) + return v + @field_validator("agents") def _validate_agents(cls, v): if isinstance(v, list) and not v: @@ -335,7 +381,6 @@ def add_subtask(self, task: "Task"): elif task.parent is not self: raise ValueError(f"{self.friendly_name()} already has a parent.") self._subtasks.add(task) - self.depends_on.add(task) def add_dependency(self, task: "Task"): """ @@ -353,6 +398,8 @@ def run( max_llm_calls: int = None, max_agent_turns: int = None, handlers: list["Handler"] = None, + raise_on_failure: bool = True, + model_kwargs: Optional[dict] = None, ) -> T: """ Run the task @@ -365,13 +412,14 @@ def run( turn_strategy=turn_strategy, max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, - raise_on_error=False, + raise_on_failure=False, handlers=handlers, + model_kwargs=model_kwargs, ) if self.is_successful(): return self.result - elif self.is_failed(): + elif raise_on_failure and self.is_failed(): raise ValueError(f"{self.friendly_name()} failed: {self.result}") @prefect_task(task_run_name=get_task_run_name) @@ -383,6 +431,7 @@ async def run_async( max_llm_calls: int = None, max_agent_turns: int = None, handlers: list["Handler"] = None, + raise_on_failure: bool = True, ) -> T: """ Run the task @@ -395,22 +444,22 @@ async def run_async( turn_strategy=turn_strategy, max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, - raise_on_error=False, + raise_on_failure=False, handlers=handlers, ) if self.is_successful(): return self.result - elif self.is_failed(): + elif raise_on_failure and self.is_failed(): raise ValueError(f"{self.friendly_name()} failed: {self.result}") @contextmanager - def create_context(self): + def create_context(self) -> Generator[Self, None, None]: stack = ctx.get("tasks") or [] with ctx(tasks=stack + 
[self]): yield self - def __enter__(self): + def __enter__(self) -> Self: # use stack so we can enter the context multiple times self._cm_stack.append(ExitStack()) return self._cm_stack[-1].enter_context(self.create_context()) @@ -444,7 +493,11 @@ def is_ready(self) -> bool: Returns True if all dependencies are complete and this task is incomplete, meaning it is ready to be worked on. """ - return self.is_incomplete() and all(t.is_complete() for t in self.depends_on) + depends_on = self.depends_on + if self.wait_for_subtasks: + depends_on = depends_on.union(self._subtasks) + + return self.is_incomplete() and all(t.is_complete() for t in depends_on) def get_agents(self) -> list[Agent]: if self.agents is not None: @@ -463,17 +516,32 @@ def get_agents(self) -> list[Agent]: else: return [controlflow.defaults.agent] - def get_tools(self) -> list[Union[Tool, Callable]]: + def get_tools(self) -> list[Tool]: + """ + Return a list of all tools available for the task. + + Note this does not include completion tools, which are handled separately. + """ tools = self.tools.copy() if self.interactive: tools.append(cli_input) - return tools + for memory in self.memories: + tools.extend(memory.get_tools()) + return as_tools(tools) def get_completion_tools(self) -> list[Tool]: - tools = [ - self.create_success_tool(), - self.create_fail_tool(), - ] + """ + Return a list of all completion tools available for the task. + """ + tools = [] + completion_tools = self.completion_tools + if completion_tools is None: + completion_tools = ["SUCCEED", "FAIL"] + + if "SUCCEED" in completion_tools: + tools.append(self.get_success_tool()) + if "FAIL" in completion_tools: + tools.append(self.get_fail_tool()) return tools def get_prompt(self) -> str: @@ -506,12 +574,14 @@ def mark_failed(self, reason: Optional[str] = None): def mark_skipped(self): self.set_status(TaskStatus.SKIPPED) - def create_success_tool(self) -> Tool: + def get_success_tool(self) -> Tool: """ Create an agent-compatible tool for marking this task as successful. """ options = {} - instructions = None + instructions = unwrap(""" + Use this tool to mark the task as successful and provide a result. + """) result_schema = None # if the result_type is a tuple of options, then we want the LLM to provide @@ -532,10 +602,12 @@ def create_success_tool(self) -> Tool: options_str = "\n\n".join( f"Option {i}: {option}" for i, option in serialized_options.items() ) - instructions = f""" + instructions += "\n\n" + unwrap(""" Provide a single integer as the result, corresponding to the index - of your chosen option. Your options are: {options_str} - """ + of your chosen option. Your options are: + + {options_str} + """).format(options_str=options_str) # otherwise try to load the schema for the result type elif self.result_type is not None: @@ -571,7 +643,7 @@ def succeed(result: result_schema) -> str: # type: ignore return succeed - def create_fail_tool(self) -> Tool: + def get_fail_tool(self) -> Tool: """ Create an agent-compatible tool for failing this task. 
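When `result_type` is a list of labeled options, the success tool built above asks the agent for the integer index of its choice and maps it back to the underlying value. A sketch:

```python
import controlflow as cf

task = cf.Task(
    "Classify the sentiment of this review: 'Great product!'",
    result_type=["positive", "negative", "neutral"],
)

# the agent's success tool receives an index (0, 1, or 2); task.result
# is the corresponding label after the task completes
result = task.run()
```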
""" diff --git a/src/controlflow/tools/tools.py b/src/controlflow/tools/tools.py index 683f09e4..b6c9d19e 100644 --- a/src/controlflow/tools/tools.py +++ b/src/controlflow/tools/tools.py @@ -6,13 +6,12 @@ import langchain_core.tools import pydantic -import pydantic.v1 from langchain_core.messages import InvalidToolCall, ToolCall from prefect.utilities.asyncutils import run_coro_as_sync from pydantic import Field, PydanticSchemaGenerationError, TypeAdapter import controlflow -from controlflow.utilities.general import ControlFlowModel +from controlflow.utilities.general import ControlFlowModel, unwrap from controlflow.utilities.prefect import create_markdown_artifact, prefect_task TOOL_CALL_FUNCTION_RESULT_TEMPLATE = """ @@ -176,13 +175,13 @@ def from_function( if len(description) > 1024: raise ValueError( - inspect.cleandoc(f""" - {name}: The tool's description exceeds 1024 - characters. Please provide a shorter description, fewer - annotations, or pass - `include_param_descriptions=False` or - `include_return_description=False` to `from_function`. - """).replace("\n", " ") + unwrap(f""" + {name}: The tool's description exceeds 1024 + characters. Please provide a shorter description, fewer + annotations, or pass + `include_param_descriptions=False` or + `include_return_description=False` to `from_function`. + """) ) return cls( diff --git a/src/controlflow/utilities/general.py b/src/controlflow/utilities/general.py index 4e43ffac..ce8e5e07 100644 --- a/src/controlflow/utilities/general.py +++ b/src/controlflow/utilities/general.py @@ -1,5 +1,7 @@ import hashlib import json +import re +import textwrap from typing import Optional, Union import prefect @@ -32,6 +34,22 @@ def hash_objects(input_data: tuple, len: int = 8) -> str: return hasher.hexdigest()[:len] +def unwrap(text: str) -> str: + """ + Given a multi-line string, dedent, remove newlines within paragraphs, but keep paragraph breaks. + """ + # Dedent the text + dedented_text = textwrap.dedent(text) + + # Remove newlines within paragraphs, but keep paragraph breaks + cleaned_text = re.sub(r"(? 
0 assert len(handler.agent_messages) == 1 + + +class TestLLMRules: + def test_llm_rules_from_model_openai(self): + agent = Agent(model=ChatOpenAI(model="gpt-4o-mini")) + rules = agent.get_llm_rules() + assert isinstance(rules, OpenAIRules) + + def test_llm_rules_from_model_anthropic(self): + agent = Agent(model=ChatAnthropic(model="claude-3-haiku-20240307")) + rules = agent.get_llm_rules() + assert isinstance(rules, AnthropicRules) + + def test_custom_llm_rules(self): + rules = LLMRules(model=None) + agent = Agent(llm_rules=rules, model=ChatOpenAI(model="gpt-4o-mini")) + assert agent.get_llm_rules() is rules diff --git a/tests/conftest.py b/tests/conftest.py index 1dc1acbe..dd63c005 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,29 +1,9 @@ import pytest from prefect.testing.utilities import prefect_test_harness -from controlflow.settings import temporary_settings - from .fixtures import * -@pytest.fixture(autouse=True, scope="session") -def temp_controlflow_settings(): - with temporary_settings( - pretty_print_agent_events=False, - log_all_messages=True, - log_level="DEBUG", - orchestrator_max_agent_turns=10, - orchestrator_max_llm_calls=10, - ): - yield - - -@pytest.fixture(autouse=True) -def reset_settings_after_each_test(): - with temporary_settings(): - yield - - @pytest.fixture(autouse=True, scope="session") def prefect_test_fixture(): """ diff --git a/tests/events/test_history.py b/tests/events/test_history.py index 2fc02b7d..a256652c 100644 --- a/tests/events/test_history.py +++ b/tests/events/test_history.py @@ -4,22 +4,22 @@ class TestFileHistory: - def test_write_to_thread_id_file(self, tmpdir): - h = FileHistory(base_path=tmpdir) + def test_write_to_thread_id_file(self, tmp_path): + h = FileHistory(base_path=tmp_path) event = UserMessage(content="test") thread_id = "abc" - # assert a file called 'abc.json' does not exist in tmpdir - assert not (tmpdir / f"{thread_id}.json").exists() + # assert a file called 'abc.json' does not exist in tmp_path + assert not (tmp_path / f"{thread_id}.json").exists() h.add_events(thread_id, [event]) - # assert a file called 'abc.json' exists in tmpdir - assert (tmpdir / f"{thread_id}.json").exists() + # assert a file called 'abc.json' exists in tmp_path + assert (tmp_path / f"{thread_id}.json").exists() - def test_read_from_thread_id_file(self, tmpdir): - h1 = FileHistory(base_path=tmpdir) - h2 = FileHistory(base_path=tmpdir) + def test_read_from_thread_id_file(self, tmp_path): + h1 = FileHistory(base_path=tmp_path) + h2 = FileHistory(base_path=tmp_path) event = UserMessage(content="test") thread_id = "abc" @@ -27,9 +27,9 @@ def test_read_from_thread_id_file(self, tmpdir): # read with different history object assert h2.get_events(thread_id) == [event] - def test_file_histories_respect_base_path(self, tmpdir): - h1 = FileHistory(base_path=tmpdir) - h2 = FileHistory(base_path=tmpdir / "subdir") + def test_file_histories_respect_base_path(self, tmp_path): + h1 = FileHistory(base_path=tmp_path) + h2 = FileHistory(base_path=tmp_path / "subdir") event = UserMessage(content="test") thread_id = "abc" @@ -38,27 +38,27 @@ def test_file_histories_respect_base_path(self, tmpdir): assert h2.get_events(thread_id) == [] assert h1.get_events(thread_id) == [event] - def test_file_history_creates_dir(self, tmpdir): - h = FileHistory(base_path=tmpdir / "subdir") + def test_file_history_creates_dir(self, tmp_path): + h = FileHistory(base_path=tmp_path / "subdir") event = UserMessage(content="test") thread_id = "abc" h.add_events(thread_id, [event]) - 
assert (tmpdir / "subdir" / f"{thread_id}.json").exists() + assert (tmp_path / "subdir" / f"{thread_id}.json").exists() class TestFileHistoryFlow: - def test_flow_uses_file_history(self, tmpdir): - f1 = Flow(thread_id="abc", history=FileHistory(base_path=tmpdir)) - f2 = Flow(thread_id="abc", history=FileHistory(base_path=tmpdir)) + def test_flow_uses_file_history(self, tmp_path): + f1 = Flow(thread_id="abc", history=FileHistory(base_path=tmp_path)) + f2 = Flow(thread_id="abc", history=FileHistory(base_path=tmp_path)) event = UserMessage(content="test") f1.add_events([event]) assert f2.get_events() == [event] - def test_flow_sets_thread_id_for_file_history(self, tmpdir): - f1 = Flow(thread_id="abc", history=FileHistory(base_path=tmpdir)) - f2 = Flow(thread_id="xyz", history=FileHistory(base_path=tmpdir)) - f3 = Flow(thread_id="abc", history=FileHistory(base_path=tmpdir)) + def test_flow_sets_thread_id_for_file_history(self, tmp_path): + f1 = Flow(thread_id="abc", history=FileHistory(base_path=tmp_path)) + f2 = Flow(thread_id="xyz", history=FileHistory(base_path=tmp_path)) + f3 = Flow(thread_id="abc", history=FileHistory(base_path=tmp_path)) f1.add_events([UserMessage(content="test")]) assert len(f1.get_events()) == 1 diff --git a/tests/fixtures/controlflow.py b/tests/fixtures/controlflow.py index 06d571a3..d62deed1 100644 --- a/tests/fixtures/controlflow.py +++ b/tests/fixtures/controlflow.py @@ -1,12 +1,53 @@ +import chromadb import pytest import controlflow -from controlflow.llm.messages import BaseMessage +from controlflow.events.history import InMemoryHistory +from controlflow.memory.providers.chroma import ChromaMemory +from controlflow.settings import temporary_settings from controlflow.utilities.testing import FakeLLM +@pytest.fixture(autouse=True, scope="session") +def temp_controlflow_settings(): + with temporary_settings( + pretty_print_agent_events=False, + log_all_messages=True, + log_level="DEBUG", + orchestrator_max_agent_turns=10, + orchestrator_max_llm_calls=10, + ): + yield + + +@pytest.fixture(autouse=True) +def reset_settings_after_each_test(): + with temporary_settings(): + yield + + +@pytest.fixture(autouse=True) +def temp_controlflow_defaults(tmp_path, monkeypatch): + # use in-memory history + monkeypatch.setattr( + controlflow.defaults, + "history", + InMemoryHistory(), + ) + + monkeypatch.setattr( + controlflow.defaults, + "memory_provider", + ChromaMemory( + client=chromadb.PersistentClient(path=str(tmp_path / "controlflow-memory")) + ), + ) + + yield + + @pytest.fixture(autouse=True) -def restore_defaults(monkeypatch): +def reset_defaults_after_each_test(monkeypatch): """ Monkeypatch defaults to themselves, which will automatically reset them after every test """ diff --git a/tests/memory/__init__.py b/tests/memory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/memory/test_memory.py b/tests/memory/test_memory.py new file mode 100644 index 00000000..3234d6dc --- /dev/null +++ b/tests/memory/test_memory.py @@ -0,0 +1,46 @@ +import chromadb +import pytest + +import controlflow +from controlflow.memory.providers.chroma import ChromaMemory + + +class TestMemory: + def test_store_and_retrieve(self): + m = controlflow.Memory(key="test", instructions="test") + m.add("The number is 42") + result = m.search("numbers") + assert len(result) == 1 + assert "The number is 42" in result.values() + + def test_delete(self): + m = controlflow.Memory(key="test", instructions="test") + m_id = m.add("The number is 42") + m.delete(m_id) + result = 
m.search("numbers")
+        assert len(result) == 0
+
+    def test_search(self):
+        m = controlflow.Memory(key="test", instructions="test")
+        m.add("The number is 42")
+        m.add("The number is 43")
+        result = m.search("numbers")
+        assert len(result) == 2
+        assert "The number is 42" in result.values()
+        assert "The number is 43" in result.values()
+
+
+class TestMemoryProvider:
+    def test_load_from_string_invalid(self):
+        with pytest.raises(ValueError):
+            controlflow.Memory(key="test", instructions="test", provider="invalid")
+
+    def test_load_from_string_chroma_db(self):
+        m = controlflow.Memory(key="test", instructions="test", provider="chroma-db")
+        assert isinstance(m.provider, ChromaMemory)
+
+    def test_load_from_instance(self, tmp_path):
+        mp = ChromaMemory(
+            client=chromadb.PersistentClient(path=str(tmp_path / "test_path"))
+        )
+        m = controlflow.Memory(key="test", instructions="test", provider=mp)
+        assert m.provider is mp
diff --git a/tests/orchestration/test_orchestrator.py b/tests/orchestration/test_orchestrator.py
index 522329f5..a31e0e82 100644
--- a/tests/orchestration/test_orchestrator.py
+++ b/tests/orchestration/test_orchestrator.py
@@ -1,26 +1,28 @@
+from unittest.mock import MagicMock, patch
+
 import pytest
 
+import controlflow.orchestration.conditions
 from controlflow.agents import Agent
 from controlflow.flows import Flow
 from controlflow.orchestration.orchestrator import Orchestrator
-from controlflow.orchestration.turn_strategies import (  # Add this import
-    Popcorn,
-    TurnStrategy,
-)
+from controlflow.orchestration.turn_strategies import Popcorn, TurnStrategy
 from controlflow.tasks.task import Task
+from controlflow.utilities.testing import FakeLLM, SimpleTask
 
 
 class TestOrchestratorLimits:
-    call_count = 0
-    turn_count = 0
-
     @pytest.fixture
-    def mocked_orchestrator(self, default_fake_llm):
-        # Reset counts at the start of each test
-        self.call_count = 0
-        self.turn_count = 0
+    def orchestrator(self, default_fake_llm):
+        default_fake_llm.set_responses([dict(name="count_call")])
+        self.calls = 0
+        self.turns = 0
 
         class TwoCallTurnStrategy(TurnStrategy):
+            """
+            A turn strategy that ends a turn after 2 calls
+            """
+
             calls: int = 0
 
             def get_tools(self, *args, **kwargs):
@@ -30,84 +32,52 @@ def get_next_agent(self, current_agent, available_agents):
                 return current_agent
 
            def begin_turn(ts_instance):
-                self.turn_count += 1
+                self.turns += 1
                 super().begin_turn()
 
-            def should_end_turn(ts_instance):
-                ts_instance.calls += 1
+            def should_end_turn(ts_self):
+                ts_self.calls += 1
                 # if this would be the third call, end the turn
-                if ts_instance.calls >= 3:
-                    ts_instance.calls = 0
+                if ts_self.calls >= 3:
+                    ts_self.calls = 0
                     return True
-                # record a new call for the unit test
-                self.call_count += 1
                 return False
 
-        agent = Agent()
+        def count_call():
+            self.calls += 1
+
+        agent = Agent(tools=[count_call])
         task = Task("Test task", agents=[agent])
         flow = Flow()
         orchestrator = Orchestrator(
-            tasks=[task], flow=flow, agent=agent, turn_strategy=TwoCallTurnStrategy()
+            tasks=[task],
+            flow=flow,
+            agent=agent,
+            turn_strategy=TwoCallTurnStrategy(),
         )
         return orchestrator
 
-    def test_default_limits(self, mocked_orchestrator):
-        mocked_orchestrator.run()
-
-        assert self.turn_count == 5
-        assert self.call_count == 10
-
-    @pytest.mark.parametrize(
-        "max_agent_turns, max_llm_calls, expected_turns, expected_calls",
-        [
-            (1, 1, 1, 1),
-            (1, 2, 1, 2),
-            (5, 3, 2, 3),
-            (3, 12, 3, 6),
-        ],
-    )
-    def test_custom_limits(
-        self,
-        mocked_orchestrator,
-        max_agent_turns,
-        max_llm_calls,
-        expected_turns,
-
        expected_calls,
-    ):
-        mocked_orchestrator.run(
-            max_agent_turns=max_agent_turns, max_llm_calls=max_llm_calls
+    def test_max_llm_calls(self, orchestrator):
+        orchestrator.run(max_llm_calls=5)
+        assert self.calls == 5
+
+    def test_max_agent_turns(self, orchestrator):
+        orchestrator.run(max_agent_turns=3)
+        assert self.calls == 6
+
+    def test_max_llm_calls_and_max_agent_turns(self, orchestrator):
+        orchestrator.run(
+            max_llm_calls=10,
+            max_agent_turns=3,
+            model_kwargs={"tool_choice": "required"},
         )
+        assert self.calls == 6
 
-        assert self.turn_count == expected_turns
-        assert self.call_count == expected_calls
-
-    def test_task_limit(self, mocked_orchestrator):
-        task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent])
-        mocked_orchestrator.tasks = [task]
-        mocked_orchestrator.run()
-        assert task.is_failed()
-        assert self.turn_count == 3
-        # Note: the call count will be 6 because the orchestrator call count is
-        # incremented in "should_end_turn" which is called before the task's
-        # call count is evaluated
-        assert self.call_count == 6
-
-    def test_task_lifetime_limit(self, mocked_orchestrator):
-        task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent])
-        mocked_orchestrator.tasks = [task]
-        mocked_orchestrator.run(max_agent_turns=1)
-        assert task.is_incomplete()
-        mocked_orchestrator.run(max_agent_turns=1)
-        assert task.is_incomplete()
-        mocked_orchestrator.run(max_agent_turns=1)
-        assert task.is_failed()
-
-        assert self.turn_count == 3
-        # Note: the call count will be 6 because the orchestrator call count is
-        # incremented in "should_end_turn" which is called before the task's
-        # call count is evaluated
-        assert self.call_count == 6
+    def test_default_limits(self, orchestrator):
+        orchestrator.run(model_kwargs={"tool_choice": "required"})
+        assert self.calls == 10  # the session test fixture sets orchestrator_max_llm_calls=10
 
 
 class TestOrchestratorCreation:
@@ -162,3 +132,120 @@ def test_run_keeps_existing_agent_if_set(self):
         orchestrator.run(max_agent_turns=0)
 
         assert orchestrator.agent == agent1
+
+
+class TestRunEndConditions:
+    def test_run_until_all_complete(self, monkeypatch):
+        task1 = SimpleTask()
+        task2 = SimpleTask()
+        orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent())
+
+        # Mock the run_agent_turn method
+        def mock_run_agent_turn(*args, **kwargs):
+            task1.mark_successful()
+            task2.mark_successful()
+            return 1
+
+        monkeypatch.setitem(
+            orchestrator.__dict__,
+            "run_agent_turn",
+            MagicMock(side_effect=mock_run_agent_turn),
+        )
+
+        orchestrator.run(run_until=controlflow.orchestration.conditions.AllComplete())
+
+        assert all(task.is_complete() for task in orchestrator.tasks)
+
+    def test_run_until_any_complete(self, monkeypatch):
+        task1 = SimpleTask()
+        task2 = SimpleTask()
+        orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent())
+
+        # Mock the run_agent_turn method
+        def mock_run_agent_turn(*args, **kwargs):
+            task1.mark_successful()
+            return 1
+
+        monkeypatch.setitem(
+            orchestrator.__dict__,
+            "run_agent_turn",
+            MagicMock(side_effect=mock_run_agent_turn),
+        )
+
+        orchestrator.run(run_until=controlflow.orchestration.conditions.AnyComplete())
+
+        assert any(task.is_complete() for task in orchestrator.tasks)
+
+    def test_run_until_fn_condition(self, monkeypatch):
+        task1 = SimpleTask()
+        task2 = SimpleTask()
+        orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent())
+
+        # Mock the run_agent_turn method
+        def mock_run_agent_turn(*args, **kwargs):
+            task2.mark_successful()
+            return 1
+
+ monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run( + run_until=controlflow.orchestration.conditions.FnCondition( + lambda context: context.orchestrator.tasks[1].is_complete() + ) + ) + + assert task2.is_complete() + + def test_run_until_lambda_condition(self, monkeypatch): + task1 = SimpleTask() + task2 = SimpleTask() + orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent()) + + # Mock the run_agent_turn method + def mock_run_agent_turn(*args, **kwargs): + task2.mark_successful() + return 1 + + monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run( + run_until=lambda context: context.orchestrator.tasks[1].is_complete() + ) + + assert task2.is_complete() + + def test_compound_condition(self, monkeypatch): + task1 = SimpleTask() + task2 = SimpleTask() + orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent()) + + # Mock the run_agent_turn method + def mock_run_agent_turn(*args, **kwargs): + return 1 + + monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run( + run_until=( + # this condition will always fail + controlflow.orchestration.conditions.FnCondition(lambda context: False) + | + # this condition will always pass + controlflow.orchestration.conditions.FnCondition(lambda context: True) + ) + ) + + # assert to prove we reach this point and the run stopped + assert True diff --git a/tests/tasks/test_decorator.py b/tests/tasks/test_decorator.py deleted file mode 100644 index 360876ef..00000000 --- a/tests/tasks/test_decorator.py +++ /dev/null @@ -1,45 +0,0 @@ -import controlflow - - -class TestDecorator: - def test_decorator(self): - @controlflow.task - def write_poem(topic: str) -> str: - """write a poem about `topic`""" - - task = write_poem.as_task("AI") - assert task.name == "write_poem" - assert task.objective == "write a poem about `topic`" - assert task.result_type is str - - def test_decorator_can_return_context(self): - @controlflow.task - def write_poem(topic: str) -> str: - return f"write a poem about {topic}" - - task = write_poem.as_task("AI") - assert task.context["Additional context"] == "write a poem about AI" - - def test_return_annotation(self): - @controlflow.task - def generate_tags(text: str) -> list[str]: - """Generate a list of tags for the given text.""" - - task = generate_tags.as_task("Fly me to the moon") - assert task.result_type == list[str] - - def test_objective_can_be_provided_as_kwarg(self): - @controlflow.task(objective="Write a poem about `topic`") - def write_poem(topic: str) -> str: - """Writes a poem.""" - - task = write_poem.as_task("AI") - assert task.objective == "Write a poem about `topic`" - - def test_run_task(self): - @controlflow.task - def extract_fruit(text: str) -> list[str]: - return "Extract any fruit mentioned in the text; all lowercase" - - result = extract_fruit("I like apples and bananas") - assert result == ["apples", "bananas"] diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index f2da7123..c1f7f461 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -1,4 +1,5 @@ -from typing import Annotated, Any, Dict, List +from enum import Enum +from typing import Annotated, Any, Dict, List, Literal import pytest from pydantic import BaseModel @@ -46,6 +47,7 @@ def test_task_initialization(): assert task.result is 
None
 
 
+@pytest.mark.skip(reason="IDs are not stable right now")
 def test_stable_id():
     t1 = Task(objective="Test Objective")
     t2 = Task(objective="Test Objective")
@@ -265,27 +267,19 @@ def test_str_result(self):
         task.mark_successful(result="5")
         assert task.result == "5"
 
-    def test_tuple_of_ints_result(self):
-        task = Task("choose 5", result_type=(4, 5, 6))
-        task.mark_successful(result=5)
-        assert task.result == 5
-
-    def test_tuple_of_ints_validates(self):
-        task = Task("choose 5", result_type=(4, 5, 6))
-        with pytest.raises(ValueError):
-            task.mark_successful(result=7)
-
     def test_typed_dict_result(self):
         task = Task("", result_type=dict[str, int])
         task.mark_successful(result={"a": 5, "b": "6"})
         assert task.result == {"a": 5, "b": 6}
 
     def test_special_list_type_result(self):
+        # test capitalized List type
         task = Task("", result_type=List[int])
         task.mark_successful(result=[5, 6])
         assert task.result == [5, 6]
 
     def test_special_dict_type_result(self):
+        # test capitalized Dict type
         task = Task("", result_type=Dict[str, int])
         task.mark_successful(result={"a": 5, "b": "6"})
         assert task.result == {"a": 5, "b": 6}
@@ -309,6 +303,107 @@ def test_annotated_result(self):
         assert int(task.result)
 
 
+class TestResultTypeConstrainedChoice:
+    class Letter(BaseModel):
+        letter: str
+
+        def __hash__(self):
+            return id(self)
+
+    A = Letter(letter="a")
+    B = Letter(letter="b")
+    C = Letter(letter="c")
+
+    def test_tuple_of_ints_result(self):
+        task = Task("choose 5", result_type=(4, 5, 6))
+        task.mark_successful(result=5)
+        assert task.result == 5
+
+    def test_tuple_of_ints_validates(self):
+        task = Task("choose 5", result_type=(4, 5, 6))
+        with pytest.raises(ValueError):
+            task.mark_successful(result=7)
+
+    def test_list_of_strings_result(self):
+        # test list of strings result
+        task = Task(
+            "Choose the second letter of the alphabet", result_type=["b", "c", "a"]
+        )
+        task.run()
+        assert task.result == "b"
+
+    def test_list_of_objects_result(self):
+        # test list of objects result
+        task = Task(
+            "Choose the second letter of the alphabet",
+            result_type=[self.A, self.C, self.B],
+        )
+        task.run()
+        assert task.result is self.B
+
+    def test_tuple_of_objects_result(self):
+        # test tuple of objects result
+        task = Task(
+            "Choose the second letter of the alphabet",
+            result_type=(self.A, self.C, self.B),
+        )
+        task.run()
+        assert task.result is self.B
+
+    def test_set_of_objects_result(self):
+        # test set of objects result
+        task = Task(
+            "Choose the second letter of the alphabet",
+            result_type={self.A, self.C, self.B},
+        )
+        task.run()
+        assert task.result is self.B
+
+    def test_literal_string_result(self):
+        task = Task(
+            "Choose the second letter of the alphabet",
+            result_type=Literal["a", "c", "b"],
+        )
+        task.run()
+        assert task.result == "b"
+
+    def test_enum_result(self):
+        class Letters(Enum):
+            A = "a"
+            B = "b"
+            C = "c"
+
+        task = Task("Choose the second letter of the alphabet", result_type=Letters)
+        task.run()
+        assert task.result is Letters.B
+
+    def test_literal_object_result(self):
+        # Literal with non-literal members is invalid per the typing spec, but works at runtime
+        task = Task(
+            "Choose the second letter of the alphabet",
+            result_type=Literal[self.A, self.B, self.C],  # noqa
+        )
+        task.run()
+        assert task.result is self.B
+
+    def test_list_of_literals_result(self):
+        task = Task(
+            "Choose the second and third letters of the alphabet",
+            result_type=list[Literal["a", "b", "c"]],
+        )
+        task.run()
+        assert task.result == ["b", "c"]
+
+    def test_map_labels_to_values(self):
+        task = Task(
+            "Choose the right label, in order provided in context",
+ context=dict(goals=["the second letter", "the first letter"]), + result_type=list[Literal["a", "b", "c"]], + ) + task.run() + assert task.result == ["b", "a"] + + class TestResultValidator: def test_result_validator(self): def validate_even(value: int) -> int: @@ -389,27 +484,27 @@ def always_return_none(value: Any) -> None: class TestSuccessTool: def test_success_tool(self): task = Task("choose 5", result_type=int) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=5)) assert task.is_successful() assert task.result == 5 def test_success_tool_with_list_of_options(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=1)) assert task.is_successful() assert task.result == "good" def test_success_tool_with_list_of_options_requires_int(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) - tool = task.create_success_tool() + tool = task.get_success_tool() with pytest.raises(ValueError): tool.run(input=dict(result="good")) def test_tuple_of_ints_result(self): task = Task("choose 5", result_type=(4, 5, 6)) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=1)) assert task.result == 5 @@ -422,14 +517,14 @@ class Person(BaseModel): "Who is the oldest?", result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)), ) - tool = task.create_success_tool() + tool = task.get_success_tool() tool.run(input=dict(result=1)) assert task.result == Person(name="Bob", age=35) assert isinstance(task.result, Person) class TestHandlers: - class TestHandler(Handler): + class ExampleHandler(Handler): def __init__(self): self.events = [] self.agent_messages = [] @@ -441,18 +536,85 @@ def on_agent_message(self, event: AgentMessage): self.agent_messages.append(event) def test_task_run_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() task = Task(objective="Calculate 2 + 2", result_type=int) task.run(handlers=[handler], max_llm_calls=1) assert len(handler.events) > 0 assert len(handler.agent_messages) == 1 - @pytest.mark.asyncio async def test_task_run_async_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() task = Task(objective="Calculate 2 + 2", result_type=int) await task.run_async(handlers=[handler], max_llm_calls=1) assert len(handler.events) > 0 assert len(handler.agent_messages) == 1 + + +class TestCompletionTools: + def test_default_completion_tools(self): + task = Task(objective="Test task") + assert task.completion_tools is None + tools = task.get_completion_tools() + assert len(tools) == 2 + assert any(t.name == f"mark_task_{task.id}_successful" for t in tools) + assert any(t.name == f"mark_task_{task.id}_failed" for t in tools) + + def test_only_succeed_tool(self): + task = Task(objective="Test task", completion_tools=["SUCCEED"]) + tools = task.get_completion_tools() + assert len(tools) == 1 + assert tools[0].name == f"mark_task_{task.id}_successful" + + def test_only_fail_tool(self): + task = Task(objective="Test task", completion_tools=["FAIL"]) + tools = task.get_completion_tools() + assert len(tools) == 1 + assert tools[0].name == f"mark_task_{task.id}_failed" + + def test_no_completion_tools(self): + task = Task(objective="Test task", completion_tools=[]) + tools = task.get_completion_tools() + assert len(tools) == 0 + + def test_invalid_completion_tool(self): 
+ with pytest.raises(ValueError): + Task(objective="Test task", completion_tools=["INVALID"]) + + def test_manual_success_tool(self): + task = Task(objective="Test task", completion_tools=[], result_type=int) + success_tool = task.get_success_tool() + success_tool.run(input=dict(result=5)) + assert task.is_successful() + assert task.result == 5 + + def test_manual_fail_tool(self): + task = Task(objective="Test task", completion_tools=[]) + fail_tool = task.get_fail_tool() + assert fail_tool.name == f"mark_task_{task.id}_failed" + fail_tool.run(input=dict(reason="test error")) + assert task.is_failed() + assert task.result == "test error" + + def test_completion_tools_with_run(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=["SUCCEED"]) + result = task.run(max_llm_calls=1) + assert result == 4 + assert task.is_successful() + + def test_no_completion_tools_with_run(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=[]) + task.run(max_llm_calls=1) + assert task.is_incomplete() + + async def test_completion_tools_with_run_async(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=["SUCCEED"]) + result = await task.run_async(max_llm_calls=1) + assert result == 4 + assert task.is_successful() + + async def test_no_completion_tools_with_run_async(self): + task = Task("Calculate 2 + 2", result_type=int, completion_tools=[]) + await task.run_async(max_llm_calls=1) + assert task.is_incomplete() diff --git a/tests/test_decorator.py b/tests/test_decorator.py new file mode 100644 index 00000000..468b652f --- /dev/null +++ b/tests/test_decorator.py @@ -0,0 +1,140 @@ +import asyncio + +import controlflow + + +class TestDecorator: + def test_decorator(self): + @controlflow.task + def write_poem(topic: str) -> str: + """write a poem about `topic`""" + + task = write_poem.as_task("AI") + assert task.name == "write_poem" + assert task.objective == "write a poem about `topic`" + assert task.result_type is str + + def test_decorator_can_return_context(self): + @controlflow.task + def write_poem(topic: str) -> str: + return f"write a poem about {topic}" + + task = write_poem.as_task("AI") + assert task.context["Additional context"] == "write a poem about AI" + + def test_return_annotation(self): + @controlflow.task + def generate_tags(text: str) -> list[str]: + """Generate a list of tags for the given text.""" + + task = generate_tags.as_task("Fly me to the moon") + assert task.result_type == list[str] + + def test_objective_can_be_provided_as_kwarg(self): + @controlflow.task(objective="Write a poem about `topic`") + def write_poem(topic: str) -> str: + """Writes a poem.""" + + task = write_poem.as_task("AI") + assert task.objective == "Write a poem about `topic`" + + def test_run_task(self): + @controlflow.task + def extract_fruit(text: str) -> list[str]: + return "Extract any fruit mentioned in the text; all lowercase" + + result = extract_fruit("I like apples and bananas") + assert result == ["apples", "bananas"] + + +class TestFlowDecorator: + def test_sync_flow_decorator(self): + @controlflow.flow + def sync_flow(): + return 10 + + result = sync_flow() + assert result == 10 + + async def test_async_flow_decorator(self): + @controlflow.flow + async def async_flow(): + await asyncio.sleep(0.1) + return 10 + + result = await async_flow() + assert result == 10 + + def test_flow_decorator_preserves_function_metadata(self): + @controlflow.flow + def flow_with_metadata(): + """This is a test flow.""" + return 10 + + assert 
flow_with_metadata.__name__ == "flow_with_metadata" + assert flow_with_metadata.__doc__ == "This is a test flow." + + def test_flow_decorator_with_arguments(self): + @controlflow.flow(thread="test_thread", instructions="Test instructions") + def flow_with_args(x: int): + return x + 10 + + result = flow_with_args(5) + assert result == 15 + + async def test_async_flow_decorator_with_arguments(self): + @controlflow.flow( + thread="async_test_thread", instructions="Async test instructions" + ) + async def async_flow_with_args(x: int): + await asyncio.sleep(0.1) + return x + 10 + + result = await async_flow_with_args(5) + assert result == 15 + + def test_flow_decorator_partial_application(self): + custom_flow = controlflow.flow(thread="custom_thread") + + @custom_flow + def partial_flow(): + return 10 + + result = partial_flow() + assert result == 10 + + +class TestTaskDecorator: + def test_task_decorator_sync_as_task(self): + @controlflow.task + def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + task = write_poem.as_task("AI") + assert task.name == "write_poem" + assert task.objective == "write a two-line poem about `topic`" + assert task.result_type is str + + def test_task_decorator_async_as_task(self): + @controlflow.task + async def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + task = write_poem.as_task("AI") + assert task.name == "write_poem" + assert task.objective == "write a two-line poem about `topic`" + assert task.result_type is str + + def test_task_decorator_sync(self): + @controlflow.task + def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + assert write_poem("AI") + + async def test_task_decorator_async(self): + @controlflow.task + async def write_poem(topic: str) -> str: + """write a two-line poem about `topic`""" + + assert await write_poem("AI") diff --git a/tests/test_defaults.py b/tests/test_defaults.py index 98ec5bf3..4b6bf55c 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -9,7 +9,8 @@ def test_default_model_failed_validation(): with pytest.raises( - pydantic.ValidationError, match="Input must be an instance of BaseChatModel" + pydantic.ValidationError, + match="Input must be an instance of dict or BaseChatModel", ): controlflow.defaults.model = 5 diff --git a/tests/test_run.py b/tests/test_run.py index 41f5d470..c3a2fdd0 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,11 +1,14 @@ +from controlflow import instructions from controlflow.events.base import Event from controlflow.events.events import AgentMessage +from controlflow.orchestration.conditions import AnyComplete, AnyFailed, MaxLLMCalls from controlflow.orchestration.handler import Handler -from controlflow.run import run, run_async +from controlflow.run import run, run_async, run_tasks, run_tasks_async +from controlflow.tasks.task import Task class TestHandlers: - class TestHandler(Handler): + class ExampleHandler(Handler): def __init__(self): self.events = [] self.agent_messages = [] @@ -17,13 +20,13 @@ def on_agent_message(self, event: AgentMessage): self.agent_messages.append(event) def test_run_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() run("what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1) assert len(handler.events) > 0 assert len(handler.agent_messages) == 1 async def test_run_async_with_handlers(self, default_fake_llm): - handler = self.TestHandler() + handler = self.ExampleHandler() await run_async( 
"what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1 ) @@ -40,3 +43,127 @@ def test_run(): async def test_run_async(): result = await run_async("what's 2 + 2", result_type=int) assert result == 4 + + +class TestRunUntil: + def test_any_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("complete task 2"): + run_tasks([task1, task2], run_until=AnyComplete()) + + assert task2.is_complete() + assert task1.is_incomplete() + + def test_any_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("fail task 2"): + run_tasks([task1, task2], run_until=AnyFailed(), raise_on_failure=False) + + assert task2.is_failed() + assert task1.is_incomplete() + + def test_max_llm_calls(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("say hi but do not complete any tasks"): + run_tasks([task1, task2], run_until=MaxLLMCalls(1)) + + assert task2.is_incomplete() + assert task1.is_incomplete() + + def test_min_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("complete tasks 1 and 2"): + run_tasks([task1, task2, task3], run_until=AnyComplete(min_complete=2)) + + assert task1.is_complete() + assert task2.is_complete() + assert task3.is_incomplete() + + def test_min_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("fail tasks 1 and 3"): + run_tasks( + [task1, task2, task3], + run_until=AnyFailed(min_failed=2), + raise_on_failure=False, + ) + + assert task1.is_failed() + assert task2.is_incomplete() + assert task3.is_failed() + + +class TestRunUntilAsync: + async def test_any_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("complete task 2"): + await run_tasks_async([task1, task2], run_until=AnyComplete()) + + assert task2.is_complete() + assert task1.is_incomplete() + + async def test_any_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("fail task 2"): + await run_tasks_async( + [task1, task2], run_until=AnyFailed(), raise_on_failure=False + ) + + assert task2.is_failed() + assert task1.is_incomplete() + + async def test_max_llm_calls(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("say hi but do not complete any tasks"): + await run_tasks_async([task1, task2], run_until=MaxLLMCalls(1)) + + assert task2.is_incomplete() + assert task1.is_incomplete() + + async def test_min_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("complete tasks 1 and 2"): + await run_tasks_async( + [task1, task2, task3], run_until=AnyComplete(min_complete=2) + ) + + assert task1.is_complete() + assert task2.is_complete() + assert task3.is_incomplete() + + async def test_min_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("fail tasks 1 and 3"): + await run_tasks_async( + [task1, task2, task3], + run_until=AnyFailed(min_failed=2), + raise_on_failure=False, + ) + + assert task1.is_failed() + assert task2.is_incomplete() + assert task3.is_failed() diff --git a/tests/test_settings.py b/tests/test_settings.py index 560e4b48..cf27db5a 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,5 +1,6 @@ import importlib +import openai import pytest from prefect.logging import get_logger @@ -78,7 +79,9 @@ def 
test_import_without_default_api_key_errors_when_loading_model(monkeypatch): importlib.reload(defaults_module) importlib.reload(controlflow) - with pytest.raises(ValueError, match="Did not find openai_api_key"): + with pytest.raises( + openai.OpenAIError, match="api_key client option must be set" + ): controlflow.llm.models.get_default_model() with pytest.raises( diff --git a/tests/tools/test_lc_tools.py b/tests/tools/test_lc_tools.py index 0a6d54b0..3e380843 100644 --- a/tests/tools/test_lc_tools.py +++ b/tests/tools/test_lc_tools.py @@ -2,7 +2,7 @@ from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.tools import BaseTool -from pydantic.v1 import BaseModel +from pydantic import BaseModel import controlflow from controlflow.events.events import AIMessage, ToolCall @@ -13,8 +13,8 @@ class LCBaseToolInput(BaseModel): class LCBaseTool(BaseTool): - name = "TestTool" - description = "A test tool" + name: str = "TestTool" + description: str = "A test tool" args_schema: type[BaseModel] = LCBaseToolInput def _run(self, x: int) -> str: diff --git a/tests/utilities/test_general.py b/tests/utilities/test_general.py new file mode 100644 index 00000000..d9467b3d --- /dev/null +++ b/tests/utilities/test_general.py @@ -0,0 +1,41 @@ +import controlflow.utilities.general as general + + +class TestUnwrap: + def test_unwrap(self): + assert general.unwrap("Hello, world!") == "Hello, world!" + assert ( + general.unwrap("Hello, world!\nThis is a test.") + == "Hello, world! This is a test." + ) + assert ( + general.unwrap("Hello, world!\nThis is a test.\n\nThis is another test.") + == "Hello, world! This is a test.\n\nThis is another test." + ) + + def test_unwrap_with_empty_string(self): + assert general.unwrap("") == "" + + def test_unwrap_with_multiple_newlines(self): + assert general.unwrap("\n\n\n") == "" + + def test_unwrap_with_multiline_string(self): + assert ( + general.unwrap(""" + Hello, world! + This is a test. + This is another test. + """) + == "Hello, world! This is a test. This is another test." + ) + + def test_unwrap_with_multiline_string_and_newlines(self): + assert ( + general.unwrap(""" + Hello, world! + This is a test. + + This is another test. + """) + == "Hello, world! This is a test.\n\nThis is another test." + )