diff --git a/neurodamus/cell_distributor.py b/neurodamus/cell_distributor.py index 7152cc42..65a93fb7 100644 --- a/neurodamus/cell_distributor.py +++ b/neurodamus/cell_distributor.py @@ -76,8 +76,9 @@ def get_cellref(self, gid): Returns: Cell object """ if self._binfo: - # are we in load balance mode? must replace gid with spgid - gid = self._binfo.thishost_gid(gid) + # are we in load balance mode? raw gids are in the binfo + gid_offset = self._local_nodes.offset + gid = self._binfo.thishost_gid(gid - gid_offset) + gid_offset return self._pc.gid2obj(gid) # Methods for compat with hoc @@ -242,7 +243,7 @@ def _load_nodes_balance(self, loader_f, load_balancer): self._binfo = load_balancer.load_balance_info(target_spec) # self._binfo has gidlist, but gids can appear multiple times all_gids = numpy.unique( - self._binfo.gids.as_numpy().astype("uint32") - self._local_nodes.offset + self._binfo.gids.as_numpy().astype("uint32") ) total_cells = len(all_gids) gidvec, me_infos, full_size = loader_f(self._circuit_conf, all_gids) @@ -358,19 +359,21 @@ def _init_cell_network(self): for final_gid, cell in self._gid2cell.items(): cell.re_init_rng(self._ionchannel_seed) nc = cell.connect2target(None) # Netcon doesnt require being stored - + raw_gid = final_gid - self._local_nodes.offset if self._binfo: - gid_i = int(self._binfo.gids.indwhere("==", final_gid)) + gid_i = int(self._binfo.gids.indwhere("==", raw_gid)) cb = self._binfo.bilist.object(self._binfo.cbindex.x[gid_i]) # multisplit cells call cb.multisplit() instead if cb.subtrees.count() > 0: cb.multisplit(nc, self._binfo.msgid, pc, pc.id()) cell.gid = final_gid + cell.raw_gid = raw_gid continue pc.set_gid2node(final_gid, pc.id()) pc.cell(final_gid, nc) cell.gid = final_gid # update the cell.gid last (RNGs had to use the base gid) + cell.raw_gid = raw_gid pc.multisplit() @@ -487,8 +490,9 @@ def get_cellref(self, gid): """ manager = self._find_manager(gid) if manager._binfo: - # are we in load balance mode? must replace gid with spgid - gid = manager._binfo.thishost_gid(gid) + # are we in load balance mode? raw gids are in the binfo + gid_offset = manager.local_nodes.offset + gid = manager._binfo.thishost_gid(gid - gid_offset) + gid_offset return self._pc.gid2obj(gid) def getSpGid(self, gid): @@ -502,7 +506,8 @@ def getSpGid(self, gid): """ manager = self._find_manager(gid) if manager._binfo: - return manager._binfo.thishost_gid(gid) + gid_offset = manager.local_nodes.offset + return manager._binfo.thishost_gid(gid - gid_offset) + gid_offset return gid def getPopulationInfo(self, gid): @@ -670,7 +675,7 @@ def _reuse_cell_complexity(self, target_spec: TargetSpec) -> bool: return False logging.info("Attempt reusing cx files from other targets...") - target_gids = self._get_target_gids(target_spec) + target_gids = self._get_target_raw_gids(target_spec) cx_other = {} for previous_target in self._cx_targets: @@ -712,7 +717,7 @@ def _cx_valid(self, target_spec) -> bool: return False if target_spec: # target provided, otherwise everything - target_gids = self._get_target_gids(target_spec) + target_gids = self._get_target_raw_gids(target_spec) if not self._cx_contains_gids(cx_filename, target_gids): logging.warning(" => %s invalid: changed target definition!", cx_filename) return False @@ -768,7 +773,7 @@ def _compute_save_complexities(self, target_str, mcomplex, cell_distributor): for cell in cell_distributor.cells: mcomplex.cell_complexity(cell.CellRef) - mcomplex.multisplit(cell.gid, lcx, tmp) + mcomplex.multisplit(cell.raw_gid, lcx, tmp) ms_list.append(tmp.c()) # To output build independently the contents of the file then append @@ -889,8 +894,8 @@ def _write_msdat_dict(fp, cx_dict, gids=None): fp.write(line) # raw lines, include \n # - - def _get_target_gids(self, target_spec) -> numpy.ndarray: - return self._target_manager.get_target(target_spec).get_gids() + def _get_target_raw_gids(self, target_spec) -> numpy.ndarray: + return self._target_manager.get_target(target_spec).get_raw_gids() def load_balance_info(self, target_spec): """ Loads a load-balance info for a given target. diff --git a/neurodamus/metype.py b/neurodamus/metype.py index 46f2b764..7325bf43 100644 --- a/neurodamus/metype.py +++ b/neurodamus/metype.py @@ -14,11 +14,12 @@ class BaseCell: """ Class representing an basic cell, e.g. an artificial cell """ - __slots__ = ("_cellref", "_ccell") + __slots__ = ("_cellref", "_ccell", "raw_gid") def __init__(self, gid, cell_info, circuit_info): self._cellref = None self._ccell = None + self.raw_gid = None @property def CellRef(self): diff --git a/tests/integration-e2e/test_loadbalance.py b/tests/integration-e2e/test_loadbalance.py index 7dbed753..795588a4 100644 --- a/tests/integration-e2e/test_loadbalance.py +++ b/tests/integration-e2e/test_loadbalance.py @@ -140,18 +140,81 @@ def test_multisplit(target_manager, circuit_conf, capsys): assert "Target VerySmall is a subset of the target All_Small" in captured.out +def _create_tmpconfig_lbal(config_file): + import json + import shutil + from tempfile import NamedTemporaryFile + + with open(config_file, "r") as f: + sim_config_data = json.load(f) + sim_config_data["network"] = "circuit_config_virtualpop.json" + sim_config_data["connection_overrides"] = [ + { + "name": "virtual_proj", + "source": "virtual_target", + "target": "l4pc", + "weight": 0.0 + }, + { + "name": "disconnect", + "source": "l4pc", + "target": "virtual_target", + "delay": 0.025, + "weight": 0.0 + } + ] + tmp_file = NamedTemporaryFile(suffix=".json", dir=os.path.dirname(config_file), delete=True) + shutil.copy2(config_file, tmp_file.name) + + with open(tmp_file.name, "w") as f: + json.dump(sim_config_data, f, indent=2) + return tmp_file + + +def _read_complexity_file(base_dir, pattern, cx_pattern): + import glob + # Construct the full pattern path + full_pattern = os.path.join(base_dir, pattern, cx_pattern) + + # Use glob to find files that match the pattern + matching_files = glob.glob(full_pattern) + + # Read each matching file + for file_path in matching_files: + try: + with open(file_path, 'r') as file: + content = file.read() + return content + except FileNotFoundError: + print(f"File not found: {file_path}") + + def test_loadbal_integration(): """Ensure given the right files are in the lbal dir, the correct situation is detected """ from neurodamus import Node - from neurodamus.core.configuration import GlobalConfig + from neurodamus.core.configuration import GlobalConfig, SimConfig GlobalConfig.verbosity = 2 - config_file = str(SIM_DIR / "usecase3" / "simulation_sonata.json") - nd = Node(config_file, {"lb_mode": "WholeCell"}) + + # Add connection_overrides for the virtual population so the offsets are calculated before LB + tmp_file = _create_tmpconfig_lbal(SIM_DIR / "usecase3" / "simulation_sonata.json") + nd = Node(tmp_file.name, {"lb_mode": "WholeCell"}) nd.load_targets() + SimConfig.check_connections_configure(nd._target_manager) lb = nd.compute_load_balance() nd.create_cells(lb) + # Check the complexity file + base_dir = "sim_conf" + pattern = "_loadbal_*.NodeA" # Matches any hash and population + cx_pattern = "cx_NodeA*#.dat" # Matches any cx file with the pattern + assert Path(base_dir).is_dir(), "Directory 'sim_conf' not found." + cx_file = _read_complexity_file(base_dir, pattern, cx_pattern) + lines = cx_file.splitlines() + assert int(lines[1]) == 3, "Number of gids different than 3." + # Gid should be without offset (2 instead of 1002) + assert int(lines[3].split()[0]) == 2, "gid 2 not found." + class MockedTargetManager: """ diff --git a/tests/simulations/usecase3/circuit_config_virtualpop.json b/tests/simulations/usecase3/circuit_config_virtualpop.json new file mode 100644 index 00000000..3fc665cd --- /dev/null +++ b/tests/simulations/usecase3/circuit_config_virtualpop.json @@ -0,0 +1,70 @@ +{ + "version": 2, + "networks": { + "nodes": [ + { + "nodes_file": "nodes_A.h5", + "populations": { + "NodeA": { + "type": "biophysical", + "morphologies_dir": "CircuitA/morphologies/swc", + "biophysical_neuron_models_dir": "CircuitA/hoc", + "alternate_morphologies": { + "neurolucida-asc": "CircuitA/morphologies/asc" + } + } + } + }, + { + "nodes_file": "nodes_B.h5", + "populations": { + "NodeB": { + "type": "biophysical", + "morphologies_dir": "CircuitB/morphologies/swc", + "biophysical_neuron_models_dir": "CircuitB/hoc", + "alternate_morphologies": { + "neurolucida-asc": "CircuitB/morphologies/asc" + } + } + } + }, + { + "nodes_file": "virtual_neurons.h5", + "populations": { + "A_virtual_neurons": { + "type": "virtual" + } + } + } + ], + "edges": [ + { + "edges_file": "local_edges_A.h5", + "populations": { + "NodeA__NodeA__chemical": { + "type": "chemical" + } + } + }, + { + "edges_file": "local_edges_B.h5", + "populations": { + "NodeB__NodeB__chemical": { + "type": "chemical" + } + } + }, + { + "edges_file": "edges_AB.h5", + "populations": { + "NodeA__NodeB__chemical": { + "type": "chemical" + }, + "NodeB__NodeA__chemical": { + "type": "chemical" + } + } + } + ] + } +} diff --git a/tests/simulations/usecase3/nodesets.json b/tests/simulations/usecase3/nodesets.json index 02442095..f35e9bcd 100644 --- a/tests/simulations/usecase3/nodesets.json +++ b/tests/simulations/usecase3/nodesets.json @@ -2,5 +2,8 @@ "nodesPopA": {"population": "NodeA"}, "nodesPopB": {"population": "NodeB"}, "l4pc": {"mtype": "L4_PC"}, - "Mosaic": {"population": ["NodeA", "NodeB"]} + "Mosaic": {"population": ["NodeA", "NodeB"]}, + "virtual_target": { + "population": "A_virtual_neurons" + } } diff --git a/tests/simulations/usecase3/virtual_neurons.h5 b/tests/simulations/usecase3/virtual_neurons.h5 new file mode 100644 index 00000000..7bb97733 Binary files /dev/null and b/tests/simulations/usecase3/virtual_neurons.h5 differ