Skip to content

Commit

Permalink
feat: support bard (#315)
Browse files Browse the repository at this point in the history
* feat: support bard
  • Loading branch information
yihong0618 authored Jul 20, 2023
1 parent 9e172e2 commit e0f0d3a
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 17 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ python3 xiaogpt.py --hardware LX06 --mute_xiaoai --use_gpt3

# 如果你想使用 ChatGLM api
python3 xiaogpt.py --hardware LX06 --mute_xiaoai --use_glm --glm_key ${glm_key}
# 如果你想使用 google 的 bard
python3 xiaogpt.py --hardware LX06 --mute_xiaoai --use_bard --bard_token ${bard_token}
```

## config.json
Expand Down Expand Up @@ -127,6 +129,7 @@ python3 xiaogpt.py

具体参数作用请参考 [Open AI API 文档](https://platform.openai.com/docs/api-reference/chat/create)
ChatGLM [文档](http://open.bigmodel.cn/doc/api#chatglm_130b)
Bard-API [参考](https://github.com/dsdanielpark/Bard-API)
## 配置项说明

| 参数 | 说明 | 默认值 |
Expand All @@ -136,6 +139,7 @@ ChatGLM [文档](http://open.bigmodel.cn/doc/api#chatglm_130b)
| password | 小爱账户密码 | |
| openai_key | openai的apikey | |
| glm_key | chatglm 的 apikey | |
| bard_token | bard 的 token 参考 [Bard-API](https://github.com/dsdanielpark/Bard-API) | |
| cookie | 小爱账户cookie (如果用上面密码登录可以不填) | |
| mi_did | 设备did | |
| use_command | 使用 MI command 与小爱交互 | `false` |
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "xiaogpt"
description = "Play ChatGPT with xiaomi AI speaker"
description = "Play ChatGPT or other LLM with xiaomi AI speaker"
readme = "README.md"
authors = [
{name = "yihong0618", email = "[email protected]"},
Expand All @@ -17,6 +17,7 @@ dependencies = [
"aiohttp",
"rich",
"zhipuai",
"bardapi",
"edge-tts>=6.1.3",
"EdgeGPT==0.1.26",
]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ websockets==11.0
yarl==1.8.2
edge-tts==6.1.5
zhipuai
bardapi
1 change: 1 addition & 0 deletions xiao_config.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"password": "",
"openai_key": "",
"glm_key": "",
"bard_token": "",
"cookie": "",
"mi_did": "",
"use_command": false,
Expand Down
4 changes: 3 additions & 1 deletion xiaogpt/bot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from xiaogpt.bot.gpt3_bot import GPT3Bot
from xiaogpt.bot.newbing_bot import NewBingBot
from xiaogpt.bot.glm_bot import GLMBot
from xiaogpt.bot.bard_bot import BardBot
from xiaogpt.config import Config

BOTS: dict[str, type[BaseBot]] = {
"gpt3": GPT3Bot,
"newbing": NewBingBot,
"chatgptapi": ChatGPTBot,
"glm": GLMBot,
"bard": BardBot,
}


Expand All @@ -22,4 +24,4 @@ def get_bot(config: Config) -> BaseBot:
raise ValueError(f"Unsupported bot {config.bot}, must be one of {list(BOTS)}")


__all__ = ["GPT3Bot", "ChatGPTBot", "NewBingBot", "GLMBot", "get_bot"]
__all__ = ["GPT3Bot", "ChatGPTBot", "NewBingBot", "GLMBot", "BardBot", "get_bot"]
32 changes: 32 additions & 0 deletions xiaogpt/bot/bard_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""ChatGLM bot"""
from __future__ import annotations
from typing import Any

from bardapi import BardAsync
from rich import print

from xiaogpt.bot.base_bot import BaseBot


class BardBot(BaseBot):
def __init__(
self,
bard_token: str,
) -> None:
self._bot = BardAsync(token=bard_token)
self.history = []

@classmethod
def from_config(cls, config):
return cls(bard_token=config.bard_token)

async def ask(self, query, **options):
try:
r = await self._bot.get_answer(query)
except Exception as e:
print(str(e))
print(r["content"])
return r["content"]

def ask_stream(self, query: str, **options: Any):
raise Exception("Bard do not support stream")
16 changes: 12 additions & 4 deletions xiaogpt/bot/chatgptapi_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ async def ask(self, query, **options):
ms.append({"role": "assistant", "content": h[1]})
ms.append({"role": "user", "content": f"{query}"})
kwargs = {**self.default_options, **options}
completion = await openai.ChatCompletion.acreate(messages=ms, **kwargs)
try:
completion = await openai.ChatCompletion.acreate(messages=ms, **kwargs)
except Exception as e:
print(str(e))
return ""
message = (
completion["choices"][0]
.get("message")
Expand All @@ -72,9 +76,13 @@ async def ask_stream(self, query, **options):
kwargs = {"model": "gpt-3.5-turbo", **options}
if openai.api_type == "azure":
kwargs["deployment_id"] = self.deployment_id
completion = await openai.ChatCompletion.acreate(
messages=ms, stream=True, **kwargs
)
try:
completion = await openai.ChatCompletion.acreate(
messages=ms, stream=True, **kwargs
)
except Exception as e:
print(str(e))
return

async def text_gen():
async for event in completion:
Expand Down
8 changes: 6 additions & 2 deletions xiaogpt/bot/glm_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def ask(self, query, **options):
kwargs = {**self.default_options, **options}
kwargs["prompt"] = ms
ms.append({"role": "user", "content": f"{query}"})
r = zhipuai.model_api.sse_invoke(**kwargs)
try:
r = zhipuai.model_api.sse_invoke(**kwargs)
except Exception as e:
print(str(e))
return
message = ""
for i in r.events():
message += str(i.data)
Expand All @@ -43,4 +47,4 @@ def ask(self, query, **options):
return message

def ask_stream(self, query: str, **options: Any):
pass
raise Exception("GLM do not support stream")
12 changes: 10 additions & 2 deletions xiaogpt/bot/gpt3_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ async def ask(self, query, **options):
"top_p": 1,
**options,
}
completion = await openai.Completion.acreate(**data)
try:
completion = await openai.Completion.acreate(**data)
except Exception as e:
print(str(e))
return ""
print(completion["choices"][0]["text"])
return completion["choices"][0]["text"]

Expand All @@ -43,7 +47,11 @@ async def ask_stream(self, query, **options):
"stream": True,
**options,
}
completion = await openai.Completion.acreate(**data)
try:
completion = await openai.Completion.acreate(**data)
except Exception as e:
print(str(e))
return

async def text_gen():
async for event in completion:
Expand Down
11 changes: 9 additions & 2 deletions xiaogpt/bot/newbing_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,20 @@ def clean_text(s):
async def ask(self, query, **options):
kwargs = {"conversation_style": ConversationStyle.balanced, **options}
completion = await self._bot.ask(prompt=query, **kwargs)
text = self.clean_text(completion["item"]["messages"][1]["text"])
try:
text = self.clean_text(completion["item"]["messages"][1]["text"])
except Exception as e:
print(str(e))
return
print(text)
return text

async def ask_stream(self, query, **options):
kwargs = {"conversation_style": ConversationStyle.balanced, **options}
completion = self._bot.ask_stream(prompt=query, **kwargs)
try:
completion = self._bot.ask_stream(prompt=query, **kwargs)
except Exception as e:
return

async def text_gen():
current = ""
Expand Down
16 changes: 14 additions & 2 deletions xiaogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def main():
dest="glm_key",
help="chatglm api key",
)
parser.add_argument(
"--bard_token",
dest="bard_token",
help="google bard token see https://github.com/dsdanielpark/Bard-API",
)
parser.add_argument(
"--proxy",
dest="proxy",
Expand Down Expand Up @@ -106,6 +111,13 @@ def main():
const="glm",
help="if use chatglm",
)
group.add_argument(
"--use_bard",
dest="bot",
action="store_const",
const="bard",
help="if use bard",
)
parser.add_argument(
"--bing_cookie_path",
dest="bing_cookie_path",
Expand All @@ -115,7 +127,7 @@ def main():
"--bot",
dest="bot",
help="bot type",
choices=["gpt3", "chatgptapi", "newbing", "glm"],
choices=["gpt3", "chatgptapi", "newbing", "glm", "bard"],
)
parser.add_argument(
"--config",
Expand Down Expand Up @@ -144,7 +156,7 @@ def main():
)

options = parser.parse_args()
if options.bot == "glm" and options.stream:
if options.bot in ["glm", "bard"] and options.stream:
raise Exception("For now ChatGLM do not support stream")
config = Config.from_options(options)

Expand Down
5 changes: 5 additions & 0 deletions xiaogpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Config:
password: str = os.getenv("MI_PASS", "")
openai_key: str = os.getenv("OPENAI_API_KEY", "")
glm_key: str = os.getenv("CHATGLM_KEY", "")
bard_token: str = os.getenv("BARD_TOKEN", "")
proxy: str | None = None
mi_did: str = os.getenv("MI_DID", "")
keyword: Iterable[str] = KEY_WORD
Expand Down Expand Up @@ -139,5 +140,9 @@ def read_from_file(cls, config_path: str) -> dict:
key, value = "bot", "gpt3"
elif key == "use_newbing":
key, value = "bot", "newbing"
elif key == "use_glm":
key, value = "bot", "glm"
elif key == "use_bard":
key, value = "bot", "bard"
result[key] = value
return result
7 changes: 4 additions & 3 deletions xiaogpt/xiaogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ async def wakeup_xiaoai(self):
)

async def run_forever(self):
ask_name = self.config.bot.upper()
async with ClientSession() as session:
await self.init_all_data(session)
task = asyncio.create_task(self.poll_latest_ask())
Expand Down Expand Up @@ -470,15 +471,15 @@ async def run_forever(self):
else:
# waiting for xiaoai speaker done
await asyncio.sleep(8)
await self.do_tts("正在问GPT请耐心等待")
await self.do_tts(f"正在问{ask_name}请耐心等待")
try:
print(
"以下是小爱的回答: ",
new_record.get("answers", [])[0].get("tts", {}).get("text"),
)
except IndexError:
print("小爱没回")
print("以下是GPT的回答: ", end="")
print(f"以下是 {ask_name} 的回答: ", end="")
try:
if not self.config.enable_edge_tts:
async for message in self.ask_gpt(query):
Expand All @@ -492,7 +493,7 @@ async def run_forever(self):
await self.edge_tts(self.ask_gpt(query), tts_lang)
print("回答完毕")
except Exception as e:
print(f"GPT回答出错 {str(e)}")
print(f"{ask_name} 回答出错 {str(e)}")
if self.in_conversation:
print(f"继续对话, 或用`{self.config.end_conversation}`结束对话")
await self.wakeup_xiaoai()

0 comments on commit e0f0d3a

Please sign in to comment.