diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index b35bb6b94991..1b5c03ce48f1 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -8,8 +8,15 @@ BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, + TokenClassifierOutput, +) +from transformers.models.t5.modeling_t5 import ( + T5EncoderModel, + T5ForConditionalGeneration, + T5ForTokenClassification, + T5Model, + T5Stack, ) -from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -582,6 +589,71 @@ def t5_encoder_model_forward( return outputs + @staticmethod + def t5_for_token_classification_forward( + self: T5ForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + backward_tensor_keys: Optional[List[str]] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForTokenClassification.forward. + Please refer to original code of transformers for more details. 
+ ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = T5PipelineForwards.t5_stack_forward( + self.transformer.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return outputs + def get_t5_flash_attention_forward(): from transformers.models.t5.modeling_t5 import T5Attention diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 008dead6ba5c..99b68aee2420 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -68,6 +68,9 @@ class PolicyLocation: file_name="t5", class_name="T5ForConditionalGenerationPolicy" ), "transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), + "transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation( + file_name="t5", class_name="T5ForTokenClassificationPolicy" + ), # GPT2 "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation( diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 1298f0af3e61..0b594678c71b 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -31,7 +31,13 @@ ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] +__all__ = [ + "distribute_t5_layers", + "T5ModelPolicy", + "T5ForConditionalGenerationPolicy", + "T5EncoderPolicy", + "T5ForTokenClassificationPolicy", +] class T5BasePolicy(Policy): @@ -312,9 +318,13 @@ def get_held_layers(self) -> List[nn.Module]: assert self.pipeline_stage_manager is not None stage_manager = self.pipeline_stage_manager - model = self.model - encoder = self.model.encoder - decoder = getattr(self.model, "decoder", None) + if self.model.__class__.__name__ == "T5ForTokenClassification": + model = self.model.transformer + else: + model = self.model + + encoder = model.encoder + decoder = getattr(model, "decoder", None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 @@ -353,7 +363,11 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = 
self.pipeline_stage_manager - encoder = self.model.encoder + if self.model.__class__.__name__ == "T5ForTokenClassification": + encoder = self.model.transformer.encoder + else: + encoder = self.model.encoder + decoder = getattr(self.model, "decoder", None) num_encoder_layers = len(encoder.block) @@ -542,3 +556,46 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] + + +class T5ForTokenClassificationPolicy(T5EncoderPolicy): + def module_policy(self): + from transformers.models.t5.modeling_t5 import T5ForTokenClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + T5ForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ] + ) + } + policy.update(addon_module) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=T5ForTokenClassification, + new_forward=T5PipelineForwards.t5_for_token_classification_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 2ccfb0356c2b..f6ccb297ea41 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -40,6 +40,14 @@ def data_gen_for_t5_model(): return data +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen_for_encoder_only() + data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + # output transform function output_transform_fn = lambda x: x @@ -47,6 +55,7 @@ def data_gen_for_t5_model(): loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean() loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean() loss_fn_for_conditional_generation = lambda x: x["loss"] +loss_fn_for_token_classification = lambda x: x["loss"] # define model config config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) @@ -79,3 +88,11 @@ def data_gen_for_t5_model(): loss_fn=loss_fn_for_encoder_only, model_attribute=ModelAttribute(has_control_flow=True), ) +model_zoo.register( + name="transformers_t5_for_token_classification", + model_fn=lambda: transformers.T5ForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_token_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 521dc9130b7e..6cdf5bf41c68 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -41,14 +41,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, t5 = unwrap_model(org_model) 
sharded_t5 = unwrap_model(sharded_model) - row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"] + if t5.__class__.__name__ == "T5ForTokenClassification": + row_layer_for_check = ["transformer.shared", "transformer.encoder.block[0].layer[0].SelfAttention.q"] + else: + row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: - atol, rtol = 5e-3, 5e-3 + atol, rtol = 5e-2, 5e-2 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: row_layer_grads = get_grad_tensors_for_check( t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 @@ -66,7 +69,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ != "T5ForConditionalGeneration": + if org_model.__class__.__name__ not in ["T5ForConditionalGeneration", "T5ForTokenClassification"]: check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -157,7 +160,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) @clear_cache_before_run() def run_t5_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") + sub_model_zoo = model_zoo.get_sub_registry(["transformers_t5_for_token_classification"]) for name, ( model_fn, @@ -167,7 +170,10 @@ def run_t5_test(test_config): _, ) in sub_model_zoo.items(): # skip 4-stage pp test for t5_encoder - if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model": + if test_config["pp_size"] > 2 and name in [ + "transformers_t5_encoder_model", + "transformers_t5_for_token_classification", + ]: continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
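
Note: the following is a minimal, self-contained sketch (plain transformers, no ColossalAI involved) of what the new model-zoo entry exercises. It assumes the same tiny T5Config used in tests/kit/model_zoo/transformers/t5.py and shows that T5ForTokenClassification returns a TokenClassifierOutput whose `loss` field is what loss_fn_for_token_classification reads. The tensor shapes mirror data_gen_for_token_classification; the variable names and the 16-token example sequence are illustrative only.

import torch
import transformers

# Tiny config matching the test kit's settings (d_model=128, num_layers=2, dropout_rate=0).
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

# T5ForTokenClassification only wraps the encoder stack plus a dropout + linear classifier head,
# which is why the policy above reuses T5EncoderPolicy and merely adds those two head modules.
model = transformers.T5ForTokenClassification(config)

# Same shapes as data_gen_for_token_classification: one sequence of 16 tokens,
# with one class index (0 or 1) per token as the label.
input_ids = torch.randint(0, config.vocab_size, (1, 16), dtype=torch.int64)
attention_mask = torch.ones(1, 16, dtype=torch.int64)
labels = torch.tensor([[1] + [0] * 15], dtype=torch.int64)

output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
# TokenClassifierOutput: logits has shape (batch, seq_len, num_labels); loss is the
# token-level CrossEntropyLoss that loss_fn_for_token_classification picks out via x["loss"].
assert output.logits.shape == (1, 16, config.num_labels)
assert output.loss is not None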
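As a quick sanity check of the new auto_policy mapping, one might resolve the policy for a T5ForTokenClassification instance directly. This is a hedged sketch: it assumes get_autopolicy(model) is importable from colossalai.shardformer.policies.auto_policy and takes the model as its only argument, which may differ slightly across ColossalAI versions.

import transformers
from colossalai.shardformer.policies.auto_policy import get_autopolicy  # assumed import path

config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
model = transformers.T5ForTokenClassification(config)

# With the new PolicyLocation entry, the fully-qualified class name
# "transformers.models.t5.modeling_t5.T5ForTokenClassification" should now resolve to
# T5ForTokenClassificationPolicy instead of raising an unsupported-model error.
policy = get_autopolicy(model)  # signature assumed; check the local auto_policy module
print(type(policy).__name__)  # expected: "T5ForTokenClassificationPolicy"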