Skip to content

Commit

Permalink
Fix model input length
Browse files — browse the repository at this point in the history
  • Loading branch information
xusenlin committed Aug 15, 2023
1 parent 141a5bd commit 36bbdbb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions api/vllm_routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ async def get_gen_prompt(request, model_name):


async def get_model_inputs(request, prompt, model_name):
max_input_tokens = VLLM_ENGINE.max_model_len - request.max_tokens
if isinstance(prompt, str):
input_ids = VLLM_ENGINE.encode_tokenizer(prompt).input_ids
input_ids = VLLM_ENGINE.encode_tokenizer(prompt).input_ids[-max_input_tokens:] # truncate left
elif isinstance(prompt[0], int):
input_ids = prompt
input_ids = prompt[-max_input_tokens:] # truncate left
else:
if "baichuan-13b" in model_name:
input_ids = build_baichuan_chat_input(
Expand Down

0 comments on commit 36bbdbb

Please sign in to comment.