Skip to content

Commit

Permalink
feat: support memory and non-stream mode for langchain bot
Browse files Browse the repository at this point in the history
Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming committed Oct 27, 2023
1 parent 12117de commit 321f4d9
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 56 deletions.
5 changes: 3 additions & 2 deletions xiaogpt/bot/bard_bot.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""ChatGLM bot"""
from __future__ import annotations

from typing import Any

from bardapi import BardAsync
from rich import print

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin


class BardBot(BaseBot):
class BardBot(ChatHistoryMixin, BaseBot):
def __init__(
self,
bard_token: str,
Expand Down
35 changes: 33 additions & 2 deletions xiaogpt/bot/base_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@


class BaseBot(ABC):
history: list

@abstractmethod
async def ask(self, query: str, **options: Any) -> str:
pass
Expand All @@ -23,3 +21,36 @@ async def ask_stream(self, query: str, **options: Any) -> AsyncGenerator[str, No
@abstractmethod
def from_config(cls: type[T], config: Config) -> T:
pass

@abstractmethod
def has_history(self) -> bool:
    """Return True if this bot currently holds any conversation history."""
    pass

@abstractmethod
def change_prompt(self, new_prompt: str) -> None:
    """Replace the bot's initial/system prompt with *new_prompt*."""
    pass


class ChatHistoryMixin:
    """Mixin providing simple in-memory chat history for bots.

    History is a list of mutable ``[user_query, assistant_reply]`` pairs.
    Entries must be lists, not tuples: ``change_prompt`` rewrites the first
    user message in place (``self.history[0][0] = ...``), which would raise
    ``TypeError`` on a tuple.  (The original annotation said
    ``list[tuple[str, str]]``, contradicting the implementation — fixed.)

    Subclasses are expected to initialize ``self.history = []`` in
    ``__init__``.
    """

    # Each entry is a mutable [user_query, assistant_reply] pair.
    history: list[list[str]]

    def has_history(self) -> bool:
        """Return True if at least one exchange has been recorded."""
        return bool(self.history)

    def change_prompt(self, new_prompt: str) -> None:
        """Rewrite the user side of the first recorded exchange.

        The first entry conventionally carries the initial prompt, so this
        effectively swaps the conversation's system prompt.  No-op when the
        history is empty.
        """
        if self.history:
            print(self.history)  # debug aid: show history before rewriting
            self.history[0][0] = new_prompt

    def get_messages(self) -> list[dict]:
        """Flatten the history into OpenAI-style ``role``/``content`` dicts.

        Each stored pair expands to one ``user`` message followed by one
        ``assistant`` message, preserving order.
        """
        ms = []
        for h in self.history:
            ms.append({"role": "user", "content": h[0]})
            ms.append({"role": "assistant", "content": h[1]})
        return ms

    def add_message(self, query: str, message: str) -> None:
        """Record one query/reply exchange, trimming old history.

        The first entry (usually the initial prompt) is always retained;
        beyond it, only the 5 most recent exchanges are kept.
        """
        self.history.append([f"{query}", message])
        # Keep the first exchange plus at most the 5 latest ones.
        first_history = self.history.pop(0)
        self.history = [first_history] + self.history[-5:]
23 changes: 6 additions & 17 deletions xiaogpt/bot/chatgptapi_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import openai
from rich import print

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
from xiaogpt.utils import split_sentences


class ChatGPTBot(BaseBot):
class ChatGPTBot(ChatHistoryMixin, BaseBot):
default_options = {"model": "gpt-3.5-turbo-0613"}

def __init__(
Expand Down Expand Up @@ -42,10 +42,7 @@ def from_config(cls, config):
)

async def ask(self, query, **options):
ms = []
for h in self.history:
ms.append({"role": "user", "content": h[0]})
ms.append({"role": "assistant", "content": h[1]})
ms = self.get_messages()
ms.append({"role": "user", "content": f"{query}"})
kwargs = {**self.default_options, **options}
try:
Expand All @@ -60,18 +57,12 @@ async def ask(self, query, **options):
.encode("utf8")
.decode()
)
self.history.append([f"{query}", message])
# only keep 5 history
first_history = self.history.pop(0)
self.history = [first_history] + self.history[-5:]
self.add_message(query, message)
print(message)
return message

async def ask_stream(self, query, **options):
ms = []
for h in self.history:
ms.append({"role": "user", "content": h[0]})
ms.append({"role": "assistant", "content": h[1]})
ms = self.get_messages()
ms.append({"role": "user", "content": f"{query}"})
kwargs = {"model": "gpt-3.5-turbo", **options}
if openai.api_type == "azure":
Expand Down Expand Up @@ -99,6 +90,4 @@ async def text_gen():
yield sentence
finally:
print()
self.history.append([f"{query}", message])
first_history = self.history.pop(0)
self.history = [first_history] + self.history[-5:]
self.add_message(query, message)
21 changes: 6 additions & 15 deletions xiaogpt/bot/glm_bot.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
"""ChatGLM bot"""
from __future__ import annotations

from typing import Any, AsyncGenerator
from typing import Any

import zhipuai
from rich import print

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin


class GLMBot(BaseBot):
class GLMBot(ChatHistoryMixin, BaseBot):
default_options = {"model": "chatglm_130b"}

def __init__(
self,
glm_key: str,
) -> None:
def __init__(self, glm_key: str) -> None:
self.history = []
zhipuai.api_key = glm_key

