Skip to content

Commit

Permalink
Fix record_spikes for load balancing (#167)
Browse files Browse the repository at this point in the history
## Context
After #148, we are writing
raw gids in `Nd.BalanceInfo` instead of final gids with offsets. Such
change is missing in `record_spikes()` leading to no spike written in
the report with NEURON while lb mode is enabled.

## Scope
Fix in the function `cell_distributer.py: record_spikes`.

## Testing
check spikes in `test_loadbal_integration`

## Review
* [x] PR description is complete
* [x] Coding style (imports, function length, New functions, classes or
files) are good
* [x] Unit/Scientific test added
* [ ] Updated Readme, in-code, developer documentation
  • Loading branch information
WeinaJi authored May 13, 2024
1 parent 1d2bf99 commit 5c02932
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
3 changes: 2 additions & 1 deletion neurodamus/cell_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,11 @@ def record_spikes(self, gids=None, append_spike_vecs=None):
spikevec, idvec = append_spike_vecs or (Nd.Vector(), Nd.Vector())
if gids is None:
gids = self._local_nodes.final_gids()
gid_offset = self._local_nodes.offset

for gid in gids:
# only want to collect spikes of cell pieces with the soma (i.e. the real gid)
if not self._binfo or self._binfo.thishost_gid(gid) == gid:
if not self._binfo or self._binfo.thishost_gid(gid - gid_offset) + gid_offset == gid:
self._pc.spike_record(gid, spikevec, idvec)
return spikevec, idvec

Expand Down
24 changes: 17 additions & 7 deletions tests/integration-e2e/test_loadbalance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Tests load balance."""
# Since a good deal of load balance tests are e2e we put all of them together in this group
import logging
import numpy as np
import numpy.testing as npt
import os
import pytest
import shutil
Expand Down Expand Up @@ -192,17 +194,15 @@ def _read_complexity_file(base_dir, pattern, cx_pattern):
def test_loadbal_integration():
"""Ensure given the right files are in the lbal dir, the correct situation is detected
"""
from neurodamus import Node
from neurodamus.core.configuration import GlobalConfig, SimConfig
from neurodamus import Neurodamus
from neurodamus.core.configuration import GlobalConfig
from neurodamus.replay import SpikeManager
GlobalConfig.verbosity = 2

# Add connection_overrides for the virtual population so the offsets are calculated before LB
tmp_file = _create_tmpconfig_lbal(SIM_DIR / "usecase3" / "simulation_sonata.json")
nd = Node(tmp_file.name, {"lb_mode": "WholeCell"})
nd.load_targets()
SimConfig.check_connections_configure(nd._target_manager)
lb = nd.compute_load_balance()
nd.create_cells(lb)
nd = Neurodamus(tmp_file.name, lb_mode="WholeCell")
nd.run()

# Check the complexity file
base_dir = "sim_conf"
Expand All @@ -215,6 +215,16 @@ def test_loadbal_integration():
# Gid should be without offset (2 instead of 1002)
assert int(lines[3].split()[0]) == 2, "gid 2 not found."

# check the spikes
spike_dat = Path(nd._run_conf.get("OutputRoot"))/nd._run_conf.get("SpikesFile")
timestamps_A, gids_A = SpikeManager._read_spikes_sonata(spike_dat, "NodeA")
assert len(timestamps_A) == 21
ref_times = np.array([0.2, 0.3, 0.3, 2.5, 3.4, 4.2, 5.5, 7.0, 7.4, 8.6, 13.8, 19.6, 25.7, 32.,
36.4, 38.5, 40.8, 42.6, 45.2, 48.3, 49.9])
ref_gids = np.array([1, 2, 3, 1, 2, 3, 1, 1, 2, 3, 3, 3, 3, 3, 1, 3, 2, 1, 3, 1, 2])
npt.assert_allclose(timestamps_A, ref_times)
npt.assert_allclose(gids_A, ref_gids)


class MockedTargetManager:
"""
Expand Down

0 comments on commit 5c02932

Please sign in to comment.