Skip to content

Commit

Permalink
[testing] Re-organize data tests around the one dataset per test (#1026)
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Aug 23, 2024
1 parent 11a45df commit bd1493c
Show file tree
Hide file tree
Showing 10 changed files with 970 additions and 801 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
* Added `get_json_schema_from_method_signature` which constructs Pydantic models automatically from the signature of any function with typical annotation types used throughout NeuroConv. [PR #1016](https://github.com/catalystneuro/neuroconv/pull/1016)
* Replaced all interface annotations with Pydantic types. [PR #1017](https://github.com/catalystneuro/neuroconv/pull/1017)
* Changed typehint collections (e.g. `List`) to standard collections (e.g. `list`). [PR #1021](https://github.com/catalystneuro/neuroconv/pull/1021)
* Testing now is only one dataset per test [PR #1026](https://github.com/catalystneuro/neuroconv/pull/1026)



Expand Down
297 changes: 127 additions & 170 deletions src/neuroconv/tools/testing/data_interface_mixins.py

Large diffs are not rendered by default.

64 changes: 29 additions & 35 deletions tests/test_behavior/test_audio_interface.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import shutil
import re
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from tempfile import mkdtemp
from warnings import warn

import jsonschema
import numpy as np
import pytest
from dateutil.tz import gettz
from hdmf.testing import TestCase
from numpy.testing import assert_array_equal
from pydantic import FilePath
from pynwb import NWBHDF5IO
Expand Down Expand Up @@ -38,38 +36,39 @@ def create_audio_files(
return audio_file_names


class TestAudioInterface(AudioInterfaceTestMixin, TestCase):
@classmethod
def setUpClass(cls):
class TestAudioInterface(AudioInterfaceTestMixin):

data_interface_cls = AudioInterface

@pytest.fixture(scope="class", autouse=True)
def setup_test(self, request, tmp_path_factory):

cls = request.cls

cls.session_start_time = datetime.now(tz=gettz(name="US/Pacific"))
cls.num_frames = int(1e7)
cls.num_audio_files = 3
cls.sampling_rate = 500
cls.aligned_segment_starting_times = [0.0, 20.0, 40.0]

cls.test_dir = Path(mkdtemp())
class_tmp_dir = tmp_path_factory.mktemp("class_tmp_dir")
cls.test_dir = Path(class_tmp_dir)
cls.file_paths = create_audio_files(
test_dir=cls.test_dir,
num_audio_files=cls.num_audio_files,
sampling_rate=cls.sampling_rate,
num_frames=cls.num_frames,
)
cls.data_interface_cls = AudioInterface
cls.interface_kwargs = dict(file_paths=[cls.file_paths[0]])

def setUp(self):
@pytest.fixture(scope="function", autouse=True)
def setup_converter(self):

self.nwbfile_path = str(self.test_dir / "audio_test.nwb")
self.create_audio_converter()
self.metadata = self.nwb_converter.get_metadata()
self.metadata["NWBFile"].update(session_start_time=self.session_start_time)

@classmethod
def tearDownClass(cls):
try:
shutil.rmtree(cls.test_dir)
except PermissionError: # Windows CI bug
warn(f"Unable to fully clean the temporary directory: {cls.test_dir}\n\nPlease remove it manually.")

def create_audio_converter(self):
class AudioTestNWBConverter(NWBConverter):
data_interface_classes = dict(Audio=AudioInterface)
Expand All @@ -83,18 +82,18 @@ class AudioTestNWBConverter(NWBConverter):

def test_unsupported_format(self):
exc_msg = "The currently supported file format for audio is WAV file. Some of the provided files does not match this format: ['.test']."
with self.assertRaisesWith(ValueError, exc_msg=exc_msg):
with pytest.raises(ValueError, match=re.escape(exc_msg)):
AudioInterface(file_paths=["test.test"])

def test_get_metadata(self):
audio_interface = AudioInterface(file_paths=self.file_paths)
metadata = audio_interface.get_metadata()
audio_metadata = metadata["Behavior"]["Audio"]

self.assertEqual(len(audio_metadata), self.num_audio_files)
assert len(audio_metadata) == self.num_audio_files

def test_incorrect_write_as(self):
with self.assertRaises(jsonschema.exceptions.ValidationError):
with pytest.raises(jsonschema.exceptions.ValidationError):
self.nwb_converter.run_conversion(
nwbfile_path=self.nwbfile_path,
metadata=self.metadata,
Expand Down Expand Up @@ -125,7 +124,7 @@ def test_incomplete_metadata(self):
expected_error_message = (
"The Audio metadata is incomplete (1 entry)! Expected 3 (one for each entry of 'file_paths')."
)
with self.assertRaisesWith(exc_type=AssertionError, exc_msg=expected_error_message):
with pytest.raises(AssertionError, match=re.escape(expected_error_message)):
self.nwb_converter.run_conversion(nwbfile_path=self.nwbfile_path, metadata=metadata, overwrite=True)

def test_metadata_update(self):
Expand All @@ -137,7 +136,7 @@ def test_metadata_update(self):
nwbfile = io.read()
container = nwbfile.stimulus
audio_name = metadata["Behavior"]["Audio"][0]["name"]
self.assertEqual("New description for Acoustic waveform series.", container[audio_name].description)
assert container[audio_name].description == "New description for Acoustic waveform series."

def test_not_all_metadata_are_unique(self):
metadata = deepcopy(self.metadata)
Expand All @@ -149,21 +148,18 @@ def test_not_all_metadata_are_unique(self):
],
)
expected_error_message = "Some of the names for Audio metadata are not unique."
with self.assertRaisesWith(exc_type=AssertionError, exc_msg=expected_error_message):
with pytest.raises(AssertionError, match=re.escape(expected_error_message)):
self.interface.run_conversion(nwbfile_path=self.nwbfile_path, metadata=metadata, overwrite=True)

def test_segment_starting_times_are_floats(self):
with self.assertRaisesWith(
exc_type=AssertionError, exc_msg="Argument 'aligned_segment_starting_times' must be a list of floats."
):
with pytest.raises(AssertionError, match="Argument 'aligned_segment_starting_times' must be a list of floats."):
self.interface.set_aligned_segment_starting_times(aligned_segment_starting_times=[0, 1, 2])

def test_segment_starting_times_length_mismatch(self):
with self.assertRaisesWith(
exc_type=AssertionError,
exc_msg="The number of entries in 'aligned_segment_starting_times' (4) must be equal to the number of audio file paths (3).",
):
with pytest.raises(AssertionError) as exc_info:
self.interface.set_aligned_segment_starting_times(aligned_segment_starting_times=[0.0, 1.0, 2.0, 4.0])
exc_msg = "The number of entries in 'aligned_segment_starting_times' (4) must be equal to the number of audio file paths (3)."
assert str(exc_info.value) == exc_msg

def test_set_aligned_segment_starting_times(self):
fresh_interface = AudioInterface(file_paths=self.file_paths[:2])
Expand Down Expand Up @@ -210,12 +206,10 @@ def test_run_conversion(self):
nwbfile = io.read()
container = nwbfile.stimulus
metadata = self.nwb_converter.get_metadata()
self.assertEqual(3, len(container))
assert len(container) == 3
for audio_ind, audio_metadata in enumerate(metadata["Behavior"]["Audio"]):
audio_interface_name = audio_metadata["name"]
assert audio_interface_name in container
self.assertEqual(
self.aligned_segment_starting_times[audio_ind], container[audio_interface_name].starting_time
)
self.assertEqual(self.sampling_rate, container[audio_interface_name].rate)
assert self.aligned_segment_starting_times[audio_ind] == container[audio_interface_name].starting_time
assert self.sampling_rate == container[audio_interface_name].rate
assert_array_equal(audio_test_data[audio_ind], container[audio_interface_name].data)
8 changes: 2 additions & 6 deletions tests/test_ecephys/test_mock_recording_interface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import unittest

from neuroconv.tools.testing.data_interface_mixins import (
RecordingExtractorInterfaceTestMixin,
)
from neuroconv.tools.testing.mock_interfaces import MockRecordingInterface


class TestMockRecordingInterface(unittest.TestCase, RecordingExtractorInterfaceTestMixin):
class TestMockRecordingInterface(RecordingExtractorInterfaceTestMixin):
data_interface_cls = MockRecordingInterface
interface_kwargs = [
dict(durations=[0.100]),
]
interface_kwargs = dict(durations=[0.100])
Loading

0 comments on commit bd1493c

Please sign in to comment.