From 625b876a7869af5ecc78d2504152f5d42c0f93f0 Mon Sep 17 00:00:00 2001
From: xusenlin
Date: Tue, 15 Aug 2023 10:41:14 +0800
Subject: [PATCH] Fix model max length

---
 api/generation/core.py | 11 +++++------
 api/models.py          | 10 +++++++++-
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/api/generation/core.py b/api/generation/core.py
index 0e73795..2e9edad 100644
--- a/api/generation/core.py
+++ b/api/generation/core.py
@@ -13,10 +13,10 @@
 )
 
 from api.apapter import get_prompt_adapter
-from api.utils.constants import ErrorCode
 from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan
 from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm
 from api.generation.qwen import build_qwen_chat_input, check_is_qwen
+from api.utils.constants import ErrorCode
 from api.utils.protocol import ChatMessage
 
 server_error_msg = (
@@ -298,11 +298,7 @@ def __init__(
         self.model_name = model_name.lower()
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.stream_interval = stream_interval
-
-        if context_len is None:
-            self.context_len = get_context_length(self.model.config)
-        else:
-            self.context_len = context_len
+        self.context_len = context_len
 
         self.construct_prompt = True
         if check_is_chatglm(self.model):
@@ -316,10 +312,13 @@ def __init__(
             logger.info("Using Qwen Model for Chat!")
             self.construct_prompt = False
             self.generate_stream_func = generate_stream
+            self.context_len = 8192
         else:
             self.generate_stream_func = generate_stream
 
         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
+        if self.context_len is None:
+            self.context_len = get_context_length(self.model.config)
 
     def count_token(self, params):
         prompt = params["prompt"]
diff --git a/api/models.py b/api/models.py
index 2c3ceca..10d6808 100644
--- a/api/models.py
+++ b/api/models.py
@@ -47,6 +47,14 @@ def get_generate_model():
     )
 
 
+def get_context_len(model_config):
+    if "qwen" in config.MODEL_NAME.lower():
+        max_model_len = config.CONTEXT_LEN or 8192
+    else:
+        max_model_len = config.CONTEXT_LEN or model_config.get_max_model_len()
+    return max_model_len
+
+
 def get_vllm_engine():
     try:
         from vllm.engine.arg_utils import AsyncEngineArgs
@@ -76,7 +84,7 @@ def get_vllm_engine():
     )
 
     engine_model_config = asyncio.run(engine.get_model_config())
-    engine.max_model_len = engine_model_config.get_max_model_len()
+    engine.max_model_len = get_context_len(engine_model_config)
 
     return engine