-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[gemini] add GeminiMemoryManager (#832)
* refactor StatefulTensor, tensor utilities * add unit test for GeminiMemoryManager
- Loading branch information
Showing
23 changed files
with
414 additions
and
180 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from enum import EnumMeta | ||
|
||
|
||
class GeminiMemoryManager(object):
    """Bookkeeping for stateful tensors managed by Gemini.

    Tracks how many instances have been registered and how many bytes they
    occupy, aggregated per device ('cpu' / 'cuda') in ``total_mem`` and
    broken down per tensor state in ``state_mem``.
    """

    def __init__(self, states_cls: EnumMeta):
        super().__init__()
        self.states_cls = states_cls
        self._cnter = 0    # number of registered instances

        # aggregated byte counts per device, filled in by reset()
        self.total_mem = {}
        # per-state byte counts, keyed by device first, then by state
        self.state_mem = {'cpu': {}, 'cuda': {}}

        self.reset()

    @property
    def total_number(self):
        """Number of instances registered since the last reset."""
        return self._cnter

    def reset(self):
        """Zero the instance counter and every memory counter."""
        self._cnter = 0    # the counter of instances

        self.total_mem['cpu'] = 0     # bytes occupied by instances on cpu
        self.total_mem['cuda'] = 0    # bytes occupied by instances on cuda

        # zero the per-state counters for both devices
        for device in ('cpu', 'cuda'):
            for state in self.states_cls:
                self.state_mem[device][state] = 0

    def register_new_instance(self):
        """Record that one more instance is now tracked by this manager."""
        self._cnter += 1

    def print_info(self):
        """Print a human-readable summary of the tracked memory usage."""
        summary = [
            f"Total number: {self.total_number}",
            f"Total CPU memory occupation: {self.total_mem['cpu']}",
            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
        ]
        print('\n'.join(summary))

        for state in self.states_cls:
            per_state = [
                f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
                f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
            ]
            print('\n'.join(per_state))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
from enum import Enum | ||
from typing import Optional | ||
import torch | ||
from typing import Union | ||
|
||
from colossalai.gemini.gemini_context import GeminiMemoryManager | ||
|
||
|
||
def sizeof_tensor(tensor: torch.Tensor) -> int:
    """Return the memory footprint of ``tensor`` in bytes (numel x element size)."""
    return tensor.element_size() * tensor.numel()
|
||
|
||
class TensorState(Enum):
    """Lifecycle states of a StatefulTensor.

    FREE means no payload is attached; the remaining states all imply a live
    payload (enforced by StatefulTensor's constructor asserts). The HOLD_*
    and COMPUTE names presumably map to phases of a training step — confirm
    against the code that drives the transitions.
    """
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4
|
||
|
||
class StatefulTensor(object):
    """A structure that stores a torch Tensor together with a labeled state.

    Every state or device transition is reported to the class-level
    ``GeminiMemoryManager`` (``GST_MGR``) so that per-device and per-state
    memory occupation stays up to date.

    Inspired from the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """
    # Global Stateful Tensor Manager, shared by all StatefulTensor instances.
    # NOTE: constructed once at class-definition time as a module-level side effect.
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        """Wrap ``maybe_tensor`` with the given initial ``state``.

        ``maybe_tensor`` must be None if and only if ``state`` is
        ``TensorState.FREE`` (enforced by the asserts below).
        """
        self._state = state
        self._payload = None
        self._payload_size = 0    # byte size of current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if self._state == TensorState.FREE:
            # when the state is free, payload should be None
            assert maybe_tensor is None, f"payload has to None if state is {self._state}"
        else:
            # otherwise, payload should not be None
            assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
            self._payload = maybe_tensor
            self._payload_size = sizeof_tensor(maybe_tensor)
            # register the new payload's memory with the manager (FREE -> state adds it)
            self.__trans_state_update(TensorState.FREE, state)

    def data_ptr(self):
        """Return the payload's data pointer, or 0 when there is no payload."""
        if self._payload is None:
            return 0    # if a tensor has no storage, 0 should be returned
        return self._payload.data_ptr()

    def set_null(self) -> None:
        """Release the payload and move this tensor to the FREE state.

        The manager update must happen before __release() — it reads the
        payload's device while the payload is still attached.
        """
        # notice that free stateful tensor do not need to become null again
        if self.state != TensorState.FREE:
            self.__trans_state_update(self.state, TensorState.FREE)
            self.__release()

    def is_null(self) -> bool:
        """Return True when this tensor is FREE, i.e. holds no payload."""
        if self.state == TensorState.FREE:
            # check sanity here
            assert self.payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        """Transition this tensor to ``state`` and keep the manager in sync.

        A FREE tensor may only "transition" to FREE again; transitioning a
        non-FREE tensor to FREE releases its payload.
        """
        if self.state == TensorState.FREE:
            # free stateful tensor can't change state
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state == TensorState.FREE:
            self.__release()
        else:
            self._state = state

    def move_to(self, device: Union[torch.device, int]):
        """Move the payload to ``device`` (a bare int is treated as a CUDA index).

        NOTE(review): only the device *type* is compared, so a move between
        two CUDA devices (e.g. cuda:0 -> cuda:1) is a no-op here — confirm
        this is intended.
        """
        assert self.state is not TensorState.FREE, "Can't move free stateful tensor"

        if not isinstance(device, torch.device):
            to_device = torch.device('cuda', device)
        else:
            to_device = device

        from_device_type = self.device.type
        if from_device_type == to_device.type:
            # from device == to device
            return

        # update manager's information
        self.__trans_device_update(from_device_type, to_device.type)
        self.payload.data = self.payload.data.to(to_device)

    def payload_copy(self, tensor) -> None:
        """Copy ``tensor``'s elements into the existing payload.

        Both are flattened first, so shapes may differ as long as the total
        number of elements matches (copy_ raises otherwise).
        """
        self._payload.view(-1).copy_(tensor.view(-1))

    def payload_reset(self, tensor) -> None:
        """Replace the payload with ``tensor``, re-accounting memory with the manager."""

        assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"

        if self.payload is not None:
            # release old payload
            self.__trans_state_update(self.state, TensorState.FREE)
        else:
            # otherwise, set the state to HOLD for new payload
            self._state = TensorState.HOLD
        del self._payload

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # record new payload
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        """Take over ``rhs``'s payload; ``rhs`` is released afterwards."""
        # relay the payload of rhs to current stateful tensor
        # can't support null relay right now
        assert not rhs.is_null()

        # now this function only support stateful tensor that has zero-length payload
        # because it doesn't require memory manager updating
        # you can extend this function by yourself
        assert self.payload_size == 0

        self._payload = rhs.payload
        self._payload_size = rhs.payload_size
        self._state = TensorState.HOLD
        # move rhs's accounted memory from its old state to HOLD
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        # The wrapped tensor; None when the state is FREE.
        return self._payload

    @property
    def payload_size(self) -> int:
        # Byte size of the current payload (0 when FREE).
        return self._payload_size

    @property
    def state(self) -> TensorState:
        return self._state

    @property
    def device(self) -> torch.device:
        # Raises AttributeError when the payload is None (FREE state).
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype

    @property
    def shape(self):
        return self._payload.shape

    def to(self, device: torch.device):
        """Deliberately disabled: use move_to() so the manager stays consistent."""
        raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")

    def to_(self, device: torch.device):
        """Deliberately disabled: use move_to() so the manager stays consistent."""
        raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")

    def __release(self):
        # release current payload
        # shouldn't be visible to users
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update global manager when changing the state of a tensor

        Must be called while the payload is still attached: it reads
        self.device to pick the per-device counters.
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        device_type = self.device.type

        if from_state != TensorState.FREE:
            manager.state_mem[device_type][from_state] -= size
        else:
            # when from_state is FREE, the tensor is new to manager
            # we should add its memory
            manager.total_mem[device_type] += size

        if to_state != TensorState.FREE:
            manager.state_mem[device_type][to_state] += size
        else:
            # when to_state is FREE, the tensor will be deleted soon
            # we should sub its memory
            manager.total_mem[device_type] -= size

    def __trans_device_update(self, from_type: str, to_type: str):
        """Update global manager when changing the device of a tensor
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        state = self.state

        # update aggregated information
        manager.total_mem[from_type] -= size
        manager.total_mem[to_type] += size

        # update the information of each state
        manager.state_mem[from_type][state] -= size
        manager.state_mem[to_type][state] += size
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.