From a4489384d5239074c541820ef583fb5e8371b9ae Mon Sep 17 00:00:00 2001
From: Jun Gao
Date: Thu, 9 Nov 2023 17:00:25 +0800
Subject: [PATCH] [shardformer] Fix serialization error with Tensor Parallel
 state saving (#5018)

* Fix serialization error with Tensor Parallel state saving

* Refactor state_dict CPU transfer using tree_map
---
 colossalai/checkpoint_io/utils.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 06dab1fdb72a..e1800f29b0af 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -11,6 +11,7 @@
 import torch.nn as nn
 from packaging.version import Version
 from torch.optim import Optimizer
+from torch.utils._pytree import tree_map
 
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
@@ -293,7 +294,6 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 # Helper functions for saving state dict
 # ======================================
 
-
 def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
     """
     Save state dict to checkpoint.
@@ -303,6 +303,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         checkpoint_file_path (str): path to the checkpoint file.
         use_safetensors (bool): whether to use safetensors to save the checkpoint.
     """
+    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
+    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+
     if use_safetensors:
         assert is_safetensors_available(), "safetensors is not available."
         assert checkpoint_file_path.endswith(
@@ -310,9 +313,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         ), "safetensors only supports .safetensors suffix for checkpoint file."
         from safetensors.torch import save_file as safe_save_file
 
-        safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
+        safe_save_file(state_dict_cpu, checkpoint_file_path, metadata={"format": "pt"})
     else:
-        torch.save(state_dict, checkpoint_file_path)
+        torch.save(state_dict_cpu, checkpoint_file_path)
 
 
 def save_param_groups(state_dict: dict, group_file_path: str) -> None:
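
Note for reviewers (not part of the patch): below is a minimal, self-contained
sketch of the tree_map pattern the patch applies in save_state_dict, runnable
with a stock PyTorch install. The helper name state_dict_to_cpu, the example
nested dict, and the output path "checkpoint.bin" are illustrative assumptions
and do not appear in colossalai/checkpoint_io/utils.py; torch.utils._pytree is
a private PyTorch module, which the patch itself also relies on.

import torch
from torch.utils._pytree import tree_map  # private API, same import as the patch


def state_dict_to_cpu(state_dict: dict) -> dict:
    """Recursively move every tensor in a (possibly nested) state dict to CPU."""
    # tree_map walks dicts, lists, and tuples, applying the function to each
    # leaf, so tensors nested inside optimizer state are covered as well.
    return tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Mimics an optimizer state_dict, where tensors sit below the top level;
    # a shallow loop over state_dict.values() would miss them, which is why
    # the patch uses tree_map rather than a flat .cpu() pass.
    nested = {
        "state": {0: {"exp_avg": torch.randn(4, device=device), "step": 10}},
        "param_groups": [{"lr": 1e-3}],
    }
    cpu_copy = state_dict_to_cpu(nested)
    assert cpu_copy["state"][0]["exp_avg"].device.type == "cpu"
    assert cpu_copy["state"][0]["step"] == 10  # non-tensor leaves pass through
    torch.save(cpu_copy, "checkpoint.bin")  # all tensors CPU-resident before pickling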