
Commit

Fix model input for chat
xusenlin committed Aug 15, 2023
1 parent 625b876 commit 141a5bd
Showing 5 changed files with 168 additions and 112 deletions.
53 changes: 35 additions & 18 deletions api/generation/baichuan.py
@@ -1,27 +1,44 @@
 from typing import List
 
+from transformers import PreTrainedTokenizer
+
+from api.generation.utils import parse_messages
 from api.utils.protocol import Role, ChatMessage
 
 
-def build_baichuan_chat_input(tokenizer, messages: List[ChatMessage], context_len: int = 4096):
-    """ https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/modeling_baichuan.py """
-    total_input, round_input = [], []
-    for message in messages[::-1]:
-        role, content_tokens = message.role, tokenizer.encode(message.content)
-        if role in [Role.USER, Role.SYSTEM]:
-            round_input = [195] + content_tokens + round_input
-            if total_input and len(total_input) + len(round_input) > context_len:
-                break
+def build_baichuan_chat_input(
+    tokenizer: PreTrainedTokenizer,
+    messages: List[ChatMessage],
+    context_len: int = 4096,
+    max_new_tokens: int = 256
+) -> List[int]:
+    """ https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_utils.py """
+    max_input_tokens = context_len - max_new_tokens
+    system, rounds = parse_messages(messages)
+    system_tokens = tokenizer.encode(system)
+    max_history_tokens = max_input_tokens - len(system_tokens)
+
+    history_tokens = []
+    for round in rounds[::-1]:
+        round_tokens = []
+        for message in round:
+            if message.role == Role.USER:
+                round_tokens.append(195)
             else:
-                total_input = round_input + total_input
-                round_input = []
-        elif role == Role.ASSISTANT:
-            round_input = [196] + content_tokens + round_input
-        else:
-            raise ValueError(f"message role not supported yet: {role}")
-    total_input = total_input[-context_len:]  # truncate left
-    total_input.append(196)
-    return total_input
+                round_tokens.append(196)
+            round_tokens.extend(tokenizer.encode(message.content))
+
+        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
+            history_tokens = round_tokens + history_tokens  # concat left
+            if len(history_tokens) < max_history_tokens:
+                continue
+        break
+
+    input_tokens = system_tokens + history_tokens
+    if messages[-1].role != Role.ASSISTANT:
+        input_tokens.append(196)
+
+    return input_tokens[-max_input_tokens:]  # truncate left
 
 
 def check_is_baichuan(model):
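For context, a minimal usage sketch of the new builder (not part of the commit; the ChatMessage constructor and the tokenizer loading shown here are assumptions about the surrounding project): with the defaults, the prompt budget is context_len - max_new_tokens = 4096 - 256 = 3840 tokens, and whole conversation rounds are dropped from the oldest side once that budget would be exceeded.

# Hypothetical usage sketch, assuming ChatMessage(role=..., content=...) from
# api.utils.protocol and a Baichuan tokenizer loaded with trust_remote_code=True.
from transformers import AutoTokenizer

from api.generation.baichuan import build_baichuan_chat_input
from api.utils.protocol import ChatMessage, Role

tokenizer = AutoTokenizer.from_pretrained(
    "baichuan-inc/Baichuan-13B-Chat", trust_remote_code=True, use_fast=False
)
messages = [
    ChatMessage(role=Role.SYSTEM, content="You answer concisely."),
    ChatMessage(role=Role.USER, content="Summarize the plot of Hamlet."),
]
input_ids = build_baichuan_chat_input(
    tokenizer, messages, context_len=4096, max_new_tokens=256
)
# The prompt is capped so that generation still fits inside the 4096-token window.
assert len(input_ids) <= 4096 - 256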
57 changes: 4 additions & 53 deletions api/generation/core.py
@@ -4,18 +4,12 @@
 import torch
 import torch.nn.functional as F
 from loguru import logger
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)
 
 from api.apapter import get_prompt_adapter
 from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan
 from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm
 from api.generation.qwen import build_qwen_chat_input, check_is_qwen
+from api.generation.utils import prepare_logits_processor, is_partial_stop, get_context_length
 from api.utils.constants import ErrorCode
 from api.utils.protocol import ChatMessage
 
@@ -24,30 +18,6 @@
 )
 
 
-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """Check whether the output contains a partial stop str."""
-    for i in range(0, min(len(output), len(stop_str))):
-        if stop_str.startswith(output[-i:]):
-            return True
-    return False
-
-
 @torch.inference_mode()
 def generate_stream(
     model,
@@ -76,9 +46,9 @@ def generate_stream(
     )
 
     if isinstance(prompt, list) and check_is_baichuan(model):
-        input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len)
+        input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len, max_new_tokens)
     elif isinstance(prompt, list) and check_is_qwen(model):
-        input_ids = build_qwen_chat_input(tokenizer, prompt)
+        input_ids = build_qwen_chat_input(tokenizer, prompt, context_len, max_new_tokens)
         stop_token_ids.extend([tokenizer.im_end_id, tokenizer.im_start_id])
     else:
         input_ids = tokenizer(prompt).input_ids
@@ -262,25 +232,6 @@ def generate_stream(
     torch.cuda.empty_cache()
 
 
-SEQUENCE_LENGTH_KEYS = [
-    "max_sequence_length",
-    "seq_length",
-    "max_position_embeddings",
-    "max_seq_len",
-    "model_max_length",
-]
-
-
-def get_context_length(config):
-    """Get the context length of a model from a huggingface model config."""
-    for key in SEQUENCE_LENGTH_KEYS:
-        if hasattr(config, key):
-            val = getattr(config, key)
-            if val is not None:
-                return val
-    return 2048
-
-
 class ModelServer:
     def __init__(
         self,
@@ -312,7 +263,7 @@ def __init__(
             logger.info("Using Qwen Model for Chat!")
             self.construct_prompt = False
             self.generate_stream_func = generate_stream
-            self.context_len = 8192
+            self.context_len = 8192 if self.context_len is None else self.context_len
         else:
            self.generate_stream_func = generate_stream
 
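The net effect inside generate_stream is easier to see in one place: the chat builders now receive the generation budget, so truncation happens before decoding instead of the request overflowing the window mid-stream. A rough sketch of the dispatch after this commit (the helper name is ours; generate_stream inlines this logic, and the argument plumbing around it is an assumption):

# Sketch only: mirrors the input dispatch in generate_stream after this commit.
# _select_input_ids is a hypothetical name used here for illustration.
def _select_input_ids(model, tokenizer, prompt, context_len, max_new_tokens):
    if isinstance(prompt, list) and check_is_baichuan(model):
        # prompt is a list of ChatMessage; the builder keeps at most
        # context_len - max_new_tokens prompt tokens.
        return build_baichuan_chat_input(tokenizer, prompt, context_len, max_new_tokens)
    if isinstance(prompt, list) and check_is_qwen(model):
        return build_qwen_chat_input(tokenizer, prompt, context_len, max_new_tokens)
    # Plain string prompts are tokenized as before.
    return tokenizer(prompt).input_ids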
64 changes: 38 additions & 26 deletions api/generation/qwen.py
@@ -2,15 +2,21 @@
 
 from transformers import PreTrainedTokenizer
 
+from api.generation.baichuan import parse_messages
 from api.utils.protocol import Role, ChatMessage
 
 
 def build_qwen_chat_input(
     tokenizer: PreTrainedTokenizer,
     messages: List[ChatMessage],
-    max_window_size: int = 6144,
-):
+    context_len: int = 8192,
+    max_new_tokens: int = 256
+) -> List[int]:
     """ https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py """
+    max_input_tokens = context_len - max_new_tokens
+    system, rounds = parse_messages(messages)
+    system = "You are a helpful assistant." + system  # fix system prompt
+
     im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id]
     nl_tokens = tokenizer.encode("\n")
 
@@ -19,31 +25,37 @@ def _tokenize_str(role, content):
             role, allowed_special=set()
         ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
 
-    system_tokens_part = _tokenize_str("system", "You are a helpful assistant.")
+    system_tokens_part = _tokenize_str("system", system)
     system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
 
-    context_tokens = []
-    for i, message in enumerate(messages[::-1]):
-        role, content = message.role, message.content
-        if context_tokens:
-            context_tokens = nl_tokens + context_tokens
-
-        if role == Role.USER:
-            content_tokens = _tokenize_str("user", content)
-        elif role == Role.SYSTEM:
-            content_tokens = _tokenize_str("system", content)
-        elif role == Role.ASSISTANT:
-            content_tokens = _tokenize_str("assistant", content)
-        else:
-            raise ValueError(f"message role not supported yet: {role}")
-
-        if len(im_start_tokens + content_tokens + im_end_tokens + context_tokens) > max_window_size:
-            break
-        else:
-            context_tokens = im_start_tokens + content_tokens + im_end_tokens + context_tokens
-
-    context_tokens = system_tokens + nl_tokens + context_tokens
-    return context_tokens + nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens
+    max_history_tokens = max_input_tokens - len(system_tokens)
+
+    history_tokens = []
+    for round in rounds[::-1]:
+        round_tokens = []
+        for message in round:
+            if round_tokens:
+                round_tokens += nl_tokens
+
+            if message.role == Role.USER:
+                content_tokens = im_start_tokens + _tokenize_str("user", message.content) + im_end_tokens
+            else:
+                content_tokens = im_start_tokens + _tokenize_str("assistant", message.content) + im_end_tokens
+
+            round_tokens.extend(content_tokens)
+
+        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
+            if history_tokens:
+                history_tokens = nl_tokens + history_tokens
+
+            history_tokens = round_tokens + history_tokens  # concat left
+            if len(history_tokens) < max_history_tokens:
+                continue
+        break
+
+    input_tokens = system_tokens + nl_tokens + history_tokens
+    if messages[-1].role != Role.ASSISTANT:
+        input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens
+    return input_tokens[-max_input_tokens:]  # truncate left
 
 
 def check_is_qwen(model):
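For orientation, the ids returned by build_qwen_chat_input decode to Qwen's ChatML layout; the sketch below is an illustration with made-up message contents, not code from this commit.

# Roughly what the returned ids decode to for one system prompt and one user turn.
# The trailing open "assistant" block is appended because the last message is not
# an assistant message, which cues the model to generate the reply.
chatml_text = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n"
    "<|im_start|>assistant\n"
)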
80 changes: 80 additions & 0 deletions api/generation/utils.py
@@ -0,0 +1,80 @@
+from typing import List
+from typing import Tuple
+
+from transformers.generation.logits_process import (
+    LogitsProcessorList,
+    RepetitionPenaltyLogitsProcessor,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+)
+
+from api.utils.protocol import ChatMessage, Role
+
+
+def parse_messages(messages: List[ChatMessage], split_role=Role.USER) -> Tuple[str, List[List[ChatMessage]]]:
+    system, rounds = "", []
+    round = []
+    for i, message in enumerate(messages):
+        if message.role == Role.SYSTEM:
+            system = message.content
+            continue
+        if message.role == split_role and round:
+            rounds.append(round)
+            round = []
+        round.append(message)
+    if round:
+        rounds.append(round)
+    return system, rounds
+
+
+def prepare_logits_processor(
+    temperature: float, repetition_penalty: float, top_p: float, top_k: int
+) -> LogitsProcessorList:
+    processor_list = LogitsProcessorList()
+    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
+    if temperature >= 1e-5 and temperature != 1.0:
+        processor_list.append(TemperatureLogitsWarper(temperature))
+    if repetition_penalty > 1.0:
+        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+    if 1e-8 <= top_p < 1.0:
+        processor_list.append(TopPLogitsWarper(top_p))
+    if top_k > 0:
+        processor_list.append(TopKLogitsWarper(top_k))
+    return processor_list
+
+
+def is_partial_stop(output: str, stop_str: str):
+    """Check whether the output contains a partial stop str."""
+    for i in range(0, min(len(output), len(stop_str))):
+        if stop_str.startswith(output[-i:]):
+            return True
+    return False
+
+
+# Models don't use the same configuration key for determining the maximum
+# sequence length. Store them here so we can sanely check them.
+# NOTE: The ordering here is important. Some models have two of these and we
+# have a preference for which value gets used.
+SEQUENCE_LENGTH_KEYS = [
+    "max_sequence_length",
+    "seq_length",
+    "max_position_embeddings",
+    "max_seq_len",
+    "model_max_length",
+]
+
+
+def get_context_length(config):
+    """Get the context length of a model from a huggingface model config."""
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling:
+        rope_scaling_factor = config.rope_scaling["factor"]
+    else:
+        rope_scaling_factor = 1
+
+    for key in SEQUENCE_LENGTH_KEYS:
+        val = getattr(config, key, None)
+        if val is not None:
+            return int(rope_scaling_factor * val)
+    return 2048
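
A quick walk-through of the new parse_messages helper, which both builders rely on (illustrative only; the ChatMessage constructor shown is an assumption about api.utils.protocol, which this commit does not touch):

# Hypothetical example of splitting a conversation into rounds.
messages = [
    ChatMessage(role=Role.SYSTEM, content="Be brief."),
    ChatMessage(role=Role.USER, content="Hi"),
    ChatMessage(role=Role.ASSISTANT, content="Hello!"),
    ChatMessage(role=Role.USER, content="What is 2 + 2?"),
]
system, rounds = parse_messages(messages)
# system == "Be brief."
# rounds == [
#     [<user: "Hi">, <assistant: "Hello!">],   # round 1
#     [<user: "What is 2 + 2?">],              # round 2, still awaiting a reply
# ]
# Each round starts at a USER message (split_role), and the system prompt is
# returned separately so the builders can always keep it while dropping old rounds.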
26 changes: 11 additions & 15 deletions api/vllm_routes/utils.py
@@ -27,21 +27,17 @@ async def get_model_inputs(request, prompt, model_name):
         input_ids = prompt
     else:
         if "baichuan-13b" in model_name:
-            input_ids = build_baichuan_chat_input(VLLM_ENGINE.encode_tokenizer, prompt)
+            input_ids = build_baichuan_chat_input(
+                VLLM_ENGINE.encode_tokenizer,
+                prompt,
+                max_new_tokens=request.max_tokens,
+            )
         elif "qwen" in model_name:
-            input_ids = build_qwen_chat_input(VLLM_ENGINE.encode_tokenizer, prompt)
+            input_ids = build_qwen_chat_input(
+                VLLM_ENGINE.encode_tokenizer,
+                prompt,
+                max_new_tokens=request.max_tokens,
+            )
         else:
             raise ValueError(f"Model not supported yet: {model_name}")
 
-    token_num = len(input_ids)
-    if token_num + request.max_tokens > VLLM_ENGINE.max_model_len:
-        return input_ids, create_error_response(
-            HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {VLLM_ENGINE.max_model_len} tokens. "
-            f"However, you requested {request.max_tokens + token_num} tokens "
-            f"({token_num} in the messages, "
-            f"{request.max_tokens} in the completion). "
-            f"Please reduce the length of the messages or completion.",
-        )
-    else:
-        return input_ids, None
+    return input_ids, None
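
The explicit max_model_len check could be dropped because the builders now do the budgeting themselves. The arithmetic they apply, using this commit's defaults as illustrative numbers only:

# Illustrative budget only; request.max_tokens comes from the API request.
context_len = 4096                               # e.g. the Baichuan-13B window
max_new_tokens = 256                             # request.max_tokens passed through
max_input_tokens = context_len - max_new_tokens  # 3840 tokens left for the prompt
# build_baichuan_chat_input / build_qwen_chat_input drop the oldest rounds and
# finally slice input_tokens[-max_input_tokens:], so prompt plus completion can no
# longer exceed the model window, and get_model_inputs simply returns (input_ids, None).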
