Skip to content

Commit

Permalink
Refactored cell distribution function, added export allocation to pic…
Browse files Browse the repository at this point in the history
…kle file
  • Loading branch information
st4rl3ss committed Feb 5, 2024
1 parent 1ff8006 commit 5659b11
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
4 changes: 2 additions & 2 deletions neurodamus/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from .utils import compat
from .utils.logging import log_stage, log_verbose, log_all
from .utils.memory import DryRunStats, trim_memory, pool_shrink, free_event_queues, print_mem_usage
from .utils.memory import print_allocation_stats
from .utils.memory import print_allocation_stats, export_allocation_stats, distribute_cells
from .utils.timeit import TimerManager, timeit
from .core.coreneuron_configuration import CoreConfig
# Internal Plugins
Expand Down Expand Up @@ -1962,13 +1962,13 @@ def run(self):
"""
if SimConfig.dry_run:
log_stage("============= DRY RUN (SKIP SIMULATION) =============")
from .utils.memory import distribute_cells
self._dry_run_stats.display_total()
self._dry_run_stats.display_node_suggestions()
ranks = int(SimConfig.prosp_hosts)
self._dry_run_stats.collect_all_mpi()
allocation, total_memory_per_rank = distribute_cells(self._dry_run_stats, ranks)
print_allocation_stats(allocation, total_memory_per_rank)
export_allocation_stats(allocation, "allocation.bin")
return
if not SimConfig.simulate_model:
self.sim_init()
Expand Down
77 changes: 47 additions & 30 deletions neurodamus/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,29 +179,33 @@ def distribute_cells(dry_run_stats, num_ranks, batch_size=10) -> (dict, dict):
rank_memory (dict): A dictionary where keys are rank IDs
and values are the total memory load on each rank.
"""
# Check inputs
logging.debug("Distributing cells across %d ranks", num_ranks)
logging.debug("Checking inputs...")
set_metype_gids = set()
for values in dry_run_stats.metype_gids.values():
set_metype_gids.update(values.keys())
assert set_metype_gids == set(dry_run_stats.metype_memory.keys())
average_syns_keys = set(dry_run_stats.average_syns_per_cell.keys())
metype_memory_keys = set(dry_run_stats.metype_memory.keys())
assert average_syns_keys == metype_memory_keys
assert num_ranks > 0, "num_ranks must be a positive integer"

def validate_inputs(dry_run_stats, num_ranks, batch_size):
assert isinstance(dry_run_stats, DryRunStats), "dry_run_stats must be a DryRunStats object"
assert isinstance(num_ranks, int), "num_ranks must be an integer"
assert num_ranks > 0, "num_ranks must be a positive integer"
assert isinstance(batch_size, int), "batch_size must be an integer"
assert batch_size > 0, "batch_size must be a positive integer"
set_metype_gids = set()
for values in dry_run_stats.metype_gids.values():
set_metype_gids.update(values.keys())
assert set_metype_gids == set(dry_run_stats.metype_memory.keys())
average_syns_keys = set(dry_run_stats.average_syns_per_cell.keys())
metype_memory_keys = set(dry_run_stats.metype_memory.keys())
assert average_syns_keys == metype_memory_keys

# Check inputs
validate_inputs(dry_run_stats, num_ranks, batch_size)

# Multiply the average number of synapses per cell by 2.0
# This is done since the biggest memory load for a synapse is 2.0 kB and at this point in the
# code we have lost the information on whether they are excitatory or inhibitory
# so we just take the biggest value to be safe. (the difference between the two is minimal)
logging.debug("Multiplying the average number of synapses per cell by 2.0")
average_syns_mem_per_cell = {k: v * 2.0 for k, v in dry_run_stats.average_syns_per_cell.items()}

# Prepare a list of tuples (cell_id, memory_load)
# We sum the memory load of the cell type and the average number of synapses per cell
logging.debug("Creating generator...")

def generate_cells(metype_gids):
for cell_type, gids in metype_gids.items():
for gid in gids:
Expand All @@ -210,12 +214,22 @@ def generate_cells(metype_gids):
yield gid, memory_usage

