Skip to content

Commit

Permalink
fix: Async callback for langchain (#353)
Browse files Browse the repository at this point in the history
* fix: Async callback for langchain

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Oct 26, 2023
1 parent 414b77f commit d8b06d4
Show file tree
Hide file tree
Showing 9 changed files with 752 additions and 104 deletions.
623 changes: 595 additions & 28 deletions pdm.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]
requires-python = ">=3.7.1"
requires-python = ">=3.8.1"
dependencies = [
"miservice_fork",
"openai",
Expand All @@ -22,11 +22,11 @@ dependencies = [
"EdgeGPT==0.1.26",
"langchain==0.0.301",
"datetime==5.2",
"bs4==0.0.1",
"beautifulsoup4>=4.12.0",
"chardet==5.1.0",
"typing==3.7.4.3",
"google-search-results==2.4.2",
"numexpr==2.8.6"
"google-search-results>=2.4.2",
"numexpr==2.8.6",
]
license = {text = "MIT"}
dynamic = ["version"]
Expand Down
40 changes: 28 additions & 12 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# This file is @generated by PDM.
# Please do not edit it manually.

aiohttp==3.8.5
aiohttp==3.8.4
aiosignal==1.3.1
annotated-types==0.6.0
anyio==3.6.2
async-timeout==4.0.2
attrs==22.2.0
Expand All @@ -11,10 +12,13 @@ beautifulsoup4==4.12.2
BingImageCreator==0.1.3
browser-cookie3==0.19.1
cachetools==4.2.4
certifi==2023.7.22
certifi==2022.12.7
chardet==5.1.0
charset-normalizer==3.1.0
colorama==0.4.6
dataclasses==0.6
dataclasses-json==0.6.1
datetime==5.2
deep-translator==1.11.4
edge-tts==6.1.3
EdgeGPT==0.1.26
Expand All @@ -23,6 +27,7 @@ google-api-core==1.34.0
google-auth==1.35.0
google-cloud-core==1.7.3
google-cloud-translate==2.0.1
google-search-results==2.4.2
googleapis-common-protos==1.59.1
grpcio==1.56.2
grpcio-status==1.48.2
Expand All @@ -33,38 +38,49 @@ httpcore==0.16.3
httpx==0.24.1
hyperframe==6.0.1
idna==3.4
jsonpatch==1.33
jsonpointer==2.4
langchain==0.0.301
langsmith==0.0.52
lz4==4.3.2
markdown-it-py==2.2.0
marshmallow==3.20.1
mdurl==0.1.2
miservice-fork==2.1.1
multidict==6.0.4
mypy-extensions==1.0.0
numexpr==2.8.6
numpy==1.24.4
openai==0.27.2
packaging==23.2
prompt-toolkit==3.0.38
protobuf==3.20.3
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycryptodomex==3.18.0
pygments==2.15.0
pydantic==2.4.2
pydantic-core==2.10.1
pygments==2.14.0
PyJWT==2.8.0
pytz==2023.3.post1
PyYAML==6.0.1
regex==2022.10.31
requests==2.31.0
requests==2.28.2
rich==13.3.2
rsa==4.9
setuptools==68.0.0
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
SQLAlchemy==2.0.22
tenacity==8.2.3
tqdm==4.65.0
typing==3.7.4.3
typing-extensions==4.8.0
typing-inspect==0.9.0
urllib3==1.26.15
wcwidth==0.2.6
websockets==11.0
yarl==1.8.2
zhipuai==1.0.7
langchain==0.0.301
datetime==5.2
bs4==0.0.1
chardet==5.1.0
typing==3.7.4.3
google-search-results==2.4.2
numexpr==2.8.6

zope-interface==6.1
28 changes: 13 additions & 15 deletions xiaogpt/bot/langchain_bot.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

import openai
import asyncio
import os

from rich import print

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.utils import split_sentences

from xiaogpt.langchain.callbacks import AsyncIteratorCallbackHandler
from xiaogpt.langchain.chain import agent_search
from xiaogpt.langchain.stream_call_back import streaming_call_queue

import os
from xiaogpt.utils import split_sentences


class LangChainBot(BaseBot):
Expand Down Expand Up @@ -42,18 +41,17 @@ def from_config(cls, config):
async def ask(self, query, **options):
# Todo,Currently only supports stream
raise Exception(
"The bot does not support it. Please use 'ask_streamadd --stream'"
"The bot does not support it. Please use 'ask_stream, add: --stream'"
)

async def ask_stream(self, query, **options):
agent_search(query)
callback = AsyncIteratorCallbackHandler()
task = asyncio.create_task(agent_search(query, callback))
try:
while True:
if not streaming_call_queue.empty():
token = streaming_call_queue.get()
print(token, end="")
yield token
else:
break
async for message in split_sentences(callback.aiter()):
yield message
except Exception as e:
print("An error occurred:", str(e))
finally:
print()
await task
87 changes: 87 additions & 0 deletions xiaogpt/langchain/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator
from uuid import UUID

from langchain.callbacks.base import AsyncCallbackHandler


class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
"""Callback handler that returns an async iterator."""

@property
def always_verbose(self) -> bool:
return True

def __init__(self) -> None:
self.queue = asyncio.Queue()
self.done = asyncio.Event()

async def on_chain_start(
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.done.clear()

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if token is not None and token != "":
print(token, end="")
self.queue.put_nowait(token)

async def on_chain_end(
self,
outputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
self.done.set()

async def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
self.done.set()

async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
# Wait for the next token in the queue,
# but stop waiting if the done event is set
done, other = await asyncio.wait(
[
# NOTE: If you add other tasks here, update the code below,
# which assumes each set has exactly one task each
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)

# Cancel the other task
if other:
other.pop().cancel()

# Extract the value of the first completed task
token_or_done = done.pop().result()

# If the extracted value is the boolean True, the done event was set
if token_or_done is True:
break

# Otherwise, the extracted value is a token, which we yield
yield token_or_done
18 changes: 5 additions & 13 deletions xiaogpt/langchain/chain.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.tools import BaseTool
from langchain.llms import OpenAI
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMMathChain
from langchain.utilities import SerpAPIWrapper
from langchain.chat_models import ChatOpenAI
from langchain.memory import ChatMessageHistory
from xiaogpt.langchain.stream_call_back import StreamCallbackHandler
from langchain.agents.agent_toolkits import ZapierToolkit
from langchain.utilities.zapier import ZapierNLAWrapper
from langchain.memory import ConversationBufferMemory
from langchain.utilities import SerpAPIWrapper


def agent_search(query):
async def agent_search(query: str, callback: BaseCallbackHandler) -> None:
llm = ChatOpenAI(
streaming=True,
temperature=0,
model="gpt-3.5-turbo-0613",
callbacks=[StreamCallbackHandler()],
)

# Initialization: search chain, mathematical calculation chain
Expand All @@ -35,4 +27,4 @@ def agent_search(query):
)

# query eg:'杭州亚运会中国队获得了多少枚金牌?' // '计算3的2次方'
agent.run(query)
await agent.arun(query, callbacks=[callback])
16 changes: 0 additions & 16 deletions xiaogpt/langchain/stream_call_back.py

This file was deleted.

2 changes: 1 addition & 1 deletion xiaogpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def calculate_tts_elapse(text: str) -> float:
return len(_no_elapse_chars.sub("", text)) / speed


_ending_punctuations = ("。", "?", "!", ";", ".", "?", "!", ";")
_ending_punctuations = ("。", "?", "!", ";", "\n", "?", "!", ";")


async def split_sentences(text_stream: AsyncIterator[str]) -> AsyncIterator[str]:
Expand Down
Loading

0 comments on commit d8b06d4

Please sign in to comment.