Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change 'electrode_group' col of Units to be ElectrodeGroup type #820

Merged
merged 8 commits into from
Feb 25, 2019
2 changes: 1 addition & 1 deletion src/pynwb/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(self, **kwargs):
'default': None, 'shape': (None, 2)},
{'name': 'electrodes', 'type': 'array_data', 'doc': 'the electrodes that each unit came from',
'default': None},
{'name': 'electrode_group', 'type': 'array_data', 'default': None,
{'name': 'electrode_group', 'type': 'ElectrodeGroup', 'default': None,
'doc': 'the electrode group that each unit came from'},
{'name': 'waveform_mean', 'type': 'array_data', 'doc': 'the spike waveform mean for each unit',
'default': None},
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/ui_write/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def addContainer(self, nwbfile):
location='CA1', filtering='none',
group=electrode_group)

nwbfile.add_unit(id=1, electrodes=[1])
nwbfile.add_unit(id=2, electrodes=[1])
nwbfile.add_unit(id=1, electrodes=[1], electrode_group=electrode_group)
nwbfile.add_unit(id=2, electrodes=[1], electrode_group=electrode_group)
nwbfile.units.to_dataframe()
self.container = nwbfile.units

Expand Down
9 changes: 9 additions & 0 deletions tests/unit/pynwb_tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
DecompositionSeries
from pynwb.file import TimeSeries, DynamicTable
from pynwb.core import VectorData
from pynwb.device import Device
from pynwb.ecephys import ElectrodeGroup


class AnnotationSeriesConstructor(unittest.TestCase):
Expand Down Expand Up @@ -149,6 +151,13 @@ def test_times_and_intervals(self):
self.assertTrue(np.all(ut['obs_intervals'][0] == np.array([[0, 2]])))
self.assertTrue(np.all(ut['obs_intervals'][1] == np.array([[2, 3], [4, 5]])))

def test_electrode_group(self):
ut = Units()
device = Device('test_device')
electrode_group = ElectrodeGroup('test_electrode_group', 'description', 'location', device)
ut.add_unit(electrode_group=electrode_group)
self.assertEqual(ut['electrode_group'][0], electrode_group)


if __name__ == '__main__':
unittest.main()