Skip to content

Commit

Permalink
[BBPBGLIB-1027] Fix: Load Balance with multi-populations (#92)
Browse files Browse the repository at this point in the history
## Context

A number of Sonata circuits wouldn't use the load balance object, even
though it had been built. See BBPBGLIB-1027

The issue stems from the fact that we compute load balance for a single
population, but there was no way to identify which one. That information
is required when later instantiating the circuit.

## Scope

This PR has two sides:
1. As an interim solution, until we compute load balance for all circuits,
we pass the single load_balance object to all of them.
2. An actual improvement to `LoadBalance`, making it aware of its
population, so that:
- It won't mix populations by attempting to create a sub load balance
out of another population
- We can easily identify whether the load balancer is suitable for a
given circuit

Additionally, we _stopped_ enabling `MultiSplit` automatically since it's
a pretty advanced and delicate option. Instead, when there are many more
cores than cells, a warning is logged suggesting `--lb-mode=MultiSplit`.


## Testing
```
neurodamus-py/tests/simulations/v5_sonata $ srun -Aproj16 -n2 neurodamus simulation_config.json --lb-mode=WholeCell --verbose
[...]
[STEP] LOADING NODES
[STEP] Circuit default
[VERB]  -> Nodes Format: NodeFormat.SONATA, Loader: load_sonata
[INFO] Reading Nodes (METype) info from '/gpfs/bbp.cscs.ch/project/proj1/circuits/SomatosensoryCxS1-v5.r0/O1-sonata/sonata/networks/nodes/default/nodes.h5'
[INFO]  => Cell distribution from Load Balance is valid
```

## Review
* [x] PR description is complete
* [x] Coding style (imports, function length, new functions, classes or
files) is good
* [ ] Unit/Scientific test added
* [ ] Updated Readme, in-code, developer documentation
  • Loading branch information
ferdonline authored Dec 13, 2023
1 parent 1e1200d commit 4adf400
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 25 deletions.
31 changes: 20 additions & 11 deletions neurodamus/cell_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,10 @@ def load_nodes(self, load_balancer=None, *, _loader=None, loader_opts=None):
loader_f = (lambda *args: _loader(*args, **loader_opts)) if loader_opts else _loader

logging.info("Reading Nodes (METype) info from '%s'", conf.CellLibraryFile)
if load_balancer and load_balancer.population != self._target_spec.population:
log_verbose("Load balance object doesn't apply to '%s'", self._target_spec.population)
load_balancer = None
if not load_balancer or SimConfig.dry_run:
# Use common loading routine, providing the loader
gidvec, me_infos, *cell_counts = self._load_nodes(loader_f)
else:
gidvec, me_infos, *cell_counts = self._load_nodes_balance(loader_f, load_balancer)
Expand Down Expand Up @@ -596,7 +598,8 @@ class LoadBalance:
generating and loading the various files.
LoadBalance instances target the current system (cpu count) and circuit
(nrn_path) BUT check/create load distribution for any given target.
BUT check/create load distribution for any given target.
The circuit is identified by the nodes file AND population.
NOTE: Given the heavy costs of computing load balance, some state files are created
which allow the balance info to be reused. These are
Expand All @@ -607,26 +610,27 @@ class LoadBalance:
For more information refer to the developer documentation.
"""
_base_output_dir = "sim_conf"
_circuit_lb_dir_tpl = "_loadbal_%s"
_circuit_lb_dir_tpl = "_loadbal_%s.%s" # Placeholders are (file_src_hash, population)
_cx_filename_tpl = "cx_%s#.dat" # use # to well delimiter the target name
_cpu_assign_filename_tpl = "cx_%s#.%s.dat" # prefix must be same (imposed by Neuron)

def __init__(self, balance_mode, nodes_path, target_manager, target_cpu_count=None):
def __init__(self, balance_mode, nodes_path, pop, target_manager, target_cpu_count=None):
"""
Creates a new Load Balance object, associated with a given node file
"""
self.lb_mode = balance_mode
self.target_cpu_count = target_cpu_count or MPI.size
self._target_manager = target_manager
self._valid_loadbalance = set()
self._lb_dir, self._cx_targets = self._get_circuit_loadbal_dir(nodes_path)
self.population = pop or ""
self._lb_dir, self._cx_targets = self._get_circuit_loadbal_dir(nodes_path, self.population)
log_verbose("Found existing targets with loadbal: %s", self._cx_targets)

@classmethod
@run_only_rank0
def _get_circuit_loadbal_dir(cls, node_file) -> tuple:
def _get_circuit_loadbal_dir(cls, node_file, pop) -> tuple:
"""Ensure lbal dir exists. dir may be crated on rank 0"""
lb_dir = cls._loadbal_dir(node_file)
lb_dir = cls._loadbal_dir(node_file, pop)
if lb_dir.is_dir():
return lb_dir, cls._get_lbdir_targets(lb_dir)

Expand All @@ -644,10 +648,15 @@ def _get_lbdir_targets(cls, lb_dir: Path) -> list:
)

@run_only_rank0
def valid_load_distribution(self, target_spec) -> bool:
def valid_load_distribution(self, target_spec: TargetSpec) -> bool:
"""Checks whether we have valid load-balance files, attempting to
derive from larger target distributions if possible.
"""
if (target_spec.population or "") != self.population:
logging.info(" => Load balance Population mismatch. Requested: %s, Existing: %s",
target_spec.population, self.population)
return False

target_name = target_spec.simple_name

# Check cache
Expand Down Expand Up @@ -676,7 +685,7 @@ def valid_load_distribution(self, target_spec) -> bool:
return False

# -
def _reuse_cell_complexity(self, target_spec) -> bool:
def _reuse_cell_complexity(self, target_spec: TargetSpec) -> bool:
"""Check if the complexities of all target gids were already calculated
for another target.
"""
Expand Down Expand Up @@ -917,10 +926,10 @@ def load_balance_info(self, target_spec):
return Nd.BalanceInfo(bal_filename, MPI.rank, MPI.size)

@classmethod
def _loadbal_dir(cls, nodefile) -> Path:
def _loadbal_dir(cls, nodefile, population) -> Path:
"""Returns the dir where load balance files are stored for a given nodes file"""
nodefile_hash = hashlib.md5(nodefile.encode()).digest().hex()[:10]
return Path(cls._base_output_dir) / (cls._circuit_lb_dir_tpl % nodefile_hash)
return Path(cls._base_output_dir) / (cls._circuit_lb_dir_tpl % (nodefile_hash, population))

def _cx_filename(self, target_str, basename_str=False) -> Path:
"""Gets the filename of a cell complexity file for a given target"""
Expand Down
6 changes: 3 additions & 3 deletions neurodamus/core/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def parse(cls, lb_mode):

class AutoBalanceModeParams:
"""Parameters for auto-selecting a load-balance mode"""
multisplit_cpu_cell_ratio = 4 # Complexity not worth unless large ratio
multisplit_cpu_cell_ratio = 4 # For warning
cell_count = 1000
duration = 1000
mpi_ranks = 200
Expand All @@ -173,8 +173,8 @@ def auto_select(cls, use_neuron, cell_count, duration, auto_params=AutoBalanceMo
lb_mode = LoadBalanceMode.RoundRobin
reason = "Single rank - not worth using Load Balance"
elif use_neuron and MPI.size >= auto_params.multisplit_cpu_cell_ratio * cell_count:
lb_mode = LoadBalanceMode.MultiSplit
reason = "CPU-Cell ratio"
logging.warn("There's potentially a high number of empty ranks. "
"To activate multi-split set --lb-mode=MultiSplit")
elif (cell_count > auto_params.cell_count
and duration > auto_params.duration
and MPI.size > auto_params.mpi_ranks):
Expand Down
10 changes: 4 additions & 6 deletions neurodamus/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,16 +383,15 @@ def compute_load_balance(self):
circuit.CircuitPath if is_sonata_config
else self._run_conf["nrnPath"] or circuit.CircuitPath
)
load_balancer = LoadBalance(lb_mode, data_src, self._target_manager, prosp_hosts)
pop = target_spec.population
load_balancer = LoadBalance(lb_mode, data_src, pop, self._target_manager, prosp_hosts)

if load_balancer.valid_load_distribution(target_spec):
logging.info("Load Balancing done.")
return load_balancer

logging.info("Could not reuse load balance data. Doing a Full Load-Balance")
cell_dist = self._circuits.new_node_manager(
circuit, self._target_manager, self._run_conf
)
cell_dist = self._circuits.new_node_manager(circuit, self._target_manager, self._run_conf)
with load_balancer.generate_load_balance(target_spec, cell_dist):
# Instantiate a basic circuit to evaluate complexities
cell_dist.finalize()
Expand Down Expand Up @@ -427,8 +426,6 @@ def create_cells(self, load_balance=None):
logging.info("Memory usage after inizialization:")
print_mem_usage()
self._dry_run_stats = DryRunStats()
# We load the memory usage rather early since it will be needed at the moment we load
# the cell ids. This way we can avoid gidvec from having gids of known metype cells.
self._dry_run_stats.try_import_cell_memory_usage()
loader_opts = {"dry_run_stats": self._dry_run_stats}
else:
Expand All @@ -455,6 +452,7 @@ def create_cells(self, load_balance=None):
logging.warning("Skipped node population (restrict_node_populations)")
continue
self._circuits.new_node_manager(circuit, self._target_manager, self._run_conf,
load_balancer=load_balance,
loader_opts=loader_opts)

lfp_weights_file = self._run_conf.get("LFPWeightsPath")
Expand Down
10 changes: 5 additions & 5 deletions tests/integration-e2e/test_loadbalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def target_manager_hoc():

def test_loadbal_no_cx(target_manager_hoc, caplog):
from neurodamus.cell_distributor import LoadBalance, TargetSpec
lbal = LoadBalance(1, "/gpfs/fake_path_to_nodes_1", target_manager_hoc, 4)
lbal = LoadBalance(1, "/gpfs/fake_path_to_nodes_1", "pop", target_manager_hoc, 4)
assert not lbal._cx_targets
assert not lbal._valid_loadbalance
with caplog.at_level(logging.INFO):
Expand All @@ -35,10 +35,10 @@ def test_loadbal_subtarget(target_manager_hoc, caplog):
tmp_path = tempfile.TemporaryDirectory("test_loadbal_subtarget")
os.chdir(tmp_path.name)
nodes_file = "/gpfs/fake_node_path"
lbdir, _ = LoadBalance._get_circuit_loadbal_dir(nodes_file)
lbdir, _ = LoadBalance._get_circuit_loadbal_dir(nodes_file, "pop")
shutil.copyfile(SIM_DIR / "1k_v5_balance" / "cx_Small.dat", lbdir / "cx_Small#.dat")

lbal = LoadBalance(1, nodes_file, target_manager_hoc, 4)
lbal = LoadBalance(1, nodes_file, "pop", target_manager_hoc, 4)
assert "Small" in lbal._cx_targets
assert not lbal._valid_loadbalance
with caplog.at_level(logging.INFO):
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_load_balance_integrated(target_manager_hoc, circuit_conf):
cell_manager = CellDistributor(circuit_conf, target_manager_hoc)
cell_manager.load_nodes()

lbal = LoadBalance(1, circuit_conf.CircuitPath, target_manager_hoc, 4)
lbal = LoadBalance(1, circuit_conf.CircuitPath, "", target_manager_hoc, 4)
t1 = TargetSpec("Small")
assert not lbal._cx_valid(t1)

Expand Down Expand Up @@ -108,7 +108,7 @@ def test_multisplit(target_manager_hoc, circuit_conf, capsys):

cell_manager = CellDistributor(circuit_conf, target_manager_hoc)
cell_manager.load_nodes()
lbal = LoadBalance(MULTI_SPLIT, circuit_conf.CircuitPath, target_manager_hoc, 4)
lbal = LoadBalance(MULTI_SPLIT, circuit_conf.CircuitPath, "", target_manager_hoc, 4)
t1 = TargetSpec("Small")
assert not lbal._cx_valid(t1)

Expand Down

0 comments on commit 4adf400

Please sign in to comment.