Solve knowledge graph issue when calling Gemini model (infiniflow#2738)
### What problem does this PR solve?
infiniflow#2720

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
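
The heart of the fix shows up in the `GeminiChat` changes below: history entries produced by the knowledge graph prompts can carry a `system` role, which Gemini's `google.generativeai` chat history rejects, so roles are remapped (and `content` renamed to `parts`) before calling `generate_content`. A minimal sketch of that normalization, assuming OpenAI-style messages as input; the helper name and the sample history are illustrative only, not part of the PR:

```python
# Sketch of the role normalization applied in GeminiChat below.
# google.generativeai only accepts the roles "user" and "model" inside the
# chat history, so "assistant"/"system" entries are remapped and "content"
# is renamed to "parts". Helper name and sample messages are hypothetical.

def normalize_history_for_gemini(history):
    for item in history:
        if item.get("role") == "assistant":
            item["role"] = "model"   # Gemini's name for the assistant turn
        if item.get("role") == "system":
            item["role"] = "user"    # Gemini rejects "system" inside history
        if "content" in item:
            item["parts"] = item.pop("content")  # Gemini expects "parts"
    return history


history = [
    {"role": "system", "content": "Extract entities and relations as JSON."},
    {"role": "user", "content": "Alice joined Acme Corp in 2021."},
]
print(normalize_history_for_gemini(history))
# [{'role': 'user', 'parts': 'Extract entities and relations as JSON.'},
#  {'role': 'user', 'parts': 'Alice joined Acme Corp in 2021.'}]
```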
JobSmithManipulation authored Oct 8, 2024
1 parent d92acdc commit 16472eb
Showing 1 changed file with 64 additions and 62 deletions.
126 changes: 64 additions & 62 deletions rag/llm/chat_model.py
@@ -23,7 +23,7 @@
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import os
import os
import json
import requests
import asyncio
@@ -62,17 +62,17 @@ def chat_streamly(self, system, history, gen_conf):
stream=True,
**gen_conf)
for resp in response:
if not resp.choices:continue
if not resp.choices: continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
total_tokens = (
(
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
if not hasattr(resp, "usage") or not resp.usage
else resp.usage.get("total_tokens",total_tokens)
else resp.usage.get("total_tokens", total_tokens)
)
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ def chat_streamly(self, system, history, gen_conf):

class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url: base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url="https://api.moonshot.cn/v1"
if not base_url: base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url)


@@ -108,7 +108,7 @@ def __init__(self, key=None, model_name="", base_url=""):

class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url="https://api.deepseek.com/v1"
if not base_url: base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)


@@ -178,14 +178,14 @@ def chat_streamly(self, system, history, gen_conf):
stream=True,
**self._format_params(gen_conf))
for resp in response:
if not resp.choices:continue
if not resp.choices: continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
total_tokens = (
(
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
if not hasattr(resp, "usage")
else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ def chat_streamly(self, system, history, gen_conf):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans
else:
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
"Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

@@ -298,7 +299,7 @@ def chat_streamly(self, system, history, gen_conf):
**gen_conf
)
for resp in response:
if not resp.choices[0].delta.content:continue
if not resp.choices[0].delta.content: continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
@@ -411,15 +412,15 @@ def __init__(self, key, model_name):
self.client = Client(port=12345, protocol="grpc", asyncio=True)

def _prepare_prompt(self, system, history, gen_conf):
from rag.svr.jina_server import Prompt,Generation
from rag.svr.jina_server import Prompt, Generation
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
return Prompt(message=history, gen_conf=gen_conf)

def _stream_response(self, endpoint, prompt):
from rag.svr.jina_server import Prompt,Generation
from rag.svr.jina_server import Prompt, Generation
answer = ""
try:
res = self.client.stream_doc(
@@ -463,10 +464,10 @@ def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/

class MiniMaxChat(Base):
def __init__(
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
):
if not base_url:
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ def chat_streamly(self, system, history, gen_conf):
messages=history,
**gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:continue
if not resp.choices or not resp.choices[0].delta.content: continue
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
@@ -620,19 +621,18 @@ def chat(self, system, history, gen_conf):
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")
for item in history:
if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
item["content"] = [{"text":item["content"]}]

if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text": item["content"]}]

try:
# Send the message to the model, using a basic inference configuration.
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}] ,
system=[{"text": (system if system else "Answer the user's message.")}],
)

# Extract and print the response text.
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
@@ -652,9 +652,9 @@ def chat_streamly(self, system, history, gen_conf):
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")
for item in history:
if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
item["content"] = [{"text":item["content"]}]
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text": item["content"]}]

if self.model_name.split('.')[0] == 'ai21':
try:
response = self.client.converse(
@@ -684,7 +684,7 @@ def chat_streamly(self, system, history, gen_conf):
if "contentBlockDelta" in resp:
ans += resp["contentBlockDelta"]["delta"]["text"]
yield ans

except (ClientError, Exception) as e:
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

@@ -693,22 +693,21 @@ def chat_streamly(self, system, history, gen_conf):

class GeminiChat(Base):

def __init__(self, key, model_name,base_url=None):
from google.generativeai import client,GenerativeModel
def __init__(self, key, model_name, base_url=None):
from google.generativeai import client, GenerativeModel

client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = 'models/' + model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client


def chat(self,system,history,gen_conf):

def chat(self, system, history, gen_conf):
from google.generativeai.types import content_types

if system:
self.model._system_instruction = content_types.to_content(system)

if 'max_tokens' in gen_conf:
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
for k in list(gen_conf.keys()):
@@ -717,9 +716,11 @@ def chat(self,system,history,gen_conf):
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item :
if 'role' in item and item['role'] == 'system':
item['role'] = 'user'
if 'content' in item:
item['parts'] = item.pop('content')

try:
response = self.model.generate_content(
history,
@@ -731,7 +732,7 @@

def chat_streamly(self, system, history, gen_conf):
from google.generativeai.types import content_types

if system:
self.model._system_instruction = content_types.to_content(system)
if 'max_tokens' in gen_conf:
@@ -742,25 +743,25 @@ def chat_streamly(self, system, history, gen_conf):
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item :
if 'content' in item:
item['parts'] = item.pop('content')
ans = ""
try:
response = self.model.generate_content(
history,
generation_config=gen_conf,stream=True)
generation_config=gen_conf, stream=True)
for resp in response:
ans += resp.text
yield ans

except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

yield response._chunks[-1].usage_metadata.total_token_count
yield response._chunks[-1].usage_metadata.total_token_count


class GroqChat:
def __init__(self, key, model_name,base_url=''):
def __init__(self, key, model_name, base_url=''):
self.client = Groq(api_key=key)
self.model_name = model_name

@@ -942,7 +943,7 @@ def chat_streamly(self, system, history, gen_conf):
class LeptonAIChat(Base):
def __init__(self, key, model_name, base_url=None):
if not base_url:
base_url = os.path.join("https://"+model_name+".lepton.run","api","v1")
base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
super().__init__(key, model_name, base_url)


@@ -1058,7 +1059,7 @@ def chat(self, system, history, gen_conf):
)

_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
if system:
_history.insert(0, {"Role": "system", "Content": system})
if "temperature" in gen_conf:
@@ -1084,7 +1085,7 @@ def chat_streamly(self, system, history, gen_conf):
)

_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
if system:
_history.insert(0, {"Role": "system", "Content": system})

@@ -1121,7 +1122,7 @@ def chat_streamly(self, system, history, gen_conf):

class SparkChat(Base):
def __init__(
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1141,26 +1142,27 @@ def __init__(self, key, model_name, base_url=None):
import qianfan

key = json.loads(key)
ak = key.get("yiyan_ak","")
sk = key.get("yiyan_sk","")
self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
ak = key.get("yiyan_ak", "")
sk = key.get("yiyan_sk", "")
self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
self.model_name = model_name.lower()
self.system = ""

def chat(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
) + 1
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
ans = ""

try:
response = self.client.do(
model=self.model_name,
messages=history,
model=self.model_name,
messages=history,
system=self.system,
**gen_conf
).body
@@ -1174,17 +1176,18 @@ def chat_streamly(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
) + 1
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
ans = ""
total_tokens = 0

try:
response = self.client.do(
model=self.model_name,
messages=history,
model=self.model_name,
messages=history,
system=self.system,
stream=True,
**gen_conf
@@ -1415,4 +1418,3 @@ def chat_streamly(self, system, history, gen_conf):
yield ans + "\n**ERROR**: " + str(e)

yield response._chunks[-1].usage_metadata.total_token_count
