Skip to content

Commit

Permalink
Merge pull request #354 from PrefectHQ/flow-context
Browse files Browse the repository at this point in the history
Make flow context explicit
  • Loading branch information
jlowin authored Oct 9, 2024
2 parents 33c0701 + 622db7e commit 4fec8c2
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 41 deletions.
15 changes: 7 additions & 8 deletions docs/concepts/flows.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,18 @@ The following flow properties are inferred from the decorated function:
| ------------- | ------------- |
| `name` | The function's name |
| `description` | The function's docstring |
| `context` | The function's arguments (keyed by argument name) |
| `context` | The function's arguments, if specified as `context_kwargs` (keyed by argument name) |

Additional properties can be set by passing keyword arguments directly to the `@flow` decorator or to the `flow_kwargs` parameter when calling the decorated function.

<Tip>
You may not want the arguments to your flow function to be used as context. In that case, you can set `args_as_context=False` when decorating or calling the function:
To automatically put some of your flow's arguments into the global context that all agents can see, specify `context_kwargs` when decorating your flow:

```python
@cf.flow(args_as_context=False)
def my_flow(secret_var: str):
@cf.flow(context_kwargs=["x"])
def my_flow(x: int, y: int):
# x will be automatically added to a global, agent-visible context
...
```
</Tip>

Additional properties can be set by passing keyword arguments directly to the `@flow` decorator or to the `flow_kwargs` parameter when calling the decorated function.

### The `Flow` object and context manager

Expand Down
68 changes: 35 additions & 33 deletions src/controlflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ def flow(
thread: Optional[str] = None,
instructions: Optional[str] = None,
tools: Optional[list[Callable[..., Any]]] = None,
default_agent: Optional[Agent] = None, # Changed from 'agents'
default_agent: Optional[Agent] = None,
retries: Optional[int] = None,
retry_delay_seconds: Optional[Union[float, int]] = None,
timeout_seconds: Optional[Union[float, int]] = None,
prefect_kwargs: Optional[dict[str, Any]] = None,
args_as_context: Optional[bool] = True,
context_kwargs: Optional[list[str]] = None,
**kwargs: Optional[dict[str, Any]],
):
"""
Expand All @@ -46,67 +46,69 @@ def flow(
instructions (str, optional): Instructions for the flow. Defaults to None.
tools (list[Callable], optional): List of tools to be used in the flow. Defaults to None.
default_agent (Agent, optional): The default agent to be used in the flow. Defaults to None.
args_as_context (bool, optional): Whether to pass the arguments as context to the flow. Defaults to True.
context_kwargs (list[str], optional): List of argument names to be added to the flow context.
Defaults to None.
Returns:
callable: The wrapped function or a new flow decorator if `fn` is not provided.
"""
...

if fn is None:
return functools.partial(
flow,
thread=thread,
instructions=instructions,
tools=tools,
default_agent=default_agent, # Changed from 'agents'
default_agent=default_agent,
retries=retries,
retry_delay_seconds=retry_delay_seconds,
timeout_seconds=timeout_seconds,
args_as_context=args_as_context,
context_kwargs=context_kwargs,
**kwargs,
)

sig = inspect.signature(fn)

def _inner_wrapper(*wrapper_args, flow_kwargs: dict = None, **wrapper_kwargs):
# first process callargs
bound = sig.bind(*wrapper_args, **wrapper_kwargs)
bound.apply_defaults()

flow_kwargs = kwargs | (flow_kwargs or {})

def create_flow_context(bound_args):
flow_kwargs = kwargs.copy()
if thread is not None:
flow_kwargs.setdefault("thread_id", thread)
if tools is not None:
flow_kwargs.setdefault("tools", tools)
if default_agent is not None: # Changed from 'agents'
flow_kwargs.setdefault(
"default_agent", default_agent
) # Changed from 'agents'

context = bound.arguments if args_as_context else {}

with (
Flow(
name=fn.__name__,
description=fn.__doc__,
context=context,
**flow_kwargs,
),
controlflow.instructions(instructions),
):
return fn(*wrapper_args, **wrapper_kwargs)
if default_agent is not None:
flow_kwargs.setdefault("default_agent", default_agent)

context = {}
if context_kwargs:
context = {k: bound_args[k] for k in context_kwargs if k in bound_args}

return Flow(
name=fn.__name__,
description=fn.__doc__,
context=context,
**flow_kwargs,
)

if asyncio.iscoroutinefunction(fn):

@functools.wraps(fn)
async def wrapper(*wrapper_args, **wrapper_kwargs):
return await _inner_wrapper(*wrapper_args, **wrapper_kwargs)
bound = sig.bind(*wrapper_args, **wrapper_kwargs)
bound.apply_defaults()
with (
create_flow_context(bound.arguments),
controlflow.instructions(instructions),
):
return await fn(*wrapper_args, **wrapper_kwargs)
else:

@functools.wraps(fn)
def wrapper(*wrapper_args, **wrapper_kwargs):
return _inner_wrapper(*wrapper_args, **wrapper_kwargs)
bound = sig.bind(*wrapper_args, **wrapper_kwargs)
bound.apply_defaults()
with (
create_flow_context(bound.arguments),
controlflow.instructions(instructions),
):
return fn(*wrapper_args, **wrapper_kwargs)

wrapper = prefect_flow(
timeout_seconds=timeout_seconds,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,33 @@ def partial_flow():
result = partial_flow()
assert result == 10

def test_flow_decorator_with_context_kwargs(self):
@controlflow.flow(context_kwargs=["x", "z"])
def flow_with_context(x: int, y: int, z: str):
flow = controlflow.flows.get_flow()
return flow.context

result = flow_with_context(1, 2, "test")
assert result == {"x": 1, "z": "test"}

def test_flow_decorator_without_context_kwargs(self):
@controlflow.flow
def flow_without_context(x: int, y: int, z: str):
flow = controlflow.flows.get_flow()
return flow.context

result = flow_without_context(1, 2, "test")
assert result == {}

async def test_async_flow_decorator_with_context_kwargs(self):
@controlflow.flow(context_kwargs=["a", "b"])
async def async_flow_with_context(a: int, b: str, c: float):
flow = controlflow.flows.get_flow()
return flow.context

result = await async_flow_with_context(10, "hello", 3.14)
assert result == {"a": 10, "b": "hello"}


class TestTaskDecorator:
def test_task_decorator_sync_as_task(self):
Expand Down

0 comments on commit 4fec8c2

Please sign in to comment.