Skip to content

Commit

Permalink
[BBPBGLIB-1145] Write gids without offset in WholeCell LB complexity …
Browse files Browse the repository at this point in the history
…file (#148)

## Context
Addresses issue highlighted in BBPBGLIB-1145, where the offsets were
being calculated before the load balancing, causing issue sometimes when
having virtual populations alphabetically ordered before the first real
population.

## Review
* [x] PR description is complete
* [x] Coding style (imports, function length, New functions, classes or
files) are good
* [x] Unit/Scientific test added
* [ ] Updated Readme, in-code, developer documentation
  • Loading branch information
jorblancoa authored and atemerev committed May 24, 2024
1 parent 6c37569 commit 75019b9
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 18 deletions.
31 changes: 18 additions & 13 deletions neurodamus/cell_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def get_cellref(self, gid):
Returns: Cell object
"""
if self._binfo:
# are we in load balance mode? must replace gid with spgid
gid = self._binfo.thishost_gid(gid)
# are we in load balance mode? raw gids are in the binfo
gid_offset = self._local_nodes.offset
gid = self._binfo.thishost_gid(gid - gid_offset) + gid_offset
return self._pc.gid2obj(gid)

# Methods for compat with hoc
Expand Down Expand Up @@ -242,7 +243,7 @@ def _load_nodes_balance(self, loader_f, load_balancer):
self._binfo = load_balancer.load_balance_info(target_spec)
# self._binfo has gidlist, but gids can appear multiple times
all_gids = numpy.unique(
self._binfo.gids.as_numpy().astype("uint32") - self._local_nodes.offset
self._binfo.gids.as_numpy().astype("uint32")
)
total_cells = len(all_gids)
gidvec, me_infos, full_size = loader_f(self._circuit_conf, all_gids)
Expand Down Expand Up @@ -358,19 +359,21 @@ def _init_cell_network(self):
for final_gid, cell in self._gid2cell.items():
cell.re_init_rng(self._ionchannel_seed)
nc = cell.connect2target(None) # Netcon doesnt require being stored

raw_gid = final_gid - self._local_nodes.offset
if self._binfo:
gid_i = int(self._binfo.gids.indwhere("==", final_gid))
gid_i = int(self._binfo.gids.indwhere("==", raw_gid))
cb = self._binfo.bilist.object(self._binfo.cbindex.x[gid_i])
# multisplit cells call cb.multisplit() instead
if cb.subtrees.count() > 0:
cb.multisplit(nc, self._binfo.msgid, pc, pc.id())
cell.gid = final_gid
cell.raw_gid = raw_gid
continue

pc.set_gid2node(final_gid, pc.id())
pc.cell(final_gid, nc)
cell.gid = final_gid # update the cell.gid last (RNGs had to use the base gid)
cell.raw_gid = raw_gid

pc.multisplit()

Expand Down Expand Up @@ -487,8 +490,9 @@ def get_cellref(self, gid):
"""
manager = self._find_manager(gid)
if manager._binfo:
# are we in load balance mode? must replace gid with spgid
gid = manager._binfo.thishost_gid(gid)
# are we in load balance mode? raw gids are in the binfo
gid_offset = manager.local_nodes.offset
gid = manager._binfo.thishost_gid(gid - gid_offset) + gid_offset
return self._pc.gid2obj(gid)

def getSpGid(self, gid):
Expand All @@ -502,7 +506,8 @@ def getSpGid(self, gid):
"""
manager = self._find_manager(gid)
if manager._binfo:
return manager._binfo.thishost_gid(gid)
gid_offset = manager.local_nodes.offset
return manager._binfo.thishost_gid(gid - gid_offset) + gid_offset
return gid

def getPopulationInfo(self, gid):
Expand Down Expand Up @@ -670,7 +675,7 @@ def _reuse_cell_complexity(self, target_spec: TargetSpec) -> bool:
return False

logging.info("Attempt reusing cx files from other targets...")
target_gids = self._get_target_gids(target_spec)
target_gids = self._get_target_raw_gids(target_spec)
cx_other = {}

for previous_target in self._cx_targets:
Expand Down Expand Up @@ -712,7 +717,7 @@ def _cx_valid(self, target_spec) -> bool:
return False

if target_spec: # target provided, otherwise everything
target_gids = self._get_target_gids(target_spec)
target_gids = self._get_target_raw_gids(target_spec)
if not self._cx_contains_gids(cx_filename, target_gids):
logging.warning(" => %s invalid: changed target definition!", cx_filename)
return False
Expand Down Expand Up @@ -768,7 +773,7 @@ def _compute_save_complexities(self, target_str, mcomplex, cell_distributor):

for cell in cell_distributor.cells:
mcomplex.cell_complexity(cell.CellRef)
mcomplex.multisplit(cell.gid, lcx, tmp)
mcomplex.multisplit(cell.raw_gid, lcx, tmp)
ms_list.append(tmp.c())

# To output build independently the contents of the file then append
Expand Down Expand Up @@ -889,8 +894,8 @@ def _write_msdat_dict(fp, cx_dict, gids=None):
fp.write(line) # raw lines, include \n

# -
def _get_target_gids(self, target_spec) -> numpy.ndarray:
return self._target_manager.get_target(target_spec).get_gids()
def _get_target_raw_gids(self, target_spec) -> numpy.ndarray:
return self._target_manager.get_target(target_spec).get_raw_gids()

def load_balance_info(self, target_spec):
""" Loads a load-balance info for a given target.
Expand Down
3 changes: 2 additions & 1 deletion neurodamus/metype.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ class BaseCell:
"""
Class representing an basic cell, e.g. an artificial cell
"""
__slots__ = ("_cellref", "_ccell")
__slots__ = ("_cellref", "_ccell", "raw_gid")

def __init__(self, gid, cell_info, circuit_info):
self._cellref = None
self._ccell = None
self.raw_gid = None

@property
def CellRef(self):
Expand Down
69 changes: 66 additions & 3 deletions tests/integration-e2e/test_loadbalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,81 @@ def test_multisplit(target_manager, circuit_conf, capsys):
assert "Target VerySmall is a subset of the target All_Small" in captured.out


def _create_tmpconfig_lbal(config_file):
import json
import shutil
from tempfile import NamedTemporaryFile

with open(config_file, "r") as f:
sim_config_data = json.load(f)
sim_config_data["network"] = "circuit_config_virtualpop.json"
sim_config_data["connection_overrides"] = [
{
"name": "virtual_proj",
"source": "virtual_target",
"target": "l4pc",
"weight": 0.0
},
{
"name": "disconnect",
"source": "l4pc",
"target": "virtual_target",
"delay": 0.025,
"weight": 0.0
}
]
tmp_file = NamedTemporaryFile(suffix=".json", dir=os.path.dirname(config_file), delete=True)
shutil.copy2(config_file, tmp_file.name)

with open(tmp_file.name, "w") as f:
json.dump(sim_config_data, f, indent=2)
return tmp_file


def _read_complexity_file(base_dir, pattern, cx_pattern):
import glob
# Construct the full pattern path
full_pattern = os.path.join(base_dir, pattern, cx_pattern)

# Use glob to find files that match the pattern
matching_files = glob.glob(full_pattern)

# Read each matching file
for file_path in matching_files:
try:
with open(file_path, 'r') as file:
content = file.read()
return content
except FileNotFoundError:
print(f"File not found: {file_path}")


def test_loadbal_integration():
"""Ensure given the right files are in the lbal dir, the correct situation is detected
"""
from neurodamus import Node
from neurodamus.core.configuration import GlobalConfig
from neurodamus.core.configuration import GlobalConfig, SimConfig
GlobalConfig.verbosity = 2
config_file = str(SIM_DIR / "usecase3" / "simulation_sonata.json")
nd = Node(config_file, {"lb_mode": "WholeCell"})

# Add connection_overrides for the virtual population so the offsets are calculated before LB
tmp_file = _create_tmpconfig_lbal(SIM_DIR / "usecase3" / "simulation_sonata.json")
nd = Node(tmp_file.name, {"lb_mode": "WholeCell"})
nd.load_targets()
SimConfig.check_connections_configure(nd._target_manager)
lb = nd.compute_load_balance()
nd.create_cells(lb)

# Check the complexity file
base_dir = "sim_conf"
pattern = "_loadbal_*.NodeA" # Matches any hash and population
cx_pattern = "cx_NodeA*#.dat" # Matches any cx file with the pattern
assert Path(base_dir).is_dir(), "Directory 'sim_conf' not found."
cx_file = _read_complexity_file(base_dir, pattern, cx_pattern)
lines = cx_file.splitlines()
assert int(lines[1]) == 3, "Number of gids different than 3."
# Gid should be without offset (2 instead of 1002)
assert int(lines[3].split()[0]) == 2, "gid 2 not found."


class MockedTargetManager:
"""
Expand Down
70 changes: 70 additions & 0 deletions tests/simulations/usecase3/circuit_config_virtualpop.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
{
"version": 2,
"networks": {
"nodes": [
{
"nodes_file": "nodes_A.h5",
"populations": {
"NodeA": {
"type": "biophysical",
"morphologies_dir": "CircuitA/morphologies/swc",
"biophysical_neuron_models_dir": "CircuitA/hoc",
"alternate_morphologies": {
"neurolucida-asc": "CircuitA/morphologies/asc"
}
}
}
},
{
"nodes_file": "nodes_B.h5",
"populations": {
"NodeB": {
"type": "biophysical",
"morphologies_dir": "CircuitB/morphologies/swc",
"biophysical_neuron_models_dir": "CircuitB/hoc",
"alternate_morphologies": {
"neurolucida-asc": "CircuitB/morphologies/asc"
}
}
}
},
{
"nodes_file": "virtual_neurons.h5",
"populations": {
"A_virtual_neurons": {
"type": "virtual"
}
}
}
],
"edges": [
{
"edges_file": "local_edges_A.h5",
"populations": {
"NodeA__NodeA__chemical": {
"type": "chemical"
}
}
},
{
"edges_file": "local_edges_B.h5",
"populations": {
"NodeB__NodeB__chemical": {
"type": "chemical"
}
}
},
{
"edges_file": "edges_AB.h5",
"populations": {
"NodeA__NodeB__chemical": {
"type": "chemical"
},
"NodeB__NodeA__chemical": {
"type": "chemical"
}
}
}
]
}
}
5 changes: 4 additions & 1 deletion tests/simulations/usecase3/nodesets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@
"nodesPopA": {"population": "NodeA"},
"nodesPopB": {"population": "NodeB"},
"l4pc": {"mtype": "L4_PC"},
"Mosaic": {"population": ["NodeA", "NodeB"]}
"Mosaic": {"population": ["NodeA", "NodeB"]},
"virtual_target": {
"population": "A_virtual_neurons"
}
}
Binary file added tests/simulations/usecase3/virtual_neurons.h5
Binary file not shown.

0 comments on commit 75019b9

Please sign in to comment.