Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/lm-sys/FastChat
Browse files Browse the repository at this point in the history
  • Loading branch information
surak committed Oct 15, 2024
2 parents 9fef8fc + 63d5da2 commit eecdde5
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 141 deletions.
79 changes: 6 additions & 73 deletions fastchat/serve/api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,14 @@ def get_api_provider_stream_iter(
)
elif model_api_dict["api_type"] == "bard":
prompt = conv.to_openai_api_messages()
stream_iter = bard_api_stream_iter(
stream_iter = gemini_api_stream_iter(
model_api_dict["model_name"],
prompt,
temperature,
top_p,
api_key=model_api_dict["api_key"],
None, # use Bard's default temperature
None, # use Bard's default top_p
max_new_tokens,
api_key=(model_api_dict["api_key"] or os.environ["BARD_API_KEY"]),
use_stream=False,
)
elif model_api_dict["api_type"] == "mistral":
if model_api_dict.get("vision-arena", False):
Expand Down Expand Up @@ -759,75 +761,6 @@ def gemini_api_stream_iter(
}


def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None):
del top_p # not supported
del temperature # not supported

if api_key is None:
api_key = os.environ["BARD_API_KEY"]

# convert conv to conv_bard
conv_bard = []
for turn in conv:
if turn["role"] == "user":
conv_bard.append({"author": "0", "content": turn["content"]})
elif turn["role"] == "assistant":
conv_bard.append({"author": "1", "content": turn["content"]})
else:
raise ValueError(f"Unsupported role: {turn['role']}")

params = {
"model": model_name,
"prompt": conv_bard,
}
logger.info(f"==== request ====\n{params}")

try:
res = requests.post(
f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}",
json={
"prompt": {
"messages": conv_bard,
},
},
timeout=60,
)
except Exception as e:
logger.error(f"==== error ====\n{e}")
yield {
"text": f"**API REQUEST ERROR** Reason: {e}.",
"error_code": 1,
}

if res.status_code != 200:
logger.error(f"==== error ==== ({res.status_code}): {res.text}")
yield {
"text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.",
"error_code": 1,
}

response_json = res.json()
if "candidates" not in response_json:
logger.error(f"==== error ==== response blocked: {response_json}")
reason = response_json["filters"][0]["reason"]
yield {
"text": f"**API REQUEST ERROR** Reason: {reason}.",
"error_code": 1,
}

response = response_json["candidates"][0]["content"]
pos = 0
while pos < len(response):
# simulate token streaming
pos += 5
time.sleep(0.001)
data = {
"text": response[:pos],
"error_code": 0,
}
yield data


def ai2_api_stream_iter(
model_name,
model_id,
Expand Down
Loading

0 comments on commit eecdde5

Please sign in to comment.