Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: multi hop agent #1

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7914525
implement multihop agent
BeatrixCohere Mar 26, 2024
c047868
remove
BeatrixCohere Mar 26, 2024
8b0c86a
Update prompt
BeatrixCohere Mar 27, 2024
157fcd0
Remove new line
BeatrixCohere Mar 27, 2024
db7be4d
Fix notebok
BeatrixCohere Mar 27, 2024
6853755
Lint
BeatrixCohere Mar 27, 2024
bb4762d
Spelling
BeatrixCohere Mar 27, 2024
768dd37
Fix test
BeatrixCohere Mar 27, 2024
eeb46f1
Merge conflicts
BeatrixCohere Mar 28, 2024
a549a11
Add chat history
BeatrixCohere Mar 28, 2024
068549b
Fix prompt
BeatrixCohere Mar 28, 2024
4bde450
Update libs/partners/cohere/langchain_cohere/multi_hop/agent.py
efriis Mar 28, 2024
fc98316
Remove new lines
BeatrixCohere Mar 28, 2024
e705454
Merge branch 'beatrix/MultiHopAgent' of github.com:BeatrixCohere/lang…
BeatrixCohere Mar 28, 2024
74b55e3
Update the naming and notebook
BeatrixCohere Mar 28, 2024
fed0bad
Test
BeatrixCohere Mar 29, 2024
e00a3f1
Add premable override
BeatrixCohere Mar 29, 2024
a9e6a4a
Fix text repsonse
BeatrixCohere Mar 29, 2024
e3cf615
Format
BeatrixCohere Mar 29, 2024
a24ffc3
Delete
BeatrixCohere Mar 29, 2024
59c8c39
Fix formatting
BeatrixCohere Mar 29, 2024
e46b382
Increase default timeout
harry-cohere Mar 29, 2024
c31c157
Fix stop sequences
harry-cohere Mar 29, 2024
b3b9d3f
convert parameter types
harry-cohere Mar 29, 2024
ec58586
add resilience to action parsing
harry-cohere Mar 29, 2024
403f78a
prompt rendering changes
harry-cohere Mar 29, 2024
8f8d8f0
prompt rendering changes
harry-cohere Mar 29, 2024
a6f445d
fix types
harry-cohere Mar 29, 2024
610d9c4
don't nest prompt function
harry-cohere Mar 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 318 additions & 0 deletions libs/partners/cohere/docs/multi_hop_agent.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions libs/partners/cohere/langchain_cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.embeddings import CohereEmbeddings
from langchain_cohere.rag_retrievers import CohereRagRetriever
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
from langchain_cohere.rerank import CohereRerank

__all__ = [
Expand All @@ -11,4 +12,5 @@
"CohereRagRetriever",
"CohereRerank",
"create_cohere_tools_agent",
"create_cohere_react_agent",
]
18 changes: 14 additions & 4 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_cohere_chat_request(
*,
documents: Optional[List[Dict[str, str]]] = None,
connectors: Optional[List[Dict[str, str]]] = None,
stop_sequences: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.
Expand Down Expand Up @@ -115,6 +116,7 @@ def get_cohere_chat_request(
"documents": formatted_docs,
"connectors": connectors,
"prompt_truncation": prompt_truncation,
"stop_sequences": stop_sequences,
**kwargs,
}

Expand Down Expand Up @@ -180,7 +182,9 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)

if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
Expand Down Expand Up @@ -210,7 +214,9 @@ async def _astream(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)

if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
Expand Down Expand Up @@ -266,7 +272,9 @@ def _generate(
)
return generate_from_stream(stream_iter)

request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)
response = self.client.chat(**request)

generation_info = self._get_generation_info(response)
Expand All @@ -290,7 +298,9 @@ async def _agenerate(
)
return await agenerate_from_stream(stream_iter)

request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)
response = self.client.chat(**request)

generation_info = self._get_generation_info(response)
Expand Down
23 changes: 18 additions & 5 deletions libs/partners/cohere/langchain_cohere/cohere_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.function_calling import (
PYTHON_TO_JSON_TYPES,
convert_to_openai_function,
)

JSON_TO_PYTHON_TYPES = {v: k for k, v in PYTHON_TO_JSON_TYPES.items()}


def create_cohere_tools_agent(
Expand Down Expand Up @@ -83,8 +88,12 @@ def _convert_to_cohere_tool(
description=tool.description,
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
description=param_definition.get("description")
if "description" in param_definition
else "",
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required="default" not in param_definition,
)
for param_name, param_definition in tool.args.items()
Expand All @@ -101,7 +110,9 @@ def _convert_to_cohere_tool(
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required="default" not in param_definition,
)
for param_name, param_definition in tool.get("properties", {}).items()
Expand All @@ -121,7 +132,9 @@ def _convert_to_cohere_tool(
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required=param_name in parameters.get("required", []),
)
for param_name, param_definition in properties.items()
Expand Down
5 changes: 5 additions & 0 deletions libs/partners/cohere/langchain_cohere/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,24 @@ class BaseCohere(Serializable):
user_agent: str = "langchain"
"""Identifier for the application making the request."""

timeout_seconds: Optional[float] = 300

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["cohere_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
)
client_name = values["user_agent"]
timeout_seconds = values.get("timeout_seconds")
values["client"] = cohere.Client(
api_key=values["cohere_api_key"].get_secret_value(),
timeout=timeout_seconds,
client_name=client_name,
)
values["async_client"] = cohere.AsyncClient(
api_key=values["cohere_api_key"].get_secret_value(),
timeout=timeout_seconds,
client_name=client_name,
)
return values
Expand Down
Empty file.
Loading
Loading