Merge pull request #5749 from hpcaitech/prefetch
[Gemini] Prefetch next chunk before each op
botbw authored May 29, 2024
2 parents b96c639 + 154720b commit 023ea13
Showing 15 changed files with 239 additions and 65 deletions.
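Summary of the visible changes: a new max_prefetch option is threaded from GeminiPlugin through GeminiDDP down to the placement policy. Chunk gathering gains an asynchronous path (all_gather with async_op=True) that returns a dist.Work handle, and GeminiZeROHook.pre_op uses it to wait on chunks prefetched by earlier ops, fetch the remainder synchronously, and launch asynchronous prefetches of chunks the policy expects upcoming ops to need. GeminiManager keeps the bookkeeping (async_works, wait_chunks, add_work) and excludes in-flight chunks from eviction.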
2 changes: 2 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -329,6 +329,7 @@ def __init__(
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
max_prefetch: int = 0,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -387,6 +388,7 @@ def __init__(
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
)
self.zero_optim_config = dict(
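A minimal usage sketch of the new knob (illustrative only: the launcher call, model, and the other GeminiPlugin arguments are placeholders and may differ by ColossalAI version):

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# launch with: torchrun --nproc_per_node=1 train.py
colossalai.launch_from_torch()

plugin = GeminiPlugin(
    placement_policy="static",
    max_prefetch=2,  # new in this PR: prefetch up to 2 chunks ahead of each op
)
booster = Booster(plugin=plugin)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024))
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)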
13 changes: 8 additions & 5 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -359,14 +359,15 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
else:
raise NotImplementedError

def access_chunk(self):
def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
# sanity check
assert self.chunk_temp is None

maybe_work = None
if not self.is_gathered:
self.__gather()
maybe_work = self.__gather(async_op=async_access)
self.__update_tensors_ptr()
return maybe_work

def release_chunk(self):
"""Release the usable chunk. It's an operation done in CUDA."""
@@ -512,17 +513,19 @@ def optim_update(self) -> None:
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())

def __gather(self):
def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
if not self.is_gathered:
# sanity check
assert self.cuda_shard is not None

alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)

self.cuda_shard = None
self.is_gathered = True
return work
return None

def __scatter(self):
if self.keep_gathered:
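For reference, a standalone sketch (not ColossalAI code) of the torch.distributed pattern __gather now relies on: with async_op=True the collective returns immediately with a dist.Work handle that the caller waits on later, which is what allows the gather to overlap with other work.

import os
import torch
import torch.distributed as dist

def main():
    # single-process gloo group just to make the sketch runnable
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    shard = torch.arange(4, dtype=torch.float32)
    gather_list = [torch.empty_like(shard) for _ in range(dist.get_world_size())]

    # async_op=True returns a Work handle instead of blocking
    work = dist.all_gather(gather_list, shard, async_op=True)

    # ... other computation could run here while the collective is in flight ...

    work.wait()  # block until the gathered result is ready
    print(torch.cat(gather_list))
    dist.destroy_process_group()

if __name__ == "__main__":
    main()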
12 changes: 7 additions & 5 deletions colossalai/zero/gemini/chunk/manager.py
@@ -111,15 +111,16 @@ def close_all_groups(self):
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])

def access_chunk(self, chunk: Chunk) -> None:
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk can be used for calculation."""
if chunk in self.accessed_chunks:
return
return None
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == "cpu":
chunk.shard_move(get_accelerator().get_current_device())
self.__add_accessed_chunk(chunk)
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
self.__add_memory_usage(chunk.memory_usage)
return maybe_work

def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@ def __add_memory_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] += v

def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
maybe_work = chunk.access_chunk(async_access=async_access)
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
return maybe_work

def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
2 changes: 2 additions & 0 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -78,6 +78,7 @@ def __init__(
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
max_prefetch: int = 0,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -131,6 +132,7 @@ def __init__(
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
max_prefetch=max_prefetch,
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
34 changes: 30 additions & 4 deletions colossalai/zero/gemini/gemini_hook.py
@@ -24,16 +24,42 @@ def __init__(self, gemini_manager: GeminiManager) -> None:
self._training_phase = TrainingPhase.FORWARD

def pre_op(self, params):
# map params to chunks
params = [p for p in params if not is_ddp_ignored(p)]
chunks = self._chunk_manager.get_chunks(params)
all_chunks = self._chunk_manager.get_chunks(params)

# wait for prefetched chunks, filter out those that were not prefetched
chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)

# transfer state
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks)
for chunk in chunks:

# evict chunks, taking asynchronously fetched chunks into account
self._gemini_manager.adjust_layout(
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
)

# fetch the rest synchronously
for chunk in chunks_fetch_sync:
self._chunk_manager.access_chunk(chunk)

# record cuda model data of the current OP
# get possible chunks to prefetch
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
is_warmup=self._gemini_manager.is_warmup(),
compute_list=self._gemini_manager.compute_list,
compute_idx=self._gemini_manager.compute_idx,
async_works=self._gemini_manager.async_works,
)

# prefetch
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)

# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()

def post_op(self, params):
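Condensed into a self-contained sketch (hypothetical names; the real hook goes through ChunkManager and GeminiManager), the flow pre_op now implements is: wait on chunks a previous op already prefetched, fetch the rest synchronously, then launch asynchronous prefetches for the chunks the policy predicts next.

from typing import Dict, List, Optional

class Work:
    """Stand-in for torch.distributed.Work."""
    def wait(self) -> None:
        pass

def fetch(chunk: str, async_access: bool = False) -> Optional[Work]:
    """Stand-in for ChunkManager.access_chunk."""
    return Work() if async_access else None

def pre_op(needed: List[str], upcoming: List[str], async_works: Dict[str, Work]) -> None:
    # 1) wait on chunks that were prefetched by an earlier op
    fetch_sync = []
    for chunk in needed:
        if chunk in async_works:
            async_works.pop(chunk).wait()
        else:
            fetch_sync.append(chunk)

    # 2) fetch synchronously whatever was not prefetched
    for chunk in fetch_sync:
        fetch(chunk)

    # 3) asynchronously prefetch chunks predicted for upcoming ops
    for chunk in upcoming:
        work = fetch(chunk, async_access=True)
        if work is not None:
            async_works[chunk] = work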
55 changes: 46 additions & 9 deletions colossalai/zero/gemini/gemini_mgr.py
@@ -1,12 +1,13 @@
import functools
from time import time
from typing import Dict, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.distributed as dist

from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector, MemStats
from .placement_policy import PlacementPolicyFactory
from .placement_policy import PlacementPolicy, PlacementPolicyFactory


class GeminiManager:
@@ -41,9 +42,12 @@ def __init__(
self._mem_stats_collector = (
ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
)
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._placement_policy = policy_cls(
chunk_manager=chunk_manager, mem_stats_collector=self._mem_stats_collector, **placement_kwargs
)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
self._async_works: Dict[Chunk, dist.Work] = {}

self._h2d_volume = 0
self._d2h_volume = 0
@@ -91,18 +95,20 @@ def post_iter(self):
self._warmup = False
self.reset_attributes()

def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
"""Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belong to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
# don't evict chunks that are asynchronously fetched
can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
self._layout_time += time() - start

vol, evict_time = self._placement_policy.evict_tensors(
can_evict_chunks=hold_cuda_tensor_list,
can_evict_chunks=can_evict_chunks,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
@@ -114,6 +120,21 @@ def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
# move COMPUTE tensors to CUDA
self._h2d_volume += cuda_demand

def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk, ...]:
non_prefetched_chunks = []
for chunk in chunks:
if chunk in self._async_works:
self._async_works[chunk].wait()
del self._async_works[chunk]
else:
non_prefetched_chunks.append(chunk)
return tuple(non_prefetched_chunks)

def add_work(self, chunk: Chunk, work: dist.Work):
assert work is not None
assert chunk not in self._async_works
self._async_works[chunk] = work

@functools.lru_cache(maxsize=None)
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
start = time()
@@ -133,9 +154,9 @@ def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk,
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
return cuda_demand, can_evict_chunks

def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
self._compute_idx += 1
if self._warmup and self._placement_policy.need_mem_stats:
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
self._compute_list.append(chunks)

def sample_overall_data(self):
Expand All @@ -156,6 +177,22 @@ def cuda_margin_mem(self) -> Optional[float]:
return self._mem_stats_collector.cuda_margin_mem
return None

@property
def placement_policy(self) -> PlacementPolicy:
return self._placement_policy

@property
def compute_list(self) -> List[Tuple[Chunk, ...]]:
return self._compute_list

@property
def compute_idx(self) -> int:
return self._compute_idx

@property
def async_works(self) -> Dict[Chunk, dist.Work]:
return self._async_works

@property
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats
(Diffs for the remaining 9 changed files are not shown.)
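The placement-policy side of the change is among the files not shown above, so the following is only a rough sketch of what a get_prefetch_chunks implementation could look like given its call site in gemini_hook.py; everything beyond that call signature (max_prefetch, accessed_chunks, the selection order, skipping prefetch during warmup) is an assumption.

from typing import Dict, List, Set, Tuple

def get_prefetch_chunks(
    is_warmup: bool,
    compute_list: List[Tuple["Chunk", ...]],
    compute_idx: int,
    async_works: Dict["Chunk", "dist.Work"],
    max_prefetch: int,
    accessed_chunks: Set["Chunk"],
) -> List["Chunk"]:
    # compute_list is still being recorded during warmup, so there is nothing to look ahead into
    if is_warmup or max_prefetch <= 0:
        return []
    prefetch: List["Chunk"] = []
    # walk the recorded op order past the current op and pick chunks that are not
    # yet resident, not already in flight, and not already selected
    for chunks in compute_list[compute_idx + 1:]:
        for chunk in chunks:
            if chunk in accessed_chunks or chunk in async_works or chunk in prefetch:
                continue
            prefetch.append(chunk)
            if len(prefetch) >= max_prefetch:
                return prefetch
    return prefetch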