Expand All @@ -24,10 +21,7 @@ def from_config(cls, config):
return cls(glm_key=config.glm_key)

def ask(self, query, **options):
ms = []
for h in self.history:
ms.append({"role": "user", "content": h[0]})
ms.append({"role": "assistant", "content": h[1]})
ms = self.get_messages()
kwargs = {**self.default_options, **options}
kwargs["prompt"] = ms
ms.append({"role": "user", "content": f"{query}"})
Expand All @@ -40,10 +34,7 @@ def ask(self, query, **options):
for i in r.events():
message += str(i.data)

self.history.append([f"{query}", message])
# only keep 5 history
first_history = self.history.pop(0)
self.history = [first_history] + self.history[-5:]
self.add_message(query, message)
print(message)
return message

Expand Down
4 changes: 2 additions & 2 deletions xiaogpt/bot/gpt3_bot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import openai
from rich import print

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
from xiaogpt.utils import split_sentences


class GPT3Bot(BaseBot):
class GPT3Bot(ChatHistoryMixin, BaseBot):
def __init__(self, openai_key, api_base=None, proxy=None):
openai.api_key = openai_key
if api_base:
Expand Down
18 changes: 11 additions & 7 deletions xiaogpt/bot/langchain_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import os

from langchain.memory import ConversationBufferWindowMemory
from rich import print

from xiaogpt.bot.base_bot import BaseBot
Expand All @@ -26,8 +27,14 @@ def __init__(
os.environ["OPENAI_API_BASE"] = api_base
if proxy:
os.environ["OPENAI_PROXY"] = proxy
# Todo,Plan to implement within langchain
self.history = []
self.memory = ConversationBufferWindowMemory()

def has_history(self) -> bool:
    """Return True when the conversation memory already holds messages."""
    return bool(self.memory.chat_memory.messages)

def change_prompt(self, new_prompt: str) -> None:
    """Wipe the agent's memory and seed it with *new_prompt* as the first
    user message."""
    memory = self.memory
    memory.clear()
    memory.chat_memory.add_user_message(new_prompt)

@classmethod
def from_config(cls, config):
Expand All @@ -39,14 +46,11 @@ def from_config(cls, config):
)

async def ask(self, query, **options):
    """Run the agent once in non-stream mode and return its final answer,
    sharing this bot's conversation memory."""
    return await agent_search(query, self.memory)

async def ask_stream(self, query, **options):
callback = AsyncIteratorCallbackHandler()
task = asyncio.create_task(agent_search(query, callback))
task = asyncio.create_task(agent_search(query, self.memory, callback))
try:
async for message in split_sentences(callback.aiter()):
yield message
Expand Down
6 changes: 3 additions & 3 deletions xiaogpt/bot/newbing_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from EdgeGPT import Chatbot, ConversationStyle

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
from xiaogpt.utils import split_sentences

_reference_link_re = re.compile(r"\[\d+\]: .+?\n+")


class NewBingBot(BaseBot):
class NewBingBot(ChatHistoryMixin, BaseBot):
def __init__(
self,
bing_cookie_path: str = "",
Expand Down Expand Up @@ -52,7 +52,7 @@ async def ask_stream(self, query, **options):
kwargs = {"conversation_style": ConversationStyle.balanced, **options}
try:
completion = self._bot.ask_stream(prompt=query, **kwargs)
except Exception as e:
except Exception:
return

async def text_gen():
Expand Down
11 changes: 7 additions & 4 deletions xiaogpt/langchain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMMathChain
from langchain.chat_models import ChatOpenAI
from langchain.schema.memory import BaseMemory
from langchain.utilities import SerpAPIWrapper


async def agent_search(query: str, callback: BaseCallbackHandler) -> None:
async def agent_search(
query: str, memeory: BaseMemory, callback: BaseCallbackHandler | None = None
) -> str:
llm = ChatOpenAI(
streaming=True,
temperature=0,
Expand All @@ -23,8 +26,8 @@ async def agent_search(query: str, callback: BaseCallbackHandler) -> None:
]

agent = initialize_agent(
tools, llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=False
tools, llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=False, memory=memeory
)

callbacks = [callback] if callback else None
# query eg:'杭州亚运会中国队获得了多少枚金牌?' // '计算3的2次方'
await agent.arun(query, callbacks=[callback])
return await agent.arun(query, callbacks=callbacks)
6 changes: 2 additions & 4 deletions xiaogpt/xiaogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,7 @@ def _change_prompt(self, new_prompt):
new_prompt = "以下都" + new_prompt
print(f"Prompt from {self.config.prompt} change to {new_prompt}")
self.config.prompt = new_prompt
if self.chatbot.history:
print(self.chatbot.history)
self.chatbot.history[0][0] = new_prompt
self.chatbot.change_prompt(new_prompt)

async def get_latest_ask_from_xiaoai(self, session: ClientSession) -> dict | None:
retries = 3
Expand Down Expand Up @@ -489,7 +487,7 @@ async def run_forever(self):

print("-" * 20)
print("问题:" + query + "?")
if not self.chatbot.history:
if not self.chatbot.has_history():
query = f"{query}{self.config.prompt}"
if self.config.mute_xiaoai:
await self.stop_if_xiaoai_is_playing()
Expand Down

0 comments on commit 321f4d9

Please sign in to comment.