Commit

Fix model max length

This commit stops hard-coding the Qwen context window inside the generation core, falls back to the model config only when no context_len has been supplied, and adds a get_context_len helper in api/models.py that also sets the vLLM engine's max_model_len.
xusenlin committed Aug 15, 2023
1 parent d8313af commit 625b876
Showing 2 changed files with 14 additions and 7 deletions.
api/generation/core.py: 5 additions & 6 deletions
@@ -13,10 +13,10 @@
 )

 from api.apapter import get_prompt_adapter
-from api.utils.constants import ErrorCode
 from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan
 from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm
 from api.generation.qwen import build_qwen_chat_input, check_is_qwen
+from api.utils.constants import ErrorCode
 from api.utils.protocol import ChatMessage

 server_error_msg = (
@@ -298,11 +298,7 @@ def __init__(
         self.model_name = model_name.lower()
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.stream_interval = stream_interval

-        if context_len is None:
-            self.context_len = get_context_length(self.model.config)
-        else:
-            self.context_len = context_len
+        self.context_len = context_len

         self.construct_prompt = True
         if check_is_chatglm(self.model):
@@ -316,10 +312,13 @@
             logger.info("Using Qwen Model for Chat!")
             self.construct_prompt = False
             self.generate_stream_func = generate_stream
-            self.context_len = 8192
         else:
             self.generate_stream_func = generate_stream

         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
+        if self.context_len is None:
+            self.context_len = get_context_length(self.model.config)

     def count_token(self, params):
         prompt = params["prompt"]
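Taken together, the core.py changes keep whatever context_len the caller passes, stop the Qwen branch from overwriting it with a hard-coded 8192, and fall back to the model config only after the model-type checks have run. Below is a minimal, self-contained sketch of that resolution order; read_context_length_from_config is a hypothetical stand-in for the get_context_length helper the real file uses, and ModelServerSketch is not the actual class name.

# Sketch only: illustrates the context-length resolution order after this commit.
def read_context_length_from_config(config) -> int:
    # Common attributes where Hugging Face configs record the maximum sequence length.
    for attr in ("max_sequence_length", "max_position_embeddings", "seq_length"):
        value = getattr(config, attr, None)
        if value:
            return int(value)
    return 2048  # conservative default when the config reports nothing


class ModelServerSketch:
    def __init__(self, model, context_len=None):
        self.model = model
        # 1. Keep the caller-supplied value, which may be None.
        self.context_len = context_len
        # 2. Model-type checks run here; after this commit they no longer
        #    overwrite self.context_len with a hard-coded value.
        # 3. Fall back to the model config only if nothing was supplied.
        if self.context_len is None:
            self.context_len = read_context_length_from_config(self.model.config)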
api/models.py: 9 additions & 1 deletion
@@ -47,6 +47,14 @@ def get_generate_model():
     )


+def get_context_len(model_config):
+    if "qwen" in config.MODEL_NAME.lower():
+        max_model_len = config.CONTEXT_LEN or 8192
+    else:
+        max_model_len = config.CONTEXT_LEN or model_config.get_max_model_len()
+    return max_model_len
+
+
 def get_vllm_engine():
     try:
         from vllm.engine.arg_utils import AsyncEngineArgs
@@ -76,7 +84,7 @@ def get_vllm_engine():
     )

     engine_model_config = asyncio.run(engine.get_model_config())
-    engine.max_model_len = engine_model_config.get_max_model_len()
+    engine.max_model_len = get_context_len(engine_model_config)

     return engine
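The new get_context_len helper gives an explicit CONTEXT_LEN setting priority, falls back to 8192 for Qwen models, and otherwise takes the maximum from the vLLM model config. The standalone sketch below reproduces that decision; FakeSettings and FakeVllmModelConfig are hypothetical stand-ins for the config module and the object returned by engine.get_model_config().

# Sketch only: standalone version of the decision get_context_len makes in api/models.py.
class FakeSettings:
    MODEL_NAME = "qwen-7b-chat"
    CONTEXT_LEN = None            # set an int here to force a specific window


class FakeVllmModelConfig:
    def get_max_model_len(self) -> int:
        return 4096               # what vLLM would derive from the HF config


def get_context_len(model_config, config=FakeSettings) -> int:
    if "qwen" in config.MODEL_NAME.lower():
        # Qwen models default to 8192 unless CONTEXT_LEN overrides it.
        return config.CONTEXT_LEN or 8192
    return config.CONTEXT_LEN or model_config.get_max_model_len()


print(get_context_len(FakeVllmModelConfig()))   # prints 8192 for the Qwen example above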