# Initialize structures
logging.debug("Initializing structures...")
ranks = [(0, i) for i in range(num_ranks)] # (total_memory, rank_id)
heapq.heapify(ranks)
rank_allocation = {}
rank_memory = {}

    def assign_cells_to_rank(rank_id, batch_memory, rank_allocation, total_memory,
                             rank_memory, pop):
        """Record the current batch on rank ``rank_id`` and push the rank's
        updated memory load back onto the ``ranks`` min-heap.

        NOTE(review): this helper closes over ``cell_id`` and ``ranks`` from
        the enclosing scope — only the last gid seen (``cell_id``) is appended
        for the whole batch.  Confirm this is intended rather than appending
        every gid of the batch (the pre-refactor remainder branch appended the
        whole ``batch`` list instead).
        """
        # Lazily create the per-rank gid list for this population
        if rank_id not in rank_allocation[pop]:
            rank_allocation[pop][rank_id] = []
        rank_allocation[pop][rank_id].append(cell_id)
        # Update the total memory load of the rank
        total_memory += batch_memory
        rank_memory[pop][rank_id] = total_memory
        # Update total memory and re-add to the heap
        heapq.heappush(ranks, (total_memory, rank_id))

# Start distributing cells across ranks
for pop, metype_gids in dry_run_stats.metype_gids.items():
logging.info("Distributing cells of population %s", pop)
Expand All @@ -232,27 +246,17 @@ def generate_cells(metype_gids):
total_memory, rank_id = heapq.heappop(ranks)
logging.debug("Assigning batch to rank %d", rank_id)
# Add the cell to the rank
if rank_id not in rank_allocation[pop]:
rank_allocation[pop][rank_id] = []
rank_allocation[pop][rank_id].append(cell_id)
# Update the total memory load of the rank
total_memory += batch_memory
rank_memory[pop][rank_id] = total_memory
# Update total memory and re-add to the heap
heapq.heappush(ranks, (total_memory, rank_id))
assign_cells_to_rank(rank_id, batch_memory, rank_allocation, total_memory,
rank_memory, pop)
batch = []
batch_memory = 0

# Assign any remaining cells in the last, potentially incomplete batch
if batch:
total_memory, rank_id = heapq.heappop(ranks)
logging.debug("Assigning remaining cells in batch to rank %d", rank_id)
if rank_id not in rank_allocation[pop]:
rank_allocation[pop][rank_id] = []
rank_allocation[pop][rank_id].append(batch)
total_memory += batch_memory
rank_memory[pop][rank_id] = total_memory
heapq.heappush(ranks, (total_memory, rank_id))
assign_cells_to_rank(rank_id, batch_memory, rank_allocation, total_memory,
rank_memory, pop)

return rank_allocation, rank_memory

Expand All @@ -274,12 +278,25 @@ def print_allocation_stats(rank_allocation, rank_memory):
for pop, rank_dict in rank_memory.items():
values = list(rank_dict.values())
logging.info("Population: {}".format(pop))
logging.info("Mean: {}".format(round(statistics.mean(values))))
logging.info("Mean allocation per rank [KB]: {}".format(round(statistics.mean(values))))
try:
stdev = round(statistics.stdev(values))
except statistics.StatisticsError:
stdev = 0
logging.info("Stdev: {}".format(stdev))
logging.info("Stdev of allocation per rank [KB]: {}".format(stdev))


@run_only_rank0
def export_allocation_stats(rank_allocation, filename):
    """
    Export the allocation dictionary to a serialized pickle file.

    Args:
        rank_allocation: dict mapping population -> rank_id -> gids, as
            produced by ``distribute_cells``.
        filename: path of the pickle file to write.

    Failures are logged as a warning and otherwise ignored: the export is
    best-effort and must not abort the dry run.
    """
    import pickle
    try:
        with open(filename, 'wb') as f:
            pickle.dump(rank_allocation, f)
    except Exception as e:
        # Lazy %-args so the message is only formatted when actually logged
        logging.warning("Unable to export allocation stats: %s", e)


class SynapseMemoryUsage:
Expand Down

0 comments on commit 5659b11

Please sign in to comment.