[ShardFormer] fix qwen2 sp (#5903)
GuangyaoZhang authored Jul 15, 2024
1 parent 45c49dd commit 1c961b2
Showing 2 changed files with 56 additions and 49 deletions.
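In brief: in colossalai/shardformer/modeling/qwen2.py, the sequence-parallel split in the model forward now operates on hidden_states, the tensor actually passed to the decoder layers, rather than on inputs_embeds. Since hidden_states is initialized from inputs_embeds earlier in forward, rebinding inputs_embeds to the split result left the tensor consumed downstream at full sequence length. The test file reorders the parameterized configs and wraps each check_forward_backward call in a try/except that prints the failing test_config before re-raising.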
6 changes: 3 additions & 3 deletions colossalai/shardformer/modeling/qwen2.py
@@ -1,3 +1,4 @@
+import math
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -513,7 +514,6 @@ def forward(
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group)
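For context, in the all_to_all branch above each projected tensor is redistributed across the sequence-parallel group before attention. Below is a minimal sketch of the usual Ulysses-style exchange, assuming scatter on the hidden/head dimension and gather on the sequence dimension; the function name and defaults are illustrative, not ColossalAI's actual all_to_all_comm, and a full implementation would pair it with the inverse exchange in backward.

import torch
import torch.distributed as dist


def all_to_all_sketch(x, group, scatter_dim=2, gather_dim=1):
    # Split the local tensor along scatter_dim, exchange one chunk with every
    # rank in the group, then concatenate the received chunks along gather_dim.
    # For x of shape [batch, seq/P, hidden], the defaults turn a sequence
    # shard into a head shard: the result has shape [batch, seq, hidden/P].
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)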
@@ -698,9 +698,9 @@ def forward(
         next_decoder_cache = None
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
 
         for decoder_layer in self.layers:
             if output_hidden_states:
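The hunk above is the fix itself: only hidden_states is fed to the decoder layers, so it is the tensor that must be sharded. For reference, here is a minimal sketch of the split-forward/gather-backward contract implied by the call sites above; the class name and the grad_scale handling are illustrative, not ColossalAI's actual implementation.

import torch
import torch.distributed as dist


class _SplitForwardGatherBackward(torch.autograd.Function):
    """Keep one sequence shard in forward; all-gather gradients in backward."""

    @staticmethod
    def forward(ctx, x, dim, group, grad_scale=None):
        ctx.dim, ctx.group, ctx.grad_scale = dim, group, grad_scale
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        # Each rank keeps only its contiguous slice along the sequence dim.
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # Reassemble the full-sequence gradient from every rank's slice.
        shards = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(shards, grad_output.contiguous(), group=ctx.group)
        grad = torch.cat(shards, dim=ctx.dim)
        if ctx.grad_scale is not None:
            grad = grad * ctx.grad_scale  # e.g. 1 / sp_size in all_to_all mode
        return grad, None, None, None


def split_forward_gather_backward(x, dim, group, grad_scale=None):
    return _SplitForwardGatherBackward.apply(x, dim, group, grad_scale)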
99 changes: 53 additions & 46 deletions tests/test_shardformer/test_model/test_shard_qwen2.py
@@ -135,51 +135,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 2,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
-    ],
-)
-def run_qwen2_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-
-    clear_layout_converter()
-    Randomizer.reset_index()
-    torch.cuda.empty_cache()
-
-
-@parameterize(
-    "test_config",
-    [
-        {
-            "tp_size": 2,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": False,
-            "use_lazy_init": False,
-            "precision": "fp32",
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": False,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "zero_stage": 1,
-            "initial_scale": 1,
-        },
         { # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 2,
@@ -242,6 +197,54 @@ def run_qwen2_test(test_config):
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+    ],
+)
+def run_qwen2_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -259,7 +262,11 @@ def run_qwen2_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()
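The test-side change is a debugging aid: because @parameterize sweeps many configurations through one test body, a bare stack trace does not say which configuration failed, so the added try/except prints the offending test_config before re-raising. A generic version of the pattern, where run_case and configs are illustrative names rather than part of the test suite:

def run_with_config_report(run_case, configs):
    for cfg in configs:
        try:
            run_case(cfg)
        except Exception:
            # Surface the failing parameter set in the log, then let the
            # exception propagate so the test still fails.
            print(f"Failed config: {cfg}")
            raise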
