From b52ba65c267c4d8bf05cd57ed3386a2d466887db Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Wed, 9 Oct 2024 22:58:06 +0800 Subject: [PATCH] [Feature] Integrate lmdeploy pipeline api (#1198) * integrate lmdeploy's pipeline api * fix linting * update user guide * rename * update * update * update * rollback class name * update * remove unused code * update * update * fix ci check * compatibility * remove concurrency * Update configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py * Update docs/zh_cn/advanced_guides/evaluation_lmdeploy.md * [Bug] fix lint --------- Co-authored-by: Songyang Zhang Co-authored-by: tonysy --- .../eval_internlm_chat_lmdeploy_pytorch.py | 69 ------ configs/eval_internlm_chat_lmdeploy_tis.py | 41 ---- configs/eval_internlm_chat_turbomind_tis.py | 40 ---- configs/eval_internlm_turbomind_tis.py | 28 --- .../hf_internlm/lmdeploy_internlm2_chat_7b.py | 17 +- .../en/advanced_guides/evaluation_lmdeploy.md | 88 ++++++++ .../advanced_guides/evaluation_turbomind.md | 78 ------- .../advanced_guides/evaluation_lmdeploy.md | 86 ++++++++ .../advanced_guides/evaluation_turbomind.md | 75 ------- .../hf_internlm/lmdeploy_internlm2_chat_7b.py | 17 +- opencompass/models/__init__.py | 3 - opencompass/models/lmdeploy_pytorch.py | 188 ---------------- opencompass/models/lmdeploy_tis.py | 200 ------------------ opencompass/models/turbomind_tis.py | 135 ------------ .../models/turbomind_with_tf_above_v4_33.py | 128 ++++------- opencompass/utils/run.py | 11 +- 16 files changed, 249 insertions(+), 955 deletions(-) delete mode 100644 configs/eval_internlm_chat_lmdeploy_pytorch.py delete mode 100644 configs/eval_internlm_chat_lmdeploy_tis.py delete mode 100644 configs/eval_internlm_chat_turbomind_tis.py delete mode 100644 configs/eval_internlm_turbomind_tis.py create mode 100644 docs/en/advanced_guides/evaluation_lmdeploy.md delete mode 100644 docs/en/advanced_guides/evaluation_turbomind.md create mode 100644 docs/zh_cn/advanced_guides/evaluation_lmdeploy.md delete mode 100644 docs/zh_cn/advanced_guides/evaluation_turbomind.md delete mode 100644 opencompass/models/lmdeploy_pytorch.py delete mode 100644 opencompass/models/lmdeploy_tis.py delete mode 100644 opencompass/models/turbomind_tis.py diff --git a/configs/eval_internlm_chat_lmdeploy_pytorch.py b/configs/eval_internlm_chat_lmdeploy_pytorch.py deleted file mode 100644 index 4ea1f84c2..000000000 --- a/configs/eval_internlm_chat_lmdeploy_pytorch.py +++ /dev/null @@ -1,69 +0,0 @@ -from mmengine.config import read_base -from opencompass.models import LmdeployPytorchModel - - -with read_base(): - # choose a list of datasets - from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets - from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets - from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets - from opencompass.configs.datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets - from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets - from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets - from opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets - from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets - # and output the results in a choosen format - from opencompass.configs.summarizers.medium import summarizer - - -datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) - - -meta_template = dict( - round=[ - 
dict(role='HUMAN', begin='<|User|>:', end='\n'), - dict(role='BOT', begin='<|Bot|>:', end='\n', generate=True), - ], - eos_token_id=103028) - -# config for internlm-chat-7b -internlm_chat_7b = dict( - type=LmdeployPytorchModel, - abbr='internlm-chat-7b-pytorch', - path='internlm/internlm-chat-7b', - engine_config=dict(session_len=2048, - max_batch_size=16), - gen_config=dict(top_k=1, - top_p=0.8, - temperature=1.0, - max_new_tokens=100), - max_out_len=100, - max_seq_len=2048, - batch_size=16, - concurrency=16, - meta_template=meta_template, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='', -) - -# config for internlm-chat-20b -internlm_chat_20b = dict( - type=LmdeployPytorchModel, - abbr='internlm-chat-20b-pytorch', - path='internlm/internlm-chat-20b', - engine_config=dict(session_len=2048, - max_batch_size=8), - gen_config=dict(top_k=1, - top_p=0.8, - temperature=1.0, - max_new_tokens=100), - max_out_len=100, - max_seq_len=2048, - batch_size=8, - concurrency=8, - meta_template=meta_template, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='', - ) - -models = [internlm_chat_20b] diff --git a/configs/eval_internlm_chat_lmdeploy_tis.py b/configs/eval_internlm_chat_lmdeploy_tis.py deleted file mode 100644 index 8f5470d52..000000000 --- a/configs/eval_internlm_chat_lmdeploy_tis.py +++ /dev/null @@ -1,41 +0,0 @@ -from mmengine.config import read_base -from opencompass.models.lmdeploy_tis import LmdeployTisModel - -with read_base(): - # choose a list of datasets - from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets - from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets - from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets - from opencompass.configs.datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets - from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets - from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets - from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets - from opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets - from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets - # and output the results in a choosen format - from opencompass.configs.summarizers.medium import summarizer - -datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) - -meta_template = dict( - round=[ - dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'), - dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True), - ], - eos_token_id=92542 -) - -models = [ - dict( - type=LmdeployTisModel, - abbr='internlm-chat-20b-lmdeploy-tis', - path='internlm/internlm-chat-20b', - tis_addr='0.0.0.0:33337', - max_out_len=100, - max_seq_len=2048, - batch_size=8, - meta_template=meta_template, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='<|im_end|>', - ) -] diff --git a/configs/eval_internlm_chat_turbomind_tis.py b/configs/eval_internlm_chat_turbomind_tis.py deleted file mode 100644 index 01f42000f..000000000 --- a/configs/eval_internlm_chat_turbomind_tis.py +++ /dev/null @@ -1,40 +0,0 @@ -from mmengine.config import read_base -from opencompass.models.turbomind_tis import TurboMindTisModel - -with read_base(): - # choose a list of datasets - from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets - from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets - 
from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets - from opencompass.configs.datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets - from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets - from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets - from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets - from opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets - from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets - # and output the results in a choosen format - from opencompass.configs.summarizers.medium import summarizer - -datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) - - -meta_template = dict( - round=[ - dict(role='HUMAN', begin='<|User|>:', end='\n'), - dict(role='BOT', begin='<|Bot|>:', end='\n', generate=True), - ], - eos_token_id=103028) - -models = [ - dict( - type=TurboMindTisModel, - abbr='internlm-chat-20b-turbomind', - path='internlm', - tis_addr='0.0.0.0:33337', - max_out_len=100, - max_seq_len=2048, - batch_size=8, - meta_template=meta_template, - run_cfg=dict(num_gpus=1, num_procs=1), - ) -] diff --git a/configs/eval_internlm_turbomind_tis.py b/configs/eval_internlm_turbomind_tis.py deleted file mode 100644 index 98914fa47..000000000 --- a/configs/eval_internlm_turbomind_tis.py +++ /dev/null @@ -1,28 +0,0 @@ -from mmengine.config import read_base -from opencompass.models.turbomind_tis import TurboMindTisModel - -with read_base(): - # choose a list of datasets - from opencompass.configs.datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets - from opencompass.configs.datasets.ceval.ceval_gen_5f30c7 import ceval_datasets - from opencompass.configs.datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets - from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets - from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets - from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets - # and output the results in a choosen format - from opencompass.configs.summarizers.medium import summarizer - -datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) - -models = [ - dict( - type=TurboMindTisModel, - abbr='internlm-chat-20b-turbomind', - path='internlm', - tis_addr='0.0.0.0:33337', - max_out_len=100, - max_seq_len=2048, - batch_size=8, - run_cfg=dict(num_gpus=1, num_procs=1), - ) -] diff --git a/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py b/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py index 60097e373..38ea39d7d 100644 --- a/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py +++ b/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py @@ -1,15 +1,24 @@ from opencompass.models import TurboMindModelwithChatTemplate + models = [ dict( type=TurboMindModelwithChatTemplate, - abbr='internlm2-chat-7b-turbomind', + abbr=f'internlm2-chat-7b-lmdeploy', path='internlm/internlm2-chat-7b', - engine_config=dict(session_len=8192, max_batch_size=16, tp=1), - gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=4096), + # inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'. 
+ # If the model is not supported by 'turbomind', it will fallback to + # 'pytorch' + backend='turbomind', + # For the detailed engine config and generation config, please refer to + # https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py + engine_config=dict(tp=1), + gen_config=dict(do_sample=False), max_seq_len=8192, max_out_len=4096, - batch_size=16, + # the max number of prompts that LMDeploy receives + # in `generate` function + batch_size=5000, run_cfg=dict(num_gpus=1), ) ] diff --git a/docs/en/advanced_guides/evaluation_lmdeploy.md b/docs/en/advanced_guides/evaluation_lmdeploy.md new file mode 100644 index 000000000..bfacd4881 --- /dev/null +++ b/docs/en/advanced_guides/evaluation_lmdeploy.md @@ -0,0 +1,88 @@ +# Evaluation with LMDeploy + +We now support evaluation of models accelerated by the [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy is a toolkit designed for compressing, deploying, and serving LLM. It has a remarkable inference performance. We now illustrate how to evaluate a model with the support of LMDeploy in OpenCompass. + +## Setup + +### Install OpenCompass + +Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install the OpenCompass and prepare the evaluation datasets. + +### Install LMDeploy + +Install lmdeploy via pip (python 3.8+) + +```shell +pip install lmdeploy +``` + +The default prebuilt package is compiled on CUDA 12. However, if CUDA 11+ is required, you can install lmdeploy by: + +```shell +export LMDEPLOY_VERSION=0.6.0 +export PYTHON_VERSION=310 +pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 +``` + +## Evaluation + +When evaluating a model, it is necessary to prepare an evaluation configuration that specifies information such as the evaluation dataset, the model, and inference parameters. + +Taking [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) as an example, the evaluation config is as follows: + +```python +# configure the dataset +from mmengine.config import read_base + + +with read_base(): + # choose a list of datasets + from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets + from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets + from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets + from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import \ + gsm8k_datasets + # and output the results in a chosen format + from .summarizers.medium import summarizer + +datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) + +# configure lmdeploy +from opencompass.models import TurboMindModelwithChatTemplate + + + +# configure the model +models = [ + dict( + type=TurboMindModelwithChatTemplate, + abbr=f'internlm2-chat-7b-lmdeploy', + # model path, which can be the address of a model repository on the Hugging Face Hub or a local path + path='internlm/internlm2-chat-7b', + # inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'. 
+ # If the model is not supported by 'turbomind', it will fallback to + # 'pytorch' + backend='turbomind', + # For the detailed engine config and generation config, please refer to + # https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py + engine_config=dict(tp=1), + gen_config=dict(do_sample=False), + # the max size of the context window + max_seq_len=7168, + # the max number of new tokens + max_out_len=1024, + # the max number of prompts that LMDeploy receives + # in `generate` function + batch_size=5000, + run_cfg=dict(num_gpus=1), + ) +] +``` + +Place the aforementioned configuration in a file, such as "configs/eval_internlm2_lmdeploy.py". Then, in the home folder of OpenCompass, start evaluation by the following command: + +```shell +python run.py configs/eval_internlm2_lmdeploy.py -w outputs +``` + +You are expected to get the evaluation results after the inference and evaluation. diff --git a/docs/en/advanced_guides/evaluation_turbomind.md b/docs/en/advanced_guides/evaluation_turbomind.md deleted file mode 100644 index c1299f0b3..000000000 --- a/docs/en/advanced_guides/evaluation_turbomind.md +++ /dev/null @@ -1,78 +0,0 @@ -# Evaluation with LMDeploy - -We now support evaluation of models accelerated by the [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy is a toolkit designed for compressing, deploying, and serving LLM. **TurboMind** is an efficient inference engine proposed by LMDeploy. OpenCompass is compatible with TurboMind. We now illustrate how to evaluate a model with the support of TurboMind in OpenCompass. - -## Setup - -### Install OpenCompass - -Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install the OpenCompass and prepare the evaluation datasets. - -### Install LMDeploy - -Install lmdeploy via pip (python 3.8+) - -```shell -pip install lmdeploy -``` - -## Evaluation - -OpenCompass integrates turbomind's python API for evaluation. - -We take the InternLM-20B as example. 
Firstly, we prepare the evaluation config `configs/eval_internlm_turbomind.py`: - -```python -from mmengine.config import read_base -from opencompass.models.turbomind import TurboMindModel - - -with read_base(): - # choose a list of datasets - from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets - from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets - from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets - from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets - from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets - from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets - # and output the results in a chosen format - from .summarizers.medium import summarizer - -datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) - -# config for internlm-20b model -internlm_20b = dict( - type=TurboMindModel, - abbr='internlm-20b-turbomind', - path="internlm/internlm-20b", # this path should be same as in huggingface - engine_config=dict(session_len=2048, - max_batch_size=8, - rope_scaling_factor=1.0), - gen_config=dict(top_k=1, top_p=0.8, - temperature=1.0, - max_new_tokens=100), - max_out_len=100, - max_seq_len=2048, - batch_size=8, - concurrency=8, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='' - ) - -models = [internlm_20b] -``` - -Then, in the home folder of OpenCompass, start evaluation by the following command: - -```shell -python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b -``` - -You are expected to get the evaluation results after the inference and evaluation. - -**Note**: - -- If you want to pass more arguments for `engine_config`和`gen_config` in the evaluation config file, please refer to [TurbomindEngineConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#turbomindengineconfig) - and [GenerationConfig](https://lmdeploy.readthedocs.io/en/latest/inference/pipeline.html#generationconfig) -- If you evaluate the InternLM Chat model, please use configuration file `eval_internlm_chat_turbomind.py` -- If you evaluate the InternLM 7B model, please modify `eval_internlm_turbomind.py` or `eval_internlm_chat_turbomind.py` by changing to the setting `models = [internlm_7b]` in the last line. 
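For context on what the new `TurboMindModelwithChatTemplate` delegates to, below is a minimal, self-contained sketch of LMDeploy's `pipeline` API that this patch wraps (the model path, prompt, and engine settings are illustrative placeholders, not part of the patch itself):

```python
# Minimal sketch of the LMDeploy pipeline API used by this PR.
# The model path and prompt are placeholders.
from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline

# build a pipeline backed by the turbomind engine on a single GPU
pipe = pipeline('internlm/internlm2-chat-7b',
                backend_config=TurbomindEngineConfig(tp=1, session_len=8192))

# greedy decoding; `do_sample` requires lmdeploy >= 0.6.0, otherwise use top_k=1
gen_config = GenerationConfig(max_new_tokens=1024, do_sample=False)

# the pipeline accepts a batch of prompts and returns one response per prompt
responses = pipe(['Briefly introduce yourself.'], gen_config=gen_config)
print(responses[0].text)
```

In this PR, the model wrapper templates the prompts itself and then calls the same pipeline object in batch mode, which is why the per-instance generators and the `concurrency` argument are removed.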
diff --git a/docs/zh_cn/advanced_guides/evaluation_lmdeploy.md b/docs/zh_cn/advanced_guides/evaluation_lmdeploy.md new file mode 100644 index 000000000..158399641 --- /dev/null +++ b/docs/zh_cn/advanced_guides/evaluation_lmdeploy.md @@ -0,0 +1,86 @@ +# 使用 LMDeploy 加速评测 + +我们支持在评测大语言模型时,使用 [LMDeploy](https://github.com/InternLM/lmdeploy) 作为推理加速引擎。LMDeploy 是涵盖了 LLM 和 VLM 任务的全套轻量化、部署和服务解决方案,拥有卓越的推理性能。本教程将介绍如何使用 LMDeploy 加速对模型的评测。 + +## 环境配置 + +### 安装 OpenCompass + +请根据 OpenCompass [安装指南](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) 来安装算法库和准备数据集。 + +### 安装 LMDeploy + +使用 pip 安装 LMDeploy (python 3.8+): + +```shell +pip install lmdeploy +``` + +LMDeploy 预编译包默认基于 CUDA 12 编译。如果需要在 CUDA 11+ 下安装 LMDeploy,请执行以下命令: + +```shell +export LMDEPLOY_VERSION=0.6.0 +export PYTHON_VERSION=310 +pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 +``` + +## 评测 + +在评测一个模型时,需要准备一份评测配置,指明评测集、模型和推理参数等信息。 + +以 [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) 模型为例,相关的配置信息如下: + +```python +# configure the dataset +from mmengine.config import read_base + + +with read_base(): + # choose a list of datasets + from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets + from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets + from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets + from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import \ + gsm8k_datasets + # and output the results in a chosen format + from .summarizers.medium import summarizer + +datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) + +# configure lmdeploy +from opencompass.models import TurboMindModelwithChatTemplate + + + +# configure the model +models = [ + dict( + type=TurboMindModelwithChatTemplate, + abbr=f'internlm2-chat-7b-lmdeploy', + # model path, which can be the address of a model repository on the Hugging Face Hub or a local path + path='internlm/internlm2-chat-7b', + # inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'. 
+ # If the model is not supported by 'turbomind', it will fallback to + # 'pytorch' + backend='turbomind', + # For the detailed engine config and generation config, please refer to + # https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py + engine_config=dict(tp=1), + gen_config=dict(do_sample=False), + # the max size of the context window + max_seq_len=7168, + # the max number of new tokens + max_out_len=1024, + # the max number of prompts that LMDeploy receives + # in `generate` function + batch_size=32, + run_cfg=dict(num_gpus=1), + ) +] +``` + +把上述配置放在文件中,比如 "configs/eval_internlm2_lmdeploy.py"。然后,在 OpenCompass 的项目目录下,执行如下命令可得到评测结果: + +```shell +python run.py configs/eval_internlm2_lmdeploy.py -w outputs +``` diff --git a/docs/zh_cn/advanced_guides/evaluation_turbomind.md b/docs/zh_cn/advanced_guides/evaluation_turbomind.md deleted file mode 100644 index a7c37b758..000000000 --- a/docs/zh_cn/advanced_guides/evaluation_turbomind.md +++ /dev/null @@ -1,75 +0,0 @@ -# 评测 LMDeploy 模型 - -我们支持评测使用 [LMDeploy](https://github.com/InternLM/lmdeploy) 加速过的大语言模型。LMDeploy 由 MMDeploy 和 MMRazor 团队联合开发,是涵盖了 LLM 任务的全套轻量化、部署和服务解决方案。 **TurboMind** 是 LMDeploy 推出的高效推理引擎。OpenCompass 对 TurboMind 进行了适配,本教程将介绍如何使用 OpenCompass 来对 TurboMind 加速后的模型进行评测。 - -## 环境配置 - -### 安装 OpenCompass - -请根据 OpenCompass [安装指南](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) 来安装算法库和准备数据集。 - -### 安装 LMDeploy - -使用 pip 安装 LMDeploy (python 3.8+): - -```shell -pip install lmdeploy -``` - -## 评测 - -OpenCompass 支持分别通过 turbomind python API 评测数据集。 - -下文以 InternLM-20B 模型为例,介绍如何评测。首先我们准备好测试配置文件`configs/eval_internlm_turbomind.py`: - -```python -from mmengine.config import read_base -from opencompass.models.turbomind import TurboMindModel - - -with read_base(): - # choose a list of datasets - from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets - from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets - from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets - from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets - from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets - from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets - # and output the results in a chosen format - from .summarizers.medium import summarizer - -datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) - -# config for internlm-20b model -internlm_20b = dict( - type=TurboMindModel, - abbr='internlm-20b-turbomind', - path="internlm/internlm-20b", # 注意路径与huggingface保持一致 - engine_config=dict(session_len=2048, - max_batch_size=8, - rope_scaling_factor=1.0), - gen_config=dict(top_k=1, top_p=0.8, - temperature=1.0, - max_new_tokens=100), - max_out_len=100, - max_seq_len=2048, - batch_size=8, - concurrency=8, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='' - ) - -models = [internlm_20b] -``` - -然后,在 OpenCompass 的项目目录下,执行如下命令可得到评测结果: - -```shell -python run.py configs/eval_internlm_turbomind.py -w outputs/turbomind/internlm-20b -``` - -**注:** - -- 如果想在测评配置文件中`engine_config`和`gen_config`字段传递更多参数,请参考[TurbomindEngineConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#turbomindengineconfig) 和 [GenerationConfig](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/pipeline.html#generationconfig) -- 如果评测 InternLM Chat 模型,请使用配置文件 `eval_internlm_chat_turbomind.py` -- 如果评测 InternLM 7B 模型,请修改 `eval_internlm_turbomind.py` 或者 `eval_internlm_chat_turbomind.py`。将`models`字段配置为`models = [internlm_7b]` 。 diff --git 
a/opencompass/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py b/opencompass/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py index 60097e373..38ea39d7d 100644 --- a/opencompass/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py +++ b/opencompass/configs/models/hf_internlm/lmdeploy_internlm2_chat_7b.py @@ -1,15 +1,24 @@ from opencompass.models import TurboMindModelwithChatTemplate + models = [ dict( type=TurboMindModelwithChatTemplate, - abbr='internlm2-chat-7b-turbomind', + abbr=f'internlm2-chat-7b-lmdeploy', path='internlm/internlm2-chat-7b', - engine_config=dict(session_len=8192, max_batch_size=16, tp=1), - gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=4096), + # inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'. + # If the model is not supported by 'turbomind', it will fallback to + # 'pytorch' + backend='turbomind', + # For the detailed engine config and generation config, please refer to + # https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py + engine_config=dict(tp=1), + gen_config=dict(do_sample=False), max_seq_len=8192, max_out_len=4096, - batch_size=16, + # the max number of prompts that LMDeploy receives + # in `generate` function + batch_size=5000, run_cfg=dict(num_gpus=1), ) ] diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index 0f55b869c..580402d46 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -25,8 +25,6 @@ from .krgpt_api import KrGPT # noqa: F401 from .lightllm_api import LightllmAPI, LightllmChatAPI # noqa: F401 from .llama2 import Llama2, Llama2Chat # noqa: F401 -from .lmdeploy_pytorch import LmdeployPytorchModel # noqa: F401 -from .lmdeploy_tis import LmdeployTisModel # noqa: F401 from .minimax_api import MiniMax, MiniMaxChatCompletionV2 # noqa: F401 from .mistral_api import Mistral # noqa: F401 from .mixtral import Mixtral # noqa: F401 @@ -41,7 +39,6 @@ from .sensetime_api import SenseTime # noqa: F401 from .stepfun_api import StepFun # noqa: F401 from .turbomind import TurboMindModel # noqa: F401 -from .turbomind_tis import TurboMindTisModel # noqa: F401 from .turbomind_with_tf_above_v4_33 import \ TurboMindModelwithChatTemplate # noqa: F401 from .unigpt_api import UniGPT # noqa: F401 diff --git a/opencompass/models/lmdeploy_pytorch.py b/opencompass/models/lmdeploy_pytorch.py deleted file mode 100644 index 80924c276..000000000 --- a/opencompass/models/lmdeploy_pytorch.py +++ /dev/null @@ -1,188 +0,0 @@ -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Union - -from opencompass.models.base import BaseModel -from opencompass.utils.logging import get_logger -from opencompass.utils.prompt import PromptList - -PromptType = Union[PromptList, str] - - -def valid_str(string, coding='utf-8'): - """decode text according to its encoding type.""" - invalid_chars = [b'\xef\xbf\xbd'] - bstr = bytes(string, coding) - for invalid_char in invalid_chars: - bstr = bstr.replace(invalid_char, b'') - ret = bstr.decode(encoding=coding, errors='ignore') - return ret - - -class LmdeployPytorchModel(BaseModel): - """Model wrapper for lmdeploy pytorch engine through python API. - - Args: - path (str): path of the supported pytorch model. - max_seq_len (int): The maximum allowed sequence length of a model. - Note that the length of prompt + generated tokens shall not exceed - this value. Defaults to 2048. 
- meta_template (Dict, optional): The model's meta prompt - template if needed, in case the requirement of injecting or - wrapping of any meta instructions. - engine_config (Dict, optional): The engine config to set - arguments like session_len, max_batch_size for TurboMind. - gen_config (Dict, optional): Generation config to set - arguments like top_k, top_p, temperature. - end_str (str, optional): Whether to trim generated strings with end_str - if the model has special ending strings that are not handled well. - Defaults to None. - """ - - def __init__(self, - path: str, - concurrency: int = 8, - max_seq_len: int = 2048, - meta_template: Optional[Dict] = None, - engine_config: Optional[Dict] = None, - gen_config: Optional[Dict] = None, - end_str: Optional[str] = None): - super().__init__(path=path, - max_seq_len=max_seq_len, - meta_template=meta_template) - from lmdeploy.pytorch import engine as tm - from lmdeploy.version import version_info - - if engine_config is not None: - from lmdeploy.messages import PytorchEngineConfig - engine_config = PytorchEngineConfig(**engine_config) - # set thread_safe - if hasattr(engine_config, 'thread_safe'): - engine_config.thread_safe = True - - if gen_config is not None: - from lmdeploy.messages import GenerationConfig - gen_config = GenerationConfig(**gen_config) - - self.logger = get_logger() - tm_model = tm.Engine(path, engine_config) - self.tokenizer = tm_model.tokenizer - self.generators = [ - tm_model.create_instance() for i in range(concurrency) - ] - self.generator_ids = [i + 1 for i in range(concurrency)] - - from transformers import GenerationConfig - try: - generation_config = GenerationConfig.from_pretrained(path) - except Exception: - generation_config = None - if generation_config and hasattr(generation_config, 'eos_token_id'): - if gen_config.stop_words is None: - stop_words = [] - if isinstance(generation_config.eos_token_id, int): - stop_words.append(generation_config.eos_token_id) - else: - assert isinstance(generation_config.eos_token_id, list) - for token_id in generation_config.eos_token_id: - stop_words.append(token_id) - gen_config.stop_words = stop_words - if version_info >= (0, 6, 0): - gen_config.stop_token_ids = stop_words - self.gen_config = gen_config - self.end_str = end_str - self.major_version, self.minor_version = version_info[:2] - - def generate( - self, - inputs: List[str], - max_out_len: int = 512, - ) -> List[str]: - """Generate results given a list of inputs. - - Args: - inputs (List[str]): A list of prompts - max_out_len (int): The maximum length of the output. - - Returns: - List[str]: A list of generated strings. - """ - assert isinstance( - inputs, List), f'List(str) is expected, but got {type(inputs)}' - - # split inputs into batches - batch_size = len(self.generators) - batch_inputs = [ - inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size) - ] - - results = [] - for batch_input in batch_inputs: - with ThreadPoolExecutor() as executor: - _results = list( - executor.map( - self._generate, - self.generators[:len(batch_input)], - self.generator_ids[:len(batch_input)], - batch_input, - [self.gen_config] * len(batch_input), - [self.end_str] * len(batch_input), - )) - results += _results - return results - - def get_token_len(self, prompt: str) -> int: - input_ids = self.tokenizer.encode(prompt) - return len(input_ids) - - def wait(self): - """Wait till the next query can be sent. - - Applicable in both single-thread and multi-thread environments. 
- """ - return self.token_bucket.get_token() - - def _generate(self, - generator, - session_id, - prompt: PromptType, - gen_config=None, - end_str: Optional[str] = None) -> str: - """Generate results given a list of inputs. - - Args: - prompt (PromptType): A string or PromptDict. - The PromptDict should be organized in OpenCompass' - API format. - gen_config (GenerationConfig, optional): Generation - config to set arguments like top_k, top_p, temperature. - end_str (str, optional): Whether to trim generated strings - with end_str if the model has special ending strings - that are not handled well. - Defaults to None. - Returns: - str: The generated string. - """ - assert type( - prompt) is str, 'We only support string for TurboMind Python API' - input_ids = self.tokenizer.encode(prompt) - if self.major_version >= 0 and self.minor_version >= 4: - outputs = generator.infer(session_id, - input_ids, - gen_config=gen_config) - output_ids = outputs.token_ids - else: - _, output_ids, _ = generator.infer(session_id, - input_ids, - gen_config=gen_config) - - # stop engine - if hasattr(generator, 'end'): - generator.end(session_id) - # decode output - response_all = self.tokenizer.decode(output_ids) - # trim output - if end_str: - response_all = response_all.split(end_str)[0] - # remove invalid characters - response_all = valid_str(response_all) - return response_all diff --git a/opencompass/models/lmdeploy_tis.py b/opencompass/models/lmdeploy_tis.py deleted file mode 100644 index 9c92ef18a..000000000 --- a/opencompass/models/lmdeploy_tis.py +++ /dev/null @@ -1,200 +0,0 @@ -import threading -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from queue import Queue -from typing import Dict, List, Optional, Union - -import numpy as np - -from opencompass.models.base import BaseModel, LMTemplateParser -from opencompass.utils.logging import get_logger -from opencompass.utils.prompt import PromptList - -PromptType = Union[PromptList, str] - - -def valid_str(string, coding='utf-8'): - """decode text according to its encoding type.""" - invalid_chars = [b'\xef\xbf\xbd'] - bstr = bytes(string, coding) - for invalid_char in invalid_chars: - bstr = bstr.replace(invalid_char, b'') - ret = bstr.decode(encoding=coding, errors='ignore') - return ret - - -def prepare_tensor(name, input_tensor): - """Create grpcclient's InferInput instance according to a given tensor.""" - import tritonclient.grpc as grpcclient - from tritonclient.utils import np_to_triton_dtype - t = grpcclient.InferInput(name, list(input_tensor.shape), - np_to_triton_dtype(input_tensor.dtype)) - t.set_data_from_numpy(input_tensor) - return t - - -def stream_callback(que, result, error): - """callback function invoked by triton client.""" - que.put((result, error)) - - -class LmdeployTisModel(BaseModel): - """Model wrapper for LMDeploy Python Backend Triton Inference Server gRPC - API. - - Args: - path (str): The name of OpenAI's model. - tis_addr (str): The address (ip:port format) of turbomind's - triton inference server - max_seq_len (int): The maximum allowed sequence length of a model. - Note that the length of prompt + generated tokens shall not exceed - this value. Defaults to 2048. - meta_template (Dict, optional): The model's meta prompt - template if needed, in case the requirement of injecting or - wrapping of any meta instructions. 
- """ - - is_api: bool = True - - def __init__(self, - path: str, - tis_addr: str = '0.0.0.0:33337', - max_seq_len: int = 2048, - meta_template: Optional[Dict] = None, - end_str: Optional[str] = None): - super().__init__(path=path, - max_seq_len=max_seq_len, - meta_template=meta_template) - from lmdeploy.tokenizer import Tokenizer - - self.logger = get_logger() - self.template_parser = LMTemplateParser(meta_template) - self.eos_token_id = None - if meta_template and 'eos_token_id' in meta_template: - self.eos_token_id = meta_template['eos_token_id'] - self.tis_addr = tis_addr - self.tokenizer = Tokenizer(path) - self.end_str = end_str - - def generate( - self, - inputs: List[str or PromptList], - max_out_len: int = 512, - temperature: float = 1.0, - ) -> List[str]: - """Generate results given a list of inputs. - - Args: - inputs (List[str or PromptList]): A list of strings or PromptDicts. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - temperature (float): What sampling temperature to use, - between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more - focused and deterministic. Defaults to 0.7. - - Returns: - List[str]: A list of generated strings. - """ - - with ThreadPoolExecutor() as executor: - results = list( - executor.map(self._generate, inputs, - [max_out_len] * len(inputs), - [temperature] * len(inputs), - [self.end_str] * len(inputs))) - return results - - def wait(self): - """Wait till the next query can be sent. - - Applicable in both single-thread and multi-thread environments. - """ - return self.token_bucket.get_token() - - def get_token_len(self, prompt: str) -> int: - input_ids = self.tokenizer.encode(prompt) - return len(input_ids) - - def _call_triton_server(self, prompt, tis_addr, session_id, - request_output_len, temperature, res_que): - import tritonclient.grpc as grpcclient - - with grpcclient.InferenceServerClient(tis_addr) as client: - inputs = [ - prepare_tensor('prompt', - np.array([prompt.encode()], dtype=np.object_)), - prepare_tensor('max_tokens', - np.array([request_output_len], dtype=np.int32)), - prepare_tensor('temperature', - np.array([temperature], dtype=np.float_)), - prepare_tensor('top_p', np.array([1.0], dtype=np.float_)), - prepare_tensor('top_k', np.array([1], dtype=np.int32)), - prepare_tensor('ignore_eos', np.array([False], - dtype=np.bool_)), - prepare_tensor('stream', np.array([True], dtype=np.bool_)), - ] - - # async_stream - client.start_stream(partial(stream_callback, res_que)) - client.async_stream_infer('lmdeploy_model', - inputs, - sequence_id=session_id, - sequence_start=True, - sequence_end=True) - - res_que.put(None) - return - - def _process_result(self, que): - text = '' - while True: - res = que.get() - if res is not None: - result, err = res - if err is not None: - print(err) - else: - res = result.as_numpy('response').item().decode() - text += res - else: - return text - - def _generate(self, - prompt: str or PromptList, - max_out_len: int, - temperature: float, - end_str: Optional[str] = None) -> str: - """Generate results given a list of inputs. - - Args: - prompt (str or PromptList): A string or PromptDict. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - temperature (float): What sampling temperature to use, - between 0 and 2. 
Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more - focused and deterministic. - - Returns: - str: The generated string. - """ - assert type( - prompt - ) is str, 'We only support string for LMDeploy Python Backend TIS API' - - res_que = Queue() - - self._call_triton_server(prompt=prompt, - tis_addr=self.tis_addr, - session_id=threading.currentThread().ident, - request_output_len=max_out_len, - temperature=temperature, - res_que=res_que) - text = self._process_result(res_que) - response = valid_str(text) - if end_str: - response = response.split(end_str)[0] - return response diff --git a/opencompass/models/turbomind_tis.py b/opencompass/models/turbomind_tis.py deleted file mode 100644 index 8541b9de5..000000000 --- a/opencompass/models/turbomind_tis.py +++ /dev/null @@ -1,135 +0,0 @@ -import logging -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Union - -from opencompass.models.base import BaseModel, LMTemplateParser -from opencompass.utils.logging import get_logger -from opencompass.utils.prompt import PromptList - -PromptType = Union[PromptList, str] - - -def valid_str(string, coding='utf-8'): - """decode text according to its encoding type.""" - invalid_chars = [b'\xef\xbf\xbd'] - bstr = bytes(string, coding) - for invalid_char in invalid_chars: - bstr = bstr.replace(invalid_char, b'') - ret = bstr.decode(encoding=coding, errors='ignore') - return ret - - -class TurboMindTisModel(BaseModel): - """Model wrapper for TurboMind Triton Inference Server gRPC API. - - Args: - path (str): The name of OpenAI's model. - tis_addr (str): The address (ip:port format) of turbomind's - triton inference server - max_seq_len (int): The maximum allowed sequence length of a model. - Note that the length of prompt + generated tokens shall not exceed - this value. Defaults to 2048. - meta_template (Dict, optional): The model's meta prompt - template if needed, in case the requirement of injecting or - wrapping of any meta instructions. - """ - - is_api: bool = True - - def __init__( - self, - path: str, - tis_addr: str = '0.0.0.0:33337', - max_seq_len: int = 2048, - meta_template: Optional[Dict] = None, - ): - super().__init__(path=path, - max_seq_len=max_seq_len, - meta_template=meta_template) - from lmdeploy.serve.turbomind.utils import Preprocessor - self.preprocess = Preprocessor(tis_addr) - self.logger = get_logger() - self.template_parser = LMTemplateParser(meta_template) - self.eos_token_id = None - if meta_template and 'eos_token_id' in meta_template: - self.eos_token_id = meta_template['eos_token_id'] - self.tis_addr = tis_addr - - def generate( - self, - inputs: List[PromptType], - max_out_len: int = 512, - temperature: float = 1.0, - ) -> List[str]: - """Generate results given a list of inputs. - - Args: - inputs (List[PromptType]): A list of strings or PromptDicts. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - temperature (float): What sampling temperature to use, - between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more - focused and deterministic. Defaults to 0.7. - - Returns: - List[str]: A list of generated strings. 
- """ - - with ThreadPoolExecutor() as executor: - results = list( - executor.map(self._generate, inputs, - [max_out_len] * len(inputs), - [temperature] * len(inputs))) - return results - - def get_token_len(self, prompt: str) -> int: - input_ids, _ = self.preprocess(prompt) - return input_ids.shape[-1] - - def wait(self): - """Wait till the next query can be sent. - - Applicable in both single-thread and multi-thread environments. - """ - return self.token_bucket.get_token() - - def _generate(self, prompt: PromptType, max_out_len: int, - temperature: float) -> str: - """Generate results given a list of inputs. - - Args: - prompt (PromptType): A string or PromptDict. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - temperature (float): What sampling temperature to use, - between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more - focused and deterministic. - - Returns: - str: The generated string. - """ - assert type( - prompt) is str, 'We only support string for TurboMind RPC API' - - from lmdeploy.serve.turbomind.chatbot import Chatbot - chatbot = Chatbot(self.tis_addr, - temperature=temperature, - capability='completion', - top_k=1, - log_level=logging.ERROR) - - for status, text, n_token in chatbot.stream_infer( - session_id=threading.currentThread().ident, - prompt=prompt, - request_output_len=max_out_len, - sequence_start=True, - sequence_end=True): - continue - response = valid_str(text) - response = response.replace('', '') - return response diff --git a/opencompass/models/turbomind_with_tf_above_v4_33.py b/opencompass/models/turbomind_with_tf_above_v4_33.py index 48706671f..ab6801c9c 100644 --- a/opencompass/models/turbomind_with_tf_above_v4_33.py +++ b/opencompass/models/turbomind_with_tf_above_v4_33.py @@ -1,7 +1,6 @@ # flake8: noqa # yapf: disable import copy -from concurrent.futures import ThreadPoolExecutor from typing import Dict, List, Optional, Union from opencompass.models.base import BaseModel @@ -31,38 +30,32 @@ def __init__( self, path: str, tokenizer_only: bool = False, + backend: str = 'turbomind', engine_config: Dict = {}, gen_config: Dict = {}, - concurrency: int = 8, max_seq_len: int = None, meta_template: Optional[Dict] = None, fastchat_template: Optional[str] = None, stop_words: List[str] = [], ): - from lmdeploy.messages import TurbomindEngineConfig - from lmdeploy.turbomind import TurboMind - from lmdeploy.version import version_info - from transformers import AutoTokenizer - self.logger = get_logger() self.path = path self.tokenizer_only = tokenizer_only self.template_parser = _get_meta_template(meta_template) self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path) - self.origin_tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + from lmdeploy import version_info + from transformers import AutoTokenizer + self.version_info = version_info + self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) if not tokenizer_only: DEFAULT_ENGING_CONFIG = {'session_len': self.max_seq_len} _engine_config = DEFAULT_ENGING_CONFIG.copy() _engine_config.update(engine_config) - engine_config = TurbomindEngineConfig(**_engine_config) - tm_model = TurboMind.from_pretrained(path, engine_config=engine_config) - self.tokenizer = tm_model.tokenizer - self.generators = [tm_model.create_instance() for i in range(concurrency)] - self.generator_ids = [i + 1 for i in range(concurrency)] - 
self.concurrency = concurrency + self.pipe = self._build_pipe(path, backend, _engine_config) + else: + self.pipe = None self.gen_config = gen_config - self.version_info = version_info self.fastchat_template = fastchat_template self.stop_words = list(set(stop_words + self._get_potential_stop_words(path))) self.logger.info(f'using stop words: {self.stop_words}') @@ -76,23 +69,23 @@ def _get_potential_stop_words(self, path: Optional[str]): generation_config = None if generation_config and hasattr(generation_config, 'eos_token_id'): if isinstance(generation_config.eos_token_id, int): - potential_stop_words.append(self.origin_tokenizer.decode(generation_config.eos_token_id)) + potential_stop_words.append(self.tokenizer.decode(generation_config.eos_token_id)) else: assert isinstance(generation_config.eos_token_id, list) for token_id in generation_config.eos_token_id: - potential_stop_words.append(self.origin_tokenizer.decode(token_id)) - if self.origin_tokenizer.eos_token is not None: - potential_stop_words.append(self.origin_tokenizer.eos_token) + potential_stop_words.append(self.tokenizer.decode(token_id)) + if self.tokenizer.eos_token is not None: + potential_stop_words.append(self.tokenizer.eos_token) potential_stop_words = list(set(potential_stop_words)) potential_stop_words = [s for s in potential_stop_words if s] return potential_stop_words def generate(self, inputs: List[str], - max_out_len: int = 512, + max_out_len: int, stopping_criteria: List[str] = [], do_sample: Optional[bool] = None, - temperature: int = 1, + temperature: float = 1.0, **kwargs) -> List[str]: """Generate results given a list of inputs. @@ -104,93 +97,45 @@ def generate(self, List[str]: A list of generated strings. """ assert isinstance(inputs, List), f'List(str) is expected, but got {type(inputs)}' - messages = _convert_chat_messages(inputs) if self.fastchat_template: messages = _format_with_fast_chat_template(messages, self.fastchat_template) else: - messages = [self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages] - - # split messages into batches - batch_messages = [messages[i:i + self.concurrency] for i in range(0, len(messages), self.concurrency)] + messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages] stop_words = list(set(self.stop_words + stopping_criteria)) - encode_stop_words = [] - if stop_words is not None and len(stop_words) > 0: - for words in stop_words: - encode_stop_words += self.tokenizer.encode(words, add_bos=False) DEFAULT_GEN_CONFIG = { 'max_new_tokens': max_out_len, 'min_new_tokens': 1, - 'top_k': 1, - 'stop_words': encode_stop_words, + 'stop_words': stop_words, } gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG) gen_config.update(self.gen_config) if do_sample: - gen_config['top_k'] = 1000 + gen_config['top_k'] = 40 gen_config['temperature'] = temperature + else: + if self.version_info >= (0, 6, 0): + gen_config['do_sample'] = False + else: + gen_config['top_k'] = 1 - from lmdeploy.messages import GenerationConfig + from lmdeploy import GenerationConfig + gen_config = {k: v for k, v in gen_config.items() if hasattr(GenerationConfig, k)} gen_config = GenerationConfig(**gen_config) - if self.version_info >= (0, 6, 0): - gen_config.stop_words = stop_words - gen_config.convert_stop_bad_words_to_ids(self.tokenizer) results = [] - for batch_message in batch_messages: - n = len(batch_message) - with ThreadPoolExecutor() as executor: - _results = list( - executor.map( - self._generate, - 
self.generators[:n], - self.generator_ids[:n], - batch_message, - [gen_config] * n, - )) - results += _results + outputs = self.pipe(messages, gen_config=gen_config, do_preprocess=False) + for output in outputs: + text = self.tokenizer.decode(output.token_ids) + results.append(text) for s in stop_words: results = [r.split(s)[0] for r in results] return results - def _generate(self, - generator, - session_id, - prompt: PromptType, - gen_config=None) -> str: - """Generate results given a list of inputs. - - Args: - prompt (PromptType): A string or PromptDict. - The PromptDict should be organized in OpenCompass' - API format. - gen_config (GenerationConfig, optional): Generation - config to set arguments like top_k, top_p, temperature. - Returns: - str: The generated string. - """ - assert type(prompt) is str, 'We only support string for TurboMind Python API' - - input_ids = self.tokenizer.encode(prompt, add_bos=False) - for outputs in generator.stream_infer(session_id=session_id, - input_ids=[input_ids], - gen_config=gen_config, - sequence_start=True, - sequence_end=True, - step=0, - stream_output=False): - if self.version_info >= (0, 4, 0): - output_ids = outputs.token_ids - else: - _, output_ids, _ = outputs - response = self.tokenizer.decode(output_ids) - response = valid_str(response) - return response - def get_token_len(self, prompt: str) -> int: """Get lengths of the tokenized strings. @@ -201,5 +146,20 @@ def get_token_len(self, prompt: str) -> int: int: Length of the input tokens """ m = _convert_chat_messages([prompt])[0] - t = self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True) + t = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True) return len(t['input_ids']) + + def _build_pipe(self, model_path, backend, engine_config): + from lmdeploy import (PytorchEngineConfig, TurbomindEngineConfig, + pipeline) + + assert backend in ['pytorch', 'turbomind'], \ + f'unsupported backend type: {backend}' + + if backend == 'turbomind': + filtered = {k: v for k, v in engine_config.items() if hasattr(TurbomindEngineConfig, k)} + backend_config = TurbomindEngineConfig(**filtered) + else: + filtered = {k: v for k, v in engine_config.items() if hasattr(PytorchEngineConfig, k)} + backend_config = PytorchEngineConfig(**filtered) + return pipeline(model_path, backend_config=backend_config, log_level='INFO', max_log_len=10) diff --git a/opencompass/utils/run.py b/opencompass/utils/run.py index 67c465941..025efc4b3 100644 --- a/opencompass/utils/run.py +++ b/opencompass/utils/run.py @@ -9,7 +9,7 @@ from opencompass.datasets.custom import make_custom_dataset_config from opencompass.models import (VLLM, HuggingFace, HuggingFaceBaseModel, HuggingFaceCausalLM, HuggingFaceChatGLM3, - HuggingFacewithChatTemplate, TurboMindModel, + HuggingFacewithChatTemplate, TurboMindModelwithChatTemplate, VLLMwithChatTemplate) from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner @@ -233,7 +233,7 @@ def change_accelerator(models, accelerator): model_accels = [] for model in models: logger.info(f'Transforming {model["abbr"]} to {accelerator}') - # change HuggingFace model to VLLM or TurboMindModel + # change HuggingFace model to VLLM or LMDeploy if model['type'] in [HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3, f'{HuggingFaceBaseModel.__module__}.{HuggingFaceBaseModel.__name__}']: gen_args = dict() if model.get('generation_kwargs') is not None: @@ -254,10 +254,10 @@ def change_accelerator(models, accelerator): if accelerator 
== 'lmdeploy': logger.info(f'Transforming {model["abbr"]} to {accelerator}') - mod = TurboMindModel + mod = TurboMindModelwithChatTemplate acc_model = dict( type=f'{mod.__module__}.{mod.__name__}', - abbr=model['abbr'].replace('hf', 'turbomind') if '-hf' in model['abbr'] else model['abbr'] + '-turbomind', + abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy', path=model['path'], engine_config=dict(session_len=model['max_seq_len'], max_batch_size=model['batch_size'], @@ -270,7 +270,6 @@ def change_accelerator(models, accelerator): max_out_len=model['max_out_len'], max_seq_len=model['max_seq_len'], batch_size=model['batch_size'], - concurrency=model['batch_size'], run_cfg=model['run_cfg'], ) for item in ['meta_template']: @@ -312,7 +311,7 @@ def change_accelerator(models, accelerator): mod = TurboMindModelwithChatTemplate acc_model = dict( type=f'{mod.__module__}.{mod.__name__}', - abbr=model['abbr'].replace('hf', 'turbomind') if '-hf' in model['abbr'] else model['abbr'] + '-turbomind', + abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy', path=model['path'], engine_config=dict(max_batch_size=model.get('batch_size', 16), tp=model['run_cfg']['num_gpus']), gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),
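For reference, the `change_accelerator` helper in `opencompass/utils/run.py` rewrites a HuggingFace-style model config into its LMDeploy counterpart roughly as sketched below; the field values are illustrative only and the CLI flag mentioned afterwards is an assumption, not confirmed by this patch:

```python
# Hypothetical before/after illustration of change_accelerator(models, 'lmdeploy').
hf_model = dict(
    type='opencompass.models.HuggingFacewithChatTemplate',
    abbr='internlm2-chat-7b-hf',
    path='internlm/internlm2-chat-7b',
    max_seq_len=8192, max_out_len=4096, batch_size=16,
    run_cfg=dict(num_gpus=1),
)

# roughly what the helper produces for the lmdeploy accelerator
lmdeploy_model = dict(
    type='opencompass.models.TurboMindModelwithChatTemplate',
    abbr='internlm2-chat-7b-lmdeploy',  # '-hf' suffix replaced by '-lmdeploy'
    path=hf_model['path'],
    engine_config=dict(max_batch_size=hf_model['batch_size'],
                       tp=hf_model['run_cfg']['num_gpus']),
    gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),
    max_seq_len=hf_model['max_seq_len'],
    max_out_len=hf_model['max_out_len'],
    batch_size=hf_model['batch_size'],
    run_cfg=hf_model['run_cfg'],
)
```

In practice this transformation is triggered by passing an accelerator option to `run.py` (e.g. something like `python run.py <config> -a lmdeploy`); the exact flag name should be checked against the installed OpenCompass version.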