diff --git a/neurodamus/node.py b/neurodamus/node.py
index 965c08a64..5e072347f 100644
--- a/neurodamus/node.py
+++ b/neurodamus/node.py
@@ -33,7 +33,7 @@
 from .utils import compat
 from .utils.logging import log_stage, log_verbose, log_all
 from .utils.memory import DryRunStats, trim_memory, pool_shrink, free_event_queues, print_mem_usage
-from .utils.memory import print_allocation_stats
+from .utils.memory import print_allocation_stats, export_allocation_stats, distribute_cells
 from .utils.timeit import TimerManager, timeit
 from .core.coreneuron_configuration import CoreConfig
 # Internal Plugins
@@ -1962,13 +1962,13 @@ def run(self):
         """
         if SimConfig.dry_run:
             log_stage("============= DRY RUN (SKIP SIMULATION) =============")
-            from .utils.memory import distribute_cells
             self._dry_run_stats.display_total()
             self._dry_run_stats.display_node_suggestions()
             ranks = int(SimConfig.prosp_hosts)
             self._dry_run_stats.collect_all_mpi()
             allocation, total_memory_per_rank = distribute_cells(self._dry_run_stats, ranks)
             print_allocation_stats(allocation, total_memory_per_rank)
+            export_allocation_stats(allocation, "allocation.bin")
             return
         if not SimConfig.simulate_model:
             self.sim_init()
diff --git a/neurodamus/utils/memory.py b/neurodamus/utils/memory.py
index e87472ea7..c0ece6643 100644
--- a/neurodamus/utils/memory.py
+++ b/neurodamus/utils/memory.py
@@ -179,29 +179,33 @@ def distribute_cells(dry_run_stats, num_ranks, batch_size=10) -> (dict, dict):
         rank_memory (dict): A dictionary where keys are rank IDs and values are the total
             memory load on each rank.
     """
-    # Check inputs
     logging.debug("Distributing cells across %d ranks", num_ranks)
-    logging.debug("Checking inputs...")
-    set_metype_gids = set()
-    for values in dry_run_stats.metype_gids.values():
-        set_metype_gids.update(values.keys())
-    assert set_metype_gids == set(dry_run_stats.metype_memory.keys())
-    average_syns_keys = set(dry_run_stats.average_syns_per_cell.keys())
-    metype_memory_keys = set(dry_run_stats.metype_memory.keys())
-    assert average_syns_keys == metype_memory_keys
-    assert num_ranks > 0, "num_ranks must be a positive integer"
+
+    def validate_inputs(dry_run_stats, num_ranks, batch_size):
+        assert isinstance(dry_run_stats, DryRunStats), "dry_run_stats must be a DryRunStats object"
+        assert isinstance(num_ranks, int), "num_ranks must be an integer"
+        assert num_ranks > 0, "num_ranks must be a positive integer"
+        assert isinstance(batch_size, int), "batch_size must be an integer"
+        assert batch_size > 0, "batch_size must be a positive integer"
+        set_metype_gids = set()
+        for values in dry_run_stats.metype_gids.values():
+            set_metype_gids.update(values.keys())
+        assert set_metype_gids == set(dry_run_stats.metype_memory.keys())
+        average_syns_keys = set(dry_run_stats.average_syns_per_cell.keys())
+        metype_memory_keys = set(dry_run_stats.metype_memory.keys())
+        assert average_syns_keys == metype_memory_keys
+
+    # Check inputs
+    validate_inputs(dry_run_stats, num_ranks, batch_size)
 
     # Multiply the average number of synapses per cell by 2.0
     # This is done since the biggest memory load for a synapse is 2.0 kB and at this point in the
     # code we have lost the information on whether they are excitatory or inhibitory
     # so we just take the biggest value to be safe. (the difference between the two is minimal)
-    logging.debug("Multiplying the average number of synapses per cell by 2.0")
     average_syns_mem_per_cell = {k: v * 2.0 for k, v in dry_run_stats.average_syns_per_cell.items()}
 
     # Prepare a list of tuples (cell_id, memory_load)
     # We sum the memory load of the cell type and the average number of synapses per cell
-    logging.debug("Creating generator...")
-
     def generate_cells(metype_gids):
         for cell_type, gids in metype_gids.items():
             for gid in gids:
@@ -210,12 +214,22 @@ def generate_cells(metype_gids):
                 yield gid, memory_usage
 
     # Initialize structures
-    logging.debug("Initializing structures...")
     ranks = [(0, i) for i in range(num_ranks)]  # (total_memory, rank_id)
     heapq.heapify(ranks)
     rank_allocation = {}
     rank_memory = {}
 
+    def assign_cells_to_rank(rank_id, batch_memory, rank_allocation, total_memory,
+                             rank_memory, pop):
+        if rank_id not in rank_allocation[pop]:
+            rank_allocation[pop][rank_id] = []
+        rank_allocation[pop][rank_id].append(cell_id)
+        # Update the total memory load of the rank
+        total_memory += batch_memory
+        rank_memory[pop][rank_id] = total_memory
+        # Update total memory and re-add to the heap
+        heapq.heappush(ranks, (total_memory, rank_id))
+
     # Start distributing cells across ranks
     for pop, metype_gids in dry_run_stats.metype_gids.items():
         logging.info("Distributing cells of population %s", pop)
@@ -232,14 +246,8 @@ def generate_cells(metype_gids):
                 total_memory, rank_id = heapq.heappop(ranks)
                 logging.debug("Assigning batch to rank %d", rank_id)
                 # Add the cell to the rank
-                if rank_id not in rank_allocation[pop]:
-                    rank_allocation[pop][rank_id] = []
-                rank_allocation[pop][rank_id].append(cell_id)
-                # Update the total memory load of the rank
-                total_memory += batch_memory
-                rank_memory[pop][rank_id] = total_memory
-                # Update total memory and re-add to the heap
-                heapq.heappush(ranks, (total_memory, rank_id))
+                assign_cells_to_rank(rank_id, batch_memory, rank_allocation, total_memory,
+                                     rank_memory, pop)
                 batch = []
                 batch_memory = 0
 
@@ -247,12 +255,8 @@ def generate_cells(metype_gids):
         if batch:
             total_memory, rank_id = heapq.heappop(ranks)
             logging.debug("Assigning remaining cells in batch to rank %d", rank_id)
-            if rank_id not in rank_allocation[pop]:
-                rank_allocation[pop][rank_id] = []
-            rank_allocation[pop][rank_id].append(batch)
-            total_memory += batch_memory
-            rank_memory[pop][rank_id] = total_memory
-            heapq.heappush(ranks, (total_memory, rank_id))
+            assign_cells_to_rank(rank_id, batch_memory, rank_allocation, total_memory,
+                                 rank_memory, pop)
 
     return rank_allocation, rank_memory
 
@@ -274,12 +278,25 @@ def print_allocation_stats(rank_allocation, rank_memory):
     for pop, rank_dict in rank_memory.items():
         values = list(rank_dict.values())
         logging.info("Population: {}".format(pop))
-        logging.info("Mean: {}".format(round(statistics.mean(values))))
+        logging.info("Mean allocation per rank [KB]: {}".format(round(statistics.mean(values))))
         try:
             stdev = round(statistics.stdev(values))
         except statistics.StatisticsError:
             stdev = 0
-        logging.info("Stdev: {}".format(stdev))
+        logging.info("Stdev of allocation per rank [KB]: {}".format(stdev))
+
+
+@run_only_rank0
+def export_allocation_stats(rank_allocation, filename):
+    """
+    Export allocation dictionary to serialized pickle file.
+    """
+    import pickle
+    try:
+        with open(filename, 'wb') as f:
+            pickle.dump(rank_allocation, f)
+    except Exception as e:
+        logging.warning("Unable to export allocation stats: {}".format(e))
 
 
 class SynapseMemoryUsage:
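A minimal sketch of how the exported file could be consumed downstream, assuming the nested
{population: {rank_id: [...]}} layout built by distribute_cells and the "allocation.bin" file
name passed in node.py; this is not part of the patch itself:

    import pickle

    # Load the allocation written by export_allocation_stats during the dry run
    with open("allocation.bin", "rb") as f:
        rank_allocation = pickle.load(f)

    # Inspect how many entries each rank received per population
    for pop, rank_dict in rank_allocation.items():
        for rank_id, cells in rank_dict.items():
            print(pop, rank_id, len(cells))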