[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source

* computation/communication overlap optimization (see the sketch after this list)

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement
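The "computation/communication overlap" above is the core DistriFusion idea: each GPU denoises only its own patch of the latent and exchanges patch activations asynchronously, reusing slightly stale activations from the previous denoising step so that communication hides behind computation. The commit's real implementation lives in the changed model files below; what follows is only a toy sketch of the pattern, with `denoise_patch` standing in for the actual DiT/U-Net step and every name invented for illustration.

```python
# Toy sketch of DistriFusion-style patched parallelism with overlapped
# communication -- NOT the code added in this commit. Assumes >= 2 GPUs on one
# node, NCCL, and a launch such as `torchrun --nproc_per_node=2 sketch.py`.
import torch
import torch.distributed as dist


def denoise_patch(patch: torch.Tensor, t: int) -> torch.Tensor:
    """Stand-in for one diffusion denoising step on this rank's latent patch."""
    return patch - 0.01 * (t + 1) * torch.tanh(patch)


def patched_denoising_loop(latent: torch.Tensor, steps: int) -> torch.Tensor:
    rank, world = dist.get_rank(), dist.get_world_size()
    # Patched parallelism: each rank owns one horizontal slice of the latent.
    patch = latent.chunk(world, dim=-2)[rank].contiguous()
    # Buffers holding every rank's patch from the *previous* step; in the real
    # method these stale activations feed cross-patch attention/convolution.
    # (A real implementation runs one fully synchronized warm-up step first.)
    stale = [torch.zeros_like(patch) for _ in range(world)]
    handle = None

    for t in range(steps):
        patch = denoise_patch(patch, t)          # compute this step's patch
        if handle is not None:
            handle.wait()                        # step t-1 exchange finishes here
        context = torch.cat([s.clone() for s in stale], dim=-2)
        # Kick off the exchange of the *current* patches asynchronously, so it
        # overlaps with the next step's computation instead of blocking it.
        handle = dist.all_gather(stale, patch, async_op=True)
        _ = context  # a real model would condition boundary regions on this

    if handle is not None:
        handle.wait()
    return patch


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())       # single-node assumption
    latent = torch.randn(1, 4, 64, 64, device="cuda")
    out = patched_denoising_loop(latent, steps=10)
    dist.destroy_process_group()
```

In the actual DistriFusion scheme the stale activations are spliced into the attention and convolution inputs rather than discarded; the sketch only shows how the asynchronous all-gather lets the exchange hide behind the next denoising step.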
LRY89757 authored Jul 30, 2024
1 parent 7b38964 commit bcf0181
Showing 15 changed files with 1,089 additions and 16 deletions.
12 changes: 11 additions & 1 deletion colossalai/inference/README.md
@@ -18,7 +18,7 @@


## 📌 Introduction
ColossalAI-Inference is a module that accelerates the inference of Transformer models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate LLM inference. We also provide simple and unified APIs for user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
ColossalAI-Inference is a module that accelerates the inference of Transformer models, especially LLMs and DiT diffusion models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate LLM inference. We also provide simple and unified APIs for user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>
@@ -310,4 +310,14 @@ If you wish to cite relevant research papers, you can find the reference below.
journal={arXiv},
year={2023}
}
# Distrifusion
@InProceedings{Li_2024_CVPR,
author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month={June},
year={2024},
pages={7183-7193}
}
```
16 changes: 16 additions & 0 deletions colossalai/inference/config.py
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, when using StreamingLLM.
patched_parallelism_size(int): The patched parallelism size, when using Distrifusion.
"""

# NOTE: arrange configs according to their importance and frequency of usage
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512

# Acceleration for diffusion models (PipeFusion or Distrifusion)
patched_parallelism_size: int = 1 # for distrifusion
# pipeFusion_m_size: int = 1 # for pipefusion
# pipeFusion_n_size: int = 1 # for pipefusion

def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
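For orientation, here is a minimal usage sketch of the new option. It is not taken from the commit; it only uses names visible in this diff (`InferenceConfig`, `patched_parallelism_size`, `tp_size`) and assumes the remaining default field values pass `_verify_config`.

```python
from colossalai.inference.config import InferenceConfig

# Distrifusion: split the latent into 2 patches, one per GPU.
config = InferenceConfig(patched_parallelism_size=2)

# _verify_config() (next hunk) mirrors the value into tp_size, because the
# existing sharding checks expect a tensor-parallel-style world size.
assert config.tp_size == config.patched_parallelism_size
```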
@@ -288,6 +294,14 @@ def _verify_config(self) -> None:
# Thereafter, we swap out tokens in units of blocks, always swapping out the second block when the generated tokens exceed the limit.
self.start_token_size = self.block_size

# check Distrifusion
# TODO(@lry89757) needs a more detailed check
if self.patched_parallelism_size > 1:
# self.use_patched_parallelism = True
self.tp_size = (
self.patched_parallelism_size
)  # not real tensor parallelism: an existing validation check expects tp_size, so it is reused to carry patched_parallelism_size

# check prompt template
if self.prompt_template is None:
return
@@ -324,6 +338,7 @@ def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config

@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
patched_parallelism_size: int = 1  # for diffusion models (Distrifusion technique)


@dataclass
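Continuing the same sketch, the per-shard view is produced by `to_model_shard_inference_config`, which (per the hunk above) now carries the Distrifusion degree into `ModelShardInferenceConfig`:

```python
# Project the engine-level config into the per-shard config; the field shown
# above travels along with it.
shard_config = config.to_model_shard_inference_config()
assert shard_config.patched_parallelism_size == config.patched_parallelism_size
```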
2 changes: 1 addition & 1 deletion colossalai/inference/core/diffusion_engine.py
@@ -11,7 +11,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
