Skip to content

Commit

Permalink
Solved test issue; cells now get assigned to ranks in batches, with no more sorting
Browse files: browse the repository at this point in the history
  • Loading branch information
st4rl3ss committed Jan 31, 2024
1 parent c502dd9 commit 4abfd4e
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 36 deletions.
3 changes: 2 additions & 1 deletion neurodamus/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,12 +753,13 @@ def _get_conn_stats(self, dst_target):
return {}

local_counter = Counter()
dst_pop_name = self._cell_manager.population_name

# NOTE:
# - Estimation (and extrapolation) is performed per metype since properties can vary
# - Consider only the cells for the current target

for metype, me_gids in self._dry_run_stats.metype_gids.items():
for metype, me_gids in self._dry_run_stats.metype_gids[dst_pop_name].items():
me_gids = set(me_gids).intersection(new_gids)
me_gids_count = len(me_gids)
if not me_gids_count:
Expand Down
2 changes: 1 addition & 1 deletion neurodamus/io/cell_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def load_base_info_dry_run():
# skip_metypes = set(dry_run_stats.metype_memory.keys())
metype_gids, counts = _retrieve_unique_metypes(node_pop, all_gids)
dry_run_stats.metype_counts += counts
dry_run_stats.metype_gids = metype_gids
dry_run_stats.metype_gids[node_population] = metype_gids
gid_metype_bundle = list(metype_gids.values())
gidvec = dry_run_distribution(gid_metype_bundle, stride, stride_offset, total_cells)

Expand Down
96 changes: 62 additions & 34 deletions neurodamus/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def pretty_printing_memory_mb(memory_mb):


@run_only_rank0
def distribute_cells(dry_run_stats, num_ranks) -> (dict, dict):
def distribute_cells(dry_run_stats, num_ranks, batch_size=10) -> (dict, dict):
"""
Distributes cells across ranks based on their memory load.
Expand All @@ -182,7 +182,10 @@ def distribute_cells(dry_run_stats, num_ranks) -> (dict, dict):
# Check inputs
logging.debug("Distributing cells across %d ranks", num_ranks)
logging.debug("Checking inputs...")
assert set(dry_run_stats.metype_gids.keys()) == set(dry_run_stats.metype_memory.keys())
set_metype_gids = set()
for values in dry_run_stats.metype_gids.values():
set_metype_gids.update(values.keys())
assert set_metype_gids == set(dry_run_stats.metype_memory.keys())
average_syns_keys = set(dry_run_stats.average_syns_per_cell.keys())
metype_memory_keys = set(dry_run_stats.metype_memory.keys())
assert average_syns_keys == metype_memory_keys
Expand All @@ -199,39 +202,57 @@ def distribute_cells(dry_run_stats, num_ranks) -> (dict, dict):
# We sum the memory load of the cell type and the average number of synapses per cell
logging.debug("Creating generator...")

def generate_cells():
heap = []
for cell_type, gids in dry_run_stats.metype_gids.items():
def generate_cells(metype_gids):
for cell_type, gids in metype_gids.items():
for gid in gids:
memory_usage = dry_run_stats.metype_memory[cell_type] + average_syns_mem_per_cell[cell_type]
# Use negative memory usage as the priority for descending order
heapq.heappush(heap, (-memory_usage, gid))

# Yield from the heap in sorted order
while heap:
memory_usage, gid = heapq.heappop(heap)
yield gid, -memory_usage
memory_usage = (dry_run_stats.metype_memory[cell_type] +
average_syns_mem_per_cell[cell_type])
yield gid, memory_usage

# Initialize structures
logging.debug("Initializing structures...")
ranks = [(0, i) for i in range(num_ranks)] # (total_memory, rank_id)
heapq.heapify(ranks)
rank_allocation = {i: [] for i in range(num_ranks)}
rank_memory = {i: 0 for i in range(num_ranks)}

# Start distributing cells across ranks starting with the ones with higher memory load
logging.debug("Distributing cells across ranks...")
for cell_id, memory in generate_cells():
# Get the rank with the lowest memory load
total_memory, rank_id = heapq.heappop(ranks)
logging.debug("Assigning cell %d to rank %d", cell_id, rank_id)
# Add the cell to the rank
rank_allocation[rank_id].append(cell_id)
# Update the total memory load of the rank
total_memory += memory
rank_memory[rank_id] = total_memory
# Update total memory and re-add to the heap
heapq.heappush(ranks, (total_memory, rank_id))
rank_allocation = {}
rank_memory = {}

# Start distributing cells across ranks
for pop, metype_gids in dry_run_stats.metype_gids.items():
logging.info("Distributing cells of population %s", pop)
rank_allocation[pop] = {}
rank_memory[pop] = {}
batch = []
batch_memory = 0

for cell_id, memory in generate_cells(metype_gids):
batch.append(cell_id)
batch_memory += memory
if len(batch) == batch_size:
# Get the rank with the lowest memory load
total_memory, rank_id = heapq.heappop(ranks)
logging.debug("Assigning batch to rank %d", rank_id)
# Add the cell to the rank
if rank_id not in rank_allocation[pop]:
rank_allocation[pop][rank_id] = []
rank_allocation[pop][rank_id].append(cell_id)
# Update the total memory load of the rank
total_memory += batch_memory
rank_memory[pop][rank_id] = total_memory
# Update total memory and re-add to the heap
heapq.heappush(ranks, (total_memory, rank_id))
batch = []
batch_memory = 0

# Assign any remaining cells in the last, potentially incomplete batch
if batch:
total_memory, rank_id = heapq.heappop(ranks)
logging.debug("Assigning remaining cells in batch to rank %d", rank_id)
if rank_id not in rank_allocation[pop]:
rank_allocation[pop][rank_id] = []
rank_allocation[pop][rank_id].append(batch)
total_memory += batch_memory
rank_memory[pop][rank_id] = total_memory
heapq.heappush(ranks, (total_memory, rank_id))

return rank_allocation, rank_memory

Expand All @@ -247,12 +268,18 @@ def print_allocation_stats(rank_allocation, rank_memory):
rank_memory (dict): A dictionary where keys are rank IDs
and values are the total memory load on each rank.
"""
print("Total memory per rank: ", rank_memory)
logging.debug("Rank allocation: {}".format(rank_allocation))
logging.debug("Total memory per rank: {}".format(rank_memory))
import statistics
values = list(rank_memory.values())
print("Mean: ", round(statistics.mean(values)))
print("Median: ", round(statistics.median(values)))
print("Stdev: ", round(statistics.stdev(values)))
for pop, rank_dict in rank_memory.items():
values = list(rank_dict.values())
logging.info("Population: {}".format(pop))
logging.info("Mean: {}".format(round(statistics.mean(values))))
try:
stdev = round(statistics.stdev(values))
except statistics.StatisticsError:
stdev = 0
logging.info("Stdev: {}".format(stdev))


class SynapseMemoryUsage:
Expand Down Expand Up @@ -280,6 +307,7 @@ def __init__(self) -> None:
self.average_syns_per_cell = {}
self.metype_counts = Counter()
self.synapse_counts = Counter()
self.metype_gids = {}
_, _, self.base_memory, _ = get_task_level_mem_usage()

@run_only_rank0
Expand Down

0 comments on commit 4abfd4e

Please sign in to comment.