[gemini] add GeminiMemoryManager (#832)
* refactor StatefulTensor, tensor utilities

* add unit test for GeminiMemoryManager
1SAA authored Apr 24, 2022
1 parent 35ea6e1 commit e5ea3fd
Showing 23 changed files with 414 additions and 180 deletions.
45 changes: 45 additions & 0 deletions colossalai/gemini/gemini_context.py
@@ -0,0 +1,45 @@
from enum import EnumMeta


class GeminiMemoryManager(object):

    def __init__(self, states_cls: EnumMeta):
        super().__init__()
        self.states_cls = states_cls
        self._cnter = 0    # the counter of registered instances

        self.total_mem = dict()
        self.state_mem = dict()
        self.state_mem['cpu'] = dict()
        self.state_mem['cuda'] = dict()

        self.reset()

    @property
    def total_number(self):
        return self._cnter

    def reset(self):
        self._cnter = 0    # the counter of registered instances

        self.total_mem['cpu'] = 0    # memory occupation of instances on CPU
        self.total_mem['cuda'] = 0    # memory occupation of instances on CUDA

        # memory occupation for each state
        for state in self.states_cls:
            self.state_mem['cpu'][state] = 0
            self.state_mem['cuda'][state] = 0

    def register_new_instance(self):
        self._cnter += 1

    def print_info(self):
        print(f"Total number: {self.total_number}",
              f"Total CPU memory occupation: {self.total_mem['cpu']}",
              f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
              sep='\n')

        for state in self.states_cls:
            print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
                  f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
                  sep='\n')
204 changes: 204 additions & 0 deletions colossalai/gemini/stateful_tensor.py
@@ -0,0 +1,204 @@
from enum import Enum
from typing import Optional, Union

import torch

from colossalai.gemini.gemini_context import GeminiMemoryManager


def sizeof_tensor(tensor: torch.Tensor):
    return tensor.numel() * tensor.element_size()


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class StatefulTensor(object):
    """A structure that stores a Torch tensor together with a labeled state.
    Inspired by the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """
    # Global Stateful Tensor Manager
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        self._state = state
        self._payload = None
        self._payload_size = 0    # byte size of the current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if self._state == TensorState.FREE:
            # when the state is FREE, the payload should be None
            assert maybe_tensor is None, f"payload has to be None if state is {self._state}"
        else:
            # otherwise, the payload should not be None
            assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
            self._payload = maybe_tensor
            self._payload_size = sizeof_tensor(maybe_tensor)
            self.__trans_state_update(TensorState.FREE, state)

    def data_ptr(self):
        if self._payload is None:
            return 0    # if a tensor has no storage, 0 should be returned
        return self._payload.data_ptr()

    def set_null(self) -> None:
        # note that a free stateful tensor does not need to be set to null again
        if self.state != TensorState.FREE:
            self.__trans_state_update(self.state, TensorState.FREE)
            self.__release()

    def is_null(self) -> bool:
        if self.state == TensorState.FREE:
            # check sanity here
            assert self.payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        if self.state == TensorState.FREE:
            # a free stateful tensor can't change its state
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state == TensorState.FREE:
            self.__release()
        else:
            self._state = state

    def move_to(self, device: Union[torch.device, int]):
        assert self.state is not TensorState.FREE, "Can't move a free stateful tensor"

        if not isinstance(device, torch.device):
            to_device = torch.device('cuda', device)
        else:
            to_device = device

        from_device_type = self.device.type
        if from_device_type == to_device.type:
            # source and target devices are of the same type
            return

        # update the manager's information
        self.__trans_device_update(from_device_type, to_device.type)
        self.payload.data = self.payload.data.to(to_device)

    def payload_copy(self, tensor) -> None:
        self._payload.view(-1).copy_(tensor.view(-1))

    def payload_reset(self, tensor) -> None:
        assert tensor is not None, "Can't reset the payload of a stateful tensor to None, please use set_null() instead"

        if self.payload is not None:
            # release the old payload
            self.__trans_state_update(self.state, TensorState.FREE)
        else:
            # otherwise, set the state to HOLD for the new payload
            self._state = TensorState.HOLD
        del self._payload

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # record the new payload
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        # relay the payload of rhs to the current stateful tensor
        # relaying a null payload is not supported right now
        assert not rhs.is_null()

        # this function currently only supports a stateful tensor whose payload has zero length,
        # because that case doesn't require updating the memory manager;
        # you can extend this function by yourself
        assert self.payload_size == 0

        self._payload = rhs.payload
        self._payload_size = rhs.payload_size
        self._state = TensorState.HOLD
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload

    @property
    def payload_size(self) -> int:
        return self._payload_size

    @property
    def state(self) -> TensorState:
        return self._state

    @property
    def device(self) -> torch.device:
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype

    @property
    def shape(self):
        return self._payload.shape

    def to(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of calling .to() on a StatefulTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of calling .to_() on a StatefulTensor")

    def __release(self):
        # release the current payload
        # shouldn't be visible to users
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update global manager when changing the state of a tensor
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        device_type = self.device.type

        if from_state != TensorState.FREE:
            manager.state_mem[device_type][from_state] -= size
        else:
            # when from_state is FREE, the tensor is new to manager
            # we should add its memory
            manager.total_mem[device_type] += size

        if to_state != TensorState.FREE:
            manager.state_mem[device_type][to_state] += size
        else:
            # when to_state is FREE, the tensor will be deleted soon
            # we should sub its memory
            manager.total_mem[device_type] -= size

    def __trans_device_update(self, from_type: str, to_type: str):
        """Update global manager when changing the device of a tensor
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        state = self.state

        # update aggregated information
        manager.total_mem[from_type] -= size
        manager.total_mem[to_type] += size

        # update the information of each state
        manager.state_mem[from_type][state] -= size
        manager.state_mem[to_type][state] += size
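A similar usage sketch for the new StatefulTensor (again not part of the commit; it assumes a build of colossalai that contains this file). State transitions and payload release go through trans_state() and set_null(), and every change is mirrored in the class-level GST_MGR:

import torch

from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState

# wrap a CPU tensor; the default state is HOLD
t = StatefulTensor(torch.zeros(1024, 1024))

# mark it as in use for computation, then hold it again after the forward pass
t.trans_state(TensorState.COMPUTE)
t.trans_state(TensorState.HOLD_AFTER_FWD)

# the global manager reflects every state transition and device move
StatefulTensor.GST_MGR.print_info()

# releasing the payload sets the tensor back to FREE and removes its bytes from the manager
t.set_null()
assert t.is_null()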
8 changes: 4 additions & 4 deletions colossalai/gemini/stateful_tensor_mgr.py
@@ -2,9 +2,8 @@
 import torch
 import types
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
 from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
 from typing import List
 from colossalai.logging import get_dist_logger
@@ -30,7 +29,8 @@ def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:

         self._cpu_gpu_move_volume = 0
 
-    def register_stateful_param(self, param: ShardedParamV2) -> None:
+    def register_stateful_param(self, param) -> None:
+        from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
         assert isinstance(param, ShardedParamV2)
         for t in param.get_payload_tensors():
             assert isinstance(t, StatefulTensor)
4 changes: 2 additions & 2 deletions colossalai/gemini/tensor_placement_policy.py
@@ -4,8 +4,8 @@
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
 
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.gemini.memory_tracer import MemStatsCollector
 from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Type