From 141a5bd45c99e65a31fa16d264050605ad67bca2 Mon Sep 17 00:00:00 2001 From: xusenlin Date: Tue, 15 Aug 2023 15:26:41 +0800 Subject: [PATCH] Fix model input for chat --- api/generation/baichuan.py | 53 ++++++++++++++++--------- api/generation/core.py | 57 ++------------------------- api/generation/qwen.py | 64 +++++++++++++++++------------- api/generation/utils.py | 80 ++++++++++++++++++++++++++++++++++++++ api/vllm_routes/utils.py | 26 ++++++------- 5 files changed, 168 insertions(+), 112 deletions(-) create mode 100644 api/generation/utils.py diff --git a/api/generation/baichuan.py b/api/generation/baichuan.py index 09ddc54..1173f6f 100644 --- a/api/generation/baichuan.py +++ b/api/generation/baichuan.py @@ -1,27 +1,44 @@ from typing import List +from transformers import PreTrainedTokenizer + +from api.generation.utils import parse_messages from api.utils.protocol import Role, ChatMessage -def build_baichuan_chat_input(tokenizer, messages: List[ChatMessage], context_len: int = 4096): - """ https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/modeling_baichuan.py """ - total_input, round_input = [], [] - for message in messages[::-1]: - role, content_tokens = message.role, tokenizer.encode(message.content) - if role in [Role.USER, Role.SYSTEM]: - round_input = [195] + content_tokens + round_input - if total_input and len(total_input) + len(round_input) > context_len: - break +def build_baichuan_chat_input( + tokenizer: PreTrainedTokenizer, + messages: List[ChatMessage], + context_len: int = 4096, + max_new_tokens: int = 256 +) -> List[int]: + """ https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_utils.py """ + max_input_tokens = context_len - max_new_tokens + system, rounds = parse_messages(messages) + system_tokens = tokenizer.encode(system) + max_history_tokens = max_input_tokens - len(system_tokens) + + history_tokens = [] + for round in rounds[::-1]: + round_tokens = [] + for message in round: + if message.role == Role.USER: + round_tokens.append(195) else: - total_input = round_input + total_input - round_input = [] - elif role == Role.ASSISTANT: - round_input = [196] + content_tokens + round_input - else: - raise ValueError(f"message role not supported yet: {role}") - total_input = total_input[-context_len:] # truncate left - total_input.append(196) - return total_input + round_tokens.append(196) + round_tokens.extend(tokenizer.encode(message.content)) + + if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: + history_tokens = round_tokens + history_tokens # concat left + if len(history_tokens) < max_history_tokens: + continue + break + + input_tokens = system_tokens + history_tokens + if messages[-1].role != Role.ASSISTANT: + input_tokens.append(196) + + return input_tokens[-max_input_tokens:] # truncate left def check_is_baichuan(model): diff --git a/api/generation/core.py b/api/generation/core.py index 2e9edad..6f6e038 100644 --- a/api/generation/core.py +++ b/api/generation/core.py @@ -4,18 +4,12 @@ import torch import torch.nn.functional as F from loguru import logger -from transformers.generation.logits_process import ( - LogitsProcessorList, - RepetitionPenaltyLogitsProcessor, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, -) from api.apapter import get_prompt_adapter from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm from api.generation.qwen import build_qwen_chat_input, 
check_is_qwen +from api.generation.utils import prepare_logits_processor, is_partial_stop, get_context_length from api.utils.constants import ErrorCode from api.utils.protocol import ChatMessage @@ -24,30 +18,6 @@ ) -def prepare_logits_processor( - temperature: float, repetition_penalty: float, top_p: float, top_k: int -) -> LogitsProcessorList: - processor_list = LogitsProcessorList() - # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases. - if temperature >= 1e-5 and temperature != 1.0: - processor_list.append(TemperatureLogitsWarper(temperature)) - if repetition_penalty > 1.0: - processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) - if 1e-8 <= top_p < 1.0: - processor_list.append(TopPLogitsWarper(top_p)) - if top_k > 0: - processor_list.append(TopKLogitsWarper(top_k)) - return processor_list - - -def is_partial_stop(output: str, stop_str: str): - """Check whether the output contains a partial stop str.""" - for i in range(0, min(len(output), len(stop_str))): - if stop_str.startswith(output[-i:]): - return True - return False - - @torch.inference_mode() def generate_stream( model, @@ -76,9 +46,9 @@ def generate_stream( ) if isinstance(prompt, list) and check_is_baichuan(model): - input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len) + input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len, max_new_tokens) elif isinstance(prompt, list) and check_is_qwen(model): - input_ids = build_qwen_chat_input(tokenizer, prompt) + input_ids = build_qwen_chat_input(tokenizer, prompt, context_len, max_new_tokens) stop_token_ids.extend([tokenizer.im_end_id, tokenizer.im_start_id]) else: input_ids = tokenizer(prompt).input_ids @@ -262,25 +232,6 @@ def generate_stream( torch.cuda.empty_cache() -SEQUENCE_LENGTH_KEYS = [ - "max_sequence_length", - "seq_length", - "max_position_embeddings", - "max_seq_len", - "model_max_length", -] - - -def get_context_length(config): - """Get the context length of a model from a huggingface model config.""" - for key in SEQUENCE_LENGTH_KEYS: - if hasattr(config, key): - val = getattr(config, key) - if val is not None: - return val - return 2048 - - class ModelServer: def __init__( self, @@ -312,7 +263,7 @@ def __init__( logger.info("Using Qwen Model for Chat!") self.construct_prompt = False self.generate_stream_func = generate_stream - self.context_len = 8192 + self.context_len = 8192 if self.context_len is None else self.context_len else: self.generate_stream_func = generate_stream diff --git a/api/generation/qwen.py b/api/generation/qwen.py index bbd9456..ecf42c3 100644 --- a/api/generation/qwen.py +++ b/api/generation/qwen.py @@ -2,15 +2,21 @@ from transformers import PreTrainedTokenizer +from api.generation.baichuan import parse_messages from api.utils.protocol import Role, ChatMessage def build_qwen_chat_input( tokenizer: PreTrainedTokenizer, messages: List[ChatMessage], - max_window_size: int = 6144, -): + context_len: int = 8192, + max_new_tokens: int = 256 +) -> List[int]: """ https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py """ + max_input_tokens = context_len - max_new_tokens + system, rounds = parse_messages(messages) + system = "You are a helpful assistant." 
+ system # fix system prompt + im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id] nl_tokens = tokenizer.encode("\n") @@ -19,31 +25,37 @@ def _tokenize_str(role, content): role, allowed_special=set() ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) - system_tokens_part = _tokenize_str("system", "You are a helpful assistant.") + system_tokens_part = _tokenize_str("system", system) system_tokens = im_start_tokens + system_tokens_part + im_end_tokens - - context_tokens = [] - for i, message in enumerate(messages[::-1]): - role, content = message.role, message.content - if context_tokens: - context_tokens = nl_tokens + context_tokens - - if role == Role.USER: - content_tokens = _tokenize_str("user", content) - elif role == Role.SYSTEM: - content_tokens = _tokenize_str("system", content) - elif role == Role.ASSISTANT: - content_tokens = _tokenize_str("assistant", content) - else: - raise ValueError(f"message role not supported yet: {role}") - - if len(im_start_tokens + content_tokens + im_end_tokens + context_tokens) > max_window_size: - break - else: - context_tokens = im_start_tokens + content_tokens + im_end_tokens + context_tokens - - context_tokens = system_tokens + nl_tokens + context_tokens - return context_tokens + nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens + max_history_tokens = max_input_tokens - len(system_tokens) + + history_tokens = [] + for round in rounds[::-1]: + round_tokens = [] + for message in round: + if round_tokens: + round_tokens += nl_tokens + + if message.role == Role.USER: + content_tokens = im_start_tokens + _tokenize_str("user", message.content) + im_end_tokens + else: + content_tokens = im_start_tokens + _tokenize_str("assistant", message.content) + im_end_tokens + + round_tokens.extend(content_tokens) + + if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: + if history_tokens: + history_tokens = nl_tokens + history_tokens + + history_tokens = round_tokens + history_tokens # concat left + if len(history_tokens) < max_history_tokens: + continue + break + + input_tokens = system_tokens + nl_tokens + history_tokens + if messages[-1].role != Role.ASSISTANT: + input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens + return input_tokens[-max_input_tokens:] # truncate left def check_is_qwen(model): diff --git a/api/generation/utils.py b/api/generation/utils.py new file mode 100644 index 0000000..7dfbe0e --- /dev/null +++ b/api/generation/utils.py @@ -0,0 +1,80 @@ +from typing import List +from typing import Tuple + +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from api.utils.protocol import ChatMessage, Role + + +def parse_messages(messages: List[ChatMessage], split_role=Role.USER) -> Tuple[str, List[List[ChatMessage]]]: + system, rounds = "", [] + round = [] + for i, message in enumerate(messages): + if message.role == Role.SYSTEM: + system = message.content + continue + if message.role == split_role and round: + rounds.append(round) + round = [] + round.append(message) + if round: + rounds.append(round) + return system, rounds + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two 
cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +def is_partial_stop(output: str, stop_str: str): + """Check whether the output contains a partial stop str.""" + for i in range(0, min(len(output), len(stop_str))): + if stop_str.startswith(output[-i:]): + return True + return False + + +# Models don't use the same configuration key for determining the maximum +# sequence length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used. +SEQUENCE_LENGTH_KEYS = [ + "max_sequence_length", + "seq_length", + "max_position_embeddings", + "max_seq_len", + "model_max_length", +] + + +def get_context_length(config): + """Get the context length of a model from a huggingface model config.""" + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling: + rope_scaling_factor = config.rope_scaling["factor"] + else: + rope_scaling_factor = 1 + + for key in SEQUENCE_LENGTH_KEYS: + val = getattr(config, key, None) + if val is not None: + return int(rope_scaling_factor * val) + return 2048 diff --git a/api/vllm_routes/utils.py b/api/vllm_routes/utils.py index cf7cdc2..0782561 100644 --- a/api/vllm_routes/utils.py +++ b/api/vllm_routes/utils.py @@ -27,21 +27,17 @@ async def get_model_inputs(request, prompt, model_name): input_ids = prompt else: if "baichuan-13b" in model_name: - input_ids = build_baichuan_chat_input(VLLM_ENGINE.encode_tokenizer, prompt) + input_ids = build_baichuan_chat_input( + VLLM_ENGINE.encode_tokenizer, + prompt, + max_new_tokens=request.max_tokens, + ) elif "qwen" in model_name: - input_ids = build_qwen_chat_input(VLLM_ENGINE.encode_tokenizer, prompt) + input_ids = build_qwen_chat_input( + VLLM_ENGINE.encode_tokenizer, + prompt, + max_new_tokens=request.max_tokens, + ) else: raise ValueError(f"Model not supported yet: {model_name}") - - token_num = len(input_ids) - if token_num + request.max_tokens > VLLM_ENGINE.max_model_len: - return input_ids, create_error_response( - HTTPStatus.BAD_REQUEST, - f"This model's maximum context length is {VLLM_ENGINE.max_model_len} tokens. " - f"However, you requested {request.max_tokens + token_num} tokens " - f"({token_num} in the messages, " - f"{request.max_tokens} in the completion). " - f"Please reduce the length of the messages or completion.", - ) - else: - return input_ids, None + return input_ids, None
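
For context on the refactor above, here is a minimal, self-contained sketch of how the new parse_messages helper (added in api/generation/utils.py by this patch) groups a chat history into a system prompt plus per-round message lists. ChatMessage and Role below are simplified stand-ins for api.utils.protocol, used here only for illustration; the parse_messages body mirrors the patched code.

# Illustrative sketch only; not part of the patch.
from dataclasses import dataclass

class Role:
    # Stand-in for api.utils.protocol.Role (values assumed for illustration).
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

@dataclass
class ChatMessage:
    # Stand-in for api.utils.protocol.ChatMessage.
    role: str
    content: str

def parse_messages(messages, split_role=Role.USER):
    # Same grouping logic as api/generation/utils.py in this patch:
    # pull out the system prompt, then start a new round at every user turn.
    system, rounds, round = "", [], []
    for message in messages:
        if message.role == Role.SYSTEM:
            system = message.content
            continue
        if message.role == split_role and round:
            rounds.append(round)
            round = []
        round.append(message)
    if round:
        rounds.append(round)
    return system, rounds

messages = [
    ChatMessage(Role.SYSTEM, "Be concise."),
    ChatMessage(Role.USER, "Hi"),
    ChatMessage(Role.ASSISTANT, "Hello!"),
    ChatMessage(Role.USER, "What is 2 + 2?"),
]
system, rounds = parse_messages(messages)
# system == "Be concise."
# rounds == [[user "Hi", assistant "Hello!"], [user "What is 2 + 2?"]]
#
# build_baichuan_chat_input and build_qwen_chat_input walk rounds[::-1] and
# prepend whole rounds until max_input_tokens (context_len - max_new_tokens)
# would be exceeded, so the oldest rounds are dropped first when truncating.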