From fdd84b9087f6309e3da315559367bac3dd8c05e3 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 02:32:03 +0000
Subject: [PATCH 01/11] fix the sp

---
 colossalai/kernel/kernel_loader.py   |  2 ++
 colossalai/shardformer/layer/attn.py | 32 ++++++++++++++++++++++++----
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 2411b6482ac1..8598cf0ae0ed 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -118,6 +118,8 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
 
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index bf4fa77c6c23..29ef64a4d6fe 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -10,6 +10,7 @@
 from colossalai.kernel.kernel_loader import (
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
+    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
@@ -17,6 +18,8 @@
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
@@ -104,7 +107,7 @@ def _init_kernels_dispatch():
             }
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +116,16 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> C
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
+
+        if size > MEMORY_BOUND:
+            FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        return FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -163,7 +170,7 @@ def prepare_attn_kwargs(
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
             if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
+                attention_mask.tril_(diagonal=0)
             attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
@@ -197,6 +204,15 @@ def prepare_attn_kwargs(
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = (s_q * s_kv * element_size)
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            outputs["attention_mask"] = attention_mask
+            if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
+                outputs["attention_mask_type"] = AttnMaskType.CAUSAL
+
         return outputs
 
     @staticmethod
@@ -278,8 +294,16 @@ def attention(
             assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = (s_q * s_kv * element_size)
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
+            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
+
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,

From 216d54e3747fedb3404f181023449182410d48d8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 13 Sep 2024 02:38:39 +0000
Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/kernel/kernel_loader.py   |  2 ++
 colossalai/shardformer/layer/attn.py | 20 +++++++++++---------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 8598cf0ae0ed..36a49aae918b 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -118,9 +118,11 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+
 class FlashAttentionDaoLoader(KernelLoader):
     REGISTRY = [FlashAttentionDaoCudaExtension]
 
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 29ef64a4d6fe..0a4f985358af 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -8,9 +8,9 @@
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
-    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
@@ -116,7 +116,7 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
-
+
         if size > MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
@@ -124,8 +124,10 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-
-        return FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        return (
+            FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        )
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -204,15 +206,15 @@ def prepare_attn_kwargs(
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
-
+
         element_size = torch.tensor([], dtype=dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs["attention_mask"] = attention_mask
             if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
                 outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-
+
         return outputs
 
     @staticmethod
@@ -297,11 +299,11 @@ def attention(
         b, _, s_q, _ = q.shape
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
             assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
-
+
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (

From 0a01e2a453abaa802cb839f67f7193af52709350 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 03:33:08 +0000
Subject: [PATCH 03/11] fix the attn

---
 colossalai/shardformer/layer/attn.py | 44 +++++++++++++++-------------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 0a4f985358af..1ffbae73e0a9 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -18,7 +18,7 @@
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
-MEMORY_BOUND = 10 * 1e9
+MEMORY_BOUND = 1 * 1e9
 
 __all__ = [
     "AttnMaskType",
@@ -125,9 +125,10 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
                 mask_type
             ].load()
 
-        return (
-            FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
-        )
+        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
+            return FlashAttentionDaoLoader()
+        else:
+            return ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -163,6 +164,8 @@ def prepare_attn_kwargs(
             return {}
         assert len(shape_4d) == 4 and shape_4d[1] == 1
         b, _, s_q, s_kv = shape_4d
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         outputs = {}
         if (q_padding_mask is None or q_padding_mask.bool().all()) and (
             kv_padding_mask is None or kv_padding_mask.bool().all()
@@ -170,10 +173,13 @@ def prepare_attn_kwargs(
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
-            if s_q != 1:
-                attention_mask.tril_(diagonal=0)
-            attention_mask = attention_mask.expand(b, s_q, s_kv)
+            if memory_size > MEMORY_BOUND:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            else:
+                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+                if s_q != 1:
+                    attention_mask.tril_(diagonal=0)
+                attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -186,7 +192,10 @@ def prepare_attn_kwargs(
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+            if memory_size > MEMORY_BOUND:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            else:
+                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -199,22 +208,16 @@ def prepare_attn_kwargs(
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if s_q != 1:
-                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if memory_size > MEMORY_BOUND:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
+                else:
+                    if s_q != 1:
+                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
-
-        element_size = torch.tensor([], dtype=dtype).element_size()
-        memory_size = s_q * s_kv * element_size
-        if memory_size > MEMORY_BOUND:
-            attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            outputs["attention_mask"] = attention_mask
-            if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
-                outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-
         return outputs
 
     @staticmethod
@@ -301,7 +304,6 @@ def attention(
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
         memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
-            attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
             assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
         mask_type = attention_mask_type if attention_mask is not None else None

From 683179cefda2cb71c11d1ec83f75a4b9251ac0b0 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 03:40:56 +0000
Subject: [PATCH 04/11] fix

---
 colossalai/shardformer/layer/attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 1ffbae73e0a9..65051a61fe07 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -18,7 +18,7 @@
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
-MEMORY_BOUND = 1 * 1e9
+MEMORY_BOUND = 10 * 1e9
 
 __all__ = [
     "AttnMaskType",

From 6eb8832366c76187350059985d780acebbcd9a2d Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 05:06:56 +0000
Subject: [PATCH 05/11] fix

---
 colossalai/shardformer/layer/attn.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 65051a61fe07..c18d57de1215 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -117,7 +117,7 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
 
-        if size > MEMORY_BOUND:
+        if size >= MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
@@ -173,7 +173,7 @@ def prepare_attn_kwargs(
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size > MEMORY_BOUND:
+            if memory_size >= MEMORY_BOUND:
                 attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
@@ -192,10 +192,10 @@ def prepare_attn_kwargs(
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size > MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -208,7 +208,7 @@ def prepare_attn_kwargs(
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size > MEMORY_BOUND:
+                if memory_size >= MEMORY_BOUND:
                     attention_mask = torch.empty((0,), dtype=dtype, device=device)
                 else:
                     if s_q != 1:
@@ -303,9 +303,6 @@ def attention(
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
         memory_size = s_q * s_kv * element_size
-        if memory_size > MEMORY_BOUND:
-            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
-
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (

From f393867cff97924e0b90d81758a29bc5a2e94923 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 05:24:52 +0000
Subject: [PATCH 06/11] fix

---
 colossalai/shardformer/layer/attn.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index c18d57de1215..8890da242a02 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -173,13 +173,13 @@ def prepare_attn_kwargs(
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size >= MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)
                 attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -208,11 +208,11 @@ def prepare_attn_kwargs(
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size >= MEMORY_BOUND:
-                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
-                else:
+                if memory_size < MEMORY_BOUND:
                     if s_q != 1:
                         attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
         if invert:

From dc032172c34538abcdad101997b9637b70ef0552 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 06:00:58 +0000
Subject: [PATCH 07/11] fix

---
 colossalai/shardformer/layer/attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 8890da242a02..a2ea761bf3e9 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -173,7 +173,7 @@ def prepare_attn_kwargs(
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size < MEMORY_BOUND and not is_causal:
+            if memory_size < MEMORY_BOUND:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)

From 0b14a5512e5220450b61effd15ff2d9c93ef7d22 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 07:06:14 +0000
Subject: [PATCH 08/11] fix

---
 colossalai/shardformer/layer/attn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index a2ea761bf3e9..c755ffa2f211 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -118,7 +118,7 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             )
 
         if size >= MEMORY_BOUND:
-            FlashAttentionDaoLoader().load()
+            flash_kernel = FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@@ -126,7 +126,7 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             ].load()
 
         if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
-            return FlashAttentionDaoLoader()
+            return flash_kernel
         else:
             return ColoAttention._kernel_dispatch_map[dtype][mask_type]

From 0ad3129cb95cb74f13dead3f8369837ea997b499 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 09:01:26 +0000
Subject: [PATCH 09/11] fix

---
 colossalai/shardformer/layer/attn.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index c755ffa2f211..129b04fb796d 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -80,6 +80,7 @@ def get_pad_info(
 
 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
 
     @staticmethod
     def _init_kernels_dispatch():
@@ -105,6 +106,8 @@ def _init_kernels_dispatch():
                 torch.bfloat16: half_dispatch_map,
                 torch.float32: float_dispatch_map,
             }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
 
     @staticmethod
     def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
@@ -118,7 +121,7 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             )
 
         if size >= MEMORY_BOUND:
-            flash_kernel = FlashAttentionDaoLoader().load()
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@@ -126,7 +129,7 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             ].load()
 
         if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
-            return flash_kernel
+            return ColoAttention._flash_kernel_dispatch
         else:
             return ColoAttention._kernel_dispatch_map[dtype][mask_type]

From b5823192738299b602b0876c486ae97964652326 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 10:24:41 +0000
Subject: [PATCH 10/11] fix

---
 colossalai/shardformer/layer/attn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 129b04fb796d..7157fbed8163 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -210,6 +210,7 @@ def prepare_attn_kwargs(
                 }
             )
             if is_causal:
+                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                 if memory_size < MEMORY_BOUND:
                     if s_q != 1:

From 827ef3ee9a176de774422f7361cc95efae57e3f1 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Sat, 14 Sep 2024 10:40:35 +0000
Subject: [PATCH 11/11] fix

---
 colossalai/shardformer/layer/attn.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 7157fbed8163..2f8e4d677c54 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -195,10 +195,6 @@ def prepare_attn_kwargs(
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size < MEMORY_BOUND and not is_causal:
-                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
-            else:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -210,15 +206,18 @@ def prepare_attn_kwargs(
                 }
             )
             if is_causal:
-                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                 if memory_size < MEMORY_BOUND:
                     if s_q != 1:
+                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                         attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
                 else:
                     attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
+                if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
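
Taken together, the series keeps the dense attention mask only when it is affordable: prepare_attn_kwargs compares s_q * s_kv * element_size against MEMORY_BOUND (10 * 1e9 bytes), returns an empty placeholder mask above the bound, and _dispatch_kernel then routes the CAUSAL / PADDED_CAUSAL cases to FlashAttentionDaoLoader. Below is a minimal sketch of that gating for the causal path; the helper name and the example sizes are illustrative only and are not part of the patches.

import torch

MEMORY_BOUND = 10 * 1e9  # bytes, same constant the patches add to attn.py

def build_causal_mask_or_skip(b, s_q, s_kv, dtype, device):
    # Mirrors the causal branch of ColoAttention.prepare_attn_kwargs after this series:
    # below the bound, build the dense lower-triangular mask as before; at or above it,
    # return an empty placeholder so _dispatch_kernel picks the Dao FlashAttention kernel,
    # which applies causality internally and never materializes the mask.
    element_size = torch.tensor([], dtype=dtype).element_size()
    memory_size = s_q * s_kv * element_size  # per (s_q, s_kv) slice; batch size is not counted
    if memory_size < MEMORY_BOUND:
        attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
        if s_q != 1:
            attention_mask.tril_(diagonal=0)
        return attention_mask.expand(b, s_q, s_kv)
    return torch.empty((0,), dtype=dtype, device=device)

# e.g. in fp16: 65536 x 65536 is ~8.6e9 bytes, so the mask is still built;
# 131072 x 131072 is ~3.4e10 bytes, so the placeholder is returned and the Dao kernel path is used.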