diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index f5b66872e..6e552999c 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -23,7 +23,7 @@ from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
 
-import os
+import os
 import json
 import requests
 import asyncio
@@ -62,17 +62,17 @@ def chat_streamly(self, system, history, gen_conf):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
+                            total_tokens
+                            + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage") or not resp.usage
-                    else resp.usage.get("total_tokens",total_tokens)
+                    else resp.usage.get("total_tokens", total_tokens)
                 )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ def chat_streamly(self, system, history, gen_conf):
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+        if not base_url: base_url = "https://api.openai.com/v1"
         super().__init__(key, model_name, base_url)
 
 
 class MoonshotChat(Base):
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
-        if not base_url: base_url="https://api.moonshot.cn/v1"
+        if not base_url: base_url = "https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -108,7 +108,7 @@ def __init__(self, key=None, model_name="", base_url=""):
 
 class DeepSeekChat(Base):
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
-        if not base_url: base_url="https://api.deepseek.com/v1"
+        if not base_url: base_url = "https://api.deepseek.com/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -178,14 +178,14 @@ def chat_streamly(self, system, history, gen_conf):
                 stream=True,
                 **self._format_params(gen_conf))
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
+                            total_tokens
+                            + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage")
                     else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ def chat_streamly(self, system, history, gen_conf):
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                 yield ans
             else:
-                yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                    "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
@@ -298,7 +299,7 @@ def chat_streamly(self, system, history, gen_conf):
                 **gen_conf
             )
             for resp in response:
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices[0].delta.content: continue
                 delta = resp.choices[0].delta.content
                 ans += delta
             if resp.choices[0].finish_reason == "length":
