diff --git a/api/vllm_routes/utils.py b/api/vllm_routes/utils.py
index 0782561..29e6f6f 100644
--- a/api/vllm_routes/utils.py
+++ b/api/vllm_routes/utils.py
@@ -21,10 +21,11 @@ async def get_gen_prompt(request, model_name):
 
 
 async def get_model_inputs(request, prompt, model_name):
+    max_input_tokens = VLLM_ENGINE.max_model_len - request.max_tokens
     if isinstance(prompt, str):
-        input_ids = VLLM_ENGINE.encode_tokenizer(prompt).input_ids
+        input_ids = VLLM_ENGINE.encode_tokenizer(prompt).input_ids[-max_input_tokens:]  # truncate left
     elif isinstance(prompt[0], int):
-        input_ids = prompt
+        input_ids = prompt[-max_input_tokens:]  # truncate left
     else:
         if "baichuan-13b" in model_name:
             input_ids = build_baichuan_chat_input(
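
For context, the left-truncation this patch applies can be sketched as a standalone helper. The function name `truncate_left` is hypothetical; `max_model_len` and `max_tokens` mirror `VLLM_ENGINE.max_model_len` and `request.max_tokens` from the diff:

```python
def truncate_left(input_ids: list[int], max_model_len: int, max_tokens: int) -> list[int]:
    """Keep only the most recent prompt tokens so prompt + generation fits the window.

    A minimal sketch of the patch's logic: reserve `max_tokens` slots for
    generation, keep the last `max_model_len - max_tokens` prompt tokens.
    """
    max_input_tokens = max_model_len - max_tokens
    return input_ids[-max_input_tokens:]


# Example: a 4096-token window with 512 tokens reserved for generation
# keeps at most the last 3584 prompt tokens.
prompt_ids = list(range(5000))  # hypothetical over-long prompt
kept = truncate_left(prompt_ids, 4096, 512)
assert len(kept) == 3584 and kept[-1] == 4999  # newest tokens survive
```

One edge case worth noting: if `request.max_tokens >= max_model_len`, the computed `max_input_tokens` is zero or negative, and a slice like `input_ids[-0:]` degenerates to the full list, so callers may want to validate that case separately.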