Skip to content

Commit

Permalink
add validation
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Mar 21, 2024
1 parent ade9e96 commit 68a1dc1
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 3 deletions.
88 changes: 87 additions & 1 deletion src/pynwb/ndx_binned_spikes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import numpy as np

from pynwb import load_namespaces, get_class
from pynwb.core import NWBDataInterface
from hdmf.utils import docval, popargs_to_dict, get_docval, popargs

try:
from importlib.resources import files
Expand All @@ -18,7 +22,89 @@
# Load the namespace
load_namespaces(str(__spec_path))

BinnedAlignedSpikes = get_class("BinnedAlignedSpikes", "ndx-binned-spikes")
# BinnedAlignedSpikes = get_class("BinnedAlignedSpikes", "ndx-binned-spikes")

from pynwb import register_class, docval


@register_class(neurodata_type="BinnedAlignedSpikes", namespace="ndx-binned-spikes")
class BinnedAlignedSpikes(NWBDataInterface):
__nwbfields__ = (
"name",
"bin_width_in_milliseconds",
"milliseconds_from_event_to_first_bin",
"data",
"event_timestamps",
"units",
)

DEFAULT_NAME = "BinnedAlignedSpikes"

@docval(
{
"name": "name",
"type": str,
"doc": "The name of this container",
"default": DEFAULT_NAME,
},
{
"name": "bin_width_in_milliseconds",
"type": float,
"doc": "The length in milliseconds of the bins",
},
{
"name": "milliseconds_from_event_to_first_bin",
"type": float,
"doc": (
"The time in milliseconds from the event (e.g. a stimuli or the beginning of a trial),"
"to the first bin. Note that this is a negative number if the first bin is before the event."
),
"default": 0.0,
},
{
"name": "data",
"type": "array_data",
"shape": [(None, None, None), (None, None)],
"doc": "The source of the data",
},
{
"name": "event_timestamps",
"type": "array_data",
"doc": "The timestamps at which the event occurred.",
"shape": (None,),
},
{
"name": "units",
"type": ("DynamicTableRegion"),
"doc": "A reference to the Units table region that contains the units of the data.",
"default": None,
},
)
def __init__(self, **kwargs):

keys_to_set = ("bin_width_in_milliseconds", "milliseconds_from_event_to_first_bin", "units")
args_to_set = popargs_to_dict(keys_to_set, kwargs)

keys_to_process = ("data", "event_timestamps") # these are properties and cannot be set with setattr
args_to_process = popargs_to_dict(keys_to_process, kwargs)
super().__init__(**kwargs)

# Set the values
for key, val in args_to_set.items():
setattr(self, key, val)

# Post-process / post_init
data = args_to_process["data"]

data = data if data.ndim == 3 else data[np.newaxis, ...]

event_timestamps = args_to_process["event_timestamps"]

if data.shape[1] != event_timestamps.shape[0]:
raise ValueError("The number of event timestamps must match the number of event repetitions in the data.")

self.fields["data"] = data
self.fields["event_timestamps"] = event_timestamps


# Remove these functions from the package
Expand Down
23 changes: 21 additions & 2 deletions src/pynwb/tests/test_binned_aligned_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def setUp(self):
self.number_of_event_repetitions = 4
self.bin_width_in_milliseconds = 20.0
self.milliseconds_from_event_to_first_bin = -100.0
rng = np.random.default_rng(seed=0)
self.rng = np.random.default_rng(seed=0)

self.data = rng.integers(
self.data = self.rng.integers(
low=0,
high=100,
size=(
Expand Down Expand Up @@ -99,6 +99,25 @@ def test_constructor_units_region(self):
expected_names = [unit_name_a, unit_name_c]
self.assertListEqual(unit_table_names, expected_names)

def test_accepting_input_with_no_number_of_units_dimension(self):

data = self.rng.integers(
low=0,
high=100,
size=(
self.number_of_event_repetitions,
self.number_of_bins,
),
)
binned_aligned_spikes = BinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=data,
event_timestamps=self.event_timestamps,
)

self.assertEqual(binned_aligned_spikes.data.shape, (1, self.number_of_event_repetitions, self.number_of_bins))


class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase):
"""Simple roundtrip test for BinnedAlignedSpikes."""
Expand Down

0 comments on commit 68a1dc1

Please sign in to comment.