@@ -411,7 +412,7 @@ def __init__(self, key, model_name):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -419,7 +420,7 @@
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         answer = ""
         try:
             res = self.client.stream_doc(
@@ -463,10 +464,10 @@ def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/
 
 class MiniMaxChat(Base):
     def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+            self,
+            key,
+            model_name,
+            base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ def chat_streamly(self, system, history, gen_conf):
                 messages=history,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content: continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -620,9 +621,8 @@ def chat(self, system, history, gen_conf):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
         try:
             # Send the message to the model, using a basic inference configuration.
             response = self.client.converse(
@@ -630,9 +630,9 @@ def chat(self, system, history, gen_conf):
                 modelId=self.model_name,
                 messages=history,
                 inferenceConfig=gen_conf,
-                system=[{"text": (system if system else "Answer the user's message.")}] ,
+                system=[{"text": (system if system else "Answer the user's message.")}],
             )
-
+
             # Extract and print the response text.
ans = response["output"]["message"]["content"][0]["text"] return ans, num_tokens_from_string(ans) @@ -652,9 +652,9 @@ def chat_streamly(self, system, history, gen_conf): gen_conf["topP"] = gen_conf["top_p"] _ = gen_conf.pop("top_p") for item in history: - if not isinstance(item["content"],list) and not isinstance(item["content"],tuple): - item["content"] = [{"text":item["content"]}] - + if not isinstance(item["content"], list) and not isinstance(item["content"], tuple): + item["content"] = [{"text": item["content"]}] + if self.model_name.split('.')[0] == 'ai21': try: response = self.client.converse( @@ -684,7 +684,7 @@ def chat_streamly(self, system, history, gen_conf): if "contentBlockDelta" in resp: ans += resp["contentBlockDelta"]["delta"]["text"] yield ans - + except (ClientError, Exception) as e: yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}" @@ -693,22 +693,21 @@ def chat_streamly(self, system, history, gen_conf): class GeminiChat(Base): - def __init__(self, key, model_name,base_url=None): - from google.generativeai import client,GenerativeModel - + def __init__(self, key, model_name, base_url=None): + from google.generativeai import client, GenerativeModel + client.configure(api_key=key) _client = client.get_default_generative_client() self.model_name = 'models/' + model_name self.model = GenerativeModel(model_name=self.model_name) self.model._client = _client - - - def chat(self,system,history,gen_conf): + + def chat(self, system, history, gen_conf): from google.generativeai.types import content_types - + if system: self.model._system_instruction = content_types.to_content(system) - + if 'max_tokens' in gen_conf: gen_conf['max_output_tokens'] = gen_conf['max_tokens'] for k in list(gen_conf.keys()): @@ -717,9 +716,11 @@ def chat(self,system,history,gen_conf): for item in history: if 'role' in item and item['role'] == 'assistant': item['role'] = 'model' - if 'content' in item : + if 'role' in item and item['role'] == 'system': + item['role'] = 'user' + if 'content' in item: item['parts'] = item.pop('content') - + try: response = self.model.generate_content( history, @@ -731,7 +732,7 @@ def chat(self,system,history,gen_conf): def chat_streamly(self, system, history, gen_conf): from google.generativeai.types import content_types - + if system: self.model._system_instruction = content_types.to_content(system) if 'max_tokens' in gen_conf: @@ -742,13 +743,13 @@ def chat_streamly(self, system, history, gen_conf): for item in history: if 'role' in item and item['role'] == 'assistant': item['role'] = 'model' - if 'content' in item : + if 'content' in item: item['parts'] = item.pop('content') ans = "" try: response = self.model.generate_content( history, - generation_config=gen_conf,stream=True) + generation_config=gen_conf, stream=True) for resp in response: ans += resp.text yield ans @@ -756,11 +757,11 @@ def chat_streamly(self, system, history, gen_conf): except Exception as e: yield ans + "\n**ERROR**: " + str(e) - yield response._chunks[-1].usage_metadata.total_token_count + yield response._chunks[-1].usage_metadata.total_token_count class GroqChat: - def __init__(self, key, model_name,base_url=''): + def __init__(self, key, model_name, base_url=''): self.client = Groq(api_key=key) self.model_name = model_name @@ -942,7 +943,7 @@ def chat_streamly(self, system, history, gen_conf): class LeptonAIChat(Base): def __init__(self, key, model_name, base_url=None): if not base_url: - base_url = os.path.join("https://"+model_name+".lepton.run","api","v1") + base_url = 
os.path.join("https://" + model_name + ".lepton.run", "api", "v1") super().__init__(key, model_name, base_url) @@ -1058,7 +1059,7 @@ def chat(self, system, history, gen_conf): ) _gen_conf = {} - _history = [{k.capitalize(): v for k, v in item.items() } for item in history] + _history = [{k.capitalize(): v for k, v in item.items()} for item in history] if system: _history.insert(0, {"Role": "system", "Content": system}) if "temperature" in gen_conf: @@ -1084,7 +1085,7 @@ def chat_streamly(self, system, history, gen_conf): ) _gen_conf = {} - _history = [{k.capitalize(): v for k, v in item.items() } for item in history] + _history = [{k.capitalize(): v for k, v in item.items()} for item in history] if system: _history.insert(0, {"Role": "system", "Content": system}) @@ -1121,7 +1122,7 @@ def chat_streamly(self, system, history, gen_conf): class SparkChat(Base): def __init__( - self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1" + self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1" ): if not base_url: base_url = "https://spark-api-open.xf-yun.com/v1" @@ -1141,9 +1142,9 @@ def __init__(self, key, model_name, base_url=None): import qianfan key = json.loads(key) - ak = key.get("yiyan_ak","") - sk = key.get("yiyan_sk","") - self.client = qianfan.ChatCompletion(ak=ak,sk=sk) + ak = key.get("yiyan_ak", "") + sk = key.get("yiyan_sk", "") + self.client = qianfan.ChatCompletion(ak=ak, sk=sk) self.model_name = model_name.lower() self.system = "" @@ -1151,16 +1152,17 @@ def chat(self, system, history, gen_conf): if system: self.system = system gen_conf["penalty_score"] = ( - (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2 - ) + 1 + (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", + 0)) / 2 + ) + 1 if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] ans = "" try: response = self.client.do( - model=self.model_name, - messages=history, + model=self.model_name, + messages=history, system=self.system, **gen_conf ).body @@ -1174,8 +1176,9 @@ def chat_streamly(self, system, history, gen_conf): if system: self.system = system gen_conf["penalty_score"] = ( - (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2 - ) + 1 + (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", + 0)) / 2 + ) + 1 if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] ans = "" @@ -1183,8 +1186,8 @@ def chat_streamly(self, system, history, gen_conf): try: response = self.client.do( - model=self.model_name, - messages=history, + model=self.model_name, + messages=history, system=self.system, stream=True, **gen_conf @@ -1415,4 +1418,3 @@ def chat_streamly(self, system, history, gen_conf): yield ans + "\n**ERROR**: " + str(e) yield response._chunks[-1].usage_metadata.total_token_count - \ No newline at end of file