Skip to content

Commit

Permalink
[shardformer] Fix serialization error with Tensor Parallel state saving (#5018)
Browse files Browse the repository at this point in the history

* Fix serialization error with Tensor Parallel state saving

* Refactor state_dict CPU transfer using tree_map
  • Loading branch information
imgaojun authored Nov 9, 2023
1 parent 7244412 commit a448938
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.nn as nn
from packaging.version import Version
from torch.optim import Optimizer
from torch.utils._pytree import tree_map

from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
Expand Down Expand Up @@ -293,7 +294,6 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Helper functions for saving state dict
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Save state dict to checkpoint.

    Args:
        state_dict (dict): state dict.
        checkpoint_file_path (str): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """

    # Copy every tensor in the (possibly nested) state dict to CPU before
    # serialization to avoid serialization issues with device-resident tensors.
    def _to_cpu(leaf):
        return leaf.cpu() if torch.is_tensor(leaf) else leaf

    cpu_state_dict = tree_map(_to_cpu, state_dict)

    if not use_safetensors:
        torch.save(cpu_state_dict, checkpoint_file_path)
        return

    assert is_safetensors_available(), "safetensors is not available."
    assert checkpoint_file_path.endswith(
        ".safetensors"
    ), "safetensors only supports .safetensors suffix for checkpoint file."
    from safetensors.torch import save_file as safe_save_file

    safe_save_file(cpu_state_dict, checkpoint_file_path, metadata={"format": "pt"})


def save_param_groups(state_dict: dict, group_file_path: str) -> None:
Expand Down

0 comments on commit a448938

Please sign in to comment